mirror of
https://github.com/netbirdio/netbird.git
synced 2025-06-23 11:12:01 +02:00
Merge branch 'main' into feature/relay-integration
This commit is contained in:
commit
9700b105b3
8
.editorconfig
Normal file
8
.editorconfig
Normal file
@ -0,0 +1,8 @@
|
||||
root = true
|
||||
|
||||
[*]
|
||||
end_of_line = lf
|
||||
insert_final_newline = true
|
||||
|
||||
[*.go]
|
||||
indent_style = tab
|
35
.github/workflows/golang-test-freebsd.yml
vendored
35
.github/workflows/golang-test-freebsd.yml
vendored
@ -13,7 +13,7 @@ concurrency:
|
||||
|
||||
jobs:
|
||||
test:
|
||||
runs-on: ubuntu-latest
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- name: Test in FreeBSD
|
||||
@ -21,19 +21,26 @@ jobs:
|
||||
uses: vmactions/freebsd-vm@v1
|
||||
with:
|
||||
usesh: true
|
||||
copyback: false
|
||||
release: "14.1"
|
||||
prepare: |
|
||||
pkg install -y curl
|
||||
pkg install -y git
|
||||
pkg install -y go
|
||||
|
||||
# -x - to print all executed commands
|
||||
# -e - to faile on first error
|
||||
run: |
|
||||
set -x
|
||||
curl -o go.tar.gz https://go.dev/dl/go1.21.11.freebsd-amd64.tar.gz -L
|
||||
tar zxf go.tar.gz
|
||||
mv go /usr/local/go
|
||||
ln -s /usr/local/go/bin/go /usr/local/bin/go
|
||||
go mod tidy
|
||||
go test -timeout 5m -p 1 ./iface/...
|
||||
go test -timeout 5m -p 1 ./client/...
|
||||
cd client
|
||||
go build .
|
||||
cd ..
|
||||
set -e -x
|
||||
time go build -o netbird client/main.go
|
||||
# check all component except management, since we do not support management server on freebsd
|
||||
time go test -timeout 1m -failfast ./base62/...
|
||||
# NOTE: without -p1 `client/internal/dns` will fail becasue of `listen udp4 :33100: bind: address already in use`
|
||||
time go test -timeout 8m -failfast -p 1 ./client/...
|
||||
time go test -timeout 1m -failfast ./dns/...
|
||||
time go test -timeout 1m -failfast ./encryption/...
|
||||
time go test -timeout 1m -failfast ./formatter/...
|
||||
time go test -timeout 1m -failfast ./iface/...
|
||||
time go test -timeout 1m -failfast ./route/...
|
||||
time go test -timeout 1m -failfast ./sharedsock/...
|
||||
time go test -timeout 1m -failfast ./signal/...
|
||||
time go test -timeout 1m -failfast ./util/...
|
||||
time go test -timeout 1m -failfast ./version/...
|
||||
|
11
.github/workflows/release.yml
vendored
11
.github/workflows/release.yml
vendored
@ -79,15 +79,8 @@ jobs:
|
||||
|
||||
- name: Install goversioninfo
|
||||
run: go install github.com/josephspurrier/goversioninfo/cmd/goversioninfo@233067e
|
||||
- name: Generate windows syso 386
|
||||
run: goversioninfo -icon client/ui/netbird.ico -manifest client/manifest.xml -product-name ${{ env.PRODUCT_NAME }} -copyright "${{ env.COPYRIGHT }}" -ver-build ${{ github.run_id }} -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -product-version ${{ steps.semver_parser.outputs.fullversion }}.${{ github.run_id }} -o client/resources_windows_386.syso
|
||||
- name: Generate windows syso arm
|
||||
run: goversioninfo -icon client/ui/netbird.ico -manifest client/manifest.xml -product-name ${{ env.PRODUCT_NAME }} -copyright "${{ env.COPYRIGHT }}" -ver-build ${{ github.run_id }} -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -product-version ${{ steps.semver_parser.outputs.fullversion }}.${{ github.run_id }} -o client/resources_windows_arm.syso
|
||||
- name: Generate windows syso arm64
|
||||
run: goversioninfo -icon client/ui/netbird.ico -manifest client/manifest.xml -product-name ${{ env.PRODUCT_NAME }} -copyright "${{ env.COPYRIGHT }}" -ver-build ${{ github.run_id }} -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -product-version ${{ steps.semver_parser.outputs.fullversion }}.${{ github.run_id }} -o client/resources_windows_arm64.syso
|
||||
- name: Generate windows syso amd64
|
||||
run: goversioninfo -icon client/ui/netbird.ico -manifest client/manifest.xml -product-name ${{ env.PRODUCT_NAME }} -copyright "${{ env.COPYRIGHT }}" -ver-build ${{ github.run_id }} -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -product-version ${{ steps.semver_parser.outputs.fullversion }}.${{ github.run_id }} -o client/resources_windows_amd64.syso
|
||||
|
||||
run: goversioninfo -icon client/ui/netbird.ico -manifest client/manifest.xml -product-name ${{ env.PRODUCT_NAME }} -copyright "${{ env.COPYRIGHT }}" -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -ver-build 0 -file-version ${{ steps.semver_parser.outputs.fullversion }}.0 -product-version ${{ steps.semver_parser.outputs.fullversion }}.0 -o client/resources_windows_amd64.syso
|
||||
- name: Run GoReleaser
|
||||
uses: goreleaser/goreleaser-action@v4
|
||||
with:
|
||||
@ -170,7 +163,7 @@ jobs:
|
||||
- name: Install goversioninfo
|
||||
run: go install github.com/josephspurrier/goversioninfo/cmd/goversioninfo@233067e
|
||||
- name: Generate windows syso amd64
|
||||
run: goversioninfo -64 -icon client/ui/netbird.ico -manifest client/ui/manifest.xml -product-name ${{ env.PRODUCT_NAME }}-"UI" -copyright "${{ env.COPYRIGHT }}" -ver-build ${{ github.run_id }} -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -product-version ${{ steps.semver_parser.outputs.fullversion }}.${{ github.run_id }} -o client/ui/resources_windows_amd64.syso
|
||||
run: goversioninfo -64 -icon client/ui/netbird.ico -manifest client/ui/manifest.xml -product-name ${{ env.PRODUCT_NAME }}-"UI" -copyright "${{ env.COPYRIGHT }}" -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -ver-build 0 -file-version ${{ steps.semver_parser.outputs.fullversion }}.0 -product-version ${{ steps.semver_parser.outputs.fullversion }}.0 -o client/ui/resources_windows_amd64.syso
|
||||
|
||||
- name: Run GoReleaser
|
||||
uses: goreleaser/goreleaser-action@v4
|
||||
|
@ -151,10 +151,10 @@ jobs:
|
||||
- name: run docker compose up
|
||||
working-directory: infrastructure_files/artifacts
|
||||
run: |
|
||||
docker-compose up -d
|
||||
docker compose up -d
|
||||
sleep 5
|
||||
docker-compose ps
|
||||
docker-compose logs --tail=20
|
||||
docker compose ps
|
||||
docker compose logs --tail=20
|
||||
|
||||
- name: test running containers
|
||||
run: |
|
||||
@ -207,7 +207,7 @@ jobs:
|
||||
|
||||
- name: Postgres run cleanup
|
||||
run: |
|
||||
docker-compose down --volumes --rmi all
|
||||
docker compose down --volumes --rmi all
|
||||
rm -rf docker-compose.yml Caddyfile zitadel.env dashboard.env machinekey/zitadel-admin-sa.token turnserver.conf management.json zdb.env
|
||||
|
||||
- name: run script with Zitadel CockroachDB
|
||||
|
@ -10,10 +10,12 @@
|
||||
<img width="234" src="docs/media/logo-full.png"/>
|
||||
</p>
|
||||
<p>
|
||||
<a href="https://img.shields.io/badge/license-BSD--3-blue)">
|
||||
<img src="https://sonarcloud.io/api/project_badges/measure?project=netbirdio_netbird&metric=alert_status" />
|
||||
</a>
|
||||
<a href="https://github.com/netbirdio/netbird/blob/main/LICENSE">
|
||||
<img src="https://img.shields.io/badge/license-BSD--3-blue" />
|
||||
</a>
|
||||
<a href="https://www.codacy.com/gh/netbirdio/netbird/dashboard?utm_source=github.com&utm_medium=referral&utm_content=netbirdio/netbird&utm_campaign=Badge_Grade"><img src="https://app.codacy.com/project/badge/Grade/e3013d046aec44cdb7462c8673b00976"/></a>
|
||||
<br>
|
||||
<a href="https://join.slack.com/t/netbirdio/shared_invite/zt-vrahf41g-ik1v7fV8du6t0RwxSrJ96A">
|
||||
<img src="https://img.shields.io/badge/slack-@netbird-red.svg?logo=slack"/>
|
||||
|
@ -178,6 +178,21 @@ func (a *Anonymizer) AnonymizeDNSLogLine(logEntry string) string {
|
||||
})
|
||||
}
|
||||
|
||||
// AnonymizeRoute anonymizes a route string by replacing IP addresses with anonymized versions and
|
||||
// domain names with random strings.
|
||||
func (a *Anonymizer) AnonymizeRoute(route string) string {
|
||||
prefix, err := netip.ParsePrefix(route)
|
||||
if err == nil {
|
||||
ip := a.AnonymizeIPString(prefix.Addr().String())
|
||||
return fmt.Sprintf("%s/%d", ip, prefix.Bits())
|
||||
}
|
||||
domains := strings.Split(route, ", ")
|
||||
for i, domain := range domains {
|
||||
domains[i] = a.AnonymizeDomain(domain)
|
||||
}
|
||||
return strings.Join(domains, ", ")
|
||||
}
|
||||
|
||||
func isWellKnown(addr netip.Addr) bool {
|
||||
wellKnown := []string{
|
||||
"8.8.8.8", "8.8.4.4", // Google DNS IPv4
|
||||
|
@ -5,6 +5,7 @@ import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/spf13/cobra"
|
||||
"google.golang.org/grpc/status"
|
||||
|
||||
@ -13,6 +14,8 @@ import (
|
||||
"github.com/netbirdio/netbird/client/server"
|
||||
)
|
||||
|
||||
const errCloseConnection = "Failed to close connection: %v"
|
||||
|
||||
var debugCmd = &cobra.Command{
|
||||
Use: "debug",
|
||||
Short: "Debugging commands",
|
||||
@ -63,12 +66,17 @@ func debugBundle(cmd *cobra.Command, _ []string) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer conn.Close()
|
||||
defer func() {
|
||||
if err := conn.Close(); err != nil {
|
||||
log.Errorf(errCloseConnection, err)
|
||||
}
|
||||
}()
|
||||
|
||||
client := proto.NewDaemonServiceClient(conn)
|
||||
resp, err := client.DebugBundle(cmd.Context(), &proto.DebugBundleRequest{
|
||||
Anonymize: anonymizeFlag,
|
||||
Status: getStatusOutput(cmd),
|
||||
SystemInfo: debugSystemInfoFlag,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to bundle debug: %v", status.Convert(err).Message())
|
||||
@ -84,7 +92,11 @@ func setLogLevel(cmd *cobra.Command, args []string) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer conn.Close()
|
||||
defer func() {
|
||||
if err := conn.Close(); err != nil {
|
||||
log.Errorf(errCloseConnection, err)
|
||||
}
|
||||
}()
|
||||
|
||||
client := proto.NewDaemonServiceClient(conn)
|
||||
level := server.ParseLogLevel(args[0])
|
||||
@ -113,7 +125,11 @@ func runForDuration(cmd *cobra.Command, args []string) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer conn.Close()
|
||||
defer func() {
|
||||
if err := conn.Close(); err != nil {
|
||||
log.Errorf(errCloseConnection, err)
|
||||
}
|
||||
}()
|
||||
|
||||
client := proto.NewDaemonServiceClient(conn)
|
||||
|
||||
@ -122,17 +138,20 @@ func runForDuration(cmd *cobra.Command, args []string) error {
|
||||
return fmt.Errorf("failed to get status: %v", status.Convert(err).Message())
|
||||
}
|
||||
|
||||
restoreUp := stat.Status == string(internal.StatusConnected) || stat.Status == string(internal.StatusConnecting)
|
||||
stateWasDown := stat.Status != string(internal.StatusConnected) && stat.Status != string(internal.StatusConnecting)
|
||||
|
||||
initialLogLevel, err := client.GetLogLevel(cmd.Context(), &proto.GetLogLevelRequest{})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get log level: %v", status.Convert(err).Message())
|
||||
}
|
||||
|
||||
if _, err := client.Down(cmd.Context(), &proto.DownRequest{}); err != nil {
|
||||
return fmt.Errorf("failed to down: %v", status.Convert(err).Message())
|
||||
if stateWasDown {
|
||||
if _, err := client.Up(cmd.Context(), &proto.UpRequest{}); err != nil {
|
||||
return fmt.Errorf("failed to up: %v", status.Convert(err).Message())
|
||||
}
|
||||
cmd.Println("Netbird up")
|
||||
time.Sleep(time.Second * 10)
|
||||
}
|
||||
cmd.Println("Netbird down")
|
||||
|
||||
initialLevelTrace := initialLogLevel.GetLevel() >= proto.LogLevel_TRACE
|
||||
if !initialLevelTrace {
|
||||
@ -145,6 +164,11 @@ func runForDuration(cmd *cobra.Command, args []string) error {
|
||||
cmd.Println("Log level set to trace.")
|
||||
}
|
||||
|
||||
if _, err := client.Down(cmd.Context(), &proto.DownRequest{}); err != nil {
|
||||
return fmt.Errorf("failed to down: %v", status.Convert(err).Message())
|
||||
}
|
||||
cmd.Println("Netbird down")
|
||||
|
||||
time.Sleep(1 * time.Second)
|
||||
|
||||
if _, err := client.Up(cmd.Context(), &proto.UpRequest{}); err != nil {
|
||||
@ -162,21 +186,25 @@ func runForDuration(cmd *cobra.Command, args []string) error {
|
||||
}
|
||||
cmd.Println("\nDuration completed")
|
||||
|
||||
cmd.Println("Creating debug bundle...")
|
||||
|
||||
headerPreDown := fmt.Sprintf("----- Netbird pre-down - Timestamp: %s - Duration: %s", time.Now().Format(time.RFC3339), duration)
|
||||
statusOutput = fmt.Sprintf("%s\n%s\n%s", statusOutput, headerPreDown, getStatusOutput(cmd))
|
||||
|
||||
resp, err := client.DebugBundle(cmd.Context(), &proto.DebugBundleRequest{
|
||||
Anonymize: anonymizeFlag,
|
||||
Status: statusOutput,
|
||||
SystemInfo: debugSystemInfoFlag,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to bundle debug: %v", status.Convert(err).Message())
|
||||
}
|
||||
|
||||
if stateWasDown {
|
||||
if _, err := client.Down(cmd.Context(), &proto.DownRequest{}); err != nil {
|
||||
return fmt.Errorf("failed to down: %v", status.Convert(err).Message())
|
||||
}
|
||||
cmd.Println("Netbird down")
|
||||
|
||||
time.Sleep(1 * time.Second)
|
||||
|
||||
if restoreUp {
|
||||
if _, err := client.Up(cmd.Context(), &proto.UpRequest{}); err != nil {
|
||||
return fmt.Errorf("failed to up: %v", status.Convert(err).Message())
|
||||
}
|
||||
cmd.Println("Netbird up")
|
||||
}
|
||||
|
||||
if !initialLevelTrace {
|
||||
@ -186,16 +214,6 @@ func runForDuration(cmd *cobra.Command, args []string) error {
|
||||
cmd.Println("Log level restored to", initialLogLevel.GetLevel())
|
||||
}
|
||||
|
||||
cmd.Println("Creating debug bundle...")
|
||||
|
||||
resp, err := client.DebugBundle(cmd.Context(), &proto.DebugBundleRequest{
|
||||
Anonymize: anonymizeFlag,
|
||||
Status: statusOutput,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to bundle debug: %v", status.Convert(err).Message())
|
||||
}
|
||||
|
||||
cmd.Println(resp.GetPath())
|
||||
|
||||
return nil
|
||||
|
@ -37,6 +37,7 @@ const (
|
||||
serverSSHAllowedFlag = "allow-server-ssh"
|
||||
extraIFaceBlackListFlag = "extra-iface-blacklist"
|
||||
dnsRouteIntervalFlag = "dns-router-interval"
|
||||
systemInfoFlag = "system-info"
|
||||
)
|
||||
|
||||
var (
|
||||
@ -69,6 +70,7 @@ var (
|
||||
autoConnectDisabled bool
|
||||
extraIFaceBlackList []string
|
||||
anonymizeFlag bool
|
||||
debugSystemInfoFlag bool
|
||||
dnsRouteInterval time.Duration
|
||||
|
||||
rootCmd = &cobra.Command{
|
||||
@ -91,12 +93,15 @@ func init() {
|
||||
oldDefaultConfigPathDir = "/etc/wiretrustee/"
|
||||
oldDefaultLogFileDir = "/var/log/wiretrustee/"
|
||||
|
||||
if runtime.GOOS == "windows" {
|
||||
switch runtime.GOOS {
|
||||
case "windows":
|
||||
defaultConfigPathDir = os.Getenv("PROGRAMDATA") + "\\Netbird\\"
|
||||
defaultLogFileDir = os.Getenv("PROGRAMDATA") + "\\Netbird\\"
|
||||
|
||||
oldDefaultConfigPathDir = os.Getenv("PROGRAMDATA") + "\\Wiretrustee\\"
|
||||
oldDefaultLogFileDir = os.Getenv("PROGRAMDATA") + "\\Wiretrustee\\"
|
||||
case "freebsd":
|
||||
defaultConfigPathDir = "/var/db/netbird/"
|
||||
}
|
||||
|
||||
defaultConfigPath = defaultConfigPathDir + "config.json"
|
||||
@ -165,6 +170,8 @@ func init() {
|
||||
upCmd.PersistentFlags().BoolVar(&rosenpassPermissive, rosenpassPermissiveFlag, false, "[Experimental] Enable Rosenpass in permissive mode to allow this peer to accept WireGuard connections without requiring Rosenpass functionality from peers that do not have Rosenpass enabled.")
|
||||
upCmd.PersistentFlags().BoolVar(&serverSSHAllowed, serverSSHAllowedFlag, false, "Allow SSH server on peer. If enabled, the SSH server will be permitted")
|
||||
upCmd.PersistentFlags().BoolVar(&autoConnectDisabled, disableAutoConnectFlag, false, "Disables auto-connect feature. If enabled, then the client won't connect automatically when the service starts.")
|
||||
|
||||
debugCmd.PersistentFlags().BoolVarP(&debugSystemInfoFlag, systemInfoFlag, "S", false, "Adds system information to the debug bundle")
|
||||
}
|
||||
|
||||
// SetupCloseHandler handles SIGTERM signal and exits with success
|
||||
|
@ -31,6 +31,8 @@ var installCmd = &cobra.Command{
|
||||
configPath,
|
||||
"--log-level",
|
||||
logLevel,
|
||||
"--daemon-addr",
|
||||
daemonAddr,
|
||||
}
|
||||
|
||||
if managementURL != "" {
|
||||
|
@ -810,7 +810,7 @@ func anonymizePeerDetail(a *anonymize.Anonymizer, peer *peerStateDetailOutput) {
|
||||
}
|
||||
|
||||
for i, route := range peer.Routes {
|
||||
peer.Routes[i] = anonymizeRoute(a, route)
|
||||
peer.Routes[i] = a.AnonymizeRoute(route)
|
||||
}
|
||||
}
|
||||
|
||||
@ -846,21 +846,8 @@ func anonymizeOverview(a *anonymize.Anonymizer, overview *statusOutputOverview)
|
||||
}
|
||||
|
||||
for i, route := range overview.Routes {
|
||||
overview.Routes[i] = anonymizeRoute(a, route)
|
||||
overview.Routes[i] = a.AnonymizeRoute(route)
|
||||
}
|
||||
|
||||
overview.FQDN = a.AnonymizeDomain(overview.FQDN)
|
||||
}
|
||||
|
||||
func anonymizeRoute(a *anonymize.Anonymizer, route string) string {
|
||||
prefix, err := netip.ParsePrefix(route)
|
||||
if err == nil {
|
||||
ip := a.AnonymizeIPString(prefix.Addr().String())
|
||||
return fmt.Sprintf("%s/%d", ip, prefix.Bits())
|
||||
}
|
||||
domains := strings.Split(route, ", ")
|
||||
for i, domain := range domains {
|
||||
domains[i] = a.AnonymizeDomain(domain)
|
||||
}
|
||||
return strings.Join(domains, ", ")
|
||||
}
|
||||
|
@ -69,6 +69,11 @@ func NewOAuthFlow(ctx context.Context, config *internal.Config, isLinuxDesktopCl
|
||||
return authenticateWithDeviceCodeFlow(ctx, config)
|
||||
}
|
||||
|
||||
// On FreeBSD we currently do not support desktop environments and offer only Device Code Flow (#2384)
|
||||
if runtime.GOOS == "freebsd" {
|
||||
return authenticateWithDeviceCodeFlow(ctx, config)
|
||||
}
|
||||
|
||||
pkceFlow, err := authenticateWithPKCEFlow(ctx, config)
|
||||
if err != nil {
|
||||
// fallback to device code flow
|
||||
|
@ -94,7 +94,7 @@ func NewDefaultServer(
|
||||
|
||||
var dnsService service
|
||||
if wgInterface.IsUserspaceBind() {
|
||||
dnsService = newServiceViaMemory(wgInterface)
|
||||
dnsService = NewServiceViaMemory(wgInterface)
|
||||
} else {
|
||||
dnsService = newServiceViaListener(wgInterface, addrPort)
|
||||
}
|
||||
@ -112,7 +112,7 @@ func NewDefaultServerPermanentUpstream(
|
||||
statusRecorder *peer.Status,
|
||||
) *DefaultServer {
|
||||
log.Debugf("host dns address list is: %v", hostsDnsList)
|
||||
ds := newDefaultServer(ctx, wgInterface, newServiceViaMemory(wgInterface), statusRecorder)
|
||||
ds := newDefaultServer(ctx, wgInterface, NewServiceViaMemory(wgInterface), statusRecorder)
|
||||
ds.hostsDNSHolder.set(hostsDnsList)
|
||||
ds.permanent = true
|
||||
ds.addHostRootZone()
|
||||
@ -130,7 +130,7 @@ func NewDefaultServerIos(
|
||||
iosDnsManager IosDnsManager,
|
||||
statusRecorder *peer.Status,
|
||||
) *DefaultServer {
|
||||
ds := newDefaultServer(ctx, wgInterface, newServiceViaMemory(wgInterface), statusRecorder)
|
||||
ds := newDefaultServer(ctx, wgInterface, NewServiceViaMemory(wgInterface), statusRecorder)
|
||||
ds.iosDnsManager = iosDnsManager
|
||||
return ds
|
||||
}
|
||||
|
@ -534,7 +534,7 @@ func TestDNSServerStartStop(t *testing.T) {
|
||||
func TestDNSServerUpstreamDeactivateCallback(t *testing.T) {
|
||||
hostManager := &mockHostConfigurator{}
|
||||
server := DefaultServer{
|
||||
service: newServiceViaMemory(&mocWGIface{}),
|
||||
service: NewServiceViaMemory(&mocWGIface{}),
|
||||
localResolver: &localResolver{
|
||||
registeredMap: make(registrationMap),
|
||||
},
|
||||
|
@ -128,6 +128,9 @@ func (s *serviceViaListener) RuntimeIP() string {
|
||||
}
|
||||
|
||||
func (s *serviceViaListener) setListenerStatus(running bool) {
|
||||
s.listenerFlagLock.Lock()
|
||||
defer s.listenerFlagLock.Unlock()
|
||||
|
||||
s.listenerIsRunning = running
|
||||
}
|
||||
|
||||
|
@ -12,7 +12,7 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
type serviceViaMemory struct {
|
||||
type ServiceViaMemory struct {
|
||||
wgInterface WGIface
|
||||
dnsMux *dns.ServeMux
|
||||
runtimeIP string
|
||||
@ -22,8 +22,8 @@ type serviceViaMemory struct {
|
||||
listenerFlagLock sync.Mutex
|
||||
}
|
||||
|
||||
func newServiceViaMemory(wgIface WGIface) *serviceViaMemory {
|
||||
s := &serviceViaMemory{
|
||||
func NewServiceViaMemory(wgIface WGIface) *ServiceViaMemory {
|
||||
s := &ServiceViaMemory{
|
||||
wgInterface: wgIface,
|
||||
dnsMux: dns.NewServeMux(),
|
||||
|
||||
@ -33,7 +33,7 @@ func newServiceViaMemory(wgIface WGIface) *serviceViaMemory {
|
||||
return s
|
||||
}
|
||||
|
||||
func (s *serviceViaMemory) Listen() error {
|
||||
func (s *ServiceViaMemory) Listen() error {
|
||||
s.listenerFlagLock.Lock()
|
||||
defer s.listenerFlagLock.Unlock()
|
||||
|
||||
@ -52,7 +52,7 @@ func (s *serviceViaMemory) Listen() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *serviceViaMemory) Stop() {
|
||||
func (s *ServiceViaMemory) Stop() {
|
||||
s.listenerFlagLock.Lock()
|
||||
defer s.listenerFlagLock.Unlock()
|
||||
|
||||
@ -67,23 +67,23 @@ func (s *serviceViaMemory) Stop() {
|
||||
s.listenerIsRunning = false
|
||||
}
|
||||
|
||||
func (s *serviceViaMemory) RegisterMux(pattern string, handler dns.Handler) {
|
||||
func (s *ServiceViaMemory) RegisterMux(pattern string, handler dns.Handler) {
|
||||
s.dnsMux.Handle(pattern, handler)
|
||||
}
|
||||
|
||||
func (s *serviceViaMemory) DeregisterMux(pattern string) {
|
||||
func (s *ServiceViaMemory) DeregisterMux(pattern string) {
|
||||
s.dnsMux.HandleRemove(pattern)
|
||||
}
|
||||
|
||||
func (s *serviceViaMemory) RuntimePort() int {
|
||||
func (s *ServiceViaMemory) RuntimePort() int {
|
||||
return s.runtimePort
|
||||
}
|
||||
|
||||
func (s *serviceViaMemory) RuntimeIP() string {
|
||||
func (s *ServiceViaMemory) RuntimeIP() string {
|
||||
return s.runtimeIP
|
||||
}
|
||||
|
||||
func (s *serviceViaMemory) filterDNSTraffic() (string, error) {
|
||||
func (s *ServiceViaMemory) filterDNSTraffic() (string, error) {
|
||||
filter := s.wgInterface.GetFilter()
|
||||
if filter == nil {
|
||||
return "", fmt.Errorf("can't set DNS filter, filter not initialized")
|
||||
|
@ -4,6 +4,7 @@ package dns
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"syscall"
|
||||
"time"
|
||||
@ -19,7 +20,7 @@ type upstreamResolverIOS struct {
|
||||
*upstreamResolverBase
|
||||
lIP net.IP
|
||||
lNet *net.IPNet
|
||||
iIndex int
|
||||
interfaceName string
|
||||
}
|
||||
|
||||
func newUpstreamResolver(
|
||||
@ -32,17 +33,11 @@ func newUpstreamResolver(
|
||||
) (*upstreamResolverIOS, error) {
|
||||
upstreamResolverBase := newUpstreamResolverBase(ctx, statusRecorder)
|
||||
|
||||
index, err := getInterfaceIndex(interfaceName)
|
||||
if err != nil {
|
||||
log.Debugf("unable to get interface index for %s: %s", interfaceName, err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ios := &upstreamResolverIOS{
|
||||
upstreamResolverBase: upstreamResolverBase,
|
||||
lIP: ip,
|
||||
lNet: net,
|
||||
iIndex: index,
|
||||
interfaceName: interfaceName,
|
||||
}
|
||||
ios.upstreamClient = ios
|
||||
|
||||
@ -53,7 +48,7 @@ func (u *upstreamResolverIOS) exchange(ctx context.Context, upstream string, r *
|
||||
client := &dns.Client{}
|
||||
upstreamHost, _, err := net.SplitHostPort(upstream)
|
||||
if err != nil {
|
||||
log.Errorf("error while parsing upstream host: %s", err)
|
||||
return nil, 0, fmt.Errorf("error while parsing upstream host: %s", err)
|
||||
}
|
||||
|
||||
timeout := upstreamTimeout
|
||||
@ -65,26 +60,35 @@ func (u *upstreamResolverIOS) exchange(ctx context.Context, upstream string, r *
|
||||
upstreamIP := net.ParseIP(upstreamHost)
|
||||
if u.lNet.Contains(upstreamIP) || net.IP.IsPrivate(upstreamIP) {
|
||||
log.Debugf("using private client to query upstream: %s", upstream)
|
||||
client = u.getClientPrivate(timeout)
|
||||
client, err = GetClientPrivate(u.lIP, u.interfaceName, timeout)
|
||||
if err != nil {
|
||||
return nil, 0, fmt.Errorf("error while creating private client: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Cannot use client.ExchangeContext because it overwrites our Dialer
|
||||
return client.Exchange(r, upstream)
|
||||
}
|
||||
|
||||
// getClientPrivate returns a new DNS client bound to the local IP address of the Netbird interface
|
||||
// GetClientPrivate returns a new DNS client bound to the local IP address of the Netbird interface
|
||||
// This method is needed for iOS
|
||||
func (u *upstreamResolverIOS) getClientPrivate(dialTimeout time.Duration) *dns.Client {
|
||||
func GetClientPrivate(ip net.IP, interfaceName string, dialTimeout time.Duration) (*dns.Client, error) {
|
||||
index, err := getInterfaceIndex(interfaceName)
|
||||
if err != nil {
|
||||
log.Debugf("unable to get interface index for %s: %s", interfaceName, err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
dialer := &net.Dialer{
|
||||
LocalAddr: &net.UDPAddr{
|
||||
IP: u.lIP,
|
||||
IP: ip,
|
||||
Port: 0, // Let the OS pick a free port
|
||||
},
|
||||
Timeout: dialTimeout,
|
||||
Control: func(network, address string, c syscall.RawConn) error {
|
||||
var operr error
|
||||
fn := func(s uintptr) {
|
||||
operr = unix.SetsockoptInt(int(s), unix.IPPROTO_IP, unix.IP_BOUND_IF, u.iIndex)
|
||||
operr = unix.SetsockoptInt(int(s), unix.IPPROTO_IP, unix.IP_BOUND_IF, index)
|
||||
}
|
||||
|
||||
if err := c.Control(fn); err != nil {
|
||||
@ -101,7 +105,7 @@ func (u *upstreamResolverIOS) getClientPrivate(dialTimeout time.Duration) *dns.C
|
||||
client := &dns.Client{
|
||||
Dialer: dialer,
|
||||
}
|
||||
return client
|
||||
return client, nil
|
||||
}
|
||||
|
||||
func getInterfaceIndex(interfaceName string) (int, error) {
|
||||
|
@ -10,6 +10,7 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
nbdns "github.com/netbirdio/netbird/client/internal/dns"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/dynamic"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
||||
@ -64,7 +65,7 @@ func newClientNetworkWatcher(ctx context.Context, dnsRouteInterval time.Duration
|
||||
routePeersNotifiers: make(map[string]chan struct{}),
|
||||
routeUpdate: make(chan routesUpdate),
|
||||
peerStateUpdate: make(chan struct{}),
|
||||
handler: handlerFromRoute(rt, routeRefCounter, allowedIPsRefCounter, dnsRouteInterval, statusRecorder),
|
||||
handler: handlerFromRoute(rt, routeRefCounter, allowedIPsRefCounter, dnsRouteInterval, statusRecorder, wgInterface),
|
||||
}
|
||||
return client
|
||||
}
|
||||
@ -377,9 +378,10 @@ func (c *clientNetwork) peersStateAndUpdateWatcher() {
|
||||
}
|
||||
}
|
||||
|
||||
func handlerFromRoute(rt *route.Route, routeRefCounter *refcounter.RouteRefCounter, allowedIPsRefCounter *refcounter.AllowedIPsRefCounter, dnsRouterInteval time.Duration, statusRecorder *peer.Status) RouteHandler {
|
||||
func handlerFromRoute(rt *route.Route, routeRefCounter *refcounter.RouteRefCounter, allowedIPsRefCounter *refcounter.AllowedIPsRefCounter, dnsRouterInteval time.Duration, statusRecorder *peer.Status, wgInterface *iface.WGIface) RouteHandler {
|
||||
if rt.IsDynamic() {
|
||||
return dynamic.NewRoute(rt, routeRefCounter, allowedIPsRefCounter, dnsRouterInteval, statusRecorder)
|
||||
dns := nbdns.NewServiceViaMemory(wgInterface)
|
||||
return dynamic.NewRoute(rt, routeRefCounter, allowedIPsRefCounter, dnsRouterInteval, statusRecorder, wgInterface, fmt.Sprintf("%s:%d", dns.RuntimeIP(), dns.RuntimePort()))
|
||||
}
|
||||
return static.NewRoute(rt, routeRefCounter, allowedIPsRefCounter)
|
||||
}
|
||||
|
@ -16,6 +16,7 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/util"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
"github.com/netbirdio/netbird/management/domain"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
)
|
||||
@ -47,6 +48,8 @@ type Route struct {
|
||||
currentPeerKey string
|
||||
cancel context.CancelFunc
|
||||
statusRecorder *peer.Status
|
||||
wgInterface *iface.WGIface
|
||||
resolverAddr string
|
||||
}
|
||||
|
||||
func NewRoute(
|
||||
@ -55,6 +58,8 @@ func NewRoute(
|
||||
allowedIPsRefCounter *refcounter.AllowedIPsRefCounter,
|
||||
interval time.Duration,
|
||||
statusRecorder *peer.Status,
|
||||
wgInterface *iface.WGIface,
|
||||
resolverAddr string,
|
||||
) *Route {
|
||||
return &Route{
|
||||
route: rt,
|
||||
@ -63,6 +68,8 @@ func NewRoute(
|
||||
interval: interval,
|
||||
dynamicDomains: domainMap{},
|
||||
statusRecorder: statusRecorder,
|
||||
wgInterface: wgInterface,
|
||||
resolverAddr: resolverAddr,
|
||||
}
|
||||
}
|
||||
|
||||
@ -189,9 +196,14 @@ func (r *Route) startResolver(ctx context.Context) {
|
||||
}
|
||||
|
||||
func (r *Route) update(ctx context.Context) error {
|
||||
if resolved, err := r.resolveDomains(); err != nil {
|
||||
resolved, err := r.resolveDomains()
|
||||
if err != nil {
|
||||
if len(resolved) == 0 {
|
||||
return fmt.Errorf("resolve domains: %w", err)
|
||||
} else if err := r.updateDynamicRoutes(ctx, resolved); err != nil {
|
||||
}
|
||||
log.Warnf("Failed to resolve domains: %v", err)
|
||||
}
|
||||
if err := r.updateDynamicRoutes(ctx, resolved); err != nil {
|
||||
return fmt.Errorf("update dynamic routes: %w", err)
|
||||
}
|
||||
|
||||
@ -223,11 +235,17 @@ func (r *Route) resolve(results chan resolveResult) {
|
||||
wg.Add(1)
|
||||
go func(domain domain.Domain) {
|
||||
defer wg.Done()
|
||||
ips, err := net.LookupIP(string(domain))
|
||||
|
||||
ips, err := r.getIPsFromResolver(domain)
|
||||
if err != nil {
|
||||
log.Tracef("Failed to resolve domain %s with private resolver: %v", domain.SafeString(), err)
|
||||
ips, err = net.LookupIP(string(domain))
|
||||
if err != nil {
|
||||
results <- resolveResult{domain: domain, err: fmt.Errorf("resolve d %s: %w", domain.SafeString(), err)}
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
for _, ip := range ips {
|
||||
prefix, err := util.GetPrefixFromIP(ip)
|
||||
if err != nil {
|
||||
|
13
client/internal/routemanager/dynamic/route_generic.go
Normal file
13
client/internal/routemanager/dynamic/route_generic.go
Normal file
@ -0,0 +1,13 @@
|
||||
//go:build !ios
|
||||
|
||||
package dynamic
|
||||
|
||||
import (
|
||||
"net"
|
||||
|
||||
"github.com/netbirdio/netbird/management/domain"
|
||||
)
|
||||
|
||||
func (r *Route) getIPsFromResolver(domain domain.Domain) ([]net.IP, error) {
|
||||
return net.LookupIP(string(domain))
|
||||
}
|
55
client/internal/routemanager/dynamic/route_ios.go
Normal file
55
client/internal/routemanager/dynamic/route_ios.go
Normal file
@ -0,0 +1,55 @@
|
||||
//go:build ios
|
||||
|
||||
package dynamic
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/client/internal/dns"
|
||||
|
||||
"github.com/netbirdio/netbird/management/domain"
|
||||
)
|
||||
|
||||
const dialTimeout = 10 * time.Second
|
||||
|
||||
func (r *Route) getIPsFromResolver(domain domain.Domain) ([]net.IP, error) {
|
||||
privateClient, err := nbdns.GetClientPrivate(r.wgInterface.Address().IP, r.wgInterface.Name(), dialTimeout)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error while creating private client: %s", err)
|
||||
}
|
||||
|
||||
msg := new(dns.Msg)
|
||||
msg.SetQuestion(dns.Fqdn(string(domain)), dns.TypeA)
|
||||
|
||||
startTime := time.Now()
|
||||
|
||||
response, _, err := privateClient.Exchange(msg, r.resolverAddr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("DNS query for %s failed after %s: %s ", domain.SafeString(), time.Since(startTime), err)
|
||||
}
|
||||
|
||||
if response.Rcode != dns.RcodeSuccess {
|
||||
return nil, fmt.Errorf("dns response code: %s", dns.RcodeToString[response.Rcode])
|
||||
}
|
||||
|
||||
ips := make([]net.IP, 0)
|
||||
|
||||
for _, answ := range response.Answer {
|
||||
if aRecord, ok := answ.(*dns.A); ok {
|
||||
ips = append(ips, aRecord.A)
|
||||
}
|
||||
if aaaaRecord, ok := answ.(*dns.AAAA); ok {
|
||||
ips = append(ips, aaaaRecord.AAAA)
|
||||
}
|
||||
}
|
||||
|
||||
if len(ips) == 0 {
|
||||
return nil, fmt.Errorf("no A or AAAA records found for %s", domain.SafeString())
|
||||
}
|
||||
|
||||
return ips, nil
|
||||
}
|
@ -22,7 +22,7 @@ type Route struct {
|
||||
Interface *net.Interface
|
||||
}
|
||||
|
||||
func getRoutesFromTable() ([]netip.Prefix, error) {
|
||||
func GetRoutesFromTable() ([]netip.Prefix, error) {
|
||||
tab, err := retryFetchRIB()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("fetch RIB: %v", err)
|
||||
|
@ -427,7 +427,7 @@ func ipToAddr(ip net.IP, intf *net.Interface) (netip.Addr, error) {
|
||||
}
|
||||
|
||||
func existsInRouteTable(prefix netip.Prefix) (bool, error) {
|
||||
routes, err := getRoutesFromTable()
|
||||
routes, err := GetRoutesFromTable()
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("get routes from table: %w", err)
|
||||
}
|
||||
@ -440,7 +440,7 @@ func existsInRouteTable(prefix netip.Prefix) (bool, error) {
|
||||
}
|
||||
|
||||
func isSubRange(prefix netip.Prefix) (bool, error) {
|
||||
routes, err := getRoutesFromTable()
|
||||
routes, err := GetRoutesFromTable()
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("get routes from table: %w", err)
|
||||
}
|
||||
|
@ -206,7 +206,7 @@ func (r *SysOps) RemoveVPNRoute(prefix netip.Prefix, intf *net.Interface) error
|
||||
return nil
|
||||
}
|
||||
|
||||
func getRoutesFromTable() ([]netip.Prefix, error) {
|
||||
func GetRoutesFromTable() ([]netip.Prefix, error) {
|
||||
v4Routes, err := getRoutes(syscall.RT_TABLE_MAIN, netlink.FAMILY_V4)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get v4 routes: %w", err)
|
||||
@ -504,7 +504,7 @@ func getAddressFamily(prefix netip.Prefix) int {
|
||||
|
||||
func hasSeparateRouting() ([]netip.Prefix, error) {
|
||||
if isLegacy() {
|
||||
return getRoutesFromTable()
|
||||
return GetRoutesFromTable()
|
||||
}
|
||||
return nil, ErrRoutingIsSeparate
|
||||
}
|
||||
|
@ -24,5 +24,5 @@ func EnableIPForwarding() error {
|
||||
}
|
||||
|
||||
func hasSeparateRouting() ([]netip.Prefix, error) {
|
||||
return getRoutesFromTable()
|
||||
return GetRoutesFromTable()
|
||||
}
|
||||
|
@ -94,7 +94,7 @@ func (r *SysOps) removeFromRouteTable(prefix netip.Prefix, nexthop Nexthop) erro
|
||||
return nil
|
||||
}
|
||||
|
||||
func getRoutesFromTable() ([]netip.Prefix, error) {
|
||||
func GetRoutesFromTable() ([]netip.Prefix, error) {
|
||||
mux.Lock()
|
||||
defer mux.Unlock()
|
||||
|
||||
|
@ -263,6 +263,7 @@ message Route {
|
||||
message DebugBundleRequest {
|
||||
bool anonymize = 1;
|
||||
string status = 2;
|
||||
bool systemInfo = 3;
|
||||
}
|
||||
|
||||
message DebugBundleResponse {
|
||||
|
@ -1,3 +1,5 @@
|
||||
//go:build !android && !ios
|
||||
|
||||
package server
|
||||
|
||||
import (
|
||||
@ -6,16 +8,70 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/anonymize"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
)
|
||||
|
||||
const readmeContent = `Netbird debug bundle
|
||||
This debug bundle contains the following files:
|
||||
|
||||
status.txt: Anonymized status information of the NetBird client.
|
||||
client.log: Most recent, anonymized log file of the NetBird client.
|
||||
routes.txt: Anonymized system routes, if --system-info flag was provided.
|
||||
interfaces.txt: Anonymized network interface information, if --system-info flag was provided.
|
||||
config.txt: Anonymized configuration information of the NetBird client.
|
||||
|
||||
|
||||
Anonymization Process
|
||||
The files in this bundle have been anonymized to protect sensitive information. Here's how the anonymization was applied:
|
||||
|
||||
IP Addresses
|
||||
|
||||
IPv4 addresses are replaced with addresses starting from 192.51.100.0
|
||||
IPv6 addresses are replaced with addresses starting from 100::
|
||||
|
||||
IP addresses from non public ranges and well known addresses are not anonymized (e.g. 8.8.8.8, 100.64.0.0/10, addresses starting with 192.168., 172.16., 10., etc.).
|
||||
Reoccuring IP addresses are replaced with the same anonymized address.
|
||||
|
||||
Note: The anonymized IP addresses in the status file do not match those in the log and routes files. However, the anonymized IP addresses are consistent within the status file and across the routes and log files.
|
||||
|
||||
Domains
|
||||
All domain names (except for the netbird domains) are replaced with randomly generated strings ending in ".domain". Anonymized domains are consistent across all files in the bundle.
|
||||
Reoccuring domain names are replaced with the same anonymized domain.
|
||||
|
||||
Routes
|
||||
For anonymized routes, the IP addresses are replaced as described above. The prefix length remains unchanged. Note that for prefixes, the anonymized IP might not be a network address, but the prefix length is still correct.
|
||||
Network Interfaces
|
||||
The interfaces.txt file contains information about network interfaces, including:
|
||||
- Interface name
|
||||
- Interface index
|
||||
- MTU (Maximum Transmission Unit)
|
||||
- Flags
|
||||
- IP addresses associated with each interface
|
||||
|
||||
The IP addresses in the interfaces file are anonymized using the same process as described above. Interface names, indexes, MTUs, and flags are not anonymized.
|
||||
|
||||
Configuration
|
||||
The config.txt file contains anonymized configuration information of the NetBird client. Sensitive information such as private keys and SSH keys are excluded. The following fields are anonymized:
|
||||
- ManagementURL
|
||||
- AdminURL
|
||||
- NATExternalIPs
|
||||
- CustomDNSAddress
|
||||
|
||||
Other non-sensitive configuration options are included without anonymization.
|
||||
`
|
||||
|
||||
// DebugBundle creates a debug bundle and returns the location.
|
||||
func (s *Server) DebugBundle(_ context.Context, req *proto.DebugBundleRequest) (resp *proto.DebugBundleResponse, err error) {
|
||||
s.mutex.Lock()
|
||||
@ -30,93 +86,211 @@ func (s *Server) DebugBundle(_ context.Context, req *proto.DebugBundleRequest) (
|
||||
return nil, fmt.Errorf("create zip file: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
if err := bundlePath.Close(); err != nil {
|
||||
log.Errorf("failed to close zip file: %v", err)
|
||||
if closeErr := bundlePath.Close(); closeErr != nil && err == nil {
|
||||
err = fmt.Errorf("close zip file: %w", closeErr)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
if err2 := os.Remove(bundlePath.Name()); err2 != nil {
|
||||
log.Errorf("Failed to remove zip file: %v", err2)
|
||||
if removeErr := os.Remove(bundlePath.Name()); removeErr != nil {
|
||||
log.Errorf("Failed to remove zip file: %v", removeErr)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
archive := zip.NewWriter(bundlePath)
|
||||
defer func() {
|
||||
if err := archive.Close(); err != nil {
|
||||
log.Errorf("failed to close archive writer: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
if status := req.GetStatus(); status != "" {
|
||||
filename := "status.txt"
|
||||
if req.GetAnonymize() {
|
||||
filename = "status.anon.txt"
|
||||
}
|
||||
statusReader := strings.NewReader(status)
|
||||
if err := addFileToZip(archive, statusReader, filename); err != nil {
|
||||
return nil, fmt.Errorf("add status file to zip: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
logFile, err := os.Open(s.logFile)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("open log file: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
if err := logFile.Close(); err != nil {
|
||||
log.Errorf("failed to close original log file: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
filename := "client.log.txt"
|
||||
var logReader io.Reader
|
||||
errChan := make(chan error, 1)
|
||||
if req.GetAnonymize() {
|
||||
filename = "client.anon.log.txt"
|
||||
var writer io.WriteCloser
|
||||
logReader, writer = io.Pipe()
|
||||
|
||||
go s.anonymize(logFile, writer, errChan)
|
||||
} else {
|
||||
logReader = logFile
|
||||
}
|
||||
if err := addFileToZip(archive, logReader, filename); err != nil {
|
||||
return nil, fmt.Errorf("add log file to zip: %w", err)
|
||||
}
|
||||
|
||||
select {
|
||||
case err := <-errChan:
|
||||
if err != nil {
|
||||
if err := s.createArchive(bundlePath, req); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
default:
|
||||
}
|
||||
|
||||
return &proto.DebugBundleResponse{Path: bundlePath.Name()}, nil
|
||||
}
|
||||
|
||||
func (s *Server) anonymize(reader io.Reader, writer io.WriteCloser, errChan chan<- error) {
|
||||
scanner := bufio.NewScanner(reader)
|
||||
anonymizer := anonymize.NewAnonymizer(anonymize.DefaultAddresses())
|
||||
func (s *Server) createArchive(bundlePath *os.File, req *proto.DebugBundleRequest) error {
|
||||
archive := zip.NewWriter(bundlePath)
|
||||
if err := s.addReadme(req, archive); err != nil {
|
||||
return fmt.Errorf("add readme: %w", err)
|
||||
}
|
||||
|
||||
if err := s.addStatus(req, archive); err != nil {
|
||||
return fmt.Errorf("add status: %w", err)
|
||||
}
|
||||
|
||||
anonymizer := anonymize.NewAnonymizer(anonymize.DefaultAddresses())
|
||||
status := s.statusRecorder.GetFullStatus()
|
||||
seedFromStatus(anonymizer, &status)
|
||||
|
||||
if err := s.addConfig(req, anonymizer, archive); err != nil {
|
||||
return fmt.Errorf("add config: %w", err)
|
||||
}
|
||||
|
||||
if req.GetSystemInfo() {
|
||||
if err := s.addRoutes(req, anonymizer, archive); err != nil {
|
||||
return fmt.Errorf("add routes: %w", err)
|
||||
}
|
||||
|
||||
if err := s.addInterfaces(req, anonymizer, archive); err != nil {
|
||||
return fmt.Errorf("add interfaces: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if err := s.addLogfile(req, anonymizer, archive); err != nil {
|
||||
return fmt.Errorf("add log file: %w", err)
|
||||
}
|
||||
|
||||
if err := archive.Close(); err != nil {
|
||||
return fmt.Errorf("close archive writer: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) addReadme(req *proto.DebugBundleRequest, archive *zip.Writer) error {
|
||||
if req.GetAnonymize() {
|
||||
readmeReader := strings.NewReader(readmeContent)
|
||||
if err := addFileToZip(archive, readmeReader, "README.txt"); err != nil {
|
||||
return fmt.Errorf("add README file to zip: %w", err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) addStatus(req *proto.DebugBundleRequest, archive *zip.Writer) error {
|
||||
if status := req.GetStatus(); status != "" {
|
||||
statusReader := strings.NewReader(status)
|
||||
if err := addFileToZip(archive, statusReader, "status.txt"); err != nil {
|
||||
return fmt.Errorf("add status file to zip: %w", err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) addConfig(req *proto.DebugBundleRequest, anonymizer *anonymize.Anonymizer, archive *zip.Writer) error {
|
||||
var configContent strings.Builder
|
||||
s.addCommonConfigFields(&configContent)
|
||||
|
||||
if req.GetAnonymize() {
|
||||
if s.config.ManagementURL != nil {
|
||||
configContent.WriteString(fmt.Sprintf("ManagementURL: %s\n", anonymizer.AnonymizeURI(s.config.ManagementURL.String())))
|
||||
}
|
||||
if s.config.AdminURL != nil {
|
||||
configContent.WriteString(fmt.Sprintf("AdminURL: %s\n", anonymizer.AnonymizeURI(s.config.AdminURL.String())))
|
||||
}
|
||||
configContent.WriteString(fmt.Sprintf("NATExternalIPs: %v\n", anonymizeNATExternalIPs(s.config.NATExternalIPs, anonymizer)))
|
||||
if s.config.CustomDNSAddress != "" {
|
||||
configContent.WriteString(fmt.Sprintf("CustomDNSAddress: %s\n", anonymizer.AnonymizeString(s.config.CustomDNSAddress)))
|
||||
}
|
||||
} else {
|
||||
if s.config.ManagementURL != nil {
|
||||
configContent.WriteString(fmt.Sprintf("ManagementURL: %s\n", s.config.ManagementURL.String()))
|
||||
}
|
||||
if s.config.AdminURL != nil {
|
||||
configContent.WriteString(fmt.Sprintf("AdminURL: %s\n", s.config.AdminURL.String()))
|
||||
}
|
||||
configContent.WriteString(fmt.Sprintf("NATExternalIPs: %v\n", s.config.NATExternalIPs))
|
||||
if s.config.CustomDNSAddress != "" {
|
||||
configContent.WriteString(fmt.Sprintf("CustomDNSAddress: %s\n", s.config.CustomDNSAddress))
|
||||
}
|
||||
}
|
||||
|
||||
// Add config content to zip file
|
||||
configReader := strings.NewReader(configContent.String())
|
||||
if err := addFileToZip(archive, configReader, "config.txt"); err != nil {
|
||||
return fmt.Errorf("add config file to zip: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) addCommonConfigFields(configContent *strings.Builder) {
|
||||
configContent.WriteString("NetBird Client Configuration:\n\n")
|
||||
|
||||
// Add non-sensitive fields
|
||||
configContent.WriteString(fmt.Sprintf("WgIface: %s\n", s.config.WgIface))
|
||||
configContent.WriteString(fmt.Sprintf("WgPort: %d\n", s.config.WgPort))
|
||||
if s.config.NetworkMonitor != nil {
|
||||
configContent.WriteString(fmt.Sprintf("NetworkMonitor: %v\n", *s.config.NetworkMonitor))
|
||||
}
|
||||
configContent.WriteString(fmt.Sprintf("IFaceBlackList: %v\n", s.config.IFaceBlackList))
|
||||
configContent.WriteString(fmt.Sprintf("DisableIPv6Discovery: %v\n", s.config.DisableIPv6Discovery))
|
||||
configContent.WriteString(fmt.Sprintf("RosenpassEnabled: %v\n", s.config.RosenpassEnabled))
|
||||
configContent.WriteString(fmt.Sprintf("RosenpassPermissive: %v\n", s.config.RosenpassPermissive))
|
||||
if s.config.ServerSSHAllowed != nil {
|
||||
configContent.WriteString(fmt.Sprintf("ServerSSHAllowed: %v\n", *s.config.ServerSSHAllowed))
|
||||
}
|
||||
configContent.WriteString(fmt.Sprintf("DisableAutoConnect: %v\n", s.config.DisableAutoConnect))
|
||||
configContent.WriteString(fmt.Sprintf("DNSRouteInterval: %s\n", s.config.DNSRouteInterval))
|
||||
}
|
||||
|
||||
func (s *Server) addRoutes(req *proto.DebugBundleRequest, anonymizer *anonymize.Anonymizer, archive *zip.Writer) error {
|
||||
if routes, err := systemops.GetRoutesFromTable(); err != nil {
|
||||
log.Errorf("Failed to get routes: %v", err)
|
||||
} else {
|
||||
// TODO: get routes including nexthop
|
||||
routesContent := formatRoutes(routes, req.GetAnonymize(), anonymizer)
|
||||
routesReader := strings.NewReader(routesContent)
|
||||
if err := addFileToZip(archive, routesReader, "routes.txt"); err != nil {
|
||||
return fmt.Errorf("add routes file to zip: %w", err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) addInterfaces(req *proto.DebugBundleRequest, anonymizer *anonymize.Anonymizer, archive *zip.Writer) error {
|
||||
interfaces, err := net.Interfaces()
|
||||
if err != nil {
|
||||
return fmt.Errorf("get interfaces: %w", err)
|
||||
}
|
||||
|
||||
interfacesContent := formatInterfaces(interfaces, req.GetAnonymize(), anonymizer)
|
||||
interfacesReader := strings.NewReader(interfacesContent)
|
||||
if err := addFileToZip(archive, interfacesReader, "interfaces.txt"); err != nil {
|
||||
return fmt.Errorf("add interfaces file to zip: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) addLogfile(req *proto.DebugBundleRequest, anonymizer *anonymize.Anonymizer, archive *zip.Writer) (err error) {
|
||||
logFile, err := os.Open(s.logFile)
|
||||
if err != nil {
|
||||
return fmt.Errorf("open log file: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
if err := writer.Close(); err != nil {
|
||||
log.Errorf("Failed to close writer: %v", err)
|
||||
if err := logFile.Close(); err != nil {
|
||||
log.Errorf("Failed to close original log file: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
var logReader io.Reader
|
||||
if req.GetAnonymize() {
|
||||
var writer *io.PipeWriter
|
||||
logReader, writer = io.Pipe()
|
||||
|
||||
go s.anonymize(logFile, writer, anonymizer)
|
||||
} else {
|
||||
logReader = logFile
|
||||
}
|
||||
if err := addFileToZip(archive, logReader, "client.log"); err != nil {
|
||||
return fmt.Errorf("add log file to zip: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) anonymize(reader io.Reader, writer *io.PipeWriter, anonymizer *anonymize.Anonymizer) {
|
||||
defer func() {
|
||||
// always nil
|
||||
_ = writer.Close()
|
||||
}()
|
||||
|
||||
scanner := bufio.NewScanner(reader)
|
||||
for scanner.Scan() {
|
||||
line := anonymizer.AnonymizeString(scanner.Text())
|
||||
if _, err := writer.Write([]byte(line + "\n")); err != nil {
|
||||
errChan <- fmt.Errorf("write line to writer: %w", err)
|
||||
writer.CloseWithError(fmt.Errorf("anonymize write: %w", err))
|
||||
return
|
||||
}
|
||||
}
|
||||
if err := scanner.Err(); err != nil {
|
||||
errChan <- fmt.Errorf("read line from scanner: %w", err)
|
||||
writer.CloseWithError(fmt.Errorf("anonymize scan: %w", err))
|
||||
return
|
||||
}
|
||||
}
|
||||
@ -143,6 +317,20 @@ func addFileToZip(archive *zip.Writer, reader io.Reader, filename string) error
|
||||
header := &zip.FileHeader{
|
||||
Name: filename,
|
||||
Method: zip.Deflate,
|
||||
Modified: time.Now(),
|
||||
|
||||
CreatorVersion: 20, // Version 2.0
|
||||
ReaderVersion: 20, // Version 2.0
|
||||
Flags: 0x800, // UTF-8 filename
|
||||
}
|
||||
|
||||
// If the reader is a file, we can get more accurate information
|
||||
if f, ok := reader.(*os.File); ok {
|
||||
if stat, err := f.Stat(); err != nil {
|
||||
log.Tracef("Failed to get file stat for %s: %v", filename, err)
|
||||
} else {
|
||||
header.Modified = stat.ModTime()
|
||||
}
|
||||
}
|
||||
|
||||
writer, err := archive.CreateHeader(header)
|
||||
@ -165,6 +353,13 @@ func seedFromStatus(a *anonymize.Anonymizer, status *peer.FullStatus) {
|
||||
|
||||
for _, peer := range status.Peers {
|
||||
a.AnonymizeDomain(peer.FQDN)
|
||||
for route := range peer.GetRoutes() {
|
||||
a.AnonymizeRoute(route)
|
||||
}
|
||||
}
|
||||
|
||||
for route := range status.LocalPeerState.Routes {
|
||||
a.AnonymizeRoute(route)
|
||||
}
|
||||
|
||||
for _, nsGroup := range status.NSGroupStates {
|
||||
@ -179,3 +374,113 @@ func seedFromStatus(a *anonymize.Anonymizer, status *peer.FullStatus) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func formatRoutes(routes []netip.Prefix, anonymize bool, anonymizer *anonymize.Anonymizer) string {
|
||||
var ipv4Routes, ipv6Routes []netip.Prefix
|
||||
|
||||
// Separate IPv4 and IPv6 routes
|
||||
for _, route := range routes {
|
||||
if route.Addr().Is4() {
|
||||
ipv4Routes = append(ipv4Routes, route)
|
||||
} else {
|
||||
ipv6Routes = append(ipv6Routes, route)
|
||||
}
|
||||
}
|
||||
|
||||
// Sort IPv4 and IPv6 routes separately
|
||||
sort.Slice(ipv4Routes, func(i, j int) bool {
|
||||
return ipv4Routes[i].Bits() > ipv4Routes[j].Bits()
|
||||
})
|
||||
sort.Slice(ipv6Routes, func(i, j int) bool {
|
||||
return ipv6Routes[i].Bits() > ipv6Routes[j].Bits()
|
||||
})
|
||||
|
||||
var builder strings.Builder
|
||||
|
||||
// Format IPv4 routes
|
||||
builder.WriteString("IPv4 Routes:\n")
|
||||
for _, route := range ipv4Routes {
|
||||
formatRoute(&builder, route, anonymize, anonymizer)
|
||||
}
|
||||
|
||||
// Format IPv6 routes
|
||||
builder.WriteString("\nIPv6 Routes:\n")
|
||||
for _, route := range ipv6Routes {
|
||||
formatRoute(&builder, route, anonymize, anonymizer)
|
||||
}
|
||||
|
||||
return builder.String()
|
||||
}
|
||||
|
||||
func formatRoute(builder *strings.Builder, route netip.Prefix, anonymize bool, anonymizer *anonymize.Anonymizer) {
|
||||
if anonymize {
|
||||
anonymizedIP := anonymizer.AnonymizeIP(route.Addr())
|
||||
builder.WriteString(fmt.Sprintf("%s/%d\n", anonymizedIP, route.Bits()))
|
||||
} else {
|
||||
builder.WriteString(fmt.Sprintf("%s\n", route))
|
||||
}
|
||||
}
|
||||
|
||||
func formatInterfaces(interfaces []net.Interface, anonymize bool, anonymizer *anonymize.Anonymizer) string {
|
||||
sort.Slice(interfaces, func(i, j int) bool {
|
||||
return interfaces[i].Name < interfaces[j].Name
|
||||
})
|
||||
|
||||
var builder strings.Builder
|
||||
builder.WriteString("Network Interfaces:\n")
|
||||
|
||||
for _, iface := range interfaces {
|
||||
builder.WriteString(fmt.Sprintf("\nInterface: %s\n", iface.Name))
|
||||
builder.WriteString(fmt.Sprintf(" Index: %d\n", iface.Index))
|
||||
builder.WriteString(fmt.Sprintf(" MTU: %d\n", iface.MTU))
|
||||
builder.WriteString(fmt.Sprintf(" Flags: %v\n", iface.Flags))
|
||||
|
||||
addrs, err := iface.Addrs()
|
||||
if err != nil {
|
||||
builder.WriteString(fmt.Sprintf(" Addresses: Error retrieving addresses: %v\n", err))
|
||||
} else {
|
||||
builder.WriteString(" Addresses:\n")
|
||||
for _, addr := range addrs {
|
||||
prefix, err := netip.ParsePrefix(addr.String())
|
||||
if err != nil {
|
||||
builder.WriteString(fmt.Sprintf(" Error parsing address: %v\n", err))
|
||||
continue
|
||||
}
|
||||
ip := prefix.Addr()
|
||||
if anonymize {
|
||||
ip = anonymizer.AnonymizeIP(ip)
|
||||
}
|
||||
builder.WriteString(fmt.Sprintf(" %s/%d\n", ip, prefix.Bits()))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return builder.String()
|
||||
}
|
||||
|
||||
func anonymizeNATExternalIPs(ips []string, anonymizer *anonymize.Anonymizer) []string {
|
||||
anonymizedIPs := make([]string, len(ips))
|
||||
for i, ip := range ips {
|
||||
parts := strings.SplitN(ip, "/", 2)
|
||||
|
||||
ip1, err := netip.ParseAddr(parts[0])
|
||||
if err != nil {
|
||||
anonymizedIPs[i] = ip
|
||||
continue
|
||||
}
|
||||
ip1anon := anonymizer.AnonymizeIP(ip1)
|
||||
|
||||
if len(parts) == 2 {
|
||||
ip2, err := netip.ParseAddr(parts[1])
|
||||
if err != nil {
|
||||
anonymizedIPs[i] = fmt.Sprintf("%s/%s", ip1anon, parts[1])
|
||||
} else {
|
||||
ip2anon := anonymizer.AnonymizeIP(ip2)
|
||||
anonymizedIPs[i] = fmt.Sprintf("%s/%s", ip1anon, ip2anon)
|
||||
}
|
||||
} else {
|
||||
anonymizedIPs[i] = ip1anon.String()
|
||||
}
|
||||
}
|
||||
return anonymizedIPs
|
||||
}
|
||||
|
@ -1,11 +1,12 @@
|
||||
package system
|
||||
|
||||
import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
"testing"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
func Test_sysInfo(t *testing.T) {
|
||||
func Test_sysInfoMac(t *testing.T) {
|
||||
t.Skip("skipping darwin test")
|
||||
serialNum, prodName, manufacturer := sysInfo()
|
||||
if serialNum == "" {
|
||||
|
@ -21,6 +21,26 @@ import (
|
||||
"github.com/netbirdio/netbird/version"
|
||||
)
|
||||
|
||||
type SysInfoGetter interface {
|
||||
GetSysInfo() SysInfo
|
||||
}
|
||||
|
||||
type SysInfoWrapper struct {
|
||||
si sysinfo.SysInfo
|
||||
}
|
||||
|
||||
func (s SysInfoWrapper) GetSysInfo() SysInfo {
|
||||
s.si.GetSysInfo()
|
||||
return SysInfo{
|
||||
ChassisSerial: s.si.Chassis.Serial,
|
||||
ProductSerial: s.si.Product.Serial,
|
||||
BoardSerial: s.si.Board.Serial,
|
||||
ProductName: s.si.Product.Name,
|
||||
BoardName: s.si.Board.Name,
|
||||
ProductVendor: s.si.Product.Vendor,
|
||||
}
|
||||
}
|
||||
|
||||
// GetInfo retrieves and parses the system information
|
||||
func GetInfo(ctx context.Context) *Info {
|
||||
info := _getInfo()
|
||||
@ -45,7 +65,8 @@ func GetInfo(ctx context.Context) *Info {
|
||||
log.Warnf("failed to discover network addresses: %s", err)
|
||||
}
|
||||
|
||||
serialNum, prodName, manufacturer := sysInfo()
|
||||
si := SysInfoWrapper{}
|
||||
serialNum, prodName, manufacturer := sysInfo(si.GetSysInfo())
|
||||
|
||||
env := Environment{
|
||||
Cloud: detect_cloud.Detect(ctx),
|
||||
@ -87,20 +108,36 @@ func _getInfo() string {
|
||||
return out.String()
|
||||
}
|
||||
|
||||
func sysInfo() (serialNumber string, productName string, manufacturer string) {
|
||||
var si sysinfo.SysInfo
|
||||
si.GetSysInfo()
|
||||
func sysInfo(si SysInfo) (string, string, string) {
|
||||
isascii := regexp.MustCompile("^[[:ascii:]]+$")
|
||||
serial := si.Chassis.Serial
|
||||
if (serial == "Default string" || serial == "") && si.Product.Serial != "" {
|
||||
serial = si.Product.Serial
|
||||
|
||||
serials := []string{si.ChassisSerial, si.ProductSerial}
|
||||
serial := ""
|
||||
|
||||
for _, s := range serials {
|
||||
if isascii.MatchString(s) {
|
||||
serial = s
|
||||
if s != "Default string" {
|
||||
break
|
||||
}
|
||||
if (!isascii.MatchString(serial)) && si.Board.Serial != "" {
|
||||
serial = si.Board.Serial
|
||||
}
|
||||
name := si.Product.Name
|
||||
if (!isascii.MatchString(name)) && si.Board.Name != "" {
|
||||
name = si.Board.Name
|
||||
}
|
||||
return serial, name, si.Product.Vendor
|
||||
|
||||
if serial == "" && isascii.MatchString(si.BoardSerial) {
|
||||
serial = si.BoardSerial
|
||||
}
|
||||
|
||||
var name string
|
||||
for _, n := range []string{si.ProductName, si.BoardName} {
|
||||
if isascii.MatchString(n) {
|
||||
name = n
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
var manufacturer string
|
||||
if isascii.MatchString(si.ProductVendor) {
|
||||
manufacturer = si.ProductVendor
|
||||
}
|
||||
return serial, name, manufacturer
|
||||
}
|
||||
|
12
client/system/sysinfo_linux.go
Normal file
12
client/system/sysinfo_linux.go
Normal file
@ -0,0 +1,12 @@
|
||||
package system
|
||||
|
||||
// SysInfo used to moc out the sysinfo getter
|
||||
type SysInfo struct {
|
||||
ChassisSerial string
|
||||
ProductSerial string
|
||||
BoardSerial string
|
||||
|
||||
ProductName string
|
||||
BoardName string
|
||||
ProductVendor string
|
||||
}
|
198
client/system/sysinfo_linux_test.go
Normal file
198
client/system/sysinfo_linux_test.go
Normal file
@ -0,0 +1,198 @@
|
||||
package system
|
||||
|
||||
import "testing"
|
||||
|
||||
func Test_sysInfo(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
sysInfo SysInfo
|
||||
wantSerialNum string
|
||||
wantProdName string
|
||||
wantManufacturer string
|
||||
}{
|
||||
{
|
||||
name: "Test Case 1",
|
||||
sysInfo: SysInfo{
|
||||
ChassisSerial: "Default string",
|
||||
ProductSerial: "Default string",
|
||||
BoardSerial: "M80-G8013200245",
|
||||
ProductName: "B650M-HDV/M.2",
|
||||
BoardName: "B650M-HDV/M.2",
|
||||
ProductVendor: "ASRock",
|
||||
},
|
||||
wantSerialNum: "Default string",
|
||||
wantProdName: "B650M-HDV/M.2",
|
||||
wantManufacturer: "ASRock",
|
||||
},
|
||||
{
|
||||
name: "Empty Chassis Serial",
|
||||
sysInfo: SysInfo{
|
||||
ChassisSerial: "",
|
||||
ProductSerial: "Default string",
|
||||
BoardSerial: "M80-G8013200245",
|
||||
ProductName: "B650M-HDV/M.2",
|
||||
BoardName: "B650M-HDV/M.2",
|
||||
ProductVendor: "ASRock",
|
||||
},
|
||||
wantSerialNum: "Default string",
|
||||
wantProdName: "B650M-HDV/M.2",
|
||||
wantManufacturer: "ASRock",
|
||||
},
|
||||
{
|
||||
name: "Empty Chassis Serial",
|
||||
sysInfo: SysInfo{
|
||||
ChassisSerial: "",
|
||||
ProductSerial: "Default string",
|
||||
BoardSerial: "M80-G8013200245",
|
||||
ProductName: "B650M-HDV/M.2",
|
||||
BoardName: "B650M-HDV/M.2",
|
||||
ProductVendor: "ASRock",
|
||||
},
|
||||
wantSerialNum: "Default string",
|
||||
wantProdName: "B650M-HDV/M.2",
|
||||
wantManufacturer: "ASRock",
|
||||
},
|
||||
{
|
||||
name: "Fallback to Product Serial",
|
||||
sysInfo: SysInfo{
|
||||
ChassisSerial: "Default string",
|
||||
ProductSerial: "Product serial",
|
||||
BoardSerial: "M80-G8013200245",
|
||||
ProductName: "B650M-HDV/M.2",
|
||||
BoardName: "B650M-HDV/M.2",
|
||||
ProductVendor: "ASRock",
|
||||
},
|
||||
wantSerialNum: "Product serial",
|
||||
wantProdName: "B650M-HDV/M.2",
|
||||
wantManufacturer: "ASRock",
|
||||
},
|
||||
{
|
||||
name: "Fallback to Product Serial with default string",
|
||||
sysInfo: SysInfo{
|
||||
ChassisSerial: "Default string",
|
||||
ProductSerial: "Default string",
|
||||
BoardSerial: "M80-G8013200245",
|
||||
ProductName: "B650M-HDV/M.2",
|
||||
BoardName: "B650M-HDV/M.2",
|
||||
ProductVendor: "ASRock",
|
||||
},
|
||||
wantSerialNum: "Default string",
|
||||
wantProdName: "B650M-HDV/M.2",
|
||||
wantManufacturer: "ASRock",
|
||||
},
|
||||
{
|
||||
name: "Non UTF-8 in Chassis Serial",
|
||||
sysInfo: SysInfo{
|
||||
ChassisSerial: "\x80",
|
||||
ProductSerial: "Product serial",
|
||||
BoardSerial: "M80-G8013200245",
|
||||
ProductName: "B650M-HDV/M.2",
|
||||
BoardName: "B650M-HDV/M.2",
|
||||
ProductVendor: "ASRock",
|
||||
},
|
||||
wantSerialNum: "Product serial",
|
||||
wantProdName: "B650M-HDV/M.2",
|
||||
wantManufacturer: "ASRock",
|
||||
},
|
||||
{
|
||||
name: "Non UTF-8 in Chassis Serial and Product Serial",
|
||||
sysInfo: SysInfo{
|
||||
ChassisSerial: "\x80",
|
||||
ProductSerial: "\x80",
|
||||
BoardSerial: "M80-G8013200245",
|
||||
ProductName: "B650M-HDV/M.2",
|
||||
BoardName: "B650M-HDV/M.2",
|
||||
ProductVendor: "ASRock",
|
||||
},
|
||||
wantSerialNum: "M80-G8013200245",
|
||||
wantProdName: "B650M-HDV/M.2",
|
||||
wantManufacturer: "ASRock",
|
||||
},
|
||||
{
|
||||
name: "Non UTF-8 in Chassis Serial and Product Serial and BoardSerial",
|
||||
sysInfo: SysInfo{
|
||||
ChassisSerial: "\x80",
|
||||
ProductSerial: "\x80",
|
||||
BoardSerial: "\x80",
|
||||
ProductName: "B650M-HDV/M.2",
|
||||
BoardName: "B650M-HDV/M.2",
|
||||
ProductVendor: "ASRock",
|
||||
},
|
||||
wantSerialNum: "",
|
||||
wantProdName: "B650M-HDV/M.2",
|
||||
wantManufacturer: "ASRock",
|
||||
},
|
||||
|
||||
{
|
||||
name: "Empty Product Name",
|
||||
sysInfo: SysInfo{
|
||||
ChassisSerial: "Default string",
|
||||
ProductSerial: "Default string",
|
||||
BoardSerial: "M80-G8013200245",
|
||||
ProductName: "",
|
||||
BoardName: "boardname",
|
||||
ProductVendor: "ASRock",
|
||||
},
|
||||
wantSerialNum: "Default string",
|
||||
wantProdName: "boardname",
|
||||
wantManufacturer: "ASRock",
|
||||
},
|
||||
{
|
||||
name: "Invalid Product Name",
|
||||
sysInfo: SysInfo{
|
||||
ChassisSerial: "Default string",
|
||||
ProductSerial: "Default string",
|
||||
BoardSerial: "M80-G8013200245",
|
||||
ProductName: "\x80",
|
||||
BoardName: "boardname",
|
||||
ProductVendor: "ASRock",
|
||||
},
|
||||
wantSerialNum: "Default string",
|
||||
wantProdName: "boardname",
|
||||
wantManufacturer: "ASRock",
|
||||
},
|
||||
{
|
||||
name: "Invalid BoardName Name",
|
||||
sysInfo: SysInfo{
|
||||
ChassisSerial: "Default string",
|
||||
ProductSerial: "Default string",
|
||||
BoardSerial: "M80-G8013200245",
|
||||
ProductName: "\x80",
|
||||
BoardName: "\x80",
|
||||
ProductVendor: "ASRock",
|
||||
},
|
||||
wantSerialNum: "Default string",
|
||||
wantProdName: "",
|
||||
wantManufacturer: "ASRock",
|
||||
},
|
||||
{
|
||||
name: "Invalid chars",
|
||||
sysInfo: SysInfo{
|
||||
ChassisSerial: "\x80",
|
||||
ProductSerial: "\x80",
|
||||
BoardSerial: "\x80",
|
||||
ProductName: "\x80",
|
||||
BoardName: "\x80",
|
||||
ProductVendor: "\x80",
|
||||
},
|
||||
wantSerialNum: "",
|
||||
wantProdName: "",
|
||||
wantManufacturer: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
gotSerialNum, gotProdName, gotManufacturer := sysInfo(tt.sysInfo)
|
||||
if gotSerialNum != tt.wantSerialNum {
|
||||
t.Errorf("sysInfo() gotSerialNum = %v, want %v", gotSerialNum, tt.wantSerialNum)
|
||||
}
|
||||
if gotProdName != tt.wantProdName {
|
||||
t.Errorf("sysInfo() gotProdName = %v, want %v", gotProdName, tt.wantProdName)
|
||||
}
|
||||
if gotManufacturer != tt.wantManufacturer {
|
||||
t.Errorf("sysInfo() gotManufacturer = %v, want %v", gotManufacturer, tt.wantManufacturer)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
@ -1,4 +1,4 @@
|
||||
//go:build darwin || dragonfly || freebsd || netbsd || openbsd
|
||||
//go:build darwin
|
||||
|
||||
package main
|
||||
|
||||
|
@ -10,7 +10,7 @@ import (
|
||||
func EncryptMessage(remotePubKey wgtypes.Key, ourPrivateKey wgtypes.Key, message pb.Message) ([]byte, error) {
|
||||
byteResp, err := pb.Marshal(message)
|
||||
if err != nil {
|
||||
log.Errorf("failed marshalling message %v", err)
|
||||
log.Errorf("failed marshalling message %v, %+v", err, message.String())
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
@ -14,14 +14,29 @@ type TextFormatter struct {
|
||||
levelDesc []string
|
||||
}
|
||||
|
||||
// SyslogFormatter formats logs into text
|
||||
type SyslogFormatter struct {
|
||||
levelDesc []string
|
||||
}
|
||||
|
||||
var validLevelDesc = []string{"PANC", "FATL", "ERRO", "WARN", "INFO", "DEBG", "TRAC"}
|
||||
|
||||
|
||||
// NewTextFormatter create new MyTextFormatter instance
|
||||
func NewTextFormatter() *TextFormatter {
|
||||
return &TextFormatter{
|
||||
levelDesc: []string{"PANC", "FATL", "ERRO", "WARN", "INFO", "DEBG", "TRAC"},
|
||||
levelDesc: validLevelDesc,
|
||||
timestampFormat: time.RFC3339, // or RFC3339
|
||||
}
|
||||
}
|
||||
|
||||
// NewSyslogFormatter create new MySyslogFormatter instance
|
||||
func NewSyslogFormatter() *SyslogFormatter {
|
||||
return &SyslogFormatter{
|
||||
levelDesc: validLevelDesc,
|
||||
}
|
||||
}
|
||||
|
||||
// Format renders a single log entry
|
||||
func (f *TextFormatter) Format(entry *logrus.Entry) ([]byte, error) {
|
||||
var fields string
|
||||
@ -49,3 +64,20 @@ func (f *TextFormatter) parseLevel(level logrus.Level) string {
|
||||
|
||||
return f.levelDesc[level]
|
||||
}
|
||||
|
||||
// Format renders a single log entry
|
||||
func (f *SyslogFormatter) Format(entry *logrus.Entry) ([]byte, error) {
|
||||
var fields string
|
||||
keys := make([]string, 0, len(entry.Data))
|
||||
for k, v := range entry.Data {
|
||||
if k == "source" {
|
||||
continue
|
||||
}
|
||||
keys = append(keys, fmt.Sprintf("%s: %v", k, v))
|
||||
}
|
||||
|
||||
if len(keys) > 0 {
|
||||
fields = fmt.Sprintf("[%s] ", strings.Join(keys, ", "))
|
||||
}
|
||||
return []byte(fmt.Sprintf("%s%s\n", fields, entry.Message)), nil
|
||||
}
|
||||
|
@ -8,7 +8,7 @@ import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestLogMessageFormat(t *testing.T) {
|
||||
func TestLogTextFormat(t *testing.T) {
|
||||
|
||||
someEntry := &logrus.Entry{
|
||||
Data: logrus.Fields{"att1": 1, "att2": 2, "source": "some/fancy/path.go:46"},
|
||||
@ -24,3 +24,20 @@ func TestLogMessageFormat(t *testing.T) {
|
||||
expectedString := "^2021-02-21T01:10:30Z WARN \\[(att1: 1, att2: 2|att2: 2, att1: 1)\\] some/fancy/path.go:46: Some Message\\s+$"
|
||||
assert.Regexp(t, expectedString, parsedString)
|
||||
}
|
||||
|
||||
func TestLogSyslogFormat(t *testing.T) {
|
||||
|
||||
someEntry := &logrus.Entry{
|
||||
Data: logrus.Fields{"att1": 1, "att2": 2, "source": "some/fancy/path.go:46"},
|
||||
Time: time.Date(2021, time.Month(2), 21, 1, 10, 30, 0, time.UTC),
|
||||
Level: 3,
|
||||
Message: "Some Message",
|
||||
}
|
||||
|
||||
formatter := NewSyslogFormatter()
|
||||
result, _ := formatter.Format(someEntry)
|
||||
|
||||
parsedString := string(result)
|
||||
expectedString := "^\\[(att1: 1, att2: 2|att2: 2, att1: 1)\\] Some Message\\s+$"
|
||||
assert.Regexp(t, expectedString, parsedString)
|
||||
}
|
||||
|
@ -10,6 +10,12 @@ func SetTextFormatter(logger *logrus.Logger) {
|
||||
logger.ReportCaller = true
|
||||
logger.AddHook(NewContextHook())
|
||||
}
|
||||
// SetSyslogFormatter set the text formatter for given logger.
|
||||
func SetSyslogFormatter(logger *logrus.Logger) {
|
||||
logger.Formatter = NewSyslogFormatter()
|
||||
logger.ReportCaller = true
|
||||
logger.AddHook(NewContextHook())
|
||||
}
|
||||
|
||||
// SetJSONFormatter set the JSON formatter for given logger.
|
||||
func SetJSONFormatter(logger *logrus.Logger) {
|
||||
|
2
go.mod
2
go.mod
@ -117,7 +117,7 @@ require (
|
||||
github.com/dgraph-io/ristretto v0.1.1 // indirect
|
||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
|
||||
github.com/distribution/reference v0.6.0 // indirect
|
||||
github.com/docker/docker v26.1.3+incompatible // indirect
|
||||
github.com/docker/docker v26.1.4+incompatible // indirect
|
||||
github.com/docker/go-connections v0.5.0 // indirect
|
||||
github.com/docker/go-units v0.5.0 // indirect
|
||||
github.com/felixge/httpsnoop v1.0.4 // indirect
|
||||
|
4
go.sum
4
go.sum
@ -81,8 +81,8 @@ github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/r
|
||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
|
||||
github.com/distribution/reference v0.6.0 h1:0IXCQ5g4/QMHHkarYzh5l+u8T3t73zM5QvfrDyIgxBk=
|
||||
github.com/distribution/reference v0.6.0/go.mod h1:BbU0aIcezP1/5jX/8MP0YiH4SdvB5Y4f/wlDRiLyi3E=
|
||||
github.com/docker/docker v26.1.3+incompatible h1:lLCzRbrVZrljpVNobJu1J2FHk8V0s4BawoZippkc+xo=
|
||||
github.com/docker/docker v26.1.3+incompatible/go.mod h1:eEKB0N0r5NX/I1kEveEz05bcu8tLC/8azJZsviup8Sk=
|
||||
github.com/docker/docker v26.1.4+incompatible h1:vuTpXDuoga+Z38m1OZHzl7NKisKWaWlhjQk7IDPSLsU=
|
||||
github.com/docker/docker v26.1.4+incompatible/go.mod h1:eEKB0N0r5NX/I1kEveEz05bcu8tLC/8azJZsviup8Sk=
|
||||
github.com/docker/go-connections v0.5.0 h1:USnMq7hx7gwdVZq1L49hLXaFtUdTADjXGp+uj1Br63c=
|
||||
github.com/docker/go-connections v0.5.0/go.mod h1:ov60Kzw0kKElRwhNs9UlUHAE/F9Fe6GLaXnqyDdmEXc=
|
||||
github.com/docker/go-units v0.5.0 h1:69rxXcBk27SvSaaxTtLh/8llcHD8vYHT7WSdRZ/jvr4=
|
||||
|
@ -135,8 +135,8 @@ type AccountManager interface {
|
||||
UpdateIntegratedValidatorGroups(ctx context.Context, accountID string, userID string, groups []string) error
|
||||
GroupValidation(ctx context.Context, accountId string, groups []string) (bool, error)
|
||||
GetValidatedPeers(account *Account) (map[string]struct{}, error)
|
||||
SyncAndMarkPeer(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error)
|
||||
CancelPeerRoutines(ctx context.Context, peer *nbpeer.Peer) error
|
||||
SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error)
|
||||
OnPeerDisconnected(ctx context.Context, accountID string, peerPubKey string) error
|
||||
SyncPeerMeta(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta) error
|
||||
FindExistingPostureCheck(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error)
|
||||
GetAccountIDForPeerKey(ctx context.Context, peerKey string) (string, error)
|
||||
@ -770,10 +770,6 @@ func (a *Account) GetPeer(peerID string) *nbpeer.Peer {
|
||||
// SetJWTGroups updates the user's auto groups by synchronizing JWT groups.
|
||||
// Returns true if there are changes in the JWT group membership.
|
||||
func (a *Account) SetJWTGroups(userID string, groupsNames []string) bool {
|
||||
if len(groupsNames) == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
user, ok := a.Users[userID]
|
||||
if !ok {
|
||||
return false
|
||||
@ -978,7 +974,7 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco
|
||||
return nil, status.Errorf(status.InvalidArgument, "peer login expiration can't be smaller than one hour")
|
||||
}
|
||||
|
||||
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
@ -1029,7 +1025,7 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco
|
||||
|
||||
func (am *DefaultAccountManager) peerLoginExpirationJob(ctx context.Context, accountID string) func() (time.Duration, bool) {
|
||||
return func() (time.Duration, bool) {
|
||||
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
@ -1128,7 +1124,7 @@ func (am *DefaultAccountManager) warmupIDPCache(ctx context.Context) error {
|
||||
|
||||
// DeleteAccount deletes an account and all its users from local store and from the remote IDP if the requester is an admin and account owner
|
||||
func (am *DefaultAccountManager) DeleteAccount(ctx context.Context, accountID, userID string) error {
|
||||
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
if err != nil {
|
||||
@ -1588,7 +1584,7 @@ func (am *DefaultAccountManager) MarkPATUsed(ctx context.Context, tokenID string
|
||||
return err
|
||||
}
|
||||
|
||||
unlock := am.Store.AcquireAccountWriteLock(ctx, account.Id)
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, account.Id)
|
||||
defer unlock()
|
||||
|
||||
account, err = am.Store.GetAccountByUser(ctx, user.Id)
|
||||
@ -1671,7 +1667,7 @@ func (am *DefaultAccountManager) GetAccountFromToken(ctx context.Context, claims
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
unlock := am.Store.AcquireAccountWriteLock(ctx, newAcc.Id)
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, newAcc.Id)
|
||||
alreadyUnlocked := false
|
||||
defer func() {
|
||||
if !alreadyUnlocked {
|
||||
@ -1827,7 +1823,7 @@ func (am *DefaultAccountManager) getAccountWithAuthorizationClaims(ctx context.C
|
||||
|
||||
account, err := am.Store.GetAccountByUser(ctx, claims.UserId)
|
||||
if err == nil {
|
||||
unlockAccount := am.Store.AcquireAccountWriteLock(ctx, account.Id)
|
||||
unlockAccount := am.Store.AcquireWriteLockByUID(ctx, account.Id)
|
||||
defer unlockAccount()
|
||||
account, err = am.Store.GetAccountByUser(ctx, claims.UserId)
|
||||
if err != nil {
|
||||
@ -1847,7 +1843,7 @@ func (am *DefaultAccountManager) getAccountWithAuthorizationClaims(ctx context.C
|
||||
return account, nil
|
||||
} else if s, ok := status.FromError(err); ok && s.Type() == status.NotFound {
|
||||
if domainAccount != nil {
|
||||
unlockAccount := am.Store.AcquireAccountWriteLock(ctx, domainAccount.Id)
|
||||
unlockAccount := am.Store.AcquireWriteLockByUID(ctx, domainAccount.Id)
|
||||
defer unlockAccount()
|
||||
domainAccount, err = am.Store.GetAccountByPrivateDomain(ctx, claims.Domain)
|
||||
if err != nil {
|
||||
@ -1861,17 +1857,11 @@ func (am *DefaultAccountManager) getAccountWithAuthorizationClaims(ctx context.C
|
||||
}
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) SyncAndMarkPeer(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) {
|
||||
accountID, err := am.Store.GetAccountIDByPeerPubKey(ctx, peerPubKey)
|
||||
if err != nil {
|
||||
if errStatus, ok := status.FromError(err); ok && errStatus.Type() == status.NotFound {
|
||||
return nil, nil, nil, status.Errorf(status.Unauthenticated, "peer not registered")
|
||||
}
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
|
||||
unlock := am.Store.AcquireAccountReadLock(ctx, accountID)
|
||||
defer unlock()
|
||||
func (am *DefaultAccountManager) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) {
|
||||
accountUnlock := am.Store.AcquireReadLockByUID(ctx, accountID)
|
||||
defer accountUnlock()
|
||||
peerUnlock := am.Store.AcquireWriteLockByUID(ctx, peerPubKey)
|
||||
defer peerUnlock()
|
||||
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
if err != nil {
|
||||
@ -1891,26 +1881,20 @@ func (am *DefaultAccountManager) SyncAndMarkPeer(ctx context.Context, peerPubKey
|
||||
return peer, netMap, postureChecks, nil
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) CancelPeerRoutines(ctx context.Context, peer *nbpeer.Peer) error {
|
||||
accountID, err := am.Store.GetAccountIDByPeerPubKey(ctx, peer.Key)
|
||||
if err != nil {
|
||||
if errStatus, ok := status.FromError(err); ok && errStatus.Type() == status.NotFound {
|
||||
return status.Errorf(status.Unauthenticated, "peer not registered")
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
|
||||
defer unlock()
|
||||
func (am *DefaultAccountManager) OnPeerDisconnected(ctx context.Context, accountID string, peerPubKey string) error {
|
||||
accountUnlock := am.Store.AcquireReadLockByUID(ctx, accountID)
|
||||
defer accountUnlock()
|
||||
peerUnlock := am.Store.AcquireWriteLockByUID(ctx, peerPubKey)
|
||||
defer peerUnlock()
|
||||
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = am.MarkPeerConnected(ctx, peer.Key, false, nil, account)
|
||||
err = am.MarkPeerConnected(ctx, peerPubKey, false, nil, account)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Warnf("failed marking peer as connected %s %v", peer.Key, err)
|
||||
log.WithContext(ctx).Warnf("failed marking peer as connected %s %v", peerPubKey, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
@ -1923,7 +1907,7 @@ func (am *DefaultAccountManager) SyncPeerMeta(ctx context.Context, peerPubKey st
|
||||
return err
|
||||
}
|
||||
|
||||
unlock := am.Store.AcquireAccountReadLock(ctx, accountID)
|
||||
unlock := am.Store.AcquireReadLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
|
@ -2219,6 +2219,13 @@ func TestAccount_SetJWTGroups(t *testing.T) {
|
||||
assert.Len(t, account.Users["user2"].AutoGroups, 1, "new group should be added")
|
||||
assert.Contains(t, account.Groups, account.Users["user2"].AutoGroups[0], "groups must contain group3 from user groups")
|
||||
})
|
||||
|
||||
t.Run("remove all JWT groups", func(t *testing.T) {
|
||||
updated := account.SetJWTGroups("user1", []string{})
|
||||
assert.True(t, updated, "account should be updated")
|
||||
assert.Len(t, account.Users["user1"].AutoGroups, 1, "only non-JWT groups should remain")
|
||||
assert.Contains(t, account.Users["user1"].AutoGroups, "group1", " group1 should still be present")
|
||||
})
|
||||
}
|
||||
|
||||
func TestAccount_UserGroupsAddToPeers(t *testing.T) {
|
||||
|
@ -57,6 +57,10 @@ type Config struct {
|
||||
func (c Config) GetAuthAudiences() []string {
|
||||
audiences := []string{c.HttpConfig.AuthAudience}
|
||||
|
||||
if c.HttpConfig.ExtraAuthAudience != "" {
|
||||
audiences = append(audiences, c.HttpConfig.ExtraAuthAudience)
|
||||
}
|
||||
|
||||
if c.DeviceAuthorizationFlow != nil && c.DeviceAuthorizationFlow.ProviderConfig.Audience != "" {
|
||||
audiences = append(audiences, c.DeviceAuthorizationFlow.ProviderConfig.Audience)
|
||||
}
|
||||
@ -95,6 +99,8 @@ type HttpServerConfig struct {
|
||||
OIDCConfigEndpoint string
|
||||
// IdpSignKeyRefreshEnabled identifies the signing key is currently being rotated or not
|
||||
IdpSignKeyRefreshEnabled bool
|
||||
// Extra audience
|
||||
ExtraAuthAudience string
|
||||
}
|
||||
|
||||
// Host represents a Wiretrustee host (e.g. STUN, TURN, Signal)
|
||||
|
@ -36,7 +36,7 @@ func (d DNSSettings) Copy() DNSSettings {
|
||||
|
||||
// GetDNSSettings validates a user role and returns the DNS settings for the provided account ID
|
||||
func (am *DefaultAccountManager) GetDNSSettings(ctx context.Context, accountID string, userID string) (*DNSSettings, error) {
|
||||
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
@ -58,7 +58,7 @@ func (am *DefaultAccountManager) GetDNSSettings(ctx context.Context, accountID s
|
||||
|
||||
// SaveDNSSettings validates a user role and updates the account's DNS settings
|
||||
func (am *DefaultAccountManager) SaveDNSSettings(ctx context.Context, accountID string, userID string, dnsSettingsToSave *DNSSettings) error {
|
||||
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
|
@ -13,7 +13,7 @@ import (
|
||||
|
||||
// GetEvents returns a list of activity events of an account
|
||||
func (am *DefaultAccountManager) GetEvents(ctx context.Context, accountID, userID string) ([]*activity.Event, error) {
|
||||
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
|
@ -39,8 +39,8 @@ type FileStore struct {
|
||||
mux sync.Mutex `json:"-"`
|
||||
storeFile string `json:"-"`
|
||||
|
||||
// sync.Mutex indexed by accountID
|
||||
accountLocks sync.Map `json:"-"`
|
||||
// sync.Mutex indexed by resource ID
|
||||
resourceLocks sync.Map `json:"-"`
|
||||
globalAccountLock sync.Mutex `json:"-"`
|
||||
|
||||
metrics telemetry.AppMetrics `json:"-"`
|
||||
@ -281,26 +281,26 @@ func (s *FileStore) AcquireGlobalLock(ctx context.Context) (unlock func()) {
|
||||
return unlock
|
||||
}
|
||||
|
||||
// AcquireAccountWriteLock acquires account lock for writing to a resource and returns a function that releases the lock
|
||||
func (s *FileStore) AcquireAccountWriteLock(ctx context.Context, accountID string) (unlock func()) {
|
||||
log.WithContext(ctx).Debugf("acquiring lock for account %s", accountID)
|
||||
// AcquireWriteLockByUID acquires an ID lock for writing to a resource and returns a function that releases the lock
|
||||
func (s *FileStore) AcquireWriteLockByUID(ctx context.Context, uniqueID string) (unlock func()) {
|
||||
log.WithContext(ctx).Debugf("acquiring lock for ID %s", uniqueID)
|
||||
start := time.Now()
|
||||
value, _ := s.accountLocks.LoadOrStore(accountID, &sync.Mutex{})
|
||||
value, _ := s.resourceLocks.LoadOrStore(uniqueID, &sync.Mutex{})
|
||||
mtx := value.(*sync.Mutex)
|
||||
mtx.Lock()
|
||||
|
||||
unlock = func() {
|
||||
mtx.Unlock()
|
||||
log.WithContext(ctx).Debugf("released lock for account %s in %v", accountID, time.Since(start))
|
||||
log.WithContext(ctx).Debugf("released lock for ID %s in %v", uniqueID, time.Since(start))
|
||||
}
|
||||
|
||||
return unlock
|
||||
}
|
||||
|
||||
// AcquireAccountReadLock AcquireAccountWriteLock acquires account lock for reading a resource and returns a function that releases the lock
|
||||
// AcquireReadLockByUID acquires an ID lock for reading a resource and returns a function that releases the lock
|
||||
// This method is still returns a write lock as file store can't handle read locks
|
||||
func (s *FileStore) AcquireAccountReadLock(ctx context.Context, accountID string) (unlock func()) {
|
||||
return s.AcquireAccountWriteLock(ctx, accountID)
|
||||
func (s *FileStore) AcquireReadLockByUID(ctx context.Context, uniqueID string) (unlock func()) {
|
||||
return s.AcquireWriteLockByUID(ctx, uniqueID)
|
||||
}
|
||||
|
||||
func (s *FileStore) SaveAccount(ctx context.Context, account *Account) error {
|
||||
@ -666,6 +666,26 @@ func (s *FileStore) SaveInstallationID(ctx context.Context, ID string) error {
|
||||
return s.persist(ctx, s.storeFile)
|
||||
}
|
||||
|
||||
// SavePeer saves the peer in the account
|
||||
func (s *FileStore) SavePeer(_ context.Context, accountID string, peer *nbpeer.Peer) error {
|
||||
s.mux.Lock()
|
||||
defer s.mux.Unlock()
|
||||
|
||||
account, err := s.getAccount(accountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
newPeer := peer.Copy()
|
||||
|
||||
account.Peers[peer.ID] = newPeer
|
||||
|
||||
s.PeerKeyID2AccountID[peer.Key] = accountID
|
||||
s.PeerID2AccountID[peer.ID] = accountID
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// SavePeerStatus stores the PeerStatus in memory. It doesn't attempt to persist data to speed up things.
|
||||
// PeerStatus will be saved eventually when some other changes occur.
|
||||
func (s *FileStore) SavePeerStatus(accountID, peerID string, peerStatus nbpeer.PeerStatus) error {
|
||||
|
@ -9,6 +9,7 @@ import (
|
||||
"path"
|
||||
"strconv"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"gorm.io/driver/sqlite"
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/logger"
|
||||
@ -30,6 +31,8 @@ func loadGeolocationDatabases(dataDir string) error {
|
||||
continue
|
||||
}
|
||||
|
||||
log.Infof("geo location file %s not found , file will be downloaded", file)
|
||||
|
||||
switch file {
|
||||
case MMDBFileName:
|
||||
extractFunc := func(src string, dst string) error {
|
||||
|
@ -23,7 +23,7 @@ func (e *GroupLinkError) Error() string {
|
||||
|
||||
// GetGroup object of the peers
|
||||
func (am *DefaultAccountManager) GetGroup(ctx context.Context, accountID, groupID, userID string) (*nbgroup.Group, error) {
|
||||
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
@ -50,7 +50,7 @@ func (am *DefaultAccountManager) GetGroup(ctx context.Context, accountID, groupI
|
||||
|
||||
// GetAllGroups returns all groups in an account
|
||||
func (am *DefaultAccountManager) GetAllGroups(ctx context.Context, accountID string, userID string) ([]*nbgroup.Group, error) {
|
||||
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
@ -77,7 +77,7 @@ func (am *DefaultAccountManager) GetAllGroups(ctx context.Context, accountID str
|
||||
|
||||
// GetGroupByName filters all groups in an account by name and returns the one with the most peers
|
||||
func (am *DefaultAccountManager) GetGroupByName(ctx context.Context, groupName, accountID string) (*nbgroup.Group, error) {
|
||||
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
@ -110,7 +110,7 @@ func (am *DefaultAccountManager) GetGroupByName(ctx context.Context, groupName,
|
||||
|
||||
// SaveGroup object of the peers
|
||||
func (am *DefaultAccountManager) SaveGroup(ctx context.Context, accountID, userID string, newGroup *nbgroup.Group) error {
|
||||
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
return am.SaveGroups(ctx, accountID, userID, []*nbgroup.Group{newGroup})
|
||||
}
|
||||
@ -163,7 +163,7 @@ func (am *DefaultAccountManager) SaveGroups(ctx context.Context, accountID, user
|
||||
}
|
||||
|
||||
account.Network.IncSerial()
|
||||
if err = am.Store.SaveGroups(account.Id, account.Groups); err != nil {
|
||||
if err = am.Store.SaveAccount(ctx, account); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@ -245,7 +245,7 @@ func difference(a, b []string) []string {
|
||||
|
||||
// DeleteGroup object of the peers
|
||||
func (am *DefaultAccountManager) DeleteGroup(ctx context.Context, accountId, userId, groupID string) error {
|
||||
unlock := am.Store.AcquireAccountWriteLock(ctx, accountId)
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountId)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(ctx, accountId)
|
||||
@ -359,7 +359,7 @@ func (am *DefaultAccountManager) DeleteGroup(ctx context.Context, accountId, use
|
||||
|
||||
// ListGroups objects of the peers
|
||||
func (am *DefaultAccountManager) ListGroups(ctx context.Context, accountID string) ([]*nbgroup.Group, error) {
|
||||
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
@ -377,7 +377,7 @@ func (am *DefaultAccountManager) ListGroups(ctx context.Context, accountID strin
|
||||
|
||||
// GroupAddPeer appends peer to the group
|
||||
func (am *DefaultAccountManager) GroupAddPeer(ctx context.Context, accountID, groupID, peerID string) error {
|
||||
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
@ -413,7 +413,7 @@ func (am *DefaultAccountManager) GroupAddPeer(ctx context.Context, accountID, gr
|
||||
|
||||
// GroupDeletePeer removes peer from the group
|
||||
func (am *DefaultAccountManager) GroupDeletePeer(ctx context.Context, accountID, groupID, peerID string) error {
|
||||
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
|
@ -155,7 +155,7 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi
|
||||
log.WithContext(ctx).Tracef("peer system meta has to be provided on sync. Peer %s, remote addr %s", peerKey.String(), realIP)
|
||||
}
|
||||
|
||||
peer, netMap, postureChecks, err := s.accountManager.SyncAndMarkPeer(ctx, peerKey.String(), extractPeerMeta(ctx, syncReq.GetMeta()), realIP)
|
||||
peer, netMap, postureChecks, err := s.accountManager.SyncAndMarkPeer(ctx, accountID, peerKey.String(), extractPeerMeta(ctx, syncReq.GetMeta()), realIP)
|
||||
if err != nil {
|
||||
return mapError(ctx, err)
|
||||
}
|
||||
@ -178,11 +178,11 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi
|
||||
s.appMetrics.GRPCMetrics().CountSyncRequestDuration(time.Since(reqStart))
|
||||
}
|
||||
|
||||
return s.handleUpdates(ctx, peerKey, peer, updates, srv)
|
||||
return s.handleUpdates(ctx, accountID, peerKey, peer, updates, srv)
|
||||
}
|
||||
|
||||
// handleUpdates sends updates to the connected peer until the updates channel is closed.
|
||||
func (s *GRPCServer) handleUpdates(ctx context.Context, peerKey wgtypes.Key, peer *nbpeer.Peer, updates chan *UpdateMessage, srv proto.ManagementService_SyncServer) error {
|
||||
func (s *GRPCServer) handleUpdates(ctx context.Context, accountID string, peerKey wgtypes.Key, peer *nbpeer.Peer, updates chan *UpdateMessage, srv proto.ManagementService_SyncServer) error {
|
||||
for {
|
||||
select {
|
||||
// condition when there are some updates
|
||||
@ -193,12 +193,12 @@ func (s *GRPCServer) handleUpdates(ctx context.Context, peerKey wgtypes.Key, pee
|
||||
|
||||
if !open {
|
||||
log.WithContext(ctx).Debugf("updates channel for peer %s was closed", peerKey.String())
|
||||
s.cancelPeerRoutines(ctx, peer)
|
||||
s.cancelPeerRoutines(ctx, accountID, peer)
|
||||
return nil
|
||||
}
|
||||
log.WithContext(ctx).Debugf("received an update for peer %s", peerKey.String())
|
||||
|
||||
if err := s.sendUpdate(ctx, peerKey, peer, update, srv); err != nil {
|
||||
if err := s.sendUpdate(ctx, accountID, peerKey, peer, update, srv); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@ -206,7 +206,7 @@ func (s *GRPCServer) handleUpdates(ctx context.Context, peerKey wgtypes.Key, pee
|
||||
case <-srv.Context().Done():
|
||||
// happens when connection drops, e.g. client disconnects
|
||||
log.WithContext(ctx).Debugf("stream of peer %s has been closed", peerKey.String())
|
||||
s.cancelPeerRoutines(ctx, peer)
|
||||
s.cancelPeerRoutines(ctx, accountID, peer)
|
||||
return srv.Context().Err()
|
||||
}
|
||||
}
|
||||
@ -214,10 +214,10 @@ func (s *GRPCServer) handleUpdates(ctx context.Context, peerKey wgtypes.Key, pee
|
||||
|
||||
// sendUpdate encrypts the update message using the peer key and the server's wireguard key,
|
||||
// then sends the encrypted message to the connected peer via the sync server.
|
||||
func (s *GRPCServer) sendUpdate(ctx context.Context, peerKey wgtypes.Key, peer *nbpeer.Peer, update *UpdateMessage, srv proto.ManagementService_SyncServer) error {
|
||||
func (s *GRPCServer) sendUpdate(ctx context.Context, accountID string, peerKey wgtypes.Key, peer *nbpeer.Peer, update *UpdateMessage, srv proto.ManagementService_SyncServer) error {
|
||||
encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, update.Update)
|
||||
if err != nil {
|
||||
s.cancelPeerRoutines(ctx, peer)
|
||||
s.cancelPeerRoutines(ctx, accountID, peer)
|
||||
return status.Errorf(codes.Internal, "failed processing update message")
|
||||
}
|
||||
err = srv.SendMsg(&proto.EncryptedMessage{
|
||||
@ -225,17 +225,17 @@ func (s *GRPCServer) sendUpdate(ctx context.Context, peerKey wgtypes.Key, peer *
|
||||
Body: encryptedResp,
|
||||
})
|
||||
if err != nil {
|
||||
s.cancelPeerRoutines(ctx, peer)
|
||||
s.cancelPeerRoutines(ctx, accountID, peer)
|
||||
return status.Errorf(codes.Internal, "failed sending update message")
|
||||
}
|
||||
log.WithContext(ctx).Debugf("sent an update to peer %s", peerKey.String())
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *GRPCServer) cancelPeerRoutines(ctx context.Context, peer *nbpeer.Peer) {
|
||||
func (s *GRPCServer) cancelPeerRoutines(ctx context.Context, accountID string, peer *nbpeer.Peer) {
|
||||
s.peersUpdateManager.CloseChannel(ctx, peer.ID)
|
||||
s.turnRelayTokenManager.CancelRefresh(peer.ID)
|
||||
_ = s.accountManager.CancelPeerRoutines(ctx, peer)
|
||||
_ = s.accountManager.OnPeerDisconnected(ctx, accountID, peer.Key)
|
||||
s.ephemeralManager.OnPeerDisconnected(ctx, peer)
|
||||
}
|
||||
|
||||
|
@ -526,6 +526,43 @@ components:
|
||||
- revoked
|
||||
- auto_groups
|
||||
- usage_limit
|
||||
CreateSetupKeyRequest:
|
||||
type: object
|
||||
properties:
|
||||
name:
|
||||
description: Setup Key name
|
||||
type: string
|
||||
example: Default key
|
||||
type:
|
||||
description: Setup key type, one-off for single time usage and reusable
|
||||
type: string
|
||||
example: reusable
|
||||
expires_in:
|
||||
description: Expiration time in seconds
|
||||
type: integer
|
||||
minimum: 86400
|
||||
maximum: 31536000
|
||||
example: 86400
|
||||
auto_groups:
|
||||
description: List of group IDs to auto-assign to peers registered with this key
|
||||
type: array
|
||||
items:
|
||||
type: string
|
||||
example: "ch8i4ug6lnn4g9hqv7m0"
|
||||
usage_limit:
|
||||
description: A number of times this key can be used. The value of 0 indicates the unlimited usage.
|
||||
type: integer
|
||||
example: 0
|
||||
ephemeral:
|
||||
description: Indicate that the peer will be ephemeral or not
|
||||
type: boolean
|
||||
example: true
|
||||
required:
|
||||
- name
|
||||
- type
|
||||
- expires_in
|
||||
- auto_groups
|
||||
- usage_limit
|
||||
PersonalAccessToken:
|
||||
type: object
|
||||
properties:
|
||||
@ -1806,7 +1843,7 @@ paths:
|
||||
content:
|
||||
'application/json':
|
||||
schema:
|
||||
$ref: '#/components/schemas/SetupKeyRequest'
|
||||
$ref: '#/components/schemas/CreateSetupKeyRequest'
|
||||
responses:
|
||||
'200':
|
||||
description: A Setup Keys Object
|
||||
|
@ -254,6 +254,27 @@ type Country struct {
|
||||
// CountryCode 2-letter ISO 3166-1 alpha-2 code that represents the country
|
||||
type CountryCode = string
|
||||
|
||||
// CreateSetupKeyRequest defines model for CreateSetupKeyRequest.
|
||||
type CreateSetupKeyRequest struct {
|
||||
// AutoGroups List of group IDs to auto-assign to peers registered with this key
|
||||
AutoGroups []string `json:"auto_groups"`
|
||||
|
||||
// Ephemeral Indicate that the peer will be ephemeral or not
|
||||
Ephemeral *bool `json:"ephemeral,omitempty"`
|
||||
|
||||
// ExpiresIn Expiration time in seconds
|
||||
ExpiresIn int `json:"expires_in"`
|
||||
|
||||
// Name Setup Key name
|
||||
Name string `json:"name"`
|
||||
|
||||
// Type Setup key type, one-off for single time usage and reusable
|
||||
Type string `json:"type"`
|
||||
|
||||
// UsageLimit A number of times this key can be used. The value of 0 indicates the unlimited usage.
|
||||
UsageLimit int `json:"usage_limit"`
|
||||
}
|
||||
|
||||
// DNSSettings defines model for DNSSettings.
|
||||
type DNSSettings struct {
|
||||
// DisabledManagementGroups Groups whose DNS management is disabled
|
||||
@ -1241,7 +1262,7 @@ type PostApiRoutesJSONRequestBody = RouteRequest
|
||||
type PutApiRoutesRouteIdJSONRequestBody = RouteRequest
|
||||
|
||||
// PostApiSetupKeysJSONRequestBody defines body for PostApiSetupKeys for application/json ContentType.
|
||||
type PostApiSetupKeysJSONRequestBody = SetupKeyRequest
|
||||
type PostApiSetupKeysJSONRequestBody = CreateSetupKeyRequest
|
||||
|
||||
// PutApiSetupKeysKeyIdJSONRequestBody defines body for PutApiSetupKeysKeyId for application/json ContentType.
|
||||
type PutApiSetupKeysKeyIdJSONRequestBody = SetupKeyRequest
|
||||
|
@ -32,7 +32,7 @@ func (am *DefaultAccountManager) UpdateIntegratedValidatorGroups(ctx context.Con
|
||||
return errors.New("invalid groups")
|
||||
}
|
||||
|
||||
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
a, err := am.Store.GetAccountByUser(ctx, userID)
|
||||
|
@ -2,6 +2,7 @@ package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"path/filepath"
|
||||
@ -16,6 +17,7 @@ import (
|
||||
"google.golang.org/grpc/keepalive"
|
||||
|
||||
"github.com/netbirdio/netbird/encryption"
|
||||
"github.com/netbirdio/netbird/formatter"
|
||||
mgmtProto "github.com/netbirdio/netbird/management/proto"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
@ -83,7 +85,7 @@ func Test_SyncProtocol(t *testing.T) {
|
||||
defer func() {
|
||||
os.Remove(filepath.Join(dir, "store.json")) //nolint
|
||||
}()
|
||||
mgmtServer, mgmtAddr, err := startManagement(t, &Config{
|
||||
mgmtServer, _, mgmtAddr, err := startManagement(t, &Config{
|
||||
Stuns: []*Host{{
|
||||
Proto: "udp",
|
||||
URI: "stun:stun.wiretrustee.com:3468",
|
||||
@ -399,25 +401,28 @@ func TestServer_GetDeviceAuthorizationFlow(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func startManagement(t *testing.T, config *Config) (*grpc.Server, string, error) {
|
||||
func startManagement(t *testing.T, config *Config) (*grpc.Server, *DefaultAccountManager, string, error) {
|
||||
t.Helper()
|
||||
lis, err := net.Listen("tcp", "localhost:0")
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
return nil, nil, "", err
|
||||
}
|
||||
s := grpc.NewServer(grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp))
|
||||
store, cleanUp, err := NewTestStoreFromJson(context.Background(), config.Datadir)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
return nil, nil, "", err
|
||||
}
|
||||
t.Cleanup(cleanUp)
|
||||
|
||||
peersUpdateManager := NewPeersUpdateManager(nil)
|
||||
eventStore := &activity.InMemoryEventStore{}
|
||||
accountManager, err := BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted",
|
||||
|
||||
ctx := context.WithValue(context.Background(), formatter.ExecutionContextKey, formatter.SystemSource) //nolint:staticcheck
|
||||
|
||||
accountManager, err := BuildManager(ctx, store, peersUpdateManager, nil, "", "netbird.selfhosted",
|
||||
eventStore, nil, false, MocIntegratedValidator{})
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
return nil, nil, "", err
|
||||
}
|
||||
|
||||
rc := &RelayConfig{
|
||||
@ -428,7 +433,7 @@ func startManagement(t *testing.T, config *Config) (*grpc.Server, string, error)
|
||||
ephemeralMgr := NewEphemeralManager(store, accountManager)
|
||||
mgmtServer, err := NewServer(context.Background(), config, accountManager, peersUpdateManager, turnManager, nil, ephemeralMgr)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
return nil, nil, "", err
|
||||
}
|
||||
mgmtProto.RegisterManagementServiceServer(s, mgmtServer)
|
||||
|
||||
@ -438,7 +443,7 @@ func startManagement(t *testing.T, config *Config) (*grpc.Server, string, error)
|
||||
}
|
||||
}()
|
||||
|
||||
return s, lis.Addr().String(), nil
|
||||
return s, accountManager, lis.Addr().String(), nil
|
||||
}
|
||||
|
||||
func createRawClient(addr string) (mgmtProto.ManagementServiceClient, *grpc.ClientConn, error) {
|
||||
@ -458,3 +463,165 @@ func createRawClient(addr string) (mgmtProto.ManagementServiceClient, *grpc.Clie
|
||||
|
||||
return mgmtProto.NewManagementServiceClient(conn), conn, nil
|
||||
}
|
||||
func Test_SyncStatusRace(t *testing.T) {
|
||||
if os.Getenv("CI") == "true" && os.Getenv("NETBIRD_STORE_ENGINE") == "postgres" {
|
||||
t.Skip("Skipping on CI and Postgres store")
|
||||
}
|
||||
for i := 0; i < 500; i++ {
|
||||
t.Run(fmt.Sprintf("TestRun-%d", i), func(t *testing.T) {
|
||||
testSyncStatusRace(t)
|
||||
})
|
||||
}
|
||||
}
|
||||
func testSyncStatusRace(t *testing.T) {
|
||||
t.Helper()
|
||||
dir := t.TempDir()
|
||||
err := util.CopyFileContents("testdata/store_with_expired_peers.json", filepath.Join(dir, "store.json"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer func() {
|
||||
os.Remove(filepath.Join(dir, "store.json")) //nolint
|
||||
}()
|
||||
|
||||
mgmtServer, am, mgmtAddr, err := startManagement(t, &Config{
|
||||
Stuns: []*Host{{
|
||||
Proto: "udp",
|
||||
URI: "stun:stun.wiretrustee.com:3468",
|
||||
}},
|
||||
TURNConfig: &TURNConfig{
|
||||
TimeBasedCredentials: false,
|
||||
CredentialsTTL: util.Duration{},
|
||||
Secret: "whatever",
|
||||
Turns: []*Host{{
|
||||
Proto: "udp",
|
||||
URI: "turn:stun.wiretrustee.com:3468",
|
||||
}},
|
||||
},
|
||||
Signal: &Host{
|
||||
Proto: "http",
|
||||
URI: "signal.wiretrustee.com:10000",
|
||||
},
|
||||
Datadir: dir,
|
||||
HttpConfig: nil,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
}
|
||||
defer mgmtServer.GracefulStop()
|
||||
|
||||
client, clientConn, err := createRawClient(mgmtAddr)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
}
|
||||
|
||||
defer clientConn.Close()
|
||||
|
||||
// there are two peers already in the store, add two more
|
||||
peers, err := registerPeers(2, client)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
}
|
||||
|
||||
serverKey, err := getServerKey(client)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
}
|
||||
|
||||
concurrentPeerKey2 := peers[1]
|
||||
t.Log("Public key of concurrent peer: ", concurrentPeerKey2.PublicKey().String())
|
||||
|
||||
syncReq2 := &mgmtProto.SyncRequest{Meta: &mgmtProto.PeerSystemMeta{}}
|
||||
message2, err := encryption.EncryptMessage(*serverKey, *concurrentPeerKey2, syncReq2)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
}
|
||||
|
||||
ctx2, cancelFunc2 := context.WithCancel(context.Background())
|
||||
|
||||
//client.
|
||||
sync2, err := client.Sync(ctx2, &mgmtProto.EncryptedMessage{
|
||||
WgPubKey: concurrentPeerKey2.PublicKey().String(),
|
||||
Body: message2,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
}
|
||||
|
||||
resp2 := &mgmtProto.EncryptedMessage{}
|
||||
err = sync2.RecvMsg(resp2)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
}
|
||||
|
||||
peerWithInvalidStatus := peers[0]
|
||||
t.Log("Public key of peer with invalid status: ", peerWithInvalidStatus.PublicKey().String())
|
||||
|
||||
syncReq := &mgmtProto.SyncRequest{Meta: &mgmtProto.PeerSystemMeta{}}
|
||||
message, err := encryption.EncryptMessage(*serverKey, *peerWithInvalidStatus, syncReq)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
}
|
||||
|
||||
ctx, cancelFunc := context.WithCancel(context.Background())
|
||||
|
||||
//client.
|
||||
sync, err := client.Sync(ctx, &mgmtProto.EncryptedMessage{
|
||||
WgPubKey: peerWithInvalidStatus.PublicKey().String(),
|
||||
Body: message,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
}
|
||||
|
||||
// take the first registered peer as a base for the test. Total four.
|
||||
|
||||
resp := &mgmtProto.EncryptedMessage{}
|
||||
err = sync.RecvMsg(resp)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
}
|
||||
|
||||
cancelFunc2()
|
||||
time.Sleep(1 * time.Millisecond)
|
||||
cancelFunc()
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
|
||||
ctx, cancelFunc = context.WithCancel(context.Background())
|
||||
defer cancelFunc()
|
||||
sync, err = client.Sync(ctx, &mgmtProto.EncryptedMessage{
|
||||
WgPubKey: peerWithInvalidStatus.PublicKey().String(),
|
||||
Body: message,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
}
|
||||
|
||||
resp = &mgmtProto.EncryptedMessage{}
|
||||
err = sync.RecvMsg(resp)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
}
|
||||
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
peer, err := am.Store.GetPeerByPeerPubKey(context.Background(), peerWithInvalidStatus.PublicKey().String())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
}
|
||||
if !peer.Status.Connected {
|
||||
t.Fatal("Peer should be connected")
|
||||
}
|
||||
}
|
||||
|
@ -31,7 +31,7 @@ type MockAccountManager struct {
|
||||
ListUsersFunc func(ctx context.Context, accountID string) ([]*server.User, error)
|
||||
GetPeersFunc func(ctx context.Context, accountID, userID string) ([]*nbpeer.Peer, error)
|
||||
MarkPeerConnectedFunc func(ctx context.Context, peerKey string, connected bool, realIP net.IP) error
|
||||
SyncAndMarkPeerFunc func(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error)
|
||||
SyncAndMarkPeerFunc func(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error)
|
||||
DeletePeerFunc func(ctx context.Context, accountID, peerKey, userID string) error
|
||||
GetNetworkMapFunc func(ctx context.Context, peerKey string) (*server.NetworkMap, error)
|
||||
GetPeerNetworkFunc func(ctx context.Context, peerKey string) (*server.Network, error)
|
||||
@ -105,14 +105,14 @@ type MockAccountManager struct {
|
||||
GetAccountIDForPeerKeyFunc func(ctx context.Context, peerKey string) (string, error)
|
||||
}
|
||||
|
||||
func (am *MockAccountManager) SyncAndMarkPeer(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error) {
|
||||
func (am *MockAccountManager) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error) {
|
||||
if am.SyncAndMarkPeerFunc != nil {
|
||||
return am.SyncAndMarkPeerFunc(ctx, peerPubKey, meta, realIP)
|
||||
return am.SyncAndMarkPeerFunc(ctx, accountID, peerPubKey, meta, realIP)
|
||||
}
|
||||
return nil, nil, nil, status.Errorf(codes.Unimplemented, "method MarkPeerConnected is not implemented")
|
||||
}
|
||||
|
||||
func (am *MockAccountManager) CancelPeerRoutines(_ context.Context, peer *nbpeer.Peer) error {
|
||||
func (am *MockAccountManager) OnPeerDisconnected(_ context.Context, accountID string, peerPubKey string) error {
|
||||
// TODO implement me
|
||||
panic("implement me")
|
||||
}
|
||||
|
@ -20,7 +20,7 @@ const domainPattern = `^(?i)[a-z0-9]+([\-\.]{1}[a-z0-9]+)*\.[a-z]{2,}$`
|
||||
// GetNameServerGroup gets a nameserver group object from account and nameserver group IDs
|
||||
func (am *DefaultAccountManager) GetNameServerGroup(ctx context.Context, accountID, userID, nsGroupID string) (*nbdns.NameServerGroup, error) {
|
||||
|
||||
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
@ -48,7 +48,7 @@ func (am *DefaultAccountManager) GetNameServerGroup(ctx context.Context, account
|
||||
// CreateNameServerGroup creates and saves a new nameserver group
|
||||
func (am *DefaultAccountManager) CreateNameServerGroup(ctx context.Context, accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool, userID string, searchDomainEnabled bool) (*nbdns.NameServerGroup, error) {
|
||||
|
||||
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
@ -95,7 +95,7 @@ func (am *DefaultAccountManager) CreateNameServerGroup(ctx context.Context, acco
|
||||
// SaveNameServerGroup saves nameserver group
|
||||
func (am *DefaultAccountManager) SaveNameServerGroup(ctx context.Context, accountID, userID string, nsGroupToSave *nbdns.NameServerGroup) error {
|
||||
|
||||
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
if nsGroupToSave == nil {
|
||||
@ -130,7 +130,7 @@ func (am *DefaultAccountManager) SaveNameServerGroup(ctx context.Context, accoun
|
||||
// DeleteNameServerGroup deletes nameserver group with nsGroupID
|
||||
func (am *DefaultAccountManager) DeleteNameServerGroup(ctx context.Context, accountID, nsGroupID, userID string) error {
|
||||
|
||||
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
@ -160,7 +160,7 @@ func (am *DefaultAccountManager) DeleteNameServerGroup(ctx context.Context, acco
|
||||
// ListNameServerGroups returns a list of nameserver groups from account
|
||||
func (am *DefaultAccountManager) ListNameServerGroups(ctx context.Context, accountID string, userID string) ([]*nbdns.NameServerGroup, error) {
|
||||
|
||||
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
|
@ -7,10 +7,11 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/posture"
|
||||
"github.com/rs/xid"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/posture"
|
||||
|
||||
"github.com/netbirdio/netbird/management/proto"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
@ -149,7 +150,7 @@ func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubK
|
||||
|
||||
// UpdatePeer updates peer. Only Peer.Name, Peer.SSHEnabled, and Peer.LoginExpirationEnabled can be updated.
|
||||
func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, userID string, update *nbpeer.Peer) (*nbpeer.Peer, error) {
|
||||
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
@ -271,7 +272,7 @@ func (am *DefaultAccountManager) deletePeers(ctx context.Context, account *Accou
|
||||
|
||||
// DeletePeer removes peer from the account by its IP
|
||||
func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peerID, userID string) error {
|
||||
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
@ -355,7 +356,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
|
||||
return nil, nil, nil, status.Errorf(status.NotFound, "failed adding new peer: account not found")
|
||||
}
|
||||
|
||||
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer func() {
|
||||
if unlock != nil {
|
||||
unlock()
|
||||
@ -379,7 +380,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
|
||||
}
|
||||
|
||||
// This is a handling for the case when the same machine (with the same WireGuard pub key) tries to register twice.
|
||||
// Such case is possible when AddPeer function takes long time to finish after AcquireAccountWriteLock (e.g., database is slow)
|
||||
// Such case is possible when AddPeer function takes long time to finish after AcquireWriteLockByUID (e.g., database is slow)
|
||||
// and the peer disconnects with a timeout and tries to register again.
|
||||
// We just check if this machine has been registered before and reject the second registration.
|
||||
// The connecting peer should be able to recover with a retry.
|
||||
@ -452,6 +453,17 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
|
||||
Location: peer.Location,
|
||||
}
|
||||
|
||||
if am.geo != nil && newPeer.Location.ConnectionIP != nil {
|
||||
location, err := am.geo.Lookup(newPeer.Location.ConnectionIP)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Warnf("failed to get location for new peer realip: [%s]: %v", newPeer.Location.ConnectionIP.String(), err)
|
||||
} else {
|
||||
newPeer.Location.CountryCode = location.Country.ISOCode
|
||||
newPeer.Location.CityName = location.City.Names.En
|
||||
newPeer.Location.GeoNameID = location.City.GeonameID
|
||||
}
|
||||
}
|
||||
|
||||
// add peer to 'All' group
|
||||
group, err := account.GetGroupAll()
|
||||
if err != nil {
|
||||
@ -534,12 +546,12 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync PeerSync, ac
|
||||
}
|
||||
|
||||
if peerLoginExpired(ctx, peer, account.Settings) {
|
||||
return nil, nil, nil, status.Errorf(status.PermissionDenied, "peer login has expired, please log in once more")
|
||||
return nil, nil, nil, status.NewPeerLoginExpiredError()
|
||||
}
|
||||
|
||||
peer, updated := updatePeerMeta(peer, sync.Meta, account)
|
||||
if updated {
|
||||
err = am.Store.SaveAccount(ctx, account)
|
||||
err = am.Store.SavePeer(ctx, account.Id, peer)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
@ -588,18 +600,7 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login PeerLogin)
|
||||
Key: login.WireGuardPubKey,
|
||||
Meta: login.Meta,
|
||||
SSHKey: login.SSHKey,
|
||||
}
|
||||
if am.geo != nil && login.ConnectionIP != nil {
|
||||
location, err := am.geo.Lookup(login.ConnectionIP)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Warnf("failed to get location for new peer realip: [%s]: %v", login.ConnectionIP.String(), err)
|
||||
} else {
|
||||
newPeer.Location.ConnectionIP = login.ConnectionIP
|
||||
newPeer.Location.CountryCode = location.Country.ISOCode
|
||||
newPeer.Location.CityName = location.City.Names.En
|
||||
newPeer.Location.GeoNameID = location.City.GeonameID
|
||||
|
||||
}
|
||||
Location: nbpeer.Location{ConnectionIP: login.ConnectionIP},
|
||||
}
|
||||
|
||||
return am.AddPeer(ctx, login.SetupKey, login.UserID, newPeer)
|
||||
@ -609,44 +610,17 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login PeerLogin)
|
||||
return nil, nil, nil, status.Errorf(status.Internal, "failed while logging in peer")
|
||||
}
|
||||
|
||||
peer, err := am.Store.GetPeerByPeerPubKey(ctx, login.WireGuardPubKey)
|
||||
// when the client sends a login request with a JWT which is used to get the user ID,
|
||||
// it means that the client has already checked if it needs login and had been through the SSO flow
|
||||
// so, we can skip this check and directly proceed with the login
|
||||
if login.UserID == "" {
|
||||
err = am.checkIFPeerNeedsLoginWithoutLock(ctx, accountID, login)
|
||||
if err != nil {
|
||||
return nil, nil, nil, status.NewPeerNotRegisteredError()
|
||||
}
|
||||
|
||||
accSettings, err := am.Store.GetAccountSettings(ctx, accountID)
|
||||
if err != nil {
|
||||
return nil, nil, nil, status.Errorf(status.Internal, "failed to get account settings: %s", err)
|
||||
}
|
||||
|
||||
var isWriteLock bool
|
||||
|
||||
// duplicated logic from after the lock to have an early exit
|
||||
expired := peerLoginExpired(ctx, peer, accSettings)
|
||||
switch {
|
||||
case expired:
|
||||
if err := checkAuth(ctx, login.UserID, peer); err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
isWriteLock = true
|
||||
log.WithContext(ctx).Debugf("peer login expired, acquiring write lock")
|
||||
|
||||
case peer.UpdateMetaIfNew(login.Meta):
|
||||
isWriteLock = true
|
||||
log.WithContext(ctx).Debugf("peer changed meta, acquiring write lock")
|
||||
|
||||
default:
|
||||
isWriteLock = false
|
||||
log.WithContext(ctx).Debugf("peer meta is the same, acquiring read lock")
|
||||
}
|
||||
|
||||
var unlock func()
|
||||
|
||||
if isWriteLock {
|
||||
unlock = am.Store.AcquireAccountWriteLock(ctx, accountID)
|
||||
} else {
|
||||
unlock = am.Store.AcquireAccountReadLock(ctx, accountID)
|
||||
}
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer func() {
|
||||
if unlock != nil {
|
||||
unlock()
|
||||
@ -659,7 +633,7 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login PeerLogin)
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
|
||||
peer, err = account.FindPeerByPubKey(login.WireGuardPubKey)
|
||||
peer, err := account.FindPeerByPubKey(login.WireGuardPubKey)
|
||||
if err != nil {
|
||||
return nil, nil, nil, status.NewPeerNotRegisteredError()
|
||||
}
|
||||
@ -670,53 +644,39 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login PeerLogin)
|
||||
}
|
||||
|
||||
// this flag prevents unnecessary calls to the persistent store.
|
||||
shouldStoreAccount := false
|
||||
shouldStorePeer := false
|
||||
updateRemotePeers := false
|
||||
if peerLoginExpired(ctx, peer, account.Settings) {
|
||||
err = checkAuth(ctx, login.UserID, peer)
|
||||
err = am.handleExpiredPeer(ctx, login, account, peer)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
// If peer was expired before and if it reached this point, it is re-authenticated.
|
||||
// UserID is present, meaning that JWT validation passed successfully in the API layer.
|
||||
updatePeerLastLogin(peer, account)
|
||||
updateRemotePeers = true
|
||||
shouldStoreAccount = true
|
||||
|
||||
// sync user last login with peer last login
|
||||
user, err := account.FindUser(login.UserID)
|
||||
if err != nil {
|
||||
return nil, nil, nil, status.Errorf(status.Internal, "couldn't find user")
|
||||
}
|
||||
user.updateLastLogin(peer.LastLogin)
|
||||
|
||||
am.StoreEvent(ctx, login.UserID, peer.ID, account.Id, activity.UserLoggedInPeer, peer.EventMeta(am.GetDNSDomain()))
|
||||
shouldStorePeer = true
|
||||
}
|
||||
|
||||
isRequiresApproval, isStatusChanged, err := am.integratedPeerValidator.IsNotValidPeer(ctx, account.Id, peer, account.GetPeerGroupsList(peer.ID), account.Settings.Extra)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
|
||||
peer, updated := updatePeerMeta(peer, login.Meta, account)
|
||||
if updated {
|
||||
shouldStoreAccount = true
|
||||
shouldStorePeer = true
|
||||
}
|
||||
|
||||
peer, err = am.checkAndUpdatePeerSSHKey(ctx, peer, account, login.SSHKey)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
if peer.SSHKey != login.SSHKey {
|
||||
peer.SSHKey = login.SSHKey
|
||||
shouldStorePeer = true
|
||||
}
|
||||
|
||||
if shouldStoreAccount {
|
||||
if !isWriteLock {
|
||||
log.WithContext(ctx).Errorf("account %s should be stored but is not write locked", accountID)
|
||||
return nil, nil, nil, status.Errorf(status.Internal, "account should be stored but is not write locked")
|
||||
}
|
||||
err = am.Store.SaveAccount(ctx, account)
|
||||
if shouldStorePeer {
|
||||
err = am.Store.SavePeer(ctx, accountID, peer)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
}
|
||||
|
||||
unlock()
|
||||
unlock = nil
|
||||
|
||||
@ -724,13 +684,46 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login PeerLogin)
|
||||
am.updateAccountPeers(ctx, account)
|
||||
}
|
||||
|
||||
return am.getValidatedPeerWithMap(ctx, isRequiresApproval, account, peer)
|
||||
}
|
||||
|
||||
// checkIFPeerNeedsLoginWithoutLock checks if the peer needs login without acquiring the account lock. The check validate if the peer was not added via SSO
|
||||
// and if the peer login is expired.
|
||||
// The NetBird client doesn't have a way to check if the peer needs login besides sending a login request
|
||||
// with no JWT token and usually no setup-key. As the client can send up to two login request to check if it is expired
|
||||
// and before starting the engine, we do the checks without an account lock to avoid piling up requests.
|
||||
func (am *DefaultAccountManager) checkIFPeerNeedsLoginWithoutLock(ctx context.Context, accountID string, login PeerLogin) error {
|
||||
peer, err := am.Store.GetPeerByPeerPubKey(ctx, login.WireGuardPubKey)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// if the peer was not added with SSO login we can exit early because peers activated with setup-key
|
||||
// doesn't expire, and we avoid extra databases calls.
|
||||
if !peer.AddedWithSSOLogin() {
|
||||
return nil
|
||||
}
|
||||
|
||||
settings, err := am.Store.GetAccountSettings(ctx, accountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if peerLoginExpired(ctx, peer, settings) {
|
||||
return status.NewPeerLoginExpiredError()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) getValidatedPeerWithMap(ctx context.Context, isRequiresApproval bool, account *Account, peer *nbpeer.Peer) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) {
|
||||
var postureChecks []*posture.Checks
|
||||
|
||||
if isRequiresApproval {
|
||||
emptyMap := &NetworkMap{
|
||||
Network: account.Network.Copy(),
|
||||
}
|
||||
return peer, emptyMap, postureChecks, nil
|
||||
return peer, emptyMap, nil, nil
|
||||
}
|
||||
|
||||
approvedPeersMap, err := am.GetValidatedPeers(account)
|
||||
@ -742,6 +735,30 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login PeerLogin)
|
||||
return peer, account.GetPeerNetworkMap(ctx, peer.ID, am.dnsDomain, approvedPeersMap), postureChecks, nil
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) handleExpiredPeer(ctx context.Context, login PeerLogin, account *Account, peer *nbpeer.Peer) error {
|
||||
err := checkAuth(ctx, login.UserID, peer)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// If peer was expired before and if it reached this point, it is re-authenticated.
|
||||
// UserID is present, meaning that JWT validation passed successfully in the API layer.
|
||||
updatePeerLastLogin(peer, account)
|
||||
|
||||
// sync user last login with peer last login
|
||||
user, err := account.FindUser(login.UserID)
|
||||
if err != nil {
|
||||
return status.Errorf(status.Internal, "couldn't find user")
|
||||
}
|
||||
|
||||
err = am.Store.SaveUserLastLogin(account.Id, user.Id, peer.LastLogin)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
am.StoreEvent(ctx, login.UserID, peer.ID, account.Id, activity.UserLoggedInPeer, peer.EventMeta(am.GetDNSDomain()))
|
||||
return nil
|
||||
}
|
||||
|
||||
func checkIfPeerOwnerIsBlocked(peer *nbpeer.Peer, account *Account) error {
|
||||
if peer.AddedWithSSOLogin() {
|
||||
user, err := account.FindUser(peer.UserID)
|
||||
@ -758,11 +775,11 @@ func checkIfPeerOwnerIsBlocked(peer *nbpeer.Peer, account *Account) error {
|
||||
func checkAuth(ctx context.Context, loginUserID string, peer *nbpeer.Peer) error {
|
||||
if loginUserID == "" {
|
||||
// absence of a user ID indicates that JWT wasn't provided.
|
||||
return status.Errorf(status.PermissionDenied, "peer login has expired, please log in once more")
|
||||
return status.NewPeerLoginExpiredError()
|
||||
}
|
||||
if peer.UserID != loginUserID {
|
||||
log.WithContext(ctx).Warnf("user mismatch when logging in peer %s: peer user %s, login user %s ", peer.ID, peer.UserID, loginUserID)
|
||||
return status.Errorf(status.Unauthenticated, "can't login")
|
||||
return status.Errorf(status.Unauthenticated, "can't login with this credentials")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@ -782,31 +799,6 @@ func updatePeerLastLogin(peer *nbpeer.Peer, account *Account) {
|
||||
account.UpdatePeer(peer)
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) checkAndUpdatePeerSSHKey(ctx context.Context, peer *nbpeer.Peer, account *Account, newSSHKey string) (*nbpeer.Peer, error) {
|
||||
if len(newSSHKey) == 0 {
|
||||
log.WithContext(ctx).Debugf("no new SSH key provided for peer %s, skipping update", peer.ID)
|
||||
return peer, nil
|
||||
}
|
||||
|
||||
if peer.SSHKey == newSSHKey {
|
||||
log.WithContext(ctx).Debugf("same SSH key provided for peer %s, skipping update", peer.ID)
|
||||
return peer, nil
|
||||
}
|
||||
|
||||
peer.SSHKey = newSSHKey
|
||||
account.UpdatePeer(peer)
|
||||
|
||||
err := am.Store.SaveAccount(ctx, account)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// trigger network map update
|
||||
am.updateAccountPeers(ctx, account)
|
||||
|
||||
return peer, nil
|
||||
}
|
||||
|
||||
// UpdatePeerSSHKey updates peer's public SSH key
|
||||
func (am *DefaultAccountManager) UpdatePeerSSHKey(ctx context.Context, peerID string, sshKey string) error {
|
||||
if sshKey == "" {
|
||||
@ -819,7 +811,7 @@ func (am *DefaultAccountManager) UpdatePeerSSHKey(ctx context.Context, peerID st
|
||||
return err
|
||||
}
|
||||
|
||||
unlock := am.Store.AcquireAccountWriteLock(ctx, account.Id)
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, account.Id)
|
||||
defer unlock()
|
||||
|
||||
// ensure that we consider modification happened meanwhile (because we were outside the account lock when we fetched the account)
|
||||
@ -854,7 +846,7 @@ func (am *DefaultAccountManager) UpdatePeerSSHKey(ctx context.Context, peerID st
|
||||
|
||||
// GetPeer for a given accountID, peerID and userID error if not found.
|
||||
func (am *DefaultAccountManager) GetPeer(ctx context.Context, accountID, peerID, userID string) (*nbpeer.Peer, error) {
|
||||
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
|
@ -315,7 +315,7 @@ func (a *Account) connResourcesGenerator(ctx context.Context) (func(*PolicyRule,
|
||||
|
||||
// GetPolicy from the store
|
||||
func (am *DefaultAccountManager) GetPolicy(ctx context.Context, accountID, policyID, userID string) (*Policy, error) {
|
||||
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
@ -343,7 +343,7 @@ func (am *DefaultAccountManager) GetPolicy(ctx context.Context, accountID, polic
|
||||
|
||||
// SavePolicy in the store
|
||||
func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, userID string, policy *Policy) error {
|
||||
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
@ -371,7 +371,7 @@ func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, user
|
||||
|
||||
// DeletePolicy from the store
|
||||
func (am *DefaultAccountManager) DeletePolicy(ctx context.Context, accountID, policyID, userID string) error {
|
||||
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
@ -398,7 +398,7 @@ func (am *DefaultAccountManager) DeletePolicy(ctx context.Context, accountID, po
|
||||
|
||||
// ListPolicies from the store
|
||||
func (am *DefaultAccountManager) ListPolicies(ctx context.Context, accountID, userID string) ([]*Policy, error) {
|
||||
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
|
@ -15,7 +15,7 @@ const (
|
||||
)
|
||||
|
||||
func (am *DefaultAccountManager) GetPostureChecks(ctx context.Context, accountID, postureChecksID, userID string) (*posture.Checks, error) {
|
||||
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
@ -42,7 +42,7 @@ func (am *DefaultAccountManager) GetPostureChecks(ctx context.Context, accountID
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) SavePostureChecks(ctx context.Context, accountID, userID string, postureChecks *posture.Checks) error {
|
||||
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
@ -89,7 +89,7 @@ func (am *DefaultAccountManager) SavePostureChecks(ctx context.Context, accountI
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) DeletePostureChecks(ctx context.Context, accountID, postureChecksID, userID string) error {
|
||||
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
@ -121,7 +121,7 @@ func (am *DefaultAccountManager) DeletePostureChecks(ctx context.Context, accoun
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) ListPostureChecks(ctx context.Context, accountID, userID string) ([]*posture.Checks, error) {
|
||||
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
|
@ -17,7 +17,7 @@ import (
|
||||
|
||||
// GetRoute gets a route object from account and route IDs
|
||||
func (am *DefaultAccountManager) GetRoute(ctx context.Context, accountID string, routeID route.ID, userID string) (*route.Route, error) {
|
||||
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
@ -126,7 +126,7 @@ func getRouteDescriptor(prefix netip.Prefix, domains domain.List) string {
|
||||
|
||||
// CreateRoute creates and saves a new route
|
||||
func (am *DefaultAccountManager) CreateRoute(ctx context.Context, accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peerID string, peerGroupIDs []string, description string, netID route.NetID, masquerade bool, metric int, groups []string, enabled bool, userID string, keepRoute bool) (*route.Route, error) {
|
||||
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
@ -214,7 +214,7 @@ func (am *DefaultAccountManager) CreateRoute(ctx context.Context, accountID stri
|
||||
|
||||
// SaveRoute saves route
|
||||
func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userID string, routeToSave *route.Route) error {
|
||||
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
if routeToSave == nil {
|
||||
@ -283,7 +283,7 @@ func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userI
|
||||
|
||||
// DeleteRoute deletes route with routeID
|
||||
func (am *DefaultAccountManager) DeleteRoute(ctx context.Context, accountID string, routeID route.ID, userID string) error {
|
||||
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
@ -311,7 +311,7 @@ func (am *DefaultAccountManager) DeleteRoute(ctx context.Context, accountID stri
|
||||
|
||||
// ListRoutes returns a list of routes from account
|
||||
func (am *DefaultAccountManager) ListRoutes(ctx context.Context, accountID, userID string) ([]*route.Route, error) {
|
||||
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
|
@ -210,7 +210,7 @@ func Hash(s string) uint32 {
|
||||
// and adds it to the specified account. A list of autoGroups IDs can be empty.
|
||||
func (am *DefaultAccountManager) CreateSetupKey(ctx context.Context, accountID string, keyName string, keyType SetupKeyType,
|
||||
expiresIn time.Duration, autoGroups []string, usageLimit int, userID string, ephemeral bool) (*SetupKey, error) {
|
||||
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
keyDuration := DefaultSetupKeyDuration
|
||||
@ -256,7 +256,7 @@ func (am *DefaultAccountManager) CreateSetupKey(ctx context.Context, accountID s
|
||||
// (e.g. the key itself, creation date, ID, etc).
|
||||
// These properties are overwritten: Name, AutoGroups, Revoked. The rest is copied from the existing key.
|
||||
func (am *DefaultAccountManager) SaveSetupKey(ctx context.Context, accountID string, keyToSave *SetupKey, userID string) (*SetupKey, error) {
|
||||
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
if keyToSave == nil {
|
||||
@ -328,7 +328,7 @@ func (am *DefaultAccountManager) SaveSetupKey(ctx context.Context, accountID str
|
||||
|
||||
// ListSetupKeys returns a list of all setup keys of the account
|
||||
func (am *DefaultAccountManager) ListSetupKeys(ctx context.Context, accountID, userID string) ([]*SetupKey, error) {
|
||||
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
if err != nil {
|
||||
@ -360,7 +360,7 @@ func (am *DefaultAccountManager) ListSetupKeys(ctx context.Context, accountID, u
|
||||
|
||||
// GetSetupKey looks up a SetupKey by KeyID, returns NotFound error if not found.
|
||||
func (am *DefaultAccountManager) GetSetupKey(ctx context.Context, accountID, userID, keyID string) (*SetupKey, error) {
|
||||
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
|
@ -33,12 +33,14 @@ import (
|
||||
const (
|
||||
storeSqliteFileName = "store.db"
|
||||
idQueryCondition = "id = ?"
|
||||
accountAndIDQueryCondition = "account_id = ? and id = ?"
|
||||
peerNotFoundFMT = "peer %s not found"
|
||||
)
|
||||
|
||||
// SqlStore represents an account storage backed by a Sql DB persisted to disk
|
||||
type SqlStore struct {
|
||||
db *gorm.DB
|
||||
accountLocks sync.Map
|
||||
resourceLocks sync.Map
|
||||
globalAccountLock sync.Mutex
|
||||
metrics telemetry.AppMetrics
|
||||
installationPK int
|
||||
@ -96,33 +98,35 @@ func (s *SqlStore) AcquireGlobalLock(ctx context.Context) (unlock func()) {
|
||||
return unlock
|
||||
}
|
||||
|
||||
func (s *SqlStore) AcquireAccountWriteLock(ctx context.Context, accountID string) (unlock func()) {
|
||||
log.WithContext(ctx).Tracef("acquiring write lock for account %s", accountID)
|
||||
// AcquireWriteLockByUID acquires an ID lock for writing to a resource and returns a function that releases the lock
|
||||
func (s *SqlStore) AcquireWriteLockByUID(ctx context.Context, uniqueID string) (unlock func()) {
|
||||
log.WithContext(ctx).Tracef("acquiring write lock for ID %s", uniqueID)
|
||||
|
||||
start := time.Now()
|
||||
value, _ := s.accountLocks.LoadOrStore(accountID, &sync.RWMutex{})
|
||||
value, _ := s.resourceLocks.LoadOrStore(uniqueID, &sync.RWMutex{})
|
||||
mtx := value.(*sync.RWMutex)
|
||||
mtx.Lock()
|
||||
|
||||
unlock = func() {
|
||||
mtx.Unlock()
|
||||
log.WithContext(ctx).Tracef("released write lock for account %s in %v", accountID, time.Since(start))
|
||||
log.WithContext(ctx).Tracef("released write lock for ID %s in %v", uniqueID, time.Since(start))
|
||||
}
|
||||
|
||||
return unlock
|
||||
}
|
||||
|
||||
func (s *SqlStore) AcquireAccountReadLock(ctx context.Context, accountID string) (unlock func()) {
|
||||
log.WithContext(ctx).Tracef("acquiring read lock for account %s", accountID)
|
||||
// AcquireReadLockByUID acquires an ID lock for writing to a resource and returns a function that releases the lock
|
||||
func (s *SqlStore) AcquireReadLockByUID(ctx context.Context, uniqueID string) (unlock func()) {
|
||||
log.WithContext(ctx).Tracef("acquiring read lock for ID %s", uniqueID)
|
||||
|
||||
start := time.Now()
|
||||
value, _ := s.accountLocks.LoadOrStore(accountID, &sync.RWMutex{})
|
||||
value, _ := s.resourceLocks.LoadOrStore(uniqueID, &sync.RWMutex{})
|
||||
mtx := value.(*sync.RWMutex)
|
||||
mtx.RLock()
|
||||
|
||||
unlock = func() {
|
||||
mtx.RUnlock()
|
||||
log.WithContext(ctx).Tracef("released read lock for account %s in %v", accountID, time.Since(start))
|
||||
log.WithContext(ctx).Tracef("released read lock for ID %s in %v", uniqueID, time.Since(start))
|
||||
}
|
||||
|
||||
return unlock
|
||||
@ -271,6 +275,38 @@ func (s *SqlStore) GetInstallationID() string {
|
||||
return installation.InstallationIDValue
|
||||
}
|
||||
|
||||
func (s *SqlStore) SavePeer(ctx context.Context, accountID string, peer *nbpeer.Peer) error {
|
||||
// To maintain data integrity, we create a copy of the peer's to prevent unintended updates to other fields.
|
||||
peerCopy := peer.Copy()
|
||||
peerCopy.AccountID = accountID
|
||||
|
||||
err := s.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
||||
// check if peer exists before saving
|
||||
var peerID string
|
||||
result := tx.Model(&nbpeer.Peer{}).Select("id").Find(&peerID, accountAndIDQueryCondition, accountID, peer.ID)
|
||||
if result.Error != nil {
|
||||
return result.Error
|
||||
}
|
||||
|
||||
if peerID == "" {
|
||||
return status.Errorf(status.NotFound, peerNotFoundFMT, peer.ID)
|
||||
}
|
||||
|
||||
result = tx.Model(&nbpeer.Peer{}).Where(accountAndIDQueryCondition, accountID, peer.ID).Save(peerCopy)
|
||||
if result.Error != nil {
|
||||
return result.Error
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) SavePeerStatus(accountID, peerID string, peerStatus nbpeer.PeerStatus) error {
|
||||
var peerCopy nbpeer.Peer
|
||||
peerCopy.Status = &peerStatus
|
||||
@ -281,14 +317,14 @@ func (s *SqlStore) SavePeerStatus(accountID, peerID string, peerStatus nbpeer.Pe
|
||||
}
|
||||
result := s.db.Model(&nbpeer.Peer{}).
|
||||
Select(fieldsToUpdate).
|
||||
Where("account_id = ? AND id = ?", accountID, peerID).
|
||||
Where(accountAndIDQueryCondition, accountID, peerID).
|
||||
Updates(&peerCopy)
|
||||
if result.Error != nil {
|
||||
return result.Error
|
||||
}
|
||||
|
||||
if result.RowsAffected == 0 {
|
||||
return status.Errorf(status.NotFound, "peer %s not found", peerID)
|
||||
return status.Errorf(status.NotFound, peerNotFoundFMT, peerID)
|
||||
}
|
||||
|
||||
return nil
|
||||
@ -302,7 +338,7 @@ func (s *SqlStore) SavePeerLocation(accountID string, peerWithLocation *nbpeer.P
|
||||
peerCopy.Location = peerWithLocation.Location
|
||||
|
||||
result := s.db.Model(&nbpeer.Peer{}).
|
||||
Where("account_id = ? and id = ?", accountID, peerWithLocation.ID).
|
||||
Where(accountAndIDQueryCondition, accountID, peerWithLocation.ID).
|
||||
Updates(peerCopy)
|
||||
|
||||
if result.Error != nil {
|
||||
@ -310,7 +346,7 @@ func (s *SqlStore) SavePeerLocation(accountID string, peerWithLocation *nbpeer.P
|
||||
}
|
||||
|
||||
if result.RowsAffected == 0 {
|
||||
return status.Errorf(status.NotFound, "peer %s not found", peerWithLocation.ID)
|
||||
return status.Errorf(status.NotFound, peerNotFoundFMT, peerWithLocation.ID)
|
||||
}
|
||||
|
||||
return nil
|
||||
@ -644,7 +680,7 @@ func (s *SqlStore) GetAccountSettings(ctx context.Context, accountID string) (*S
|
||||
func (s *SqlStore) SaveUserLastLogin(accountID, userID string, lastLogin time.Time) error {
|
||||
var user User
|
||||
|
||||
result := s.db.First(&user, "account_id = ? and id = ?", accountID, userID)
|
||||
result := s.db.First(&user, accountAndIDQueryCondition, accountID, userID)
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return status.Errorf(status.NotFound, "user %s not found", userID)
|
||||
|
@ -362,6 +362,54 @@ func TestSqlite_GetAccount(t *testing.T) {
|
||||
require.Equal(t, status.NotFound, parsedErr.Type(), "should return not found error")
|
||||
}
|
||||
|
||||
func TestSqlite_SavePeer(t *testing.T) {
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("The SQLite store is not properly supported by Windows yet")
|
||||
}
|
||||
|
||||
store := newSqliteStoreFromFile(t, "testdata/store.json")
|
||||
|
||||
account, err := store.GetAccount(context.Background(), "bf1c8084-ba50-4ce7-9439-34653001fc3b")
|
||||
require.NoError(t, err)
|
||||
|
||||
// save status of non-existing peer
|
||||
peer := &nbpeer.Peer{
|
||||
Key: "peerkey",
|
||||
ID: "testpeer",
|
||||
SetupKey: "peerkeysetupkey",
|
||||
IP: net.IP{127, 0, 0, 1},
|
||||
Meta: nbpeer.PeerSystemMeta{Hostname: "testingpeer"},
|
||||
Name: "peer name",
|
||||
Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()},
|
||||
}
|
||||
ctx := context.Background()
|
||||
err = store.SavePeer(ctx, account.Id, peer)
|
||||
assert.Error(t, err)
|
||||
parsedErr, ok := status.FromError(err)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, status.NotFound, parsedErr.Type(), "should return not found error")
|
||||
|
||||
// save new status of existing peer
|
||||
account.Peers[peer.ID] = peer
|
||||
|
||||
err = store.SaveAccount(context.Background(), account)
|
||||
require.NoError(t, err)
|
||||
|
||||
updatedPeer := peer.Copy()
|
||||
updatedPeer.Status.Connected = false
|
||||
updatedPeer.Meta.Hostname = "updatedpeer"
|
||||
|
||||
err = store.SavePeer(ctx, account.Id, updatedPeer)
|
||||
require.NoError(t, err)
|
||||
|
||||
account, err = store.GetAccount(context.Background(), account.Id)
|
||||
require.NoError(t, err)
|
||||
|
||||
actual := account.Peers[peer.ID]
|
||||
assert.Equal(t, updatedPeer.Status, actual.Status)
|
||||
assert.Equal(t, updatedPeer.Meta, actual.Meta)
|
||||
}
|
||||
|
||||
func TestSqlite_SavePeerStatus(t *testing.T) {
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("The SQLite store is not properly supported by Windows yet")
|
||||
@ -402,7 +450,19 @@ func TestSqlite_SavePeerStatus(t *testing.T) {
|
||||
|
||||
actual := account.Peers["testpeer"].Status
|
||||
assert.Equal(t, newStatus, *actual)
|
||||
|
||||
newStatus.Connected = true
|
||||
|
||||
err = store.SavePeerStatus(account.Id, "testpeer", newStatus)
|
||||
require.NoError(t, err)
|
||||
|
||||
account, err = store.GetAccount(context.Background(), account.Id)
|
||||
require.NoError(t, err)
|
||||
|
||||
actual = account.Peers["testpeer"].Status
|
||||
assert.Equal(t, newStatus, *actual)
|
||||
}
|
||||
|
||||
func TestSqlite_SavePeerLocation(t *testing.T) {
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("The SQLite store is not properly supported by Windows yet")
|
||||
|
@ -95,3 +95,8 @@ func NewUserNotFoundError(userKey string) error {
|
||||
func NewPeerNotRegisteredError() error {
|
||||
return Errorf(Unauthenticated, "peer is not registered")
|
||||
}
|
||||
|
||||
// NewPeerLoginExpiredError creates a new Error with PermissionDenied type for an expired peer
|
||||
func NewPeerLoginExpiredError() error {
|
||||
return Errorf(PermissionDenied, "peer login has expired, please log in once more")
|
||||
}
|
||||
|
@ -12,10 +12,11 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
nbgroup "github.com/netbirdio/netbird/management/server/group"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"gorm.io/gorm"
|
||||
|
||||
nbgroup "github.com/netbirdio/netbird/management/server/group"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
|
||||
@ -48,12 +49,13 @@ type Store interface {
|
||||
DeleteTokenID2UserIDIndex(tokenID string) error
|
||||
GetInstallationID() string
|
||||
SaveInstallationID(ctx context.Context, ID string) error
|
||||
// AcquireAccountWriteLock should attempt to acquire account lock for write purposes and return a function that releases the lock
|
||||
AcquireAccountWriteLock(ctx context.Context, accountID string) func()
|
||||
// AcquireAccountReadLock should attempt to acquire account lock for read purposes and return a function that releases the lock
|
||||
AcquireAccountReadLock(ctx context.Context, accountID string) func()
|
||||
// AcquireWriteLockByUID should attempt to acquire a lock for write purposes and return a function that releases the lock
|
||||
AcquireWriteLockByUID(ctx context.Context, uniqueID string) func()
|
||||
// AcquireReadLockByUID should attempt to acquire lock for read purposes and return a function that releases the lock
|
||||
AcquireReadLockByUID(ctx context.Context, uniqueID string) func()
|
||||
// AcquireGlobalLock should attempt to acquire a global lock and return a function that releases the lock
|
||||
AcquireGlobalLock(ctx context.Context) func()
|
||||
SavePeer(ctx context.Context, accountID string, peer *nbpeer.Peer) error
|
||||
SavePeerStatus(accountID, peerID string, status nbpeer.PeerStatus) error
|
||||
SavePeerLocation(accountID string, peer *nbpeer.Peer) error
|
||||
SaveUserLastLogin(accountID, userID string, lastLogin time.Time) error
|
||||
|
@ -211,7 +211,7 @@ func NewOwnerUser(id string) *User {
|
||||
|
||||
// createServiceUser creates a new service user under the given account.
|
||||
func (am *DefaultAccountManager) createServiceUser(ctx context.Context, accountID string, initiatorUserID string, role UserRole, serviceUserName string, nonDeletable bool, autoGroups []string) (*UserInfo, error) {
|
||||
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
@ -267,7 +267,7 @@ func (am *DefaultAccountManager) CreateUser(ctx context.Context, accountID, user
|
||||
|
||||
// inviteNewUser Invites a USer to a given account and creates reference in datastore
|
||||
func (am *DefaultAccountManager) inviteNewUser(ctx context.Context, accountID, userID string, invite *UserInfo) (*UserInfo, error) {
|
||||
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
if am.idpManager == nil {
|
||||
@ -368,7 +368,7 @@ func (am *DefaultAccountManager) GetUser(ctx context.Context, claims jwtclaims.A
|
||||
return nil, fmt.Errorf("failed to get account with token claims %v", err)
|
||||
}
|
||||
|
||||
unlock := am.Store.AcquireAccountWriteLock(ctx, account.Id)
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, account.Id)
|
||||
defer unlock()
|
||||
|
||||
account, err = am.Store.GetAccount(ctx, account.Id)
|
||||
@ -401,7 +401,7 @@ func (am *DefaultAccountManager) GetUser(ctx context.Context, claims jwtclaims.A
|
||||
// ListUsers returns lists of all users under the account.
|
||||
// It doesn't populate user information such as email or name.
|
||||
func (am *DefaultAccountManager) ListUsers(ctx context.Context, accountID string) ([]*User, error) {
|
||||
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
@ -428,7 +428,7 @@ func (am *DefaultAccountManager) DeleteUser(ctx context.Context, accountID, init
|
||||
if initiatorUserID == targetUserID {
|
||||
return status.Errorf(status.InvalidArgument, "self deletion is not allowed")
|
||||
}
|
||||
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
@ -538,7 +538,7 @@ func (am *DefaultAccountManager) deleteUserPeers(ctx context.Context, initiatorU
|
||||
|
||||
// InviteUser resend invitations to users who haven't activated their accounts prior to the expiration period.
|
||||
func (am *DefaultAccountManager) InviteUser(ctx context.Context, accountID string, initiatorUserID string, targetUserID string) error {
|
||||
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
if am.idpManager == nil {
|
||||
@ -578,7 +578,7 @@ func (am *DefaultAccountManager) InviteUser(ctx context.Context, accountID strin
|
||||
|
||||
// CreatePAT creates a new PAT for the given user
|
||||
func (am *DefaultAccountManager) CreatePAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenName string, expiresIn int) (*PersonalAccessTokenGenerated, error) {
|
||||
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
if tokenName == "" {
|
||||
@ -628,7 +628,7 @@ func (am *DefaultAccountManager) CreatePAT(ctx context.Context, accountID string
|
||||
|
||||
// DeletePAT deletes a specific PAT from a user
|
||||
func (am *DefaultAccountManager) DeletePAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) error {
|
||||
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
@ -678,7 +678,7 @@ func (am *DefaultAccountManager) DeletePAT(ctx context.Context, accountID string
|
||||
|
||||
// GetPAT returns a specific PAT from a user
|
||||
func (am *DefaultAccountManager) GetPAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) (*PersonalAccessToken, error) {
|
||||
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
@ -710,7 +710,7 @@ func (am *DefaultAccountManager) GetPAT(ctx context.Context, accountID string, i
|
||||
|
||||
// GetAllPATs returns all PATs for a user
|
||||
func (am *DefaultAccountManager) GetAllPATs(ctx context.Context, accountID string, initiatorUserID string, targetUserID string) ([]*PersonalAccessToken, error) {
|
||||
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
@ -752,7 +752,7 @@ func (am *DefaultAccountManager) SaveOrAddUser(ctx context.Context, accountID, i
|
||||
return nil, status.Errorf(status.InvalidArgument, "provided user update is nil")
|
||||
}
|
||||
|
||||
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
updatedUsers, err := am.SaveOrAddUsers(ctx, accountID, initiatorUserID, []*User{update}, addIfNotExists)
|
||||
@ -859,7 +859,7 @@ func (am *DefaultAccountManager) SaveOrAddUsers(ctx context.Context, accountID,
|
||||
}
|
||||
|
||||
account.Network.IncSerial()
|
||||
if err = am.Store.SaveUsers(account.Id, account.Users); err != nil {
|
||||
if err = am.Store.SaveAccount(ctx, account); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
@ -35,8 +35,11 @@ func InitLog(logLevel string, logPath string) error {
|
||||
AddSyslogHook()
|
||||
}
|
||||
|
||||
//nolint:gocritic
|
||||
if os.Getenv("NB_LOG_FORMAT") == "json" {
|
||||
formatter.SetJSONFormatter(log.StandardLogger())
|
||||
} else if logPath == "syslog" {
|
||||
formatter.SetSyslogFormatter(log.StandardLogger())
|
||||
} else {
|
||||
formatter.SetTextFormatter(log.StandardLogger())
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user