Merge remote-tracking branch 'origin/main' into local/engine-restart

This commit is contained in:
Pascal Fischer 2023-10-06 16:33:01 +02:00
commit 91b45eab98
72 changed files with 2465 additions and 2508 deletions

View File

@ -112,6 +112,27 @@ jobs:
grep -A 6 PKCEAuthorizationFlow management.json | grep -A 5 ProviderConfig | grep TokenEndpoint | grep $CI_NETBIRD_AUTH_TOKEN_ENDPOINT grep -A 6 PKCEAuthorizationFlow management.json | grep -A 5 ProviderConfig | grep TokenEndpoint | grep $CI_NETBIRD_AUTH_TOKEN_ENDPOINT
grep -A 7 PKCEAuthorizationFlow management.json | grep -A 6 ProviderConfig | grep Scope | grep "$CI_NETBIRD_AUTH_SUPPORTED_SCOPES" grep -A 7 PKCEAuthorizationFlow management.json | grep -A 6 ProviderConfig | grep Scope | grep "$CI_NETBIRD_AUTH_SUPPORTED_SCOPES"
- name: Install modules
run: go mod tidy
- name: Build management binary
working-directory: management
run: CGO_ENABLED=1 go build -o netbird-mgmt main.go
- name: Build management docker image
working-directory: management
run: |
docker build -t netbirdio/management:latest .
- name: Build signal binary
working-directory: signal
run: CGO_ENABLED=0 go build -o netbird-signal main.go
- name: Build signal docker image
working-directory: signal
run: |
docker build -t netbirdio/signal:latest .
- name: run docker compose up - name: run docker compose up
working-directory: infrastructure_files working-directory: infrastructure_files
run: | run: |

View File

@ -193,7 +193,7 @@ func (a *Auth) login(urlOpener URLOpener) error {
} }
func (a *Auth) foregroundGetTokenInfo(urlOpener URLOpener) (*auth.TokenInfo, error) { func (a *Auth) foregroundGetTokenInfo(urlOpener URLOpener) (*auth.TokenInfo, error) {
oAuthFlow, err := auth.NewOAuthFlow(a.ctx, a.config) oAuthFlow, err := auth.NewOAuthFlow(a.ctx, a.config, false)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -3,6 +3,7 @@ package cmd
import ( import (
"context" "context"
"fmt" "fmt"
"os"
"strings" "strings"
"time" "time"
@ -83,6 +84,7 @@ var loginCmd = &cobra.Command{
SetupKey: setupKey, SetupKey: setupKey,
PreSharedKey: preSharedKey, PreSharedKey: preSharedKey,
ManagementUrl: managementURL, ManagementUrl: managementURL,
IsLinuxDesktopClient: isLinuxRunningDesktop(),
} }
var loginErr error var loginErr error
@ -163,7 +165,7 @@ func foregroundLogin(ctx context.Context, cmd *cobra.Command, config *internal.C
} }
func foregroundGetTokenInfo(ctx context.Context, cmd *cobra.Command, config *internal.Config) (*auth.TokenInfo, error) { func foregroundGetTokenInfo(ctx context.Context, cmd *cobra.Command, config *internal.Config) (*auth.TokenInfo, error) {
oAuthFlow, err := auth.NewOAuthFlow(ctx, config) oAuthFlow, err := auth.NewOAuthFlow(ctx, config, isLinuxRunningDesktop())
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -202,3 +204,8 @@ func openURL(cmd *cobra.Command, verificationURIComplete, userCode string) {
"https://docs.netbird.io/how-to/register-machines-using-setup-keys") "https://docs.netbird.io/how-to/register-machines-using-setup-keys")
} }
} }
// isLinuxRunningDesktop checks if a Linux OS is running desktop environment
func isLinuxRunningDesktop() bool {
return os.Getenv("DESKTOP_SESSION") != "" || os.Getenv("XDG_CURRENT_DESKTOP") != ""
}

View File

@ -148,6 +148,7 @@ func runInDaemonMode(ctx context.Context, cmd *cobra.Command) error {
NatExternalIPs: natExternalIPs, NatExternalIPs: natExternalIPs,
CleanNATExternalIPs: natExternalIPs != nil && len(natExternalIPs) == 0, CleanNATExternalIPs: natExternalIPs != nil && len(natExternalIPs) == 0,
CustomDNSAddress: customDNSAddressConverted, CustomDNSAddress: customDNSAddressConverted,
IsLinuxDesktopClient: isLinuxRunningDesktop(),
} }
var loginErr error var loginErr error

View File

@ -93,7 +93,7 @@ func Create(wgIface iFaceMapper, ipv6Supported bool) (*Manager, error) {
// AddFiltering rule to the firewall // AddFiltering rule to the firewall
// //
// If comment is empty rule ID is used as comment // Comment will be ignored because some system this feature is not supported
func (m *Manager) AddFiltering( func (m *Manager) AddFiltering(
ip net.IP, ip net.IP,
protocol fw.Protocol, protocol fw.Protocol,
@ -123,9 +123,6 @@ func (m *Manager) AddFiltering(
ipsetName = m.transformIPsetName(ipsetName, sPortVal, dPortVal) ipsetName = m.transformIPsetName(ipsetName, sPortVal, dPortVal)
ruleID := uuid.New().String() ruleID := uuid.New().String()
if comment == "" {
comment = ruleID
}
if ipsetName != "" { if ipsetName != "" {
rs, rsExists := m.rulesets[ipsetName] rs, rsExists := m.rulesets[ipsetName]
@ -157,8 +154,7 @@ func (m *Manager) AddFiltering(
// this is new ipset so we need to create firewall rule for it // this is new ipset so we need to create firewall rule for it
} }
specs := m.filterRuleSpecs("filter", ip, string(protocol), sPortVal, dPortVal, specs := m.filterRuleSpecs(ip, string(protocol), sPortVal, dPortVal, direction, action, ipsetName)
direction, action, comment, ipsetName)
if direction == fw.RuleDirectionOUT { if direction == fw.RuleDirectionOUT {
ok, err := client.Exists("filter", ChainOutputFilterName, specs...) ok, err := client.Exists("filter", ChainOutputFilterName, specs...)
@ -283,7 +279,7 @@ func (m *Manager) AllowNetbird() error {
fw.RuleDirectionIN, fw.RuleDirectionIN,
fw.ActionAccept, fw.ActionAccept,
"", "",
"allow netbird interface traffic", "",
) )
if err != nil { if err != nil {
return fmt.Errorf("failed to allow netbird interface traffic: %w", err) return fmt.Errorf("failed to allow netbird interface traffic: %w", err)
@ -296,7 +292,7 @@ func (m *Manager) AllowNetbird() error {
fw.RuleDirectionOUT, fw.RuleDirectionOUT,
fw.ActionAccept, fw.ActionAccept,
"", "",
"allow netbird interface traffic", "",
) )
return err return err
} }
@ -362,9 +358,7 @@ func (m *Manager) reset(client *iptables.IPTables, table string) error {
// filterRuleSpecs returns the specs of a filtering rule // filterRuleSpecs returns the specs of a filtering rule
func (m *Manager) filterRuleSpecs( func (m *Manager) filterRuleSpecs(
table string, ip net.IP, protocol string, sPort, dPort string, ip net.IP, protocol string, sPort, dPort string, direction fw.RuleDirection, action fw.Action, ipsetName string,
direction fw.RuleDirection, action fw.Action, comment string,
ipsetName string,
) (specs []string) { ) (specs []string) {
matchByIP := true matchByIP := true
// don't use IP matching if IP is ip 0.0.0.0 // don't use IP matching if IP is ip 0.0.0.0
@ -398,8 +392,7 @@ func (m *Manager) filterRuleSpecs(
if dPort != "" { if dPort != "" {
specs = append(specs, "--dport", dPort) specs = append(specs, "--dport", dPort)
} }
specs = append(specs, "-j", m.actionToStr(action)) return append(specs, "-j", m.actionToStr(action))
return append(specs, "-m", "comment", "--comment", comment)
} }
// rawClient returns corresponding iptables client for the given ip // rawClient returns corresponding iptables client for the given ip

View File

@ -6,6 +6,7 @@ import (
"net/http" "net/http"
"runtime" "runtime"
log "github.com/sirupsen/logrus"
"google.golang.org/grpc/codes" "google.golang.org/grpc/codes"
gstatus "google.golang.org/grpc/status" gstatus "google.golang.org/grpc/status"
@ -63,14 +64,16 @@ func (t TokenInfo) GetTokenToUse() string {
// and if that also fails, the authentication process is deemed unsuccessful // and if that also fails, the authentication process is deemed unsuccessful
// //
// On Linux distros without desktop environment support, it only tries to initialize the Device Code Flow // On Linux distros without desktop environment support, it only tries to initialize the Device Code Flow
func NewOAuthFlow(ctx context.Context, config *internal.Config) (OAuthFlow, error) { func NewOAuthFlow(ctx context.Context, config *internal.Config, isLinuxDesktopClient bool) (OAuthFlow, error) {
if runtime.GOOS == "linux" && !isLinuxRunningDesktop() { if runtime.GOOS == "linux" && !isLinuxDesktopClient {
return authenticateWithDeviceCodeFlow(ctx, config) return authenticateWithDeviceCodeFlow(ctx, config)
} }
pkceFlow, err := authenticateWithPKCEFlow(ctx, config) pkceFlow, err := authenticateWithPKCEFlow(ctx, config)
if err != nil { if err != nil {
// fallback to device code flow // fallback to device code flow
log.Debugf("failed to initialize pkce authentication with error: %v\n", err)
log.Debug("falling back to device code flow")
return authenticateWithDeviceCodeFlow(ctx, config) return authenticateWithDeviceCodeFlow(ctx, config)
} }
return pkceFlow, nil return pkceFlow, nil

View File

@ -197,7 +197,13 @@ func (p *PKCEAuthorizationFlow) parseOAuthToken(token *oauth2.Token) (TokenInfo,
tokenInfo.IDToken = idToken tokenInfo.IDToken = idToken
} }
if err := isValidAccessToken(tokenInfo.GetTokenToUse(), p.providerConfig.Audience); err != nil { // if a provider doesn't support an audience, use the Client ID for token verification
audience := p.providerConfig.Audience
if audience == "" {
audience = p.providerConfig.ClientID
}
if err := isValidAccessToken(tokenInfo.GetTokenToUse(), audience); err != nil {
return TokenInfo{}, fmt.Errorf("validate access token failed with error: %v", err) return TokenInfo{}, fmt.Errorf("validate access token failed with error: %v", err)
} }

View File

@ -7,8 +7,6 @@ import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"io" "io"
"os"
"reflect"
"strings" "strings"
) )
@ -45,15 +43,14 @@ func isValidAccessToken(token string, audience string) error {
} }
// Audience claim of JWT can be a string or an array of strings // Audience claim of JWT can be a string or an array of strings
typ := reflect.TypeOf(claims.Audience) switch aud := claims.Audience.(type) {
switch typ.Kind() { case string:
case reflect.String: if aud == audience {
if claims.Audience == audience {
return nil return nil
} }
case reflect.Slice: case []interface{}:
for _, aud := range claims.Audience.([]interface{}) { for _, audItem := range aud {
if audience == aud { if audStr, ok := audItem.(string); ok && audStr == audience {
return nil return nil
} }
} }
@ -61,8 +58,3 @@ func isValidAccessToken(token string, audience string) error {
return fmt.Errorf("invalid JWT token audience field") return fmt.Errorf("invalid JWT token audience field")
} }
// isLinuxRunningDesktop checks if a Linux OS is running desktop environment
func isLinuxRunningDesktop() bool {
return os.Getenv("DESKTOP_SESSION") != "" || os.Getenv("XDG_CURRENT_DESKTOP") != ""
}

View File

@ -106,9 +106,6 @@ func GetPKCEAuthorizationFlowInfo(ctx context.Context, privateKey string, mgmURL
func isPKCEProviderConfigValid(config PKCEAuthProviderConfig) error { func isPKCEProviderConfigValid(config PKCEAuthProviderConfig) error {
errorMSGFormat := "invalid provider configuration received from management: %s value is empty. Contact your NetBird administrator" errorMSGFormat := "invalid provider configuration received from management: %s value is empty. Contact your NetBird administrator"
if config.Audience == "" {
return fmt.Errorf(errorMSGFormat, "Audience")
}
if config.ClientID == "" { if config.ClientID == "" {
return fmt.Errorf(errorMSGFormat, "Client ID") return fmt.Errorf(errorMSGFormat, "Client ID")
} }

View File

@ -1,7 +1,7 @@
// Code generated by protoc-gen-go. DO NOT EDIT. // Code generated by protoc-gen-go. DO NOT EDIT.
// versions: // versions:
// protoc-gen-go v1.26.0 // protoc-gen-go v1.26.0
// protoc v3.21.9 // protoc v4.23.4
// source: daemon.proto // source: daemon.proto
package proto package proto
@ -42,6 +42,7 @@ type LoginRequest struct {
// omits initialized empty slices due to omitempty tags // omits initialized empty slices due to omitempty tags
CleanNATExternalIPs bool `protobuf:"varint,6,opt,name=cleanNATExternalIPs,proto3" json:"cleanNATExternalIPs,omitempty"` CleanNATExternalIPs bool `protobuf:"varint,6,opt,name=cleanNATExternalIPs,proto3" json:"cleanNATExternalIPs,omitempty"`
CustomDNSAddress []byte `protobuf:"bytes,7,opt,name=customDNSAddress,proto3" json:"customDNSAddress,omitempty"` CustomDNSAddress []byte `protobuf:"bytes,7,opt,name=customDNSAddress,proto3" json:"customDNSAddress,omitempty"`
IsLinuxDesktopClient bool `protobuf:"varint,8,opt,name=isLinuxDesktopClient,proto3" json:"isLinuxDesktopClient,omitempty"`
} }
func (x *LoginRequest) Reset() { func (x *LoginRequest) Reset() {
@ -125,6 +126,13 @@ func (x *LoginRequest) GetCustomDNSAddress() []byte {
return nil return nil
} }
func (x *LoginRequest) GetIsLinuxDesktopClient() bool {
if x != nil {
return x.IsLinuxDesktopClient
}
return false
}
type LoginResponse struct { type LoginResponse struct {
state protoimpl.MessageState state protoimpl.MessageState
sizeCache protoimpl.SizeCache sizeCache protoimpl.SizeCache
@ -1043,7 +1051,7 @@ var file_daemon_proto_rawDesc = []byte{
0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2f, 0x64, 0x65, 0x73, 0x63, 0x72, 0x69, 0x70, 0x74, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2f, 0x64, 0x65, 0x73, 0x63, 0x72, 0x69, 0x70, 0x74,
0x6f, 0x72, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x1a, 0x1f, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x6f, 0x72, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x1a, 0x1f, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65,
0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2f, 0x74, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2f, 0x74, 0x69, 0x6d, 0x65, 0x73, 0x74,
0x61, 0x6d, 0x70, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x22, 0x96, 0x02, 0x0a, 0x0c, 0x4c, 0x6f, 0x61, 0x6d, 0x70, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x22, 0xca, 0x02, 0x0a, 0x0c, 0x4c, 0x6f,
0x67, 0x69, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x1a, 0x0a, 0x08, 0x73, 0x65, 0x67, 0x69, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x1a, 0x0a, 0x08, 0x73, 0x65,
0x74, 0x75, 0x70, 0x4b, 0x65, 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x73, 0x65, 0x74, 0x75, 0x70, 0x4b, 0x65, 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x73, 0x65,
0x74, 0x75, 0x70, 0x4b, 0x65, 0x79, 0x12, 0x22, 0x0a, 0x0c, 0x70, 0x72, 0x65, 0x53, 0x68, 0x61, 0x74, 0x75, 0x70, 0x4b, 0x65, 0x79, 0x12, 0x22, 0x0a, 0x0c, 0x70, 0x72, 0x65, 0x53, 0x68, 0x61,
@ -1061,128 +1069,131 @@ var file_daemon_proto_rawDesc = []byte{
0x6e, 0x61, 0x6c, 0x49, 0x50, 0x73, 0x12, 0x2a, 0x0a, 0x10, 0x63, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x6e, 0x61, 0x6c, 0x49, 0x50, 0x73, 0x12, 0x2a, 0x0a, 0x10, 0x63, 0x75, 0x73, 0x74, 0x6f, 0x6d,
0x44, 0x4e, 0x53, 0x41, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x18, 0x07, 0x20, 0x01, 0x28, 0x0c, 0x44, 0x4e, 0x53, 0x41, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x18, 0x07, 0x20, 0x01, 0x28, 0x0c,
0x52, 0x10, 0x63, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x44, 0x4e, 0x53, 0x41, 0x64, 0x64, 0x72, 0x65, 0x52, 0x10, 0x63, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x44, 0x4e, 0x53, 0x41, 0x64, 0x64, 0x72, 0x65,
0x73, 0x73, 0x22, 0xb5, 0x01, 0x0a, 0x0d, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, 0x73, 0x70, 0x73, 0x73, 0x12, 0x32, 0x0a, 0x14, 0x69, 0x73, 0x4c, 0x69, 0x6e, 0x75, 0x78, 0x44, 0x65, 0x73,
0x6f, 0x6e, 0x73, 0x65, 0x12, 0x24, 0x0a, 0x0d, 0x6e, 0x65, 0x65, 0x64, 0x73, 0x53, 0x53, 0x4f, 0x6b, 0x74, 0x6f, 0x70, 0x43, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x18, 0x08, 0x20, 0x01, 0x28, 0x08,
0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x18, 0x01, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0d, 0x6e, 0x65, 0x65, 0x52, 0x14, 0x69, 0x73, 0x4c, 0x69, 0x6e, 0x75, 0x78, 0x44, 0x65, 0x73, 0x6b, 0x74, 0x6f, 0x70,
0x64, 0x73, 0x53, 0x53, 0x4f, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x12, 0x1a, 0x0a, 0x08, 0x75, 0x73, 0x43, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x22, 0xb5, 0x01, 0x0a, 0x0d, 0x4c, 0x6f, 0x67, 0x69, 0x6e,
0x65, 0x72, 0x43, 0x6f, 0x64, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x75, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x24, 0x0a, 0x0d, 0x6e, 0x65, 0x65, 0x64,
0x65, 0x72, 0x43, 0x6f, 0x64, 0x65, 0x12, 0x28, 0x0a, 0x0f, 0x76, 0x65, 0x72, 0x69, 0x66, 0x69, 0x73, 0x53, 0x53, 0x4f, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x18, 0x01, 0x20, 0x01, 0x28, 0x08, 0x52,
0x63, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x55, 0x52, 0x49, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0d, 0x6e, 0x65, 0x65, 0x64, 0x73, 0x53, 0x53, 0x4f, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x12, 0x1a,
0x0f, 0x76, 0x65, 0x72, 0x69, 0x66, 0x69, 0x63, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x55, 0x52, 0x49, 0x0a, 0x08, 0x75, 0x73, 0x65, 0x72, 0x43, 0x6f, 0x64, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09,
0x12, 0x38, 0x0a, 0x17, 0x76, 0x65, 0x72, 0x69, 0x66, 0x69, 0x63, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x08, 0x75, 0x73, 0x65, 0x72, 0x43, 0x6f, 0x64, 0x65, 0x12, 0x28, 0x0a, 0x0f, 0x76, 0x65,
0x55, 0x52, 0x49, 0x43, 0x6f, 0x6d, 0x70, 0x6c, 0x65, 0x74, 0x65, 0x18, 0x04, 0x20, 0x01, 0x28, 0x72, 0x69, 0x66, 0x69, 0x63, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x55, 0x52, 0x49, 0x18, 0x03, 0x20,
0x09, 0x52, 0x17, 0x76, 0x65, 0x72, 0x69, 0x66, 0x69, 0x63, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x55, 0x01, 0x28, 0x09, 0x52, 0x0f, 0x76, 0x65, 0x72, 0x69, 0x66, 0x69, 0x63, 0x61, 0x74, 0x69, 0x6f,
0x52, 0x49, 0x43, 0x6f, 0x6d, 0x70, 0x6c, 0x65, 0x74, 0x65, 0x22, 0x31, 0x0a, 0x13, 0x57, 0x61, 0x6e, 0x55, 0x52, 0x49, 0x12, 0x38, 0x0a, 0x17, 0x76, 0x65, 0x72, 0x69, 0x66, 0x69, 0x63, 0x61,
0x69, 0x74, 0x53, 0x53, 0x4f, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x69, 0x6f, 0x6e, 0x55, 0x52, 0x49, 0x43, 0x6f, 0x6d, 0x70, 0x6c, 0x65, 0x74, 0x65, 0x18,
0x74, 0x12, 0x1a, 0x0a, 0x08, 0x75, 0x73, 0x65, 0x72, 0x43, 0x6f, 0x64, 0x65, 0x18, 0x01, 0x20, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x17, 0x76, 0x65, 0x72, 0x69, 0x66, 0x69, 0x63, 0x61, 0x74,
0x01, 0x28, 0x09, 0x52, 0x08, 0x75, 0x73, 0x65, 0x72, 0x43, 0x6f, 0x64, 0x65, 0x22, 0x16, 0x0a, 0x69, 0x6f, 0x6e, 0x55, 0x52, 0x49, 0x43, 0x6f, 0x6d, 0x70, 0x6c, 0x65, 0x74, 0x65, 0x22, 0x31,
0x14, 0x57, 0x61, 0x69, 0x74, 0x53, 0x53, 0x4f, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, 0x73, 0x0a, 0x13, 0x57, 0x61, 0x69, 0x74, 0x53, 0x53, 0x4f, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65,
0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x0b, 0x0a, 0x09, 0x55, 0x70, 0x52, 0x65, 0x71, 0x75, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x1a, 0x0a, 0x08, 0x75, 0x73, 0x65, 0x72, 0x43, 0x6f, 0x64,
0x73, 0x74, 0x22, 0x0c, 0x0a, 0x0a, 0x55, 0x70, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x75, 0x73, 0x65, 0x72, 0x43, 0x6f, 0x64,
0x22, 0x3d, 0x0a, 0x0d, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x65, 0x22, 0x16, 0x0a, 0x14, 0x57, 0x61, 0x69, 0x74, 0x53, 0x53, 0x4f, 0x4c, 0x6f, 0x67, 0x69,
0x74, 0x12, 0x2c, 0x0a, 0x11, 0x67, 0x65, 0x74, 0x46, 0x75, 0x6c, 0x6c, 0x50, 0x65, 0x65, 0x72, 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x0b, 0x0a, 0x09, 0x55, 0x70, 0x52,
0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x18, 0x01, 0x20, 0x01, 0x28, 0x08, 0x52, 0x11, 0x67, 0x65, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x22, 0x0c, 0x0a, 0x0a, 0x55, 0x70, 0x52, 0x65, 0x73, 0x70,
0x74, 0x46, 0x75, 0x6c, 0x6c, 0x50, 0x65, 0x65, 0x72, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x22, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x3d, 0x0a, 0x0d, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x52, 0x65,
0x82, 0x01, 0x0a, 0x0e, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x2c, 0x0a, 0x11, 0x67, 0x65, 0x74, 0x46, 0x75, 0x6c, 0x6c,
0x73, 0x65, 0x12, 0x16, 0x0a, 0x06, 0x73, 0x74, 0x61, 0x74, 0x75, 0x73, 0x18, 0x01, 0x20, 0x01, 0x50, 0x65, 0x65, 0x72, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x18, 0x01, 0x20, 0x01, 0x28, 0x08,
0x28, 0x09, 0x52, 0x06, 0x73, 0x74, 0x61, 0x74, 0x75, 0x73, 0x12, 0x32, 0x0a, 0x0a, 0x66, 0x75, 0x52, 0x11, 0x67, 0x65, 0x74, 0x46, 0x75, 0x6c, 0x6c, 0x50, 0x65, 0x65, 0x72, 0x53, 0x74, 0x61,
0x6c, 0x6c, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x12, 0x74, 0x75, 0x73, 0x22, 0x82, 0x01, 0x0a, 0x0e, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x52, 0x65,
0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x46, 0x75, 0x6c, 0x6c, 0x53, 0x74, 0x61, 0x74, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x16, 0x0a, 0x06, 0x73, 0x74, 0x61, 0x74, 0x75, 0x73,
0x75, 0x73, 0x52, 0x0a, 0x66, 0x75, 0x6c, 0x6c, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x12, 0x24, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x73, 0x74, 0x61, 0x74, 0x75, 0x73, 0x12, 0x32,
0x0a, 0x0d, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x56, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x18, 0x0a, 0x0a, 0x66, 0x75, 0x6c, 0x6c, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x18, 0x02, 0x20, 0x01,
0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0d, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x56, 0x65, 0x72, 0x28, 0x0b, 0x32, 0x12, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x46, 0x75, 0x6c, 0x6c,
0x73, 0x69, 0x6f, 0x6e, 0x22, 0x0d, 0x0a, 0x0b, 0x44, 0x6f, 0x77, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x52, 0x0a, 0x66, 0x75, 0x6c, 0x6c, 0x53, 0x74, 0x61, 0x74,
0x65, 0x73, 0x74, 0x22, 0x0e, 0x0a, 0x0c, 0x44, 0x6f, 0x77, 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x75, 0x73, 0x12, 0x24, 0x0a, 0x0d, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x56, 0x65, 0x72, 0x73,
0x6e, 0x73, 0x65, 0x22, 0x12, 0x0a, 0x10, 0x47, 0x65, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x69, 0x6f, 0x6e, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0d, 0x64, 0x61, 0x65, 0x6d, 0x6f,
0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x22, 0xb3, 0x01, 0x0a, 0x11, 0x47, 0x65, 0x74, 0x43, 0x6e, 0x56, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x22, 0x0d, 0x0a, 0x0b, 0x44, 0x6f, 0x77, 0x6e,
0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x24, 0x0a, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x22, 0x0e, 0x0a, 0x0c, 0x44, 0x6f, 0x77, 0x6e, 0x52,
0x0d, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x55, 0x72, 0x6c, 0x18, 0x01, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x12, 0x0a, 0x10, 0x47, 0x65, 0x74, 0x43, 0x6f,
0x20, 0x01, 0x28, 0x09, 0x52, 0x0d, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x22, 0xb3, 0x01, 0x0a, 0x11,
0x55, 0x72, 0x6c, 0x12, 0x1e, 0x0a, 0x0a, 0x63, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x46, 0x69, 0x6c, 0x47, 0x65, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73,
0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0a, 0x63, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x46, 0x65, 0x12, 0x24, 0x0a, 0x0d, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x55,
0x69, 0x6c, 0x65, 0x12, 0x18, 0x0a, 0x07, 0x6c, 0x6f, 0x67, 0x46, 0x69, 0x6c, 0x65, 0x18, 0x03, 0x72, 0x6c, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0d, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65,
0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x6c, 0x6f, 0x67, 0x46, 0x69, 0x6c, 0x65, 0x12, 0x22, 0x0a, 0x6d, 0x65, 0x6e, 0x74, 0x55, 0x72, 0x6c, 0x12, 0x1e, 0x0a, 0x0a, 0x63, 0x6f, 0x6e, 0x66, 0x69,
0x0c, 0x70, 0x72, 0x65, 0x53, 0x68, 0x61, 0x72, 0x65, 0x64, 0x4b, 0x65, 0x79, 0x18, 0x04, 0x20, 0x67, 0x46, 0x69, 0x6c, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0a, 0x63, 0x6f, 0x6e,
0x01, 0x28, 0x09, 0x52, 0x0c, 0x70, 0x72, 0x65, 0x53, 0x68, 0x61, 0x72, 0x65, 0x64, 0x4b, 0x65, 0x66, 0x69, 0x67, 0x46, 0x69, 0x6c, 0x65, 0x12, 0x18, 0x0a, 0x07, 0x6c, 0x6f, 0x67, 0x46, 0x69,
0x79, 0x12, 0x1a, 0x0a, 0x08, 0x61, 0x64, 0x6d, 0x69, 0x6e, 0x55, 0x52, 0x4c, 0x18, 0x05, 0x20, 0x6c, 0x65, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x6c, 0x6f, 0x67, 0x46, 0x69, 0x6c,
0x01, 0x28, 0x09, 0x52, 0x08, 0x61, 0x64, 0x6d, 0x69, 0x6e, 0x55, 0x52, 0x4c, 0x22, 0xcf, 0x02, 0x65, 0x12, 0x22, 0x0a, 0x0c, 0x70, 0x72, 0x65, 0x53, 0x68, 0x61, 0x72, 0x65, 0x64, 0x4b, 0x65,
0x0a, 0x09, 0x50, 0x65, 0x65, 0x72, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, 0x0e, 0x0a, 0x02, 0x49, 0x79, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0c, 0x70, 0x72, 0x65, 0x53, 0x68, 0x61, 0x72,
0x50, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x49, 0x50, 0x12, 0x16, 0x0a, 0x06, 0x70, 0x65, 0x64, 0x4b, 0x65, 0x79, 0x12, 0x1a, 0x0a, 0x08, 0x61, 0x64, 0x6d, 0x69, 0x6e, 0x55, 0x52,
0x75, 0x62, 0x4b, 0x65, 0x79, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x70, 0x75, 0x62, 0x4c, 0x18, 0x05, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x61, 0x64, 0x6d, 0x69, 0x6e, 0x55, 0x52,
0x4b, 0x65, 0x79, 0x12, 0x1e, 0x0a, 0x0a, 0x63, 0x6f, 0x6e, 0x6e, 0x53, 0x74, 0x61, 0x74, 0x75, 0x4c, 0x22, 0xcf, 0x02, 0x0a, 0x09, 0x50, 0x65, 0x65, 0x72, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12,
0x73, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0a, 0x63, 0x6f, 0x6e, 0x6e, 0x53, 0x74, 0x61, 0x0e, 0x0a, 0x02, 0x49, 0x50, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x49, 0x50, 0x12,
0x74, 0x75, 0x73, 0x12, 0x46, 0x0a, 0x10, 0x63, 0x6f, 0x6e, 0x6e, 0x53, 0x74, 0x61, 0x74, 0x75, 0x16, 0x0a, 0x06, 0x70, 0x75, 0x62, 0x4b, 0x65, 0x79, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52,
0x73, 0x55, 0x70, 0x64, 0x61, 0x74, 0x65, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x06, 0x70, 0x75, 0x62, 0x4b, 0x65, 0x79, 0x12, 0x1e, 0x0a, 0x0a, 0x63, 0x6f, 0x6e, 0x6e, 0x53,
0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x74, 0x61, 0x74, 0x75, 0x73, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0a, 0x63, 0x6f, 0x6e,
0x54, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x52, 0x10, 0x63, 0x6f, 0x6e, 0x6e, 0x53, 0x6e, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x12, 0x46, 0x0a, 0x10, 0x63, 0x6f, 0x6e, 0x6e, 0x53,
0x74, 0x61, 0x74, 0x75, 0x73, 0x55, 0x70, 0x64, 0x61, 0x74, 0x65, 0x12, 0x18, 0x0a, 0x07, 0x72, 0x74, 0x61, 0x74, 0x75, 0x73, 0x55, 0x70, 0x64, 0x61, 0x74, 0x65, 0x18, 0x04, 0x20, 0x01, 0x28,
0x65, 0x6c, 0x61, 0x79, 0x65, 0x64, 0x18, 0x05, 0x20, 0x01, 0x28, 0x08, 0x52, 0x07, 0x72, 0x65, 0x0b, 0x32, 0x1a, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f,
0x6c, 0x61, 0x79, 0x65, 0x64, 0x12, 0x16, 0x0a, 0x06, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x18, 0x62, 0x75, 0x66, 0x2e, 0x54, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x52, 0x10, 0x63,
0x06, 0x20, 0x01, 0x28, 0x08, 0x52, 0x06, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x12, 0x34, 0x0a, 0x6f, 0x6e, 0x6e, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x55, 0x70, 0x64, 0x61, 0x74, 0x65, 0x12,
0x15, 0x6c, 0x6f, 0x63, 0x61, 0x6c, 0x49, 0x63, 0x65, 0x43, 0x61, 0x6e, 0x64, 0x69, 0x64, 0x61, 0x18, 0x0a, 0x07, 0x72, 0x65, 0x6c, 0x61, 0x79, 0x65, 0x64, 0x18, 0x05, 0x20, 0x01, 0x28, 0x08,
0x74, 0x65, 0x54, 0x79, 0x70, 0x65, 0x18, 0x07, 0x20, 0x01, 0x28, 0x09, 0x52, 0x15, 0x6c, 0x6f, 0x52, 0x07, 0x72, 0x65, 0x6c, 0x61, 0x79, 0x65, 0x64, 0x12, 0x16, 0x0a, 0x06, 0x64, 0x69, 0x72,
0x63, 0x61, 0x6c, 0x49, 0x63, 0x65, 0x43, 0x61, 0x6e, 0x64, 0x69, 0x64, 0x61, 0x74, 0x65, 0x54, 0x65, 0x63, 0x74, 0x18, 0x06, 0x20, 0x01, 0x28, 0x08, 0x52, 0x06, 0x64, 0x69, 0x72, 0x65, 0x63,
0x79, 0x70, 0x65, 0x12, 0x36, 0x0a, 0x16, 0x72, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x49, 0x63, 0x65, 0x74, 0x12, 0x34, 0x0a, 0x15, 0x6c, 0x6f, 0x63, 0x61, 0x6c, 0x49, 0x63, 0x65, 0x43, 0x61, 0x6e,
0x43, 0x61, 0x6e, 0x64, 0x69, 0x64, 0x61, 0x74, 0x65, 0x54, 0x79, 0x70, 0x65, 0x18, 0x08, 0x20, 0x64, 0x69, 0x64, 0x61, 0x74, 0x65, 0x54, 0x79, 0x70, 0x65, 0x18, 0x07, 0x20, 0x01, 0x28, 0x09,
0x01, 0x28, 0x09, 0x52, 0x16, 0x72, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x49, 0x63, 0x65, 0x43, 0x61, 0x52, 0x15, 0x6c, 0x6f, 0x63, 0x61, 0x6c, 0x49, 0x63, 0x65, 0x43, 0x61, 0x6e, 0x64, 0x69, 0x64,
0x6e, 0x64, 0x69, 0x64, 0x61, 0x74, 0x65, 0x54, 0x79, 0x70, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x66, 0x61, 0x74, 0x65, 0x54, 0x79, 0x70, 0x65, 0x12, 0x36, 0x0a, 0x16, 0x72, 0x65, 0x6d, 0x6f, 0x74,
0x71, 0x64, 0x6e, 0x18, 0x09, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x66, 0x71, 0x64, 0x6e, 0x22, 0x65, 0x49, 0x63, 0x65, 0x43, 0x61, 0x6e, 0x64, 0x69, 0x64, 0x61, 0x74, 0x65, 0x54, 0x79, 0x70,
0x76, 0x0a, 0x0e, 0x4c, 0x6f, 0x63, 0x61, 0x6c, 0x50, 0x65, 0x65, 0x72, 0x53, 0x74, 0x61, 0x74, 0x65, 0x18, 0x08, 0x20, 0x01, 0x28, 0x09, 0x52, 0x16, 0x72, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x49,
0x65, 0x12, 0x0e, 0x0a, 0x02, 0x49, 0x50, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x49, 0x63, 0x65, 0x43, 0x61, 0x6e, 0x64, 0x69, 0x64, 0x61, 0x74, 0x65, 0x54, 0x79, 0x70, 0x65, 0x12,
0x50, 0x12, 0x16, 0x0a, 0x06, 0x70, 0x75, 0x62, 0x4b, 0x65, 0x79, 0x18, 0x02, 0x20, 0x01, 0x28, 0x12, 0x0a, 0x04, 0x66, 0x71, 0x64, 0x6e, 0x18, 0x09, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x66,
0x09, 0x52, 0x06, 0x70, 0x75, 0x62, 0x4b, 0x65, 0x79, 0x12, 0x28, 0x0a, 0x0f, 0x6b, 0x65, 0x72, 0x71, 0x64, 0x6e, 0x22, 0x76, 0x0a, 0x0e, 0x4c, 0x6f, 0x63, 0x61, 0x6c, 0x50, 0x65, 0x65, 0x72,
0x6e, 0x65, 0x6c, 0x49, 0x6e, 0x74, 0x65, 0x72, 0x66, 0x61, 0x63, 0x65, 0x18, 0x03, 0x20, 0x01, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, 0x0e, 0x0a, 0x02, 0x49, 0x50, 0x18, 0x01, 0x20, 0x01, 0x28,
0x28, 0x08, 0x52, 0x0f, 0x6b, 0x65, 0x72, 0x6e, 0x65, 0x6c, 0x49, 0x6e, 0x74, 0x65, 0x72, 0x66, 0x09, 0x52, 0x02, 0x49, 0x50, 0x12, 0x16, 0x0a, 0x06, 0x70, 0x75, 0x62, 0x4b, 0x65, 0x79, 0x18,
0x61, 0x63, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x66, 0x71, 0x64, 0x6e, 0x18, 0x04, 0x20, 0x01, 0x28, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x70, 0x75, 0x62, 0x4b, 0x65, 0x79, 0x12, 0x28, 0x0a,
0x09, 0x52, 0x04, 0x66, 0x71, 0x64, 0x6e, 0x22, 0x3d, 0x0a, 0x0b, 0x53, 0x69, 0x67, 0x6e, 0x61, 0x0f, 0x6b, 0x65, 0x72, 0x6e, 0x65, 0x6c, 0x49, 0x6e, 0x74, 0x65, 0x72, 0x66, 0x61, 0x63, 0x65,
0x6c, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, 0x10, 0x0a, 0x03, 0x55, 0x52, 0x4c, 0x18, 0x01, 0x20, 0x18, 0x03, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0f, 0x6b, 0x65, 0x72, 0x6e, 0x65, 0x6c, 0x49, 0x6e,
0x01, 0x28, 0x09, 0x52, 0x03, 0x55, 0x52, 0x4c, 0x12, 0x1c, 0x0a, 0x09, 0x63, 0x6f, 0x6e, 0x6e, 0x74, 0x65, 0x72, 0x66, 0x61, 0x63, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x66, 0x71, 0x64, 0x6e, 0x18,
0x65, 0x63, 0x74, 0x65, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x08, 0x52, 0x09, 0x63, 0x6f, 0x6e, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x66, 0x71, 0x64, 0x6e, 0x22, 0x3d, 0x0a, 0x0b, 0x53,
0x6e, 0x65, 0x63, 0x74, 0x65, 0x64, 0x22, 0x41, 0x0a, 0x0f, 0x4d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, 0x10, 0x0a, 0x03, 0x55, 0x52,
0x6d, 0x65, 0x6e, 0x74, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, 0x10, 0x0a, 0x03, 0x55, 0x52, 0x4c, 0x4c, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x55, 0x52, 0x4c, 0x12, 0x1c, 0x0a, 0x09,
0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x55, 0x52, 0x4c, 0x12, 0x1c, 0x0a, 0x09, 0x63, 0x63, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x65, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x08, 0x52,
0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x65, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x08, 0x52, 0x09, 0x09, 0x63, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x65, 0x64, 0x22, 0x41, 0x0a, 0x0f, 0x4d, 0x61,
0x63, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x65, 0x64, 0x22, 0xef, 0x01, 0x0a, 0x0a, 0x46, 0x75, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, 0x10, 0x0a,
0x6c, 0x6c, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x12, 0x41, 0x0a, 0x0f, 0x6d, 0x61, 0x6e, 0x61, 0x03, 0x55, 0x52, 0x4c, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x55, 0x52, 0x4c, 0x12,
0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x53, 0x74, 0x61, 0x74, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x1c, 0x0a, 0x09, 0x63, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x65, 0x64, 0x18, 0x02, 0x20, 0x01,
0x0b, 0x32, 0x17, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4d, 0x61, 0x6e, 0x61, 0x67, 0x28, 0x08, 0x52, 0x09, 0x63, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x65, 0x64, 0x22, 0xef, 0x01,
0x65, 0x6d, 0x65, 0x6e, 0x74, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x0f, 0x6d, 0x61, 0x6e, 0x61, 0x0a, 0x0a, 0x46, 0x75, 0x6c, 0x6c, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x12, 0x41, 0x0a, 0x0f,
0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, 0x35, 0x0a, 0x0b, 0x73, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x53, 0x74, 0x61, 0x74, 0x65, 0x18,
0x69, 0x67, 0x6e, 0x61, 0x6c, 0x53, 0x74, 0x61, 0x74, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x17, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4d,
0x32, 0x13, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x0f,
0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x0b, 0x73, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x53, 0x74, 0x61, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12,
0x74, 0x65, 0x12, 0x3e, 0x0a, 0x0e, 0x6c, 0x6f, 0x63, 0x61, 0x6c, 0x50, 0x65, 0x65, 0x72, 0x53, 0x35, 0x0a, 0x0b, 0x73, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x53, 0x74, 0x61, 0x74, 0x65, 0x18, 0x02,
0x74, 0x61, 0x74, 0x65, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x16, 0x2e, 0x64, 0x61, 0x65, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x13, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x69,
0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x6f, 0x63, 0x61, 0x6c, 0x50, 0x65, 0x65, 0x72, 0x53, 0x74, 0x61, 0x67, 0x6e, 0x61, 0x6c, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x0b, 0x73, 0x69, 0x67, 0x6e, 0x61,
0x74, 0x65, 0x52, 0x0e, 0x6c, 0x6f, 0x63, 0x61, 0x6c, 0x50, 0x65, 0x65, 0x72, 0x53, 0x74, 0x61, 0x6c, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, 0x3e, 0x0a, 0x0e, 0x6c, 0x6f, 0x63, 0x61, 0x6c, 0x50,
0x74, 0x65, 0x12, 0x27, 0x0a, 0x05, 0x70, 0x65, 0x65, 0x72, 0x73, 0x18, 0x04, 0x20, 0x03, 0x28, 0x65, 0x65, 0x72, 0x53, 0x74, 0x61, 0x74, 0x65, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x16,
0x0b, 0x32, 0x11, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x50, 0x65, 0x65, 0x72, 0x53, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x6f, 0x63, 0x61, 0x6c, 0x50, 0x65, 0x65,
0x74, 0x61, 0x74, 0x65, 0x52, 0x05, 0x70, 0x65, 0x65, 0x72, 0x73, 0x32, 0xf7, 0x02, 0x0a, 0x0d, 0x72, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x0e, 0x6c, 0x6f, 0x63, 0x61, 0x6c, 0x50, 0x65, 0x65,
0x44, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x12, 0x36, 0x0a, 0x72, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, 0x27, 0x0a, 0x05, 0x70, 0x65, 0x65, 0x72, 0x73, 0x18,
0x05, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x12, 0x14, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x04, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x11, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x50,
0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x15, 0x2e, 0x64, 0x65, 0x65, 0x72, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x05, 0x70, 0x65, 0x65, 0x72, 0x73, 0x32,
0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0xf7, 0x02, 0x0a, 0x0d, 0x44, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63,
0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x4b, 0x0a, 0x0c, 0x57, 0x61, 0x69, 0x74, 0x53, 0x53, 0x4f, 0x65, 0x12, 0x36, 0x0a, 0x05, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x12, 0x14, 0x2e, 0x64, 0x61, 0x65,
0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x12, 0x1b, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x57, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74,
0x61, 0x69, 0x74, 0x53, 0x53, 0x4f, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x1a, 0x15, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x52,
0x73, 0x74, 0x1a, 0x1c, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x57, 0x61, 0x69, 0x74, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x4b, 0x0a, 0x0c, 0x57, 0x61, 0x69,
0x53, 0x53, 0x4f, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x74, 0x53, 0x53, 0x4f, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x12, 0x1b, 0x2e, 0x64, 0x61, 0x65, 0x6d,
0x22, 0x00, 0x12, 0x2d, 0x0a, 0x02, 0x55, 0x70, 0x12, 0x11, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6f, 0x6e, 0x2e, 0x57, 0x61, 0x69, 0x74, 0x53, 0x53, 0x4f, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x52,
0x6e, 0x2e, 0x55, 0x70, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x12, 0x2e, 0x64, 0x61, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1c, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e,
0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x55, 0x70, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x57, 0x61, 0x69, 0x74, 0x53, 0x53, 0x4f, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, 0x73, 0x70,
0x00, 0x12, 0x39, 0x0a, 0x06, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x12, 0x15, 0x2e, 0x64, 0x61, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x2d, 0x0a, 0x02, 0x55, 0x70, 0x12, 0x11, 0x2e, 0x64,
0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x55, 0x70, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a,
0x73, 0x74, 0x1a, 0x16, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x74, 0x61, 0x74, 0x12, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x55, 0x70, 0x52, 0x65, 0x73, 0x70, 0x6f,
0x75, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x33, 0x0a, 0x04, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x39, 0x0a, 0x06, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x12,
0x44, 0x6f, 0x77, 0x6e, 0x12, 0x13, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x44, 0x6f, 0x15, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x52,
0x77, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x14, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x16, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e,
0x6f, 0x6e, 0x2e, 0x44, 0x6f, 0x77, 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00,
0x00, 0x12, 0x42, 0x0a, 0x09, 0x47, 0x65, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x18, 0x12, 0x33, 0x0a, 0x04, 0x44, 0x6f, 0x77, 0x6e, 0x12, 0x13, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f,
0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x47, 0x65, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x6e, 0x2e, 0x44, 0x6f, 0x77, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x14, 0x2e,
0x67, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x19, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x44, 0x6f, 0x77, 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f,
0x6e, 0x2e, 0x47, 0x65, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x42, 0x0a, 0x09, 0x47, 0x65, 0x74, 0x43, 0x6f, 0x6e, 0x66,
0x6e, 0x73, 0x65, 0x22, 0x00, 0x42, 0x08, 0x5a, 0x06, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x69, 0x67, 0x12, 0x18, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x47, 0x65, 0x74, 0x43,
0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x19, 0x2e, 0x64,
0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x47, 0x65, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52,
0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x42, 0x08, 0x5a, 0x06, 0x2f, 0x70, 0x72,
0x6f, 0x74, 0x6f, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33,
} }
var ( var (

View File

@ -51,6 +51,7 @@ message LoginRequest {
bytes customDNSAddress = 7; bytes customDNSAddress = 7;
bool isLinuxDesktopClient = 8;
} }
message LoginResponse { message LoginResponse {

View File

@ -208,7 +208,7 @@ func (s *Server) Login(callerCtx context.Context, msg *proto.LoginRequest) (*pro
state.Set(internal.StatusConnecting) state.Set(internal.StatusConnecting)
if msg.SetupKey == "" { if msg.SetupKey == "" {
oAuthFlow, err := auth.NewOAuthFlow(ctx, config) oAuthFlow, err := auth.NewOAuthFlow(ctx, config, msg.IsLinuxDesktopClient)
if err != nil { if err != nil {
state.Set(internal.StatusLoginFailed) state.Set(internal.StatusLoginFailed)
return nil, err return nil, err

1
go.mod
View File

@ -29,6 +29,7 @@ require (
require ( require (
fyne.io/fyne/v2 v2.1.4 fyne.io/fyne/v2 v2.1.4
github.com/TheJumpCloud/jcapi-go v3.0.0+incompatible
github.com/c-robinson/iplib v1.0.3 github.com/c-robinson/iplib v1.0.3
github.com/cilium/ebpf v0.10.0 github.com/cilium/ebpf v0.10.0
github.com/coreos/go-iptables v0.7.0 github.com/coreos/go-iptables v0.7.0

2
go.sum
View File

@ -61,6 +61,8 @@ github.com/PuerkitoBio/purell v1.0.0/go.mod h1:c11w/QuzBsJSee3cPx9rAFu61PvFxuPbt
github.com/PuerkitoBio/purell v1.1.1/go.mod h1:c11w/QuzBsJSee3cPx9rAFu61PvFxuPbtSwDGJws/X0= github.com/PuerkitoBio/purell v1.1.1/go.mod h1:c11w/QuzBsJSee3cPx9rAFu61PvFxuPbtSwDGJws/X0=
github.com/PuerkitoBio/urlesc v0.0.0-20160726150825-5bd2802263f2/go.mod h1:uGdkoq3SwY9Y+13GIhn11/XLaGBb4BfwItxLd5jeuXE= github.com/PuerkitoBio/urlesc v0.0.0-20160726150825-5bd2802263f2/go.mod h1:uGdkoq3SwY9Y+13GIhn11/XLaGBb4BfwItxLd5jeuXE=
github.com/PuerkitoBio/urlesc v0.0.0-20170810143723-de5bf2ad4578/go.mod h1:uGdkoq3SwY9Y+13GIhn11/XLaGBb4BfwItxLd5jeuXE= github.com/PuerkitoBio/urlesc v0.0.0-20170810143723-de5bf2ad4578/go.mod h1:uGdkoq3SwY9Y+13GIhn11/XLaGBb4BfwItxLd5jeuXE=
github.com/TheJumpCloud/jcapi-go v3.0.0+incompatible h1:hqcTK6ZISdip65SR792lwYJTa/axESA0889D3UlZbLo=
github.com/TheJumpCloud/jcapi-go v3.0.0+incompatible/go.mod h1:6B1nuc1MUs6c62ODZDl7hVE5Pv7O2XGSkgg2olnq34I=
github.com/XiaoMi/pegasus-go-client v0.0.0-20210427083443-f3b6b08bc4c2 h1:pami0oPhVosjOu/qRHepRmdjD6hGILF7DBr+qQZeP10= github.com/XiaoMi/pegasus-go-client v0.0.0-20210427083443-f3b6b08bc4c2 h1:pami0oPhVosjOu/qRHepRmdjD6hGILF7DBr+qQZeP10=
github.com/XiaoMi/pegasus-go-client v0.0.0-20210427083443-f3b6b08bc4c2/go.mod h1:jNIx5ykW1MroBuaTja9+VpglmaJOUzezumfhLlER3oY= github.com/XiaoMi/pegasus-go-client v0.0.0-20210427083443-f3b6b08bc4c2/go.mod h1:jNIx5ykW1MroBuaTja9+VpglmaJOUzezumfhLlER3oY=
github.com/akavel/rsrc v0.8.0/go.mod h1:uLoCtb9J+EyAqh+26kdrTgmzRBFPGOolLWKpdxkKq+c= github.com/akavel/rsrc v0.8.0/go.mod h1:uLoCtb9J+EyAqh+26kdrTgmzRBFPGOolLWKpdxkKq+c=

View File

@ -46,6 +46,14 @@ NETBIRD_TOKEN_SOURCE=${NETBIRD_TOKEN_SOURCE:-accessToken}
# PKCE authorization flow # PKCE authorization flow
NETBIRD_AUTH_PKCE_REDIRECT_URL_PORTS=${NETBIRD_AUTH_PKCE_REDIRECT_URL_PORTS:-"53000"} NETBIRD_AUTH_PKCE_REDIRECT_URL_PORTS=${NETBIRD_AUTH_PKCE_REDIRECT_URL_PORTS:-"53000"}
NETBIRD_AUTH_PKCE_USE_ID_TOKEN=${NETBIRD_AUTH_PKCE_USE_ID_TOKEN:-false} NETBIRD_AUTH_PKCE_USE_ID_TOKEN=${NETBIRD_AUTH_PKCE_USE_ID_TOKEN:-false}
NETBIRD_AUTH_PKCE_AUDIENCE=$NETBIRD_AUTH_AUDIENCE
# Dashboard
# The default setting is to transmit the audience to the IDP during authorization. However,
# if your IDP does not have this capability, you can turn this off by setting it to false.
NETBIRD_DASH_AUTH_USE_AUDIENCE=${NETBIRD_DASH_AUTH_USE_AUDIENCE:-true}
NETBIRD_DASH_AUTH_AUDIENCE=$NETBIRD_AUTH_AUDIENCE
# exports # exports
export NETBIRD_DOMAIN export NETBIRD_DOMAIN
@ -87,3 +95,6 @@ export NETBIRD_AUTH_DEVICE_AUTH_SCOPE
export NETBIRD_AUTH_DEVICE_AUTH_USE_ID_TOKEN export NETBIRD_AUTH_DEVICE_AUTH_USE_ID_TOKEN
export NETBIRD_AUTH_PKCE_AUTHORIZATION_ENDPOINT export NETBIRD_AUTH_PKCE_AUTHORIZATION_ENDPOINT
export NETBIRD_AUTH_PKCE_USE_ID_TOKEN export NETBIRD_AUTH_PKCE_USE_ID_TOKEN
export NETBIRD_AUTH_PKCE_AUDIENCE
export NETBIRD_DASH_AUTH_USE_AUDIENCE
export NETBIRD_DASH_AUTH_AUDIENCE

View File

@ -164,6 +164,12 @@ done
export NETBIRD_AUTH_PKCE_REDIRECT_URLS=${REDIRECT_URLS%,} export NETBIRD_AUTH_PKCE_REDIRECT_URLS=${REDIRECT_URLS%,}
# Remove audience for providers that do not support it
if [ "$NETBIRD_DASH_AUTH_USE_AUDIENCE" = "false" ]; then
export NETBIRD_DASH_AUTH_AUDIENCE=none
export NETBIRD_AUTH_PKCE_AUDIENCE=
fi
env | grep NETBIRD env | grep NETBIRD
envsubst <docker-compose.yml.tmpl >docker-compose.yml envsubst <docker-compose.yml.tmpl >docker-compose.yml

View File

@ -12,7 +12,7 @@ services:
- NETBIRD_MGMT_API_ENDPOINT=$NETBIRD_MGMT_API_ENDPOINT - NETBIRD_MGMT_API_ENDPOINT=$NETBIRD_MGMT_API_ENDPOINT
- NETBIRD_MGMT_GRPC_API_ENDPOINT=$NETBIRD_MGMT_API_ENDPOINT - NETBIRD_MGMT_GRPC_API_ENDPOINT=$NETBIRD_MGMT_API_ENDPOINT
# OIDC # OIDC
- AUTH_AUDIENCE=$NETBIRD_AUTH_AUDIENCE - AUTH_AUDIENCE=$NETBIRD_DASH_AUTH_AUDIENCE
- AUTH_CLIENT_ID=$NETBIRD_AUTH_CLIENT_ID - AUTH_CLIENT_ID=$NETBIRD_AUTH_CLIENT_ID
- AUTH_CLIENT_SECRET=$NETBIRD_AUTH_CLIENT_SECRET - AUTH_CLIENT_SECRET=$NETBIRD_AUTH_CLIENT_SECRET
- AUTH_AUTHORITY=$NETBIRD_AUTH_AUTHORITY - AUTH_AUTHORITY=$NETBIRD_AUTH_AUTHORITY

View File

@ -12,7 +12,7 @@ services:
- NETBIRD_MGMT_API_ENDPOINT=$NETBIRD_MGMT_API_ENDPOINT - NETBIRD_MGMT_API_ENDPOINT=$NETBIRD_MGMT_API_ENDPOINT
- NETBIRD_MGMT_GRPC_API_ENDPOINT=$NETBIRD_MGMT_API_ENDPOINT - NETBIRD_MGMT_GRPC_API_ENDPOINT=$NETBIRD_MGMT_API_ENDPOINT
# OIDC # OIDC
- AUTH_AUDIENCE=$NETBIRD_AUTH_AUDIENCE - AUTH_AUDIENCE=$NETBIRD_DASH_AUTH_AUDIENCE
- AUTH_CLIENT_ID=$NETBIRD_AUTH_CLIENT_ID - AUTH_CLIENT_ID=$NETBIRD_AUTH_CLIENT_ID
- AUTH_CLIENT_SECRET=$NETBIRD_AUTH_CLIENT_SECRET - AUTH_CLIENT_SECRET=$NETBIRD_AUTH_CLIENT_SECRET
- AUTH_AUTHORITY=$NETBIRD_AUTH_AUTHORITY - AUTH_AUTHORITY=$NETBIRD_AUTH_AUTHORITY
@ -20,6 +20,7 @@ services:
- AUTH_SUPPORTED_SCOPES=$NETBIRD_AUTH_SUPPORTED_SCOPES - AUTH_SUPPORTED_SCOPES=$NETBIRD_AUTH_SUPPORTED_SCOPES
- AUTH_REDIRECT_URI=$NETBIRD_AUTH_REDIRECT_URI - AUTH_REDIRECT_URI=$NETBIRD_AUTH_REDIRECT_URI
- AUTH_SILENT_REDIRECT_URI=$NETBIRD_AUTH_SILENT_REDIRECT_URI - AUTH_SILENT_REDIRECT_URI=$NETBIRD_AUTH_SILENT_REDIRECT_URI
- NETBIRD_TOKEN_SOURCE=$NETBIRD_TOKEN_SOURCE
# SSL # SSL
- NGINX_SSL_PORT=443 - NGINX_SSL_PORT=443
# Letsencrypt # Letsencrypt

View File

@ -62,7 +62,7 @@
}, },
"PKCEAuthorizationFlow": { "PKCEAuthorizationFlow": {
"ProviderConfig": { "ProviderConfig": {
"Audience": "$NETBIRD_AUTH_AUDIENCE", "Audience": "$NETBIRD_AUTH_PKCE_AUDIENCE",
"ClientID": "$NETBIRD_AUTH_CLIENT_ID", "ClientID": "$NETBIRD_AUTH_CLIENT_ID",
"ClientSecret": "$NETBIRD_AUTH_CLIENT_SECRET", "ClientSecret": "$NETBIRD_AUTH_CLIENT_SECRET",
"AuthorizationEndpoint": "$NETBIRD_AUTH_PKCE_AUTHORIZATION_ENDPOINT", "AuthorizationEndpoint": "$NETBIRD_AUTH_PKCE_AUTHORIZATION_ENDPOINT",

View File

@ -8,6 +8,9 @@ NETBIRD_DOMAIN=""
# e.g., https://example.eu.auth0.com/.well-known/openid-configuration # e.g., https://example.eu.auth0.com/.well-known/openid-configuration
# ------------------------------------------- # -------------------------------------------
NETBIRD_AUTH_OIDC_CONFIGURATION_ENDPOINT="" NETBIRD_AUTH_OIDC_CONFIGURATION_ENDPOINT=""
# The default setting is to transmit the audience to the IDP during authorization. However,
# if your IDP does not have this capability, you can turn this off by setting it to false.
#NETBIRD_DASH_AUTH_USE_AUDIENCE=false
NETBIRD_AUTH_AUDIENCE="" NETBIRD_AUTH_AUDIENCE=""
# e.g. netbird-client # e.g. netbird-client
NETBIRD_AUTH_CLIENT_ID="" NETBIRD_AUTH_CLIENT_ID=""

View File

@ -148,8 +148,8 @@ var (
return fmt.Errorf("failed to initialize database: %s", err) return fmt.Errorf("failed to initialize database: %s", err)
} }
if key != "" { if config.DataStoreEncryptionKey != key {
log.Debugf("update config with activity store key") log.Infof("update config with activity store key")
config.DataStoreEncryptionKey = key config.DataStoreEncryptionKey = key
err := updateMgmtConfig(mgmtConfig, config) err := updateMgmtConfig(mgmtConfig, config)
if err != nil { if err != nil {
@ -466,7 +466,7 @@ func loadMgmtConfig(mgmtConfigPath string) (*server.Config, error) {
} }
func updateMgmtConfig(path string, config *server.Config) error { func updateMgmtConfig(path string, config *server.Config) error {
return util.WriteJson(path, config) return util.DirectWriteJson(path, config)
} }
// OIDCConfigResponse used for parsing OIDC config response // OIDCConfigResponse used for parsing OIDC config response

View File

@ -62,12 +62,9 @@ type AccountManager interface {
GetAccountFromPAT(pat string) (*Account, *User, *PersonalAccessToken, error) GetAccountFromPAT(pat string) (*Account, *User, *PersonalAccessToken, error)
MarkPATUsed(tokenID string) error MarkPATUsed(tokenID string) error
GetUser(claims jwtclaims.AuthorizationClaims) (*User, error) GetUser(claims jwtclaims.AuthorizationClaims) (*User, error)
AccountExists(accountId string) (*bool, error)
GetPeerByKey(peerKey string) (*Peer, error)
GetPeers(accountID, userID string) ([]*Peer, error) GetPeers(accountID, userID string) ([]*Peer, error)
MarkPeerConnected(peerKey string, connected bool) error MarkPeerConnected(peerKey string, connected bool) error
DeletePeer(accountID, peerID, userID string) (*Peer, error) DeletePeer(accountID, peerID, userID string) error
GetPeerByIP(accountId string, peerIP string) (*Peer, error)
UpdatePeer(accountID, userID string, peer *Peer) (*Peer, error) UpdatePeer(accountID, userID string, peer *Peer) (*Peer, error)
GetNetworkMap(peerID string) (*NetworkMap, error) GetNetworkMap(peerID string) (*NetworkMap, error)
GetPeerNetwork(peerID string) (*Network, error) GetPeerNetwork(peerID string) (*Network, error)
@ -83,14 +80,13 @@ type AccountManager interface {
DeleteGroup(accountId, userId, groupID string) error DeleteGroup(accountId, userId, groupID string) error
ListGroups(accountId string) ([]*Group, error) ListGroups(accountId string) ([]*Group, error)
GroupAddPeer(accountId, groupID, peerID string) error GroupAddPeer(accountId, groupID, peerID string) error
GroupDeletePeer(accountId, groupID, peerKey string) error GroupDeletePeer(accountId, groupID, peerID string) error
GroupListPeers(accountId, groupID string) ([]*Peer, error)
GetPolicy(accountID, policyID, userID string) (*Policy, error) GetPolicy(accountID, policyID, userID string) (*Policy, error)
SavePolicy(accountID, userID string, policy *Policy) error SavePolicy(accountID, userID string, policy *Policy) error
DeletePolicy(accountID, policyID, userID string) error DeletePolicy(accountID, policyID, userID string) error
ListPolicies(accountID, userID string) ([]*Policy, error) ListPolicies(accountID, userID string) ([]*Policy, error)
GetRoute(accountID, routeID, userID string) (*route.Route, error) GetRoute(accountID, routeID, userID string) (*route.Route, error)
CreateRoute(accountID string, prefix, peerID, description, netID string, masquerade bool, metric int, groups []string, enabled bool, userID string) (*route.Route, error) CreateRoute(accountID, prefix, peerID string, peerGroupIDs []string, description, netID string, masquerade bool, metric int, groups []string, enabled bool, userID string) (*route.Route, error)
SaveRoute(accountID, userID string, route *route.Route) error SaveRoute(accountID, userID string, route *route.Route) error
DeleteRoute(accountID, routeID, userID string) error DeleteRoute(accountID, routeID, userID string) error
ListRoutes(accountID, userID string) ([]*route.Route, error) ListRoutes(accountID, userID string) ([]*route.Route, error)
@ -253,22 +249,39 @@ func (a *Account) filterRoutesByGroups(routes []*route.Route, groupListMap looku
func (a *Account) getEnabledAndDisabledRoutesByPeer(peerID string) ([]*route.Route, []*route.Route) { func (a *Account) getEnabledAndDisabledRoutesByPeer(peerID string) ([]*route.Route, []*route.Route) {
var enabledRoutes []*route.Route var enabledRoutes []*route.Route
var disabledRoutes []*route.Route var disabledRoutes []*route.Route
for _, r := range a.Routes {
if r.Peer == peerID { takeRoute := func(r *route.Route, id string) {
// We need to set Peer.Key instead of Peer.ID because this object will be sent to agents as part of a network map.
// Ideally we should have a separate field for that, but fine for now.
peer := a.GetPeer(peerID) peer := a.GetPeer(peerID)
if peer == nil { if peer == nil {
log.Errorf("route %s has peer %s that doesn't exist under account %s", r.ID, peerID, a.Id) log.Errorf("route %s has peer %s that doesn't exist under account %s", r.ID, peerID, a.Id)
continue return
} }
raut := r.Copy()
raut.Peer = peer.Key
if r.Enabled { if r.Enabled {
enabledRoutes = append(enabledRoutes, raut) enabledRoutes = append(enabledRoutes, r)
return
}
disabledRoutes = append(disabledRoutes, r)
}
for _, r := range a.Routes {
if len(r.PeerGroups) != 0 {
for _, groupID := range r.PeerGroups {
group := a.GetGroup(groupID)
if group == nil {
log.Errorf("route %s has peers group %s that doesn't exist under account %s", r.ID, groupID, a.Id)
continue continue
} }
disabledRoutes = append(disabledRoutes, raut) for _, id := range group.Peers {
if id == peerID {
takeRoute(r, id)
break
}
}
}
}
if r.Peer == peerID {
takeRoute(r, peerID)
} }
} }
return enabledRoutes, disabledRoutes return enabledRoutes, disabledRoutes
@ -286,17 +299,6 @@ func (a *Account) GetRoutesByPrefix(prefix netip.Prefix) []*route.Route {
return routes return routes
} }
// GetPeerByIP returns peer by it's IP if exists under account or nil otherwise
func (a *Account) GetPeerByIP(peerIP string) *Peer {
for _, peer := range a.Peers {
if peerIP == peer.IP.String() {
return peer
}
}
return nil
}
// GetGroup returns a group by ID if exists, nil otherwise // GetGroup returns a group by ID if exists, nil otherwise
func (a *Account) GetGroup(groupID string) *Group { func (a *Account) GetGroup(groupID string) *Group {
return a.Groups[groupID] return a.Groups[groupID]
@ -316,8 +318,51 @@ func (a *Account) GetPeerNetworkMap(peerID, dnsDomain string) *NetworkMap {
} }
peersToConnect = append(peersToConnect, p) peersToConnect = append(peersToConnect, p)
} }
// Please mind, that the returned route.Route objects will contain Peer.Key instead of Peer.ID.
routesUpdate := a.getRoutesToSync(peerID, peersToConnect) routes := a.getRoutesToSync(peerID, peersToConnect)
takePeer := func(id string) (*Peer, bool) {
peer := a.GetPeer(id)
if peer == nil || peer.Meta.GoOS != "linux" {
return nil, false
}
return peer, true
}
// We need to set Peer.Key instead of Peer.ID because this object will be sent to agents as part of a network map.
// Ideally we should have a separate field for that, but fine for now.
var routesUpdate []*route.Route
seenPeers := make(map[string]bool)
for _, r := range routes {
if r.Peer != "" {
peer, valid := takePeer(r.Peer)
if !valid {
continue
}
rCopy := r.Copy()
rCopy.Peer = peer.Key // client expects the key
routesUpdate = append(routesUpdate, rCopy)
continue
}
for _, groupID := range r.PeerGroups {
if group := a.GetGroup(groupID); group != nil {
for _, peerId := range group.Peers {
peer, valid := takePeer(peerId)
if !valid {
continue
}
if _, ok := seenPeers[peer.ID]; !ok {
rCopy := r.Copy()
rCopy.ID = r.ID + ":" + peer.ID // we have to provide unit route id when distribute network map
rCopy.Peer = peer.Key // client expects the key
routesUpdate = append(routesUpdate, rCopy)
}
seenPeers[peer.ID] = true
}
}
}
}
dnsManagementStatus := a.getPeerDNSManagementStatus(peerID) dnsManagementStatus := a.getPeerDNSManagementStatus(peerID)
dnsUpdate := nbdns.Config{ dnsUpdate := nbdns.Config{
@ -577,8 +622,8 @@ func (a *Account) Copy() *Account {
} }
routes := map[string]*route.Route{} routes := map[string]*route.Route{}
for id, route := range a.Routes { for id, r := range a.Routes {
routes[id] = route.Copy() routes[id] = r.Copy()
} }
nsGroups := map[string]*nbdns.NameServerGroup{} nsGroups := map[string]*nbdns.NameServerGroup{}
@ -928,6 +973,27 @@ func (am *DefaultAccountManager) warmupIDPCache() error {
return err return err
} }
// If the Identity Provider does not support writing AppMetadata,
// in cases like this, we expect it to return all users in an "unset" field.
// We iterate over the users in the "unset" field, look up their AccountID in our store, and
// update their AppMetadata with the AccountID.
if unsetData, ok := userData[idp.UnsetAccountID]; ok {
for _, user := range unsetData {
accountID, err := am.Store.GetAccountByUser(user.ID)
if err == nil {
data := userData[accountID.Id]
if data == nil {
data = make([]*idp.UserData, 0, 1)
}
user.AppMetadata.WTAccountID = accountID.Id
userData[accountID.Id] = append(data, user)
}
}
}
delete(userData, idp.UnsetAccountID)
for accountID, users := range userData { for accountID, users := range userData {
err = am.cacheManager.Set(am.ctx, accountID, users, cacheStore.WithExpiration(cacheEntryExpiration())) err = am.cacheManager.Set(am.ctx, accountID, users, cacheStore.WithExpiration(cacheEntryExpiration()))
if err != nil { if err != nil {
@ -994,7 +1060,36 @@ func (am *DefaultAccountManager) addAccountIDToIDPAppMeta(userID string, account
func (am *DefaultAccountManager) loadAccount(_ context.Context, accountID interface{}) ([]*idp.UserData, error) { func (am *DefaultAccountManager) loadAccount(_ context.Context, accountID interface{}) ([]*idp.UserData, error) {
log.Debugf("account %s not found in cache, reloading", accountID) log.Debugf("account %s not found in cache, reloading", accountID)
return am.idpManager.GetAccount(fmt.Sprintf("%v", accountID)) accountIDString := fmt.Sprintf("%v", accountID)
account, err := am.Store.GetAccount(accountIDString)
if err != nil {
return nil, err
}
userData, err := am.idpManager.GetAccount(accountIDString)
if err != nil {
return nil, err
}
dataMap := make(map[string]*idp.UserData, len(userData))
for _, datum := range userData {
dataMap[datum.ID] = datum
}
matchedUserData := make([]*idp.UserData, 0)
for _, user := range account.Users {
if user.IsServiceUser {
continue
}
datum, ok := dataMap[user.Id]
if !ok {
log.Warnf("user %s not found in IDP", user.Id)
continue
}
matchedUserData = append(matchedUserData, datum)
}
return matchedUserData, nil
} }
func (am *DefaultAccountManager) lookupUserInCacheByEmail(email string, accountID string) (*idp.UserData, error) { func (am *DefaultAccountManager) lookupUserInCacheByEmail(email string, accountID string) (*idp.UserData, error) {
@ -1243,7 +1338,6 @@ func (am *DefaultAccountManager) redeemInvite(account *Account, userID string) e
// MarkPATUsed marks a personal access token as used // MarkPATUsed marks a personal access token as used
func (am *DefaultAccountManager) MarkPATUsed(tokenID string) error { func (am *DefaultAccountManager) MarkPATUsed(tokenID string) error {
unlock := am.Store.AcquireGlobalLock()
user, err := am.Store.GetUserByTokenID(tokenID) user, err := am.Store.GetUserByTokenID(tokenID)
if err != nil { if err != nil {
@ -1255,8 +1349,7 @@ func (am *DefaultAccountManager) MarkPATUsed(tokenID string) error {
return err return err
} }
unlock() unlock := am.Store.AcquireAccountLock(account.Id)
unlock = am.Store.AcquireAccountLock(account.Id)
defer unlock() defer unlock()
account, err = am.Store.GetAccountByUser(user.Id) account, err = am.Store.GetAccountByUser(user.Id)
@ -1383,9 +1476,7 @@ func (am *DefaultAccountManager) GetAccountFromToken(claims jwtclaims.Authorizat
if err := am.Store.SaveAccount(account); err != nil { if err := am.Store.SaveAccount(account); err != nil {
log.Errorf("failed to save account: %v", err) log.Errorf("failed to save account: %v", err)
} else { } else {
if err := am.updateAccountPeers(account); err != nil { am.updateAccountPeers(account)
log.Errorf("failed updating account peers while updating user %s", account.Id)
}
for _, g := range addNewGroups { for _, g := range addNewGroups {
if group := account.GetGroup(g); group != nil { if group := account.GetGroup(g); group != nil {
am.storeEvent(user.Id, user.Id, account.Id, activity.GroupAddedToUser, am.storeEvent(user.Id, user.Id, account.Id, activity.GroupAddedToUser,
@ -1496,26 +1587,6 @@ func isDomainValid(domain string) bool {
return re.Match([]byte(domain)) return re.Match([]byte(domain))
} }
// AccountExists checks whether account exists (returns true) or not (returns false)
func (am *DefaultAccountManager) AccountExists(accountID string) (*bool, error) {
unlock := am.Store.AcquireAccountLock(accountID)
defer unlock()
var res bool
_, err := am.Store.GetAccount(accountID)
if err != nil {
if s, ok := status.FromError(err); ok && s.Type() == status.NotFound {
res = false
return &res, nil
} else {
return nil, err
}
}
res = true
return &res, nil
}
// GetDNSDomain returns the configured dnsDomain // GetDNSDomain returns the configured dnsDomain
func (am *DefaultAccountManager) GetDNSDomain() string { func (am *DefaultAccountManager) GetDNSDomain() string {
return am.dnsDomain return am.dnsDomain

View File

@ -706,30 +706,6 @@ func createAccount(am *DefaultAccountManager, accountID, userID, domain string)
return account, nil return account, nil
} }
func TestAccountManager_AccountExists(t *testing.T) {
manager, err := createManager(t)
if err != nil {
t.Fatal(err)
return
}
expectedId := "test_account"
userId := "account_creator"
_, err = createAccount(manager, expectedId, userId, "")
if err != nil {
t.Fatal(err)
}
exists, err := manager.AccountExists(expectedId)
if err != nil {
t.Fatal(err)
}
if !*exists {
t.Errorf("expected account to exist after creation, got false")
}
}
func TestAccountManager_GetAccount(t *testing.T) { func TestAccountManager_GetAccount(t *testing.T) {
manager, err := createManager(t) manager, err := createManager(t)
if err != nil { if err != nil {
@ -1062,7 +1038,7 @@ func TestAccountManager_NetworkUpdates(t *testing.T) {
} }
}() }()
if _, err := manager.DeletePeer(account.Id, peer3.ID, userID); err != nil { if err := manager.DeletePeer(account.Id, peer3.ID, userID); err != nil {
t.Errorf("delete peer: %v", err) t.Errorf("delete peer: %v", err)
return return
} }
@ -1129,7 +1105,7 @@ func TestAccountManager_DeletePeer(t *testing.T) {
return return
} }
_, err = manager.DeletePeer(account.Id, peerKey, userID) err = manager.DeletePeer(account.Id, peerKey, userID)
if err != nil { if err != nil {
return return
} }
@ -1386,6 +1362,7 @@ func TestAccount_Copy(t *testing.T) {
Routes: map[string]*route.Route{ Routes: map[string]*route.Route{
"route1": { "route1": {
ID: "route1", ID: "route1",
PeerGroups: []string{},
Groups: []string{"group1"}, Groups: []string{"group1"},
}, },
}, },

View File

@ -18,7 +18,9 @@ type Event struct {
ID uint64 ID uint64
// InitiatorID is the ID of an object that initiated the event (e.g., a user) // InitiatorID is the ID of an object that initiated the event (e.g., a user)
InitiatorID string InitiatorID string
// InitiatorEmail is the email address of an object that initiated the event. This will be set on deleted users only // InitiatorName is the name of an object that initiated the event.
InitiatorName string
// InitiatorEmail is the email address of an object that initiated the event.
InitiatorEmail string InitiatorEmail string
// TargetID is the ID of an object that was effected by the event (e.g., a peer) // TargetID is the ID of an object that was effected by the event (e.g., a peer)
TargetID string TargetID string
@ -42,6 +44,7 @@ func (e *Event) Copy() *Event {
Activity: e.Activity, Activity: e.Activity,
ID: e.ID, ID: e.ID,
InitiatorID: e.InitiatorID, InitiatorID: e.InitiatorID,
InitiatorName: e.InitiatorName,
InitiatorEmail: e.InitiatorEmail, InitiatorEmail: e.InitiatorEmail,
TargetID: e.TargetID, TargetID: e.TargetID,
AccountID: e.AccountID, AccountID: e.AccountID,

View File

@ -11,7 +11,7 @@ import (
var iv = []byte{10, 22, 13, 79, 05, 8, 52, 91, 87, 98, 88, 98, 35, 25, 13, 05} var iv = []byte{10, 22, 13, 79, 05, 8, 52, 91, 87, 98, 88, 98, 35, 25, 13, 05}
type EmailEncrypt struct { type FieldEncrypt struct {
block cipher.Block block cipher.Block
} }
@ -25,7 +25,7 @@ func GenerateKey() (string, error) {
return readableKey, nil return readableKey, nil
} }
func NewEmailEncrypt(key string) (*EmailEncrypt, error) { func NewFieldEncrypt(key string) (*FieldEncrypt, error) {
binKey, err := base64.StdEncoding.DecodeString(key) binKey, err := base64.StdEncoding.DecodeString(key)
if err != nil { if err != nil {
return nil, err return nil, err
@ -35,14 +35,14 @@ func NewEmailEncrypt(key string) (*EmailEncrypt, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
ec := &EmailEncrypt{ ec := &FieldEncrypt{
block: block, block: block,
} }
return ec, nil return ec, nil
} }
func (ec *EmailEncrypt) Encrypt(payload string) string { func (ec *FieldEncrypt) Encrypt(payload string) string {
plainText := pkcs5Padding([]byte(payload)) plainText := pkcs5Padding([]byte(payload))
cipherText := make([]byte, len(plainText)) cipherText := make([]byte, len(plainText))
cbc := cipher.NewCBCEncrypter(ec.block, iv) cbc := cipher.NewCBCEncrypter(ec.block, iv)
@ -50,7 +50,7 @@ func (ec *EmailEncrypt) Encrypt(payload string) string {
return base64.StdEncoding.EncodeToString(cipherText) return base64.StdEncoding.EncodeToString(cipherText)
} }
func (ec *EmailEncrypt) Decrypt(data string) (string, error) { func (ec *FieldEncrypt) Decrypt(data string) (string, error) {
cipherText, err := base64.StdEncoding.DecodeString(data) cipherText, err := base64.StdEncoding.DecodeString(data)
if err != nil { if err != nil {
return "", err return "", err

View File

@ -10,7 +10,7 @@ func TestGenerateKey(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("failed to generate key: %s", err) t.Fatalf("failed to generate key: %s", err)
} }
ee, err := NewEmailEncrypt(key) ee, err := NewFieldEncrypt(key)
if err != nil { if err != nil {
t.Fatalf("failed to init email encryption: %s", err) t.Fatalf("failed to init email encryption: %s", err)
} }
@ -36,7 +36,7 @@ func TestCorruptKey(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("failed to generate key: %s", err) t.Fatalf("failed to generate key: %s", err)
} }
ee, err := NewEmailEncrypt(key) ee, err := NewFieldEncrypt(key)
if err != nil { if err != nil {
t.Fatalf("failed to init email encryption: %s", err) t.Fatalf("failed to init email encryption: %s", err)
} }
@ -51,13 +51,13 @@ func TestCorruptKey(t *testing.T) {
t.Fatalf("failed to generate key: %s", err) t.Fatalf("failed to generate key: %s", err)
} }
ee, err = NewEmailEncrypt(newKey) ee, err = NewFieldEncrypt(newKey)
if err != nil { if err != nil {
t.Fatalf("failed to init email encryption: %s", err) t.Fatalf("failed to init email encryption: %s", err)
} }
res, err := ee.Decrypt(encrypted) res, _ := ee.Decrypt(encrypted)
if err == nil || res == testData { if res == testData {
t.Fatalf("incorrect decryption, the result is: %s", res) t.Fatalf("incorrect decryption, the result is: %s", res)
} }
} }

View File

@ -7,7 +7,7 @@ import (
"path/filepath" "path/filepath"
"time" "time"
_ "github.com/mattn/go-sqlite3" // sqlite driver _ "github.com/mattn/go-sqlite3"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/activity"
@ -25,16 +25,16 @@ const (
"meta TEXT," + "meta TEXT," +
" target_id TEXT);" " target_id TEXT);"
creatTableAccountEmailQuery = `CREATE TABLE IF NOT EXISTS deleted_users (id TEXT NOT NULL, email TEXT NOT NULL);` creatTableDeletedUsersQuery = `CREATE TABLE IF NOT EXISTS deleted_users (id TEXT NOT NULL, email TEXT NOT NULL, name TEXT);`
selectDescQuery = `SELECT events.id, activity, timestamp, initiator_id, i.email as "initiator_email", target_id, t.email as "target_email", account_id, meta selectDescQuery = `SELECT events.id, activity, timestamp, initiator_id, i.name as "initiator_name", i.email as "initiator_email", target_id, t.name as "target_name", t.email as "target_email", account_id, meta
FROM events FROM events
LEFT JOIN deleted_users i ON events.initiator_id = i.id LEFT JOIN deleted_users i ON events.initiator_id = i.id
LEFT JOIN deleted_users t ON events.target_id = t.id LEFT JOIN deleted_users t ON events.target_id = t.id
WHERE account_id = ? WHERE account_id = ?
ORDER BY timestamp DESC LIMIT ? OFFSET ?;` ORDER BY timestamp DESC LIMIT ? OFFSET ?;`
selectAscQuery = `SELECT events.id, activity, timestamp, initiator_id, i.email as "initiator_email", target_id, t.email as "target_email", account_id, meta selectAscQuery = `SELECT events.id, activity, timestamp, initiator_id, i.name as "initiator_name", i.email as "initiator_email", target_id, t.name as "target_name", t.email as "target_email", account_id, meta
FROM events FROM events
LEFT JOIN deleted_users i ON events.initiator_id = i.id LEFT JOIN deleted_users i ON events.initiator_id = i.id
LEFT JOIN deleted_users t ON events.target_id = t.id LEFT JOIN deleted_users t ON events.target_id = t.id
@ -44,13 +44,13 @@ const (
insertQuery = "INSERT INTO events(activity, timestamp, initiator_id, target_id, account_id, meta) " + insertQuery = "INSERT INTO events(activity, timestamp, initiator_id, target_id, account_id, meta) " +
"VALUES(?, ?, ?, ?, ?, ?)" "VALUES(?, ?, ?, ?, ?, ?)"
insertDeleteUserQuery = `INSERT INTO deleted_users(id, email) VALUES(?, ?)` insertDeleteUserQuery = `INSERT INTO deleted_users(id, email, name) VALUES(?, ?, ?)`
) )
// Store is the implementation of the activity.Store interface backed by SQLite // Store is the implementation of the activity.Store interface backed by SQLite
type Store struct { type Store struct {
db *sql.DB db *sql.DB
emailEncrypt *EmailEncrypt fieldEncrypt *FieldEncrypt
insertStatement *sql.Stmt insertStatement *sql.Stmt
selectAscStatement *sql.Stmt selectAscStatement *sql.Stmt
@ -66,49 +66,63 @@ func NewSQLiteStore(dataDir string, encryptionKey string) (*Store, error) {
return nil, err return nil, err
} }
crypt, err := NewEmailEncrypt(encryptionKey) crypt, err := NewFieldEncrypt(encryptionKey)
if err != nil { if err != nil {
_ = db.Close()
return nil, err return nil, err
} }
_, err = db.Exec(createTableQuery) _, err = db.Exec(createTableQuery)
if err != nil { if err != nil {
_ = db.Close()
return nil, err return nil, err
} }
_, err = db.Exec(creatTableAccountEmailQuery) _, err = db.Exec(creatTableDeletedUsersQuery)
if err != nil { if err != nil {
_ = db.Close()
return nil, err
}
err = updateDeletedUsersTable(db)
if err != nil {
_ = db.Close()
return nil, err return nil, err
} }
insertStmt, err := db.Prepare(insertQuery) insertStmt, err := db.Prepare(insertQuery)
if err != nil { if err != nil {
_ = db.Close()
return nil, err return nil, err
} }
selectDescStmt, err := db.Prepare(selectDescQuery) selectDescStmt, err := db.Prepare(selectDescQuery)
if err != nil { if err != nil {
_ = db.Close()
return nil, err return nil, err
} }
selectAscStmt, err := db.Prepare(selectAscQuery) selectAscStmt, err := db.Prepare(selectAscQuery)
if err != nil { if err != nil {
_ = db.Close()
return nil, err return nil, err
} }
deleteUserStmt, err := db.Prepare(insertDeleteUserQuery) deleteUserStmt, err := db.Prepare(insertDeleteUserQuery)
if err != nil { if err != nil {
_ = db.Close()
return nil, err return nil, err
} }
s := &Store{ s := &Store{
db: db, db: db,
emailEncrypt: crypt, fieldEncrypt: crypt,
insertStatement: insertStmt, insertStatement: insertStmt,
selectDescStatement: selectDescStmt, selectDescStatement: selectDescStmt,
selectAscStatement: selectAscStmt, selectAscStatement: selectAscStmt,
deleteUserStmt: deleteUserStmt, deleteUserStmt: deleteUserStmt,
} }
return s, nil return s, nil
} }
@ -119,12 +133,14 @@ func (store *Store) processResult(result *sql.Rows) ([]*activity.Event, error) {
var operation activity.Activity var operation activity.Activity
var timestamp time.Time var timestamp time.Time
var initiator string var initiator string
var initiatorName *string
var initiatorEmail *string var initiatorEmail *string
var target string var target string
var targetUserName *string
var targetEmail *string var targetEmail *string
var account string var account string
var jsonMeta string var jsonMeta string
err := result.Scan(&id, &operation, &timestamp, &initiator, &initiatorEmail, &target, &targetEmail, &account, &jsonMeta) err := result.Scan(&id, &operation, &timestamp, &initiator, &initiatorName, &initiatorEmail, &target, &targetUserName, &targetEmail, &account, &jsonMeta)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -137,8 +153,18 @@ func (store *Store) processResult(result *sql.Rows) ([]*activity.Event, error) {
} }
} }
if targetUserName != nil {
name, err := store.fieldEncrypt.Decrypt(*targetUserName)
if err != nil {
log.Errorf("failed to decrypt username for target id: %s", target)
meta["username"] = ""
} else {
meta["username"] = name
}
}
if targetEmail != nil { if targetEmail != nil {
email, err := store.emailEncrypt.Decrypt(*targetEmail) email, err := store.fieldEncrypt.Decrypt(*targetEmail)
if err != nil { if err != nil {
log.Errorf("failed to decrypt email address for target id: %s", target) log.Errorf("failed to decrypt email address for target id: %s", target)
meta["email"] = "" meta["email"] = ""
@ -157,8 +183,17 @@ func (store *Store) processResult(result *sql.Rows) ([]*activity.Event, error) {
Meta: meta, Meta: meta,
} }
if initiatorName != nil {
name, err := store.fieldEncrypt.Decrypt(*initiatorName)
if err != nil {
log.Errorf("failed to decrypt username of initiator: %s", initiator)
} else {
event.InitiatorName = name
}
}
if initiatorEmail != nil { if initiatorEmail != nil {
email, err := store.emailEncrypt.Decrypt(*initiatorEmail) email, err := store.fieldEncrypt.Decrypt(*initiatorEmail)
if err != nil { if err != nil {
log.Errorf("failed to decrypt email address of initiator: %s", initiator) log.Errorf("failed to decrypt email address of initiator: %s", initiator)
} else { } else {
@ -191,7 +226,7 @@ func (store *Store) Get(accountID string, offset, limit int, descending bool) ([
// Save an event in the SQLite events table end encrypt the "email" element in meta map // Save an event in the SQLite events table end encrypt the "email" element in meta map
func (store *Store) Save(event *activity.Event) (*activity.Event, error) { func (store *Store) Save(event *activity.Event) (*activity.Event, error) {
var jsonMeta string var jsonMeta string
meta, err := store.saveDeletedUserEmailInEncrypted(event) meta, err := store.saveDeletedUserEmailAndNameInEncrypted(event)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -219,26 +254,31 @@ func (store *Store) Save(event *activity.Event) (*activity.Event, error) {
return eventCopy, nil return eventCopy, nil
} }
// saveDeletedUserEmailInEncrypted if the meta contains email then store it in encrypted way and delete this item from // saveDeletedUserEmailAndNameInEncrypted if the meta contains email and name then store it in encrypted way and delete
// meta map // this item from meta map
func (store *Store) saveDeletedUserEmailInEncrypted(event *activity.Event) (map[string]any, error) { func (store *Store) saveDeletedUserEmailAndNameInEncrypted(event *activity.Event) (map[string]any, error) {
email, ok := event.Meta["email"] email, ok := event.Meta["email"]
if !ok { if !ok {
return event.Meta, nil return event.Meta, nil
} }
delete(event.Meta, "email") name, ok := event.Meta["name"]
if !ok {
return event.Meta, nil
}
encrypted := store.emailEncrypt.Encrypt(fmt.Sprintf("%s", email)) encryptedEmail := store.fieldEncrypt.Encrypt(fmt.Sprintf("%s", email))
_, err := store.deleteUserStmt.Exec(event.TargetID, encrypted) encryptedName := store.fieldEncrypt.Encrypt(fmt.Sprintf("%s", name))
_, err := store.deleteUserStmt.Exec(event.TargetID, encryptedEmail, encryptedName)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if len(event.Meta) == 1 { if len(event.Meta) == 2 {
return nil, nil // nolint return nil, nil // nolint
} }
delete(event.Meta, "email") delete(event.Meta, "email")
delete(event.Meta, "name")
return event.Meta, nil return event.Meta, nil
} }
@ -249,3 +289,44 @@ func (store *Store) Close() error {
} }
return nil return nil
} }
func updateDeletedUsersTable(db *sql.DB) error {
log.Debugf("check deleted_users table version")
rows, err := db.Query(`PRAGMA table_info(deleted_users);`)
if err != nil {
return err
}
defer rows.Close()
found := false
for rows.Next() {
var (
cid int
name string
dataType string
notNull int
dfltVal sql.NullString
pk int
)
err := rows.Scan(&cid, &name, &dataType, &notNull, &dfltVal, &pk)
if err != nil {
return err
}
if name == "name" {
found = true
break
}
}
err = rows.Err()
if err != nil {
return err
}
if found {
return nil
}
log.Debugf("update delted_users table")
_, err = db.Exec(`ALTER TABLE deleted_users ADD COLUMN name TEXT;`)
return err
}

View File

@ -122,7 +122,9 @@ func (am *DefaultAccountManager) SaveDNSSettings(accountID string, userID string
am.storeEvent(userID, accountID, accountID, activity.GroupRemovedFromDisabledManagementGroups, meta) am.storeEvent(userID, accountID, accountID, activity.GroupRemovedFromDisabledManagementGroups, meta)
} }
return am.updateAccountPeers(account) am.updateAccountPeers(account)
return nil
} }
func toProtocolDNSConfig(update nbdns.Config) *proto.DNSConfig { func toProtocolDNSConfig(update nbdns.Config) *proto.DNSConfig {

View File

@ -162,7 +162,7 @@ func (e *EphemeralManager) cleanup() {
for id, p := range deletePeers { for id, p := range deletePeers {
log.Debugf("delete ephemeral peer: %s", id) log.Debugf("delete ephemeral peer: %s", id)
_, err := e.accountManager.DeletePeer(p.account.Id, id, activity.SystemInitiator) err := e.accountManager.DeletePeer(p.account.Id, id, activity.SystemInitiator)
if err != nil { if err != nil {
log.Tracef("failed to delete ephemeral peer: %s", err) log.Tracef("failed to delete ephemeral peer: %s", err)
} }

View File

@ -29,9 +29,9 @@ type MocAccountManager struct {
store *MockStore store *MockStore
} }
func (a MocAccountManager) DeletePeer(accountID, peerID, userID string) (*Peer, error) { func (a MocAccountManager) DeletePeer(accountID, peerID, userID string) error {
delete(a.store.account.Peers, peerID) delete(a.store.account.Peers, peerID)
return nil, nil //nolint:nilnil return nil //nolint:nil
} }
func TestNewManager(t *testing.T) { func TestNewManager(t *testing.T) {

View File

@ -84,10 +84,7 @@ func (am *DefaultAccountManager) SaveGroup(accountID, userID string, newGroup *G
return err return err
} }
err = am.updateAccountPeers(account) am.updateAccountPeers(account)
if err != nil {
return err
}
// the following snippet tracks the activity and stores the group events in the event store. // the following snippet tracks the activity and stores the group events in the event store.
// It has to happen after all the operations have been successfully performed. // It has to happen after all the operations have been successfully performed.
@ -229,7 +226,9 @@ func (am *DefaultAccountManager) DeleteGroup(accountId, userId, groupID string)
am.storeEvent(userId, groupID, accountId, activity.GroupDeleted, g.EventMeta()) am.storeEvent(userId, groupID, accountId, activity.GroupDeleted, g.EventMeta())
return am.updateAccountPeers(account) am.updateAccountPeers(account)
return nil
} }
// ListGroups objects of the peers // ListGroups objects of the peers
@ -281,11 +280,13 @@ func (am *DefaultAccountManager) GroupAddPeer(accountID, groupID, peerID string)
return err return err
} }
return am.updateAccountPeers(account) am.updateAccountPeers(account)
return nil
} }
// GroupDeletePeer removes peer from the group // GroupDeletePeer removes peer from the group
func (am *DefaultAccountManager) GroupDeletePeer(accountID, groupID, peerKey string) error { func (am *DefaultAccountManager) GroupDeletePeer(accountID, groupID, peerID string) error {
unlock := am.Store.AcquireAccountLock(accountID) unlock := am.Store.AcquireAccountLock(accountID)
defer unlock() defer unlock()
@ -301,7 +302,7 @@ func (am *DefaultAccountManager) GroupDeletePeer(accountID, groupID, peerKey str
account.Network.IncSerial() account.Network.IncSerial()
for i, itemID := range group.Peers { for i, itemID := range group.Peers {
if itemID == peerKey { if itemID == peerID {
group.Peers = append(group.Peers[:i], group.Peers[i+1:]...) group.Peers = append(group.Peers[:i], group.Peers[i+1:]...)
if err := am.Store.SaveAccount(account); err != nil { if err := am.Store.SaveAccount(account); err != nil {
return err return err
@ -309,31 +310,7 @@ func (am *DefaultAccountManager) GroupDeletePeer(accountID, groupID, peerKey str
} }
} }
return am.updateAccountPeers(account) am.updateAccountPeers(account)
}
// GroupListPeers returns list of the peers from the group return nil
func (am *DefaultAccountManager) GroupListPeers(accountID, groupID string) ([]*Peer, error) {
unlock := am.Store.AcquireAccountLock(accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID)
if err != nil {
return nil, status.Errorf(status.NotFound, "account not found")
}
group, ok := account.Groups[groupID]
if !ok {
return nil, status.Errorf(status.NotFound, "group with ID %s not found", groupID)
}
peers := make([]*Peer, 0, len(account.Groups))
for _, peerID := range group.Peers {
p, ok := account.Peers[peerID]
if ok {
peers = append(peers, p)
}
}
return peers, nil
} }

View File

@ -159,6 +159,11 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi
select { select {
// condition when there are some updates // condition when there are some updates
case update, open := <-updates: case update, open := <-updates:
if s.appMetrics != nil {
s.appMetrics.GRPCMetrics().UpdateChannelQueueLength(len(updates) + 1)
}
if !open { if !open {
log.Debugf("updates channel for peer %s was closed", peerKey.String()) log.Debugf("updates channel for peer %s was closed", peerKey.String())
s.cancelPeerRoutines(peer) s.cancelPeerRoutines(peer)

View File

@ -745,7 +745,13 @@ components:
type: boolean type: boolean
example: true example: true
peer: peer:
description: Peer Identifier associated with route description: Peer Identifier associated with route. This property can not be set together with `peer_groups`
type: string
example: chacbco6lnnbn6cg5s91
peer_groups:
description: Peers Group Identifier associated with route. This property can not be set together with `peer`
type: array
items:
type: string type: string
example: chacbco6lnnbn6cg5s91 example: chacbco6lnnbn6cg5s91
network: network:
@ -773,7 +779,9 @@ components:
- description - description
- network_id - network_id
- enabled - enabled
- peer # Only one property has to be set
#- peer
#- peer_groups
- network - network
- metric - metric
- masquerade - masquerade
@ -922,6 +930,10 @@ components:
description: The ID of the initiator of the event. E.g., an ID of a user that triggered the event. description: The ID of the initiator of the event. E.g., an ID of a user that triggered the event.
type: string type: string
example: google-oauth2|123456789012345678901 example: google-oauth2|123456789012345678901
initiator_name:
description: The name of the initiator of the event.
type: string
example: John Doe
initiator_email: initiator_email:
description: The e-mail address of the initiator of the event. E.g., an e-mail of a user that triggered the event. description: The e-mail address of the initiator of the event. E.g., an e-mail of a user that triggered the event.
type: string type: string
@ -942,6 +954,7 @@ components:
- activity - activity
- activity_code - activity_code
- initiator_id - initiator_id
- initiator_name
- initiator_email - initiator_email
- target_id - target_id
- meta - meta
@ -1139,8 +1152,8 @@ paths:
'500': '500':
"$ref": "#/components/responses/internal_error" "$ref": "#/components/responses/internal_error"
delete: delete:
summary: Block a User summary: Delete a User
description: This method blocks a user from accessing the system, but leaves the IDP user intact. description: This method removes a user from accessing the system. For this leaves the IDP user intact unless the `--user-delete-from-idp` is passed to management startup.
tags: [ Users ] tags: [ Users ]
security: security:
- BearerAuth: [ ] - BearerAuth: [ ]

View File

@ -1,6 +1,6 @@
// Package api provides primitives to interact with the openapi HTTP API. // Package api provides primitives to interact with the openapi HTTP API.
// //
// Code generated by github.com/deepmap/oapi-codegen version v1.11.1-0.20220912230023-4a1477f6a8ba DO NOT EDIT. // Code generated by github.com/deepmap/oapi-codegen version v1.15.0 DO NOT EDIT.
package api package api
import ( import (
@ -170,6 +170,9 @@ type Event struct {
// InitiatorId The ID of the initiator of the event. E.g., an ID of a user that triggered the event. // InitiatorId The ID of the initiator of the event. E.g., an ID of a user that triggered the event.
InitiatorId string `json:"initiator_id"` InitiatorId string `json:"initiator_id"`
// InitiatorName The name of the initiator of the event.
InitiatorName string `json:"initiator_name"`
// Meta The metadata of the event // Meta The metadata of the event
Meta map[string]string `json:"meta"` Meta map[string]string `json:"meta"`
@ -596,8 +599,11 @@ type Route struct {
// NetworkType Network type indicating if it is IPv4 or IPv6 // NetworkType Network type indicating if it is IPv4 or IPv6
NetworkType string `json:"network_type"` NetworkType string `json:"network_type"`
// Peer Peer Identifier associated with route // Peer Peer Identifier associated with route. This property can not be set together with `peer_groups`
Peer string `json:"peer"` Peer *string `json:"peer,omitempty"`
// PeerGroups Peers Group Identifier associated with route. This property can not be set together with `peer`
PeerGroups *[]string `json:"peer_groups,omitempty"`
} }
// RouteRequest defines model for RouteRequest. // RouteRequest defines model for RouteRequest.
@ -623,8 +629,11 @@ type RouteRequest struct {
// NetworkId Route network identifier, to group HA routes // NetworkId Route network identifier, to group HA routes
NetworkId string `json:"network_id"` NetworkId string `json:"network_id"`
// Peer Peer Identifier associated with route // Peer Peer Identifier associated with route. This property can not be set together with `peer_groups`
Peer string `json:"peer"` Peer *string `json:"peer,omitempty"`
// PeerGroups Peers Group Identifier associated with route. This property can not be set together with `peer`
PeerGroups *[]string `json:"peer_groups,omitempty"`
} }
// Rule defines model for Rule. // Rule defines model for Rule.

View File

@ -50,7 +50,7 @@ func (h *EventsHandler) GetAllEvents(w http.ResponseWriter, r *http.Request) {
events[i] = toEventResponse(e) events[i] = toEventResponse(e)
} }
err = h.fillEventsWithInitiatorEmail(events, account.Id, user.Id) err = h.fillEventsWithUserInfo(events, account.Id, user.Id)
if err != nil { if err != nil {
util.WriteError(err, w) util.WriteError(err, w)
return return
@ -59,8 +59,8 @@ func (h *EventsHandler) GetAllEvents(w http.ResponseWriter, r *http.Request) {
util.WriteJSONObject(w, events) util.WriteJSONObject(w, events)
} }
func (h *EventsHandler) fillEventsWithInitiatorEmail(events []*api.Event, accountId, userId string) error { func (h *EventsHandler) fillEventsWithUserInfo(events []*api.Event, accountId, userId string) error {
// build email map based on users // build email, name maps based on users
userInfos, err := h.accountManager.GetUsersFromAccount(accountId, userId) userInfos, err := h.accountManager.GetUsersFromAccount(accountId, userId)
if err != nil { if err != nil {
log.Errorf("failed to get users from account: %s", err) log.Errorf("failed to get users from account: %s", err)
@ -68,19 +68,39 @@ func (h *EventsHandler) fillEventsWithInitiatorEmail(events []*api.Event, accoun
} }
emails := make(map[string]string) emails := make(map[string]string)
names := make(map[string]string)
for _, ui := range userInfos { for _, ui := range userInfos {
emails[ui.ID] = ui.Email emails[ui.ID] = ui.Email
names[ui.ID] = ui.Name
} }
// fill event with email of initiator
var ok bool var ok bool
for _, event := range events { for _, event := range events {
// fill initiator
if event.InitiatorEmail == "" { if event.InitiatorEmail == "" {
event.InitiatorEmail, ok = emails[event.InitiatorId] event.InitiatorEmail, ok = emails[event.InitiatorId]
if !ok { if !ok {
log.Warnf("failed to resolve email for initiator: %s", event.InitiatorId) log.Warnf("failed to resolve email for initiator: %s", event.InitiatorId)
} }
} }
if event.InitiatorName == "" {
// here to allowed to be empty because in the first release we did not store the name
event.InitiatorName = names[event.InitiatorId]
}
// fill target meta
email, ok := emails[event.TargetId]
if !ok {
continue
}
event.Meta["email"] = email
username, ok := names[event.TargetId]
if !ok {
continue
}
event.Meta["username"] = username
} }
return nil return nil
} }
@ -95,6 +115,7 @@ func toEventResponse(event *activity.Event) *api.Event {
e := &api.Event{ e := &api.Event{
Id: fmt.Sprint(event.ID), Id: fmt.Sprint(event.ID),
InitiatorId: event.InitiatorID, InitiatorId: event.InitiatorID,
InitiatorName: event.InitiatorName,
InitiatorEmail: event.InitiatorEmail, InitiatorEmail: event.InitiatorEmail,
Activity: event.Activity.Message(), Activity: event.Activity.Message(),
ActivityCode: api.EventActivityCode(event.Activity.StringCode()), ActivityCode: api.EventActivityCode(event.Activity.StringCode()),

View File

@ -53,14 +53,6 @@ func initGroupTestData(user *server.User, groups ...*server.Group) *GroupsHandle
Issued: server.GroupIssuedAPI, Issued: server.GroupIssuedAPI,
}, nil }, nil
}, },
GetPeerByIPFunc: func(_ string, peerIP string) (*server.Peer, error) {
for _, peer := range TestPeers {
if peer.IP.String() == peerIP {
return peer, nil
}
}
return nil, fmt.Errorf("peer not found")
},
GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) { GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
return &server.Account{ return &server.Account{
Id: claims.AccountId, Id: claims.AccountId,

View File

@ -61,7 +61,7 @@ func (h *PeersHandler) updatePeer(account *server.Account, user *server.User, pe
} }
func (h *PeersHandler) deletePeer(accountID, userID string, peerID string, w http.ResponseWriter) { func (h *PeersHandler) deletePeer(accountID, userID string, peerID string, w http.ResponseWriter) {
_, err := h.accountManager.DeletePeer(accountID, peerID, userID) err := h.accountManager.DeletePeer(accountID, peerID, userID)
if err != nil { if err != nil {
util.WriteError(err, w) util.WriteError(err, w)
return return

View File

@ -82,7 +82,33 @@ func (h *RoutesHandler) CreateRoute(w http.ResponseWriter, r *http.Request) {
return return
} }
newRoute, err := h.accountManager.CreateRoute(account.Id, newPrefix.String(), req.Peer, req.Description, req.NetworkId, req.Masquerade, req.Metric, req.Groups, req.Enabled, user.Id) peerId := ""
if req.Peer != nil {
peerId = *req.Peer
}
peerGroupIds := []string{}
if req.PeerGroups != nil {
peerGroupIds = *req.PeerGroups
}
if (peerId != "" && len(peerGroupIds) > 0) || (peerId == "" && len(peerGroupIds) == 0) {
util.WriteError(status.Errorf(status.InvalidArgument, "only one peer or peer_groups should be provided"), w)
return
}
// do not allow non Linux peers
if peer := account.GetPeer(peerId); peer != nil {
if peer.Meta.GoOS != "linux" {
util.WriteError(status.Errorf(status.InvalidArgument, "non-linux peers are non supported as network routes"), w)
return
}
}
newRoute, err := h.accountManager.CreateRoute(
account.Id, newPrefix.String(), peerId, peerGroupIds,
req.Description, req.NetworkId, req.Masquerade, req.Metric, req.Groups, req.Enabled, user.Id,
)
if err != nil { if err != nil {
util.WriteError(err, w) util.WriteError(err, w)
return return
@ -135,19 +161,49 @@ func (h *RoutesHandler) UpdateRoute(w http.ResponseWriter, r *http.Request) {
return return
} }
if req.Peer != nil && req.PeerGroups != nil {
util.WriteError(status.Errorf(status.InvalidArgument, "only peer or peers_group should be provided"), w)
return
}
if req.Peer == nil && req.PeerGroups == nil {
util.WriteError(status.Errorf(status.InvalidArgument, "either peer or peers_group should be provided"), w)
return
}
peerID := ""
if req.Peer != nil {
peerID = *req.Peer
}
// do not allow non Linux peers
if peer := account.GetPeer(peerID); peer != nil {
if peer.Meta.GoOS != "linux" {
util.WriteError(status.Errorf(status.InvalidArgument, "non-linux peers are non supported as network routes"), w)
return
}
}
newRoute := &route.Route{ newRoute := &route.Route{
ID: routeID, ID: routeID,
Network: newPrefix, Network: newPrefix,
NetID: req.NetworkId, NetID: req.NetworkId,
NetworkType: prefixType, NetworkType: prefixType,
Masquerade: req.Masquerade, Masquerade: req.Masquerade,
Peer: req.Peer,
Metric: req.Metric, Metric: req.Metric,
Description: req.Description, Description: req.Description,
Enabled: req.Enabled, Enabled: req.Enabled,
Groups: req.Groups, Groups: req.Groups,
} }
if req.Peer != nil {
newRoute.Peer = peerID
}
if req.PeerGroups != nil {
newRoute.PeerGroups = *req.PeerGroups
}
err = h.accountManager.SaveRoute(account.Id, user.Id, newRoute) err = h.accountManager.SaveRoute(account.Id, user.Id, newRoute)
if err != nil { if err != nil {
util.WriteError(err, w) util.WriteError(err, w)
@ -208,16 +264,21 @@ func (h *RoutesHandler) GetRoute(w http.ResponseWriter, r *http.Request) {
} }
func toRouteResponse(serverRoute *route.Route) *api.Route { func toRouteResponse(serverRoute *route.Route) *api.Route {
return &api.Route{ route := &api.Route{
Id: serverRoute.ID, Id: serverRoute.ID,
Description: serverRoute.Description, Description: serverRoute.Description,
NetworkId: serverRoute.NetID, NetworkId: serverRoute.NetID,
Enabled: serverRoute.Enabled, Enabled: serverRoute.Enabled,
Peer: serverRoute.Peer, Peer: &serverRoute.Peer,
Network: serverRoute.Network.String(), Network: serverRoute.Network.String(),
NetworkType: serverRoute.NetworkType.String(), NetworkType: serverRoute.NetworkType.String(),
Masquerade: serverRoute.Masquerade, Masquerade: serverRoute.Masquerade,
Metric: serverRoute.Metric, Metric: serverRoute.Metric,
Groups: serverRoute.Groups, Groups: serverRoute.Groups,
} }
if len(serverRoute.PeerGroups) > 0 {
route.PeerGroups = &serverRoute.PeerGroups
}
return route
} }

View File

@ -24,15 +24,22 @@ import (
const ( const (
existingRouteID = "existingRouteID" existingRouteID = "existingRouteID"
existingRouteID2 = "existingRouteID2" // for peer_groups test
notFoundRouteID = "notFoundRouteID" notFoundRouteID = "notFoundRouteID"
existingPeerIP = "100.64.0.100" existingPeerIP1 = "100.64.0.100"
existingPeerID = "peer-id" existingPeerIP2 = "100.64.0.101"
notFoundPeerID = "nonExistingPeer" notFoundPeerID = "nonExistingPeer"
existingPeerKey = "existingPeerKey" existingPeerKey = "existingPeerKey"
nonLinuxExistingPeerKey = "darwinExistingPeerKey"
testAccountID = "test_id" testAccountID = "test_id"
existingGroupID = "testGroup" existingGroupID = "testGroup"
notFoundGroupID = "nonExistingGroup"
) )
var emptyString = ""
var existingPeerID = "peer-id"
var nonLinuxExistingPeerID = "darwin-peer-id"
var baseExistingRoute = &route.Route{ var baseExistingRoute = &route.Route{
ID: existingRouteID, ID: existingRouteID,
Description: "base route", Description: "base route",
@ -51,8 +58,19 @@ var testingAccount = &server.Account{
Peers: map[string]*server.Peer{ Peers: map[string]*server.Peer{
existingPeerID: { existingPeerID: {
Key: existingPeerKey, Key: existingPeerKey,
IP: netip.MustParseAddr(existingPeerIP).AsSlice(), IP: netip.MustParseAddr(existingPeerIP1).AsSlice(),
ID: existingPeerID, ID: existingPeerID,
Meta: server.PeerSystemMeta{
GoOS: "linux",
},
},
nonLinuxExistingPeerID: {
Key: nonLinuxExistingPeerID,
IP: netip.MustParseAddr(existingPeerIP2).AsSlice(),
ID: nonLinuxExistingPeerID,
Meta: server.PeerSystemMeta{
GoOS: "darwin",
},
}, },
}, },
Users: map[string]*server.User{ Users: map[string]*server.User{
@ -67,17 +85,26 @@ func initRoutesTestData() *RoutesHandler {
if routeID == existingRouteID { if routeID == existingRouteID {
return baseExistingRoute, nil return baseExistingRoute, nil
} }
if routeID == existingRouteID2 {
route := baseExistingRoute.Copy()
route.PeerGroups = []string{existingGroupID}
return route, nil
}
return nil, status.Errorf(status.NotFound, "route with ID %s not found", routeID) return nil, status.Errorf(status.NotFound, "route with ID %s not found", routeID)
}, },
CreateRouteFunc: func(accountID string, network, peerID, description, netID string, masquerade bool, metric int, groups []string, enabled bool, _ string) (*route.Route, error) { CreateRouteFunc: func(accountID, network, peerID string, peerGroups []string, description, netID string, masquerade bool, metric int, groups []string, enabled bool, _ string) (*route.Route, error) {
if peerID == notFoundPeerID { if peerID == notFoundPeerID {
return nil, status.Errorf(status.InvalidArgument, "peer with ID %s not found", peerID) return nil, status.Errorf(status.InvalidArgument, "peer with ID %s not found", peerID)
} }
if len(peerGroups) > 0 && peerGroups[0] == notFoundGroupID {
return nil, status.Errorf(status.InvalidArgument, "peer groups with ID %s not found", peerGroups[0])
}
networkType, p, _ := route.ParseNetwork(network) networkType, p, _ := route.ParseNetwork(network)
return &route.Route{ return &route.Route{
ID: existingRouteID, ID: existingRouteID,
NetID: netID, NetID: netID,
Peer: peerID, Peer: peerID,
PeerGroups: peerGroups,
Network: p, Network: p,
NetworkType: networkType, NetworkType: networkType,
Description: description, Description: description,
@ -98,15 +125,6 @@ func initRoutesTestData() *RoutesHandler {
} }
return nil return nil
}, },
GetPeerByIPFunc: func(_ string, peerIP string) (*server.Peer, error) {
if peerIP != existingPeerID {
return nil, status.Errorf(status.NotFound, "Peer with ID %s not found", peerIP)
}
return &server.Peer{
Key: existingPeerKey,
IP: netip.MustParseAddr(existingPeerID).AsSlice(),
}, nil
},
GetAccountFromTokenFunc: func(_ jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) { GetAccountFromTokenFunc: func(_ jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
return testingAccount, testingAccount.Users["test_user"], nil return testingAccount, testingAccount.Users["test_user"], nil
}, },
@ -124,6 +142,9 @@ func initRoutesTestData() *RoutesHandler {
} }
func TestRoutesHandlers(t *testing.T) { func TestRoutesHandlers(t *testing.T) {
baseExistingRouteWithPeerGroups := baseExistingRoute.Copy()
baseExistingRouteWithPeerGroups.PeerGroups = []string{existingGroupID}
tt := []struct { tt := []struct {
name string name string
expectedStatus int expectedStatus int
@ -147,6 +168,14 @@ func TestRoutesHandlers(t *testing.T) {
requestPath: "/api/routes/" + notFoundRouteID, requestPath: "/api/routes/" + notFoundRouteID,
expectedStatus: http.StatusNotFound, expectedStatus: http.StatusNotFound,
}, },
{
name: "Get Existing Route with Peer Groups",
requestType: http.MethodGet,
requestPath: "/api/routes/" + existingRouteID2,
expectedStatus: http.StatusOK,
expectedBody: true,
expectedRoute: toRouteResponse(baseExistingRouteWithPeerGroups),
},
{ {
name: "Delete Existing Route", name: "Delete Existing Route",
requestType: http.MethodDelete, requestType: http.MethodDelete,
@ -173,13 +202,21 @@ func TestRoutesHandlers(t *testing.T) {
Description: "Post", Description: "Post",
NetworkId: "awesomeNet", NetworkId: "awesomeNet",
Network: "192.168.0.0/16", Network: "192.168.0.0/16",
Peer: existingPeerID, Peer: &existingPeerID,
NetworkType: route.IPv4NetworkString, NetworkType: route.IPv4NetworkString,
Masquerade: false, Masquerade: false,
Enabled: false, Enabled: false,
Groups: []string{existingGroupID}, Groups: []string{existingGroupID},
}, },
}, },
{
name: "POST Non Linux Peer",
requestType: http.MethodPost,
requestPath: "/api/routes",
requestBody: bytes.NewBufferString(fmt.Sprintf("{\"Description\":\"Post\",\"Network\":\"192.168.0.0/16\",\"network_id\":\"awesomeNet\",\"Peer\":\"%s\",\"groups\":[\"%s\"]}", nonLinuxExistingPeerID, existingGroupID)),
expectedStatus: http.StatusUnprocessableEntity,
expectedBody: false,
},
{ {
name: "POST Not Found Peer", name: "POST Not Found Peer",
requestType: http.MethodPost, requestType: http.MethodPost,
@ -204,6 +241,24 @@ func TestRoutesHandlers(t *testing.T) {
expectedStatus: http.StatusUnprocessableEntity, expectedStatus: http.StatusUnprocessableEntity,
expectedBody: false, expectedBody: false,
}, },
{
name: "POST UnprocessableEntity when both peer and peer_groups are provided",
requestType: http.MethodPost,
requestPath: "/api/routes",
requestBody: bytes.NewBuffer(
[]byte(fmt.Sprintf("{\"Description\":\"Post\",\"Network\":\"192.168.0.0/16\",\"network_id\":\"awesomeNet\",\"peer\":\"%s\",\"peer_groups\":[\"%s\"],\"groups\":[\"%s\"]}", existingPeerID, existingGroupID, existingGroupID))),
expectedStatus: http.StatusUnprocessableEntity,
expectedBody: false,
},
{
name: "POST UnprocessableEntity when no peer and peer_groups are provided",
requestType: http.MethodPost,
requestPath: "/api/routes",
requestBody: bytes.NewBuffer(
[]byte(fmt.Sprintf("{\"Description\":\"Post\",\"Network\":\"192.168.0.0/16\",\"network_id\":\"awesomeNet\",\"groups\":[\"%s\"]}", existingPeerID))),
expectedStatus: http.StatusUnprocessableEntity,
expectedBody: false,
},
{ {
name: "PUT OK", name: "PUT OK",
requestType: http.MethodPut, requestType: http.MethodPut,
@ -216,7 +271,27 @@ func TestRoutesHandlers(t *testing.T) {
Description: "Post", Description: "Post",
NetworkId: "awesomeNet", NetworkId: "awesomeNet",
Network: "192.168.0.0/16", Network: "192.168.0.0/16",
Peer: existingPeerID, Peer: &existingPeerID,
NetworkType: route.IPv4NetworkString,
Masquerade: false,
Enabled: false,
Groups: []string{existingGroupID},
},
},
{
name: "PUT OK when peer_groups provided",
requestType: http.MethodPut,
requestPath: "/api/routes/" + existingRouteID,
requestBody: bytes.NewBufferString(fmt.Sprintf("{\"Description\":\"Post\",\"Network\":\"192.168.0.0/16\",\"network_id\":\"awesomeNet\",\"peer_groups\":[\"%s\"],\"groups\":[\"%s\"]}", existingGroupID, existingGroupID)),
expectedStatus: http.StatusOK,
expectedBody: true,
expectedRoute: &api.Route{
Id: existingRouteID,
Description: "Post",
NetworkId: "awesomeNet",
Network: "192.168.0.0/16",
Peer: &emptyString,
PeerGroups: &[]string{existingGroupID},
NetworkType: route.IPv4NetworkString, NetworkType: route.IPv4NetworkString,
Masquerade: false, Masquerade: false,
Enabled: false, Enabled: false,
@ -239,6 +314,14 @@ func TestRoutesHandlers(t *testing.T) {
expectedStatus: http.StatusUnprocessableEntity, expectedStatus: http.StatusUnprocessableEntity,
expectedBody: false, expectedBody: false,
}, },
{
name: "PUT Non Linux Peer",
requestType: http.MethodPut,
requestPath: "/api/routes/" + existingRouteID,
requestBody: bytes.NewBufferString(fmt.Sprintf("{\"Description\":\"Post\",\"Network\":\"192.168.0.0/16\",\"network_id\":\"awesomeNet\",\"Peer\":\"%s\",\"groups\":[\"%s\"]}", nonLinuxExistingPeerID, existingGroupID)),
expectedStatus: http.StatusUnprocessableEntity,
expectedBody: false,
},
{ {
name: "PUT Invalid Network Identifier", name: "PUT Invalid Network Identifier",
requestType: http.MethodPut, requestType: http.MethodPut,
@ -255,6 +338,24 @@ func TestRoutesHandlers(t *testing.T) {
expectedStatus: http.StatusUnprocessableEntity, expectedStatus: http.StatusUnprocessableEntity,
expectedBody: false, expectedBody: false,
}, },
{
name: "PUT UnprocessableEntity when both peer and peer_groups are provided",
requestType: http.MethodPut,
requestPath: "/api/routes/" + existingRouteID,
requestBody: bytes.NewBuffer(
[]byte(fmt.Sprintf("{\"Description\":\"Post\",\"Network\":\"192.168.0.0/16\",\"network_id\":\"awesomeNet\",\"peer\":\"%s\",\"peer_groups\":[\"%s\"],\"groups\":[\"%s\"]}", existingPeerID, existingGroupID, existingGroupID))),
expectedStatus: http.StatusUnprocessableEntity,
expectedBody: false,
},
{
name: "PUT UnprocessableEntity when no peer and peer_groups are provided",
requestType: http.MethodPut,
requestPath: "/api/routes/" + existingRouteID,
requestBody: bytes.NewBuffer(
[]byte(fmt.Sprintf("{\"Description\":\"Post\",\"Network\":\"192.168.0.0/16\",\"network_id\":\"awesomeNet\",\"groups\":[\"%s\"]}", existingPeerID))),
expectedStatus: http.StatusUnprocessableEntity,
expectedBody: false,
},
} }
p := initRoutesTestData() p := initRoutesTestData()

View File

@ -77,6 +77,7 @@ func WriteErrorResponse(errMsg string, httpStatus int, w http.ResponseWriter) {
// WriteError converts an error to an JSON error response. // WriteError converts an error to an JSON error response.
// If it is known internal error of type server.Error then it sets the messages from the error, a generic message otherwise // If it is known internal error of type server.Error then it sets the messages from the error, a generic message otherwise
func WriteError(err error, w http.ResponseWriter) { func WriteError(err error, w http.ResponseWriter) {
log.Errorf("got a handler error: %s", err.Error())
errStatus, ok := status.FromError(err) errStatus, ok := status.FromError(err)
httpStatus := http.StatusInternalServerError httpStatus := http.StatusInternalServerError
msg := "internal server error" msg := "internal server error"

View File

@ -210,47 +210,7 @@ func (ac *AuthentikCredentials) Authenticate() (JWTToken, error) {
} }
// UpdateUserAppMetadata updates user app metadata based on userID and metadata map. // UpdateUserAppMetadata updates user app metadata based on userID and metadata map.
func (am *AuthentikManager) UpdateUserAppMetadata(userID string, appMetadata AppMetadata) error { func (am *AuthentikManager) UpdateUserAppMetadata(_ string, _ AppMetadata) error {
ctx, err := am.authenticationContext()
if err != nil {
return err
}
userPk, err := strconv.ParseInt(userID, 10, 32)
if err != nil {
return err
}
var pendingInvite bool
if appMetadata.WTPendingInvite != nil {
pendingInvite = *appMetadata.WTPendingInvite
}
patchedUserReq := api.PatchedUserRequest{
Attributes: map[string]interface{}{
wtAccountID: appMetadata.WTAccountID,
wtPendingInvite: pendingInvite,
},
}
_, resp, err := am.apiClient.CoreApi.CoreUsersPartialUpdate(ctx, int32(userPk)).
PatchedUserRequest(patchedUserReq).
Execute()
if err != nil {
return err
}
defer resp.Body.Close()
if am.appMetrics != nil {
am.appMetrics.IDPMetrics().CountUpdateUserAppMetadata()
}
if resp.StatusCode != http.StatusOK {
if am.appMetrics != nil {
am.appMetrics.IDPMetrics().CountRequestStatusError()
}
return fmt.Errorf("unable to update user %s, statusCode %d", userID, resp.StatusCode)
}
return nil return nil
} }
@ -283,7 +243,10 @@ func (am *AuthentikManager) GetUserDataByID(userID string, appMetadata AppMetada
return nil, fmt.Errorf("unable to get user %s, statusCode %d", userID, resp.StatusCode) return nil, fmt.Errorf("unable to get user %s, statusCode %d", userID, resp.StatusCode)
} }
return parseAuthentikUser(*user) userData := parseAuthentikUser(*user)
userData.AppMetadata = appMetadata
return userData, nil
} }
// GetAccount returns all the users for a given profile. // GetAccount returns all the users for a given profile.
@ -293,8 +256,7 @@ func (am *AuthentikManager) GetAccount(accountID string) ([]*UserData, error) {
return nil, err return nil, err
} }
accountFilter := fmt.Sprintf("{%q:%q}", wtAccountID, accountID) userList, resp, err := am.apiClient.CoreApi.CoreUsersList(ctx).Execute()
userList, resp, err := am.apiClient.CoreApi.CoreUsersList(ctx).Attributes(accountFilter).Execute()
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -313,10 +275,9 @@ func (am *AuthentikManager) GetAccount(accountID string) ([]*UserData, error) {
users := make([]*UserData, 0) users := make([]*UserData, 0)
for _, user := range userList.Results { for _, user := range userList.Results {
userData, err := parseAuthentikUser(user) userData := parseAuthentikUser(user)
if err != nil { userData.AppMetadata.WTAccountID = accountID
return nil, err
}
users = append(users, userData) users = append(users, userData)
} }
@ -350,65 +311,16 @@ func (am *AuthentikManager) GetAllAccounts() (map[string][]*UserData, error) {
indexedUsers := make(map[string][]*UserData) indexedUsers := make(map[string][]*UserData)
for _, user := range userList.Results { for _, user := range userList.Results {
userData, err := parseAuthentikUser(user) userData := parseAuthentikUser(user)
if err != nil { indexedUsers[UnsetAccountID] = append(indexedUsers[UnsetAccountID], userData)
return nil, err
}
accountID := userData.AppMetadata.WTAccountID
if accountID != "" {
if _, ok := indexedUsers[accountID]; !ok {
indexedUsers[accountID] = make([]*UserData, 0)
}
indexedUsers[accountID] = append(indexedUsers[accountID], userData)
}
} }
return indexedUsers, nil return indexedUsers, nil
} }
// CreateUser creates a new user in authentik Idp and sends an invitation. // CreateUser creates a new user in authentik Idp and sends an invitation.
func (am *AuthentikManager) CreateUser(email, name, accountID, invitedByEmail string) (*UserData, error) { func (am *AuthentikManager) CreateUser(_, _, _, _ string) (*UserData, error) {
ctx, err := am.authenticationContext() return nil, fmt.Errorf("method CreateUser not implemented")
if err != nil {
return nil, err
}
groupID, err := am.getUserGroupByName("netbird")
if err != nil {
return nil, err
}
defaultBoolValue := true
createUserRequest := api.UserRequest{
Email: &email,
Name: name,
IsActive: &defaultBoolValue,
Groups: []string{groupID},
Username: email,
Attributes: map[string]interface{}{
wtAccountID: accountID,
wtPendingInvite: &defaultBoolValue,
},
}
user, resp, err := am.apiClient.CoreApi.CoreUsersCreate(ctx).UserRequest(createUserRequest).Execute()
if err != nil {
return nil, err
}
defer resp.Body.Close()
if am.appMetrics != nil {
am.appMetrics.IDPMetrics().CountCreateUser()
}
if resp.StatusCode != http.StatusCreated {
if am.appMetrics != nil {
am.appMetrics.IDPMetrics().CountRequestStatusError()
}
return nil, fmt.Errorf("unable to create user, statusCode %d", resp.StatusCode)
}
return parseAuthentikUser(*user)
} }
// GetUserByEmail searches users with a given email. // GetUserByEmail searches users with a given email.
@ -438,11 +350,7 @@ func (am *AuthentikManager) GetUserByEmail(email string) ([]*UserData, error) {
users := make([]*UserData, 0) users := make([]*UserData, 0)
for _, user := range userList.Results { for _, user := range userList.Results {
userData, err := parseAuthentikUser(user) users = append(users, parseAuthentikUser(user))
if err != nil {
return nil, err
}
users = append(users, userData)
} }
return users, nil return users, nil
@ -501,64 +409,10 @@ func (am *AuthentikManager) authenticationContext() (context.Context, error) {
return context.WithValue(context.Background(), api.ContextAPIKeys, value), nil return context.WithValue(context.Background(), api.ContextAPIKeys, value), nil
} }
// getUserGroupByName retrieves the user group for assigning new users. func parseAuthentikUser(user api.User) *UserData {
// If the group is not found, a new group with the specified name will be created.
func (am *AuthentikManager) getUserGroupByName(name string) (string, error) {
ctx, err := am.authenticationContext()
if err != nil {
return "", err
}
groupList, resp, err := am.apiClient.CoreApi.CoreGroupsList(ctx).Name(name).Execute()
if err != nil {
return "", err
}
defer resp.Body.Close()
if groupList != nil {
if len(groupList.Results) > 0 {
return groupList.Results[0].Pk, nil
}
}
createGroupRequest := api.GroupRequest{Name: name}
group, resp, err := am.apiClient.CoreApi.CoreGroupsCreate(ctx).GroupRequest(createGroupRequest).Execute()
if err != nil {
return "", err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusCreated {
return "", fmt.Errorf("unable to create user group, statusCode: %d", resp.StatusCode)
}
return group.Pk, nil
}
func parseAuthentikUser(user api.User) (*UserData, error) {
var attributes struct {
AccountID string `json:"wt_account_id"`
PendingInvite bool `json:"wt_pending_invite"`
}
helper := JsonParser{}
buf, err := helper.Marshal(user.Attributes)
if err != nil {
return nil, err
}
err = helper.Unmarshal(buf, &attributes)
if err != nil {
return nil, err
}
return &UserData{ return &UserData{
Email: *user.Email, Email: *user.Email,
Name: user.Name, Name: user.Name,
ID: strconv.FormatInt(int64(user.Pk), 10), ID: strconv.FormatInt(int64(user.Pk), 10),
AppMetadata: AppMetadata{ }
WTAccountID: attributes.AccountID,
WTPendingInvite: &attributes.PendingInvite,
},
}, nil
} }

View File

@ -1,7 +1,6 @@
package idp package idp
import ( import (
"encoding/json"
"fmt" "fmt"
"io" "io"
"net/http" "net/http"
@ -11,18 +10,12 @@ import (
"time" "time"
"github.com/golang-jwt/jwt" "github.com/golang-jwt/jwt"
"github.com/netbirdio/netbird/management/server/telemetry"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/server/telemetry"
) )
const ( const profileFields = "id,displayName,mail,userPrincipalName"
// azure extension properties template
wtAccountIDTpl = "extension_%s_wt_account_id"
wtPendingInviteTpl = "extension_%s_wt_pending_invite"
profileFields = "id,displayName,mail,userPrincipalName"
extensionFields = "id,name,targetObjects"
)
// AzureManager azure manager client instance. // AzureManager azure manager client instance.
type AzureManager struct { type AzureManager struct {
@ -58,21 +51,6 @@ type AzureCredentials struct {
// azureProfile represents an azure user profile. // azureProfile represents an azure user profile.
type azureProfile map[string]any type azureProfile map[string]any
// passwordProfile represent authentication method for,
// newly created user profile.
type passwordProfile struct {
ForceChangePasswordNextSignIn bool `json:"forceChangePasswordNextSignIn"`
Password string `json:"password"`
}
// azureExtension represent custom attribute,
// that can be added to user objects in Azure Active Directory (AD).
type azureExtension struct {
Name string `json:"name"`
DataType string `json:"dataType"`
TargetObjects []string `json:"targetObjects"`
}
// NewAzureManager creates a new instance of the AzureManager. // NewAzureManager creates a new instance of the AzureManager.
func NewAzureManager(config AzureClientConfig, appMetrics telemetry.AppMetrics) (*AzureManager, error) { func NewAzureManager(config AzureClientConfig, appMetrics telemetry.AppMetrics) (*AzureManager, error) {
httpTransport := http.DefaultTransport.(*http.Transport).Clone() httpTransport := http.DefaultTransport.(*http.Transport).Clone()
@ -115,7 +93,7 @@ func NewAzureManager(config AzureClientConfig, appMetrics telemetry.AppMetrics)
appMetrics: appMetrics, appMetrics: appMetrics,
} }
manager := &AzureManager{ return &AzureManager{
ObjectID: config.ObjectID, ObjectID: config.ObjectID,
ClientID: config.ClientID, ClientID: config.ClientID,
GraphAPIEndpoint: config.GraphAPIEndpoint, GraphAPIEndpoint: config.GraphAPIEndpoint,
@ -123,14 +101,7 @@ func NewAzureManager(config AzureClientConfig, appMetrics telemetry.AppMetrics)
credentials: credentials, credentials: credentials,
helper: helper, helper: helper,
appMetrics: appMetrics, appMetrics: appMetrics,
} }, nil
err := manager.configureAppMetadata()
if err != nil {
return nil, err
}
return manager, nil
} }
// jwtStillValid returns true if the token still valid and have enough time to be used and get a response from azure. // jwtStillValid returns true if the token still valid and have enough time to be used and get a response from azure.
@ -236,44 +207,14 @@ func (ac *AzureCredentials) Authenticate() (JWTToken, error) {
} }
// CreateUser creates a new user in azure AD Idp. // CreateUser creates a new user in azure AD Idp.
func (am *AzureManager) CreateUser(email, name, accountID, invitedByEmail string) (*UserData, error) { func (am *AzureManager) CreateUser(_, _, _, _ string) (*UserData, error) {
payload, err := buildAzureCreateUserRequestPayload(email, name, accountID, am.ClientID) return nil, fmt.Errorf("method CreateUser not implemented")
if err != nil {
return nil, err
}
body, err := am.post("users", payload)
if err != nil {
return nil, err
}
if am.appMetrics != nil {
am.appMetrics.IDPMetrics().CountCreateUser()
}
var profile azureProfile
err = am.helper.Unmarshal(body, &profile)
if err != nil {
return nil, err
}
wtAccountIDField := extensionName(wtAccountIDTpl, am.ClientID)
profile[wtAccountIDField] = accountID
wtPendingInviteField := extensionName(wtPendingInviteTpl, am.ClientID)
profile[wtPendingInviteField] = true
return profile.userData(am.ClientID), nil
} }
// GetUserDataByID requests user data from keycloak via ID. // GetUserDataByID requests user data from keycloak via ID.
func (am *AzureManager) GetUserDataByID(userID string, appMetadata AppMetadata) (*UserData, error) { func (am *AzureManager) GetUserDataByID(userID string, appMetadata AppMetadata) (*UserData, error) {
wtAccountIDField := extensionName(wtAccountIDTpl, am.ClientID)
wtPendingInviteField := extensionName(wtPendingInviteTpl, am.ClientID)
selectFields := strings.Join([]string{profileFields, wtAccountIDField, wtPendingInviteField}, ",")
q := url.Values{} q := url.Values{}
q.Add("$select", selectFields) q.Add("$select", profileFields)
body, err := am.get("users/"+userID, q) body, err := am.get("users/"+userID, q)
if err != nil { if err != nil {
@ -290,18 +231,17 @@ func (am *AzureManager) GetUserDataByID(userID string, appMetadata AppMetadata)
return nil, err return nil, err
} }
return profile.userData(am.ClientID), nil userData := profile.userData()
userData.AppMetadata = appMetadata
return userData, nil
} }
// GetUserByEmail searches users with a given email. // GetUserByEmail searches users with a given email.
// If no users have been found, this function returns an empty list. // If no users have been found, this function returns an empty list.
func (am *AzureManager) GetUserByEmail(email string) ([]*UserData, error) { func (am *AzureManager) GetUserByEmail(email string) ([]*UserData, error) {
wtAccountIDField := extensionName(wtAccountIDTpl, am.ClientID)
wtPendingInviteField := extensionName(wtPendingInviteTpl, am.ClientID)
selectFields := strings.Join([]string{profileFields, wtAccountIDField, wtPendingInviteField}, ",")
q := url.Values{} q := url.Values{}
q.Add("$select", selectFields) q.Add("$select", profileFields)
body, err := am.get("users/"+email, q) body, err := am.get("users/"+email, q)
if err != nil { if err != nil {
@ -319,20 +259,15 @@ func (am *AzureManager) GetUserByEmail(email string) ([]*UserData, error) {
} }
users := make([]*UserData, 0) users := make([]*UserData, 0)
users = append(users, profile.userData(am.ClientID)) users = append(users, profile.userData())
return users, nil return users, nil
} }
// GetAccount returns all the users for a given profile. // GetAccount returns all the users for a given profile.
func (am *AzureManager) GetAccount(accountID string) ([]*UserData, error) { func (am *AzureManager) GetAccount(accountID string) ([]*UserData, error) {
wtAccountIDField := extensionName(wtAccountIDTpl, am.ClientID)
wtPendingInviteField := extensionName(wtPendingInviteTpl, am.ClientID)
selectFields := strings.Join([]string{profileFields, wtAccountIDField, wtPendingInviteField}, ",")
q := url.Values{} q := url.Values{}
q.Add("$select", selectFields) q.Add("$select", profileFields)
q.Add("$filter", fmt.Sprintf("%s eq '%s'", wtAccountIDField, accountID))
body, err := am.get("users", q) body, err := am.get("users", q)
if err != nil { if err != nil {
@ -351,7 +286,10 @@ func (am *AzureManager) GetAccount(accountID string) ([]*UserData, error) {
users := make([]*UserData, 0) users := make([]*UserData, 0)
for _, profile := range profiles.Value { for _, profile := range profiles.Value {
users = append(users, profile.userData(am.ClientID)) userData := profile.userData()
userData.AppMetadata.WTAccountID = accountID
users = append(users, userData)
} }
return users, nil return users, nil
@ -360,12 +298,8 @@ func (am *AzureManager) GetAccount(accountID string) ([]*UserData, error) {
// GetAllAccounts gets all registered accounts with corresponding user data. // GetAllAccounts gets all registered accounts with corresponding user data.
// It returns a list of users indexed by accountID. // It returns a list of users indexed by accountID.
func (am *AzureManager) GetAllAccounts() (map[string][]*UserData, error) { func (am *AzureManager) GetAllAccounts() (map[string][]*UserData, error) {
wtAccountIDField := extensionName(wtAccountIDTpl, am.ClientID)
wtPendingInviteField := extensionName(wtPendingInviteTpl, am.ClientID)
selectFields := strings.Join([]string{profileFields, wtAccountIDField, wtPendingInviteField}, ",")
q := url.Values{} q := url.Values{}
q.Add("$select", selectFields) q.Add("$select", profileFields)
body, err := am.get("users", q) body, err := am.get("users", q)
if err != nil { if err != nil {
@ -384,67 +318,15 @@ func (am *AzureManager) GetAllAccounts() (map[string][]*UserData, error) {
indexedUsers := make(map[string][]*UserData) indexedUsers := make(map[string][]*UserData)
for _, profile := range profiles.Value { for _, profile := range profiles.Value {
userData := profile.userData(am.ClientID) userData := profile.userData()
indexedUsers[UnsetAccountID] = append(indexedUsers[UnsetAccountID], userData)
accountID := userData.AppMetadata.WTAccountID
if accountID != "" {
if _, ok := indexedUsers[accountID]; !ok {
indexedUsers[accountID] = make([]*UserData, 0)
}
indexedUsers[accountID] = append(indexedUsers[accountID], userData)
}
} }
return indexedUsers, nil return indexedUsers, nil
} }
// UpdateUserAppMetadata updates user app metadata based on userID. // UpdateUserAppMetadata updates user app metadata based on userID.
func (am *AzureManager) UpdateUserAppMetadata(userID string, appMetadata AppMetadata) error { func (am *AzureManager) UpdateUserAppMetadata(_ string, _ AppMetadata) error {
jwtToken, err := am.credentials.Authenticate()
if err != nil {
return err
}
wtAccountIDField := extensionName(wtAccountIDTpl, am.ClientID)
wtPendingInviteField := extensionName(wtPendingInviteTpl, am.ClientID)
data, err := am.helper.Marshal(map[string]any{
wtAccountIDField: appMetadata.WTAccountID,
wtPendingInviteField: appMetadata.WTPendingInvite,
})
if err != nil {
return err
}
payload := strings.NewReader(string(data))
reqURL := fmt.Sprintf("%s/users/%s", am.GraphAPIEndpoint, userID)
req, err := http.NewRequest(http.MethodPatch, reqURL, payload)
if err != nil {
return err
}
req.Header.Add("authorization", "Bearer "+jwtToken.AccessToken)
req.Header.Add("content-type", "application/json")
log.Debugf("updating idp metadata for user %s", userID)
resp, err := am.httpClient.Do(req)
if err != nil {
if am.appMetrics != nil {
am.appMetrics.IDPMetrics().CountRequestError()
}
return err
}
defer resp.Body.Close()
if am.appMetrics != nil {
am.appMetrics.IDPMetrics().CountUpdateUserAppMetadata()
}
if resp.StatusCode != http.StatusNoContent {
return fmt.Errorf("unable to update the appMetadata, statusCode %d", resp.StatusCode)
}
return nil return nil
} }
@ -454,7 +336,7 @@ func (am *AzureManager) InviteUserByID(_ string) error {
return fmt.Errorf("method InviteUserByID not implemented") return fmt.Errorf("method InviteUserByID not implemented")
} }
// DeleteUser from Azure // DeleteUser from Azure.
func (am *AzureManager) DeleteUser(userID string) error { func (am *AzureManager) DeleteUser(userID string) error {
jwtToken, err := am.credentials.Authenticate() jwtToken, err := am.credentials.Authenticate()
if err != nil { if err != nil {
@ -491,81 +373,6 @@ func (am *AzureManager) DeleteUser(userID string) error {
return nil return nil
} }
func (am *AzureManager) getUserExtensions() ([]azureExtension, error) {
q := url.Values{}
q.Add("$select", extensionFields)
resource := fmt.Sprintf("applications/%s/extensionProperties", am.ObjectID)
body, err := am.get(resource, q)
if err != nil {
return nil, err
}
var extensions struct{ Value []azureExtension }
err = am.helper.Unmarshal(body, &extensions)
if err != nil {
return nil, err
}
return extensions.Value, nil
}
func (am *AzureManager) createUserExtension(name string) (*azureExtension, error) {
extension := azureExtension{
Name: name,
DataType: "string",
TargetObjects: []string{"User"},
}
payload, err := am.helper.Marshal(extension)
if err != nil {
return nil, err
}
resource := fmt.Sprintf("applications/%s/extensionProperties", am.ObjectID)
body, err := am.post(resource, string(payload))
if err != nil {
return nil, err
}
var userExtension azureExtension
err = am.helper.Unmarshal(body, &userExtension)
if err != nil {
return nil, err
}
return &userExtension, nil
}
// configureAppMetadata sets up app metadata extensions if they do not exists.
func (am *AzureManager) configureAppMetadata() error {
wtAccountIDField := extensionName(wtAccountIDTpl, am.ClientID)
wtPendingInviteField := extensionName(wtPendingInviteTpl, am.ClientID)
extensions, err := am.getUserExtensions()
if err != nil {
return err
}
// If the wt_account_id extension does not already exist, create it.
if !hasExtension(extensions, wtAccountIDField) {
_, err = am.createUserExtension(wtAccountID)
if err != nil {
return err
}
}
// If the wt_pending_invite extension does not already exist, create it.
if !hasExtension(extensions, wtPendingInviteField) {
_, err = am.createUserExtension(wtPendingInvite)
if err != nil {
return err
}
}
return nil
}
// get perform Get requests. // get perform Get requests.
func (am *AzureManager) get(resource string, q url.Values) ([]byte, error) { func (am *AzureManager) get(resource string, q url.Values) ([]byte, error) {
jwtToken, err := am.credentials.Authenticate() jwtToken, err := am.credentials.Authenticate()
@ -602,44 +409,8 @@ func (am *AzureManager) get(resource string, q url.Values) ([]byte, error) {
return io.ReadAll(resp.Body) return io.ReadAll(resp.Body)
} }
// post perform Post requests.
func (am *AzureManager) post(resource string, body string) ([]byte, error) {
jwtToken, err := am.credentials.Authenticate()
if err != nil {
return nil, err
}
reqURL := fmt.Sprintf("%s/%s", am.GraphAPIEndpoint, resource)
req, err := http.NewRequest(http.MethodPost, reqURL, strings.NewReader(body))
if err != nil {
return nil, err
}
req.Header.Add("authorization", "Bearer "+jwtToken.AccessToken)
req.Header.Add("content-type", "application/json")
resp, err := am.httpClient.Do(req)
if err != nil {
if am.appMetrics != nil {
am.appMetrics.IDPMetrics().CountRequestError()
}
return nil, err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusCreated {
if am.appMetrics != nil {
am.appMetrics.IDPMetrics().CountRequestStatusError()
}
return nil, fmt.Errorf("unable to post %s, statusCode %d", reqURL, resp.StatusCode)
}
return io.ReadAll(resp.Body)
}
// userData construct user data from keycloak profile. // userData construct user data from keycloak profile.
func (ap azureProfile) userData(clientID string) *UserData { func (ap azureProfile) userData() *UserData {
id, ok := ap["id"].(string) id, ok := ap["id"].(string)
if !ok { if !ok {
id = "" id = ""
@ -655,66 +426,9 @@ func (ap azureProfile) userData(clientID string) *UserData {
name = "" name = ""
} }
accountIDField := extensionName(wtAccountIDTpl, clientID)
accountID, ok := ap[accountIDField].(string)
if !ok {
accountID = ""
}
pendingInviteField := extensionName(wtPendingInviteTpl, clientID)
pendingInvite, ok := ap[pendingInviteField].(bool)
if !ok {
pendingInvite = false
}
return &UserData{ return &UserData{
Email: email, Email: email,
Name: name, Name: name,
ID: id, ID: id,
AppMetadata: AppMetadata{
WTAccountID: accountID,
WTPendingInvite: &pendingInvite,
},
} }
} }
func buildAzureCreateUserRequestPayload(email, name, accountID, clientID string) (string, error) {
wtAccountIDField := extensionName(wtAccountIDTpl, clientID)
wtPendingInviteField := extensionName(wtPendingInviteTpl, clientID)
req := &azureProfile{
"accountEnabled": true,
"displayName": name,
"mailNickName": strings.Join(strings.Split(name, " "), ""),
"userPrincipalName": email,
"passwordProfile": passwordProfile{
ForceChangePasswordNextSignIn: true,
Password: GeneratePassword(8, 1, 1, 1),
},
wtAccountIDField: accountID,
wtPendingInviteField: true,
}
str, err := json.Marshal(req)
if err != nil {
return "", err
}
return string(str), nil
}
func extensionName(extensionTpl, clientID string) string {
clientID = strings.ReplaceAll(clientID, "-", "")
return fmt.Sprintf(extensionTpl, clientID)
}
// hasExtension checks whether a given extension by name,
// exists in an list of extensions.
func hasExtension(extensions []azureExtension, name string) bool {
for _, ext := range extensions {
if ext.Name == name {
return true
}
}
return false
}

View File

@ -8,15 +8,6 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
type mockAzureCredentials struct {
jwtToken JWTToken
err error
}
func (mc *mockAzureCredentials) Authenticate() (JWTToken, error) {
return mc.jwtToken, mc.err
}
func TestAzureJwtStillValid(t *testing.T) { func TestAzureJwtStillValid(t *testing.T) {
type jwtStillValidTest struct { type jwtStillValidTest struct {
name string name string
@ -124,115 +115,9 @@ func TestAzureAuthenticate(t *testing.T) {
} }
} }
func TestAzureUpdateUserAppMetadata(t *testing.T) {
type updateUserAppMetadataTest struct {
name string
inputReqBody string
expectedReqBody string
appMetadata AppMetadata
statusCode int
helper ManagerHelper
managerCreds ManagerCredentials
assertErrFunc assert.ErrorAssertionFunc
assertErrFuncMessage string
}
appMetadata := AppMetadata{WTAccountID: "ok"}
updateUserAppMetadataTestCase1 := updateUserAppMetadataTest{
name: "Bad Authentication",
expectedReqBody: "",
appMetadata: appMetadata,
statusCode: 400,
helper: JsonParser{},
managerCreds: &mockAzureCredentials{
jwtToken: JWTToken{},
err: fmt.Errorf("error"),
},
assertErrFunc: assert.Error,
assertErrFuncMessage: "should return error",
}
updateUserAppMetadataTestCase2 := updateUserAppMetadataTest{
name: "Bad Status Code",
expectedReqBody: fmt.Sprintf("{\"extension__wt_account_id\":\"%s\",\"extension__wt_pending_invite\":null}", appMetadata.WTAccountID),
appMetadata: appMetadata,
statusCode: 400,
helper: JsonParser{},
managerCreds: &mockAzureCredentials{
jwtToken: JWTToken{},
},
assertErrFunc: assert.Error,
assertErrFuncMessage: "should return error",
}
updateUserAppMetadataTestCase3 := updateUserAppMetadataTest{
name: "Bad Response Parsing",
statusCode: 400,
helper: &mockJsonParser{marshalErrorString: "error"},
managerCreds: &mockAzureCredentials{
jwtToken: JWTToken{},
},
assertErrFunc: assert.Error,
assertErrFuncMessage: "should return error",
}
updateUserAppMetadataTestCase4 := updateUserAppMetadataTest{
name: "Good request",
expectedReqBody: fmt.Sprintf("{\"extension__wt_account_id\":\"%s\",\"extension__wt_pending_invite\":null}", appMetadata.WTAccountID),
appMetadata: appMetadata,
statusCode: 204,
helper: JsonParser{},
managerCreds: &mockAzureCredentials{
jwtToken: JWTToken{},
},
assertErrFunc: assert.NoError,
assertErrFuncMessage: "shouldn't return error",
}
invite := true
updateUserAppMetadataTestCase5 := updateUserAppMetadataTest{
name: "Update Pending Invite",
expectedReqBody: fmt.Sprintf("{\"extension__wt_account_id\":\"%s\",\"extension__wt_pending_invite\":true}", appMetadata.WTAccountID),
appMetadata: AppMetadata{
WTAccountID: "ok",
WTPendingInvite: &invite,
},
statusCode: 204,
helper: JsonParser{},
managerCreds: &mockAzureCredentials{
jwtToken: JWTToken{},
},
assertErrFunc: assert.NoError,
assertErrFuncMessage: "shouldn't return error",
}
for _, testCase := range []updateUserAppMetadataTest{updateUserAppMetadataTestCase1, updateUserAppMetadataTestCase2,
updateUserAppMetadataTestCase3, updateUserAppMetadataTestCase4, updateUserAppMetadataTestCase5} {
t.Run(testCase.name, func(t *testing.T) {
reqClient := mockHTTPClient{
resBody: testCase.inputReqBody,
code: testCase.statusCode,
}
manager := &AzureManager{
httpClient: &reqClient,
credentials: testCase.managerCreds,
helper: testCase.helper,
}
err := manager.UpdateUserAppMetadata("1", testCase.appMetadata)
testCase.assertErrFunc(t, err, testCase.assertErrFuncMessage)
assert.Equal(t, testCase.expectedReqBody, reqClient.reqBody, "request body should match")
})
}
}
func TestAzureProfile(t *testing.T) { func TestAzureProfile(t *testing.T) {
type azureProfileTest struct { type azureProfileTest struct {
name string name string
clientID string
invite bool invite bool
inputProfile azureProfile inputProfile azureProfile
expectedUserData UserData expectedUserData UserData
@ -240,90 +125,53 @@ func TestAzureProfile(t *testing.T) {
azureProfileTestCase1 := azureProfileTest{ azureProfileTestCase1 := azureProfileTest{
name: "Good Request", name: "Good Request",
clientID: "25d0b095-0484-40d2-9fd3-03f8f4abbb3c",
invite: false, invite: false,
inputProfile: azureProfile{ inputProfile: azureProfile{
"id": "test1", "id": "test1",
"displayName": "John Doe", "displayName": "John Doe",
"userPrincipalName": "test1@test.com", "userPrincipalName": "test1@test.com",
"extension_25d0b095048440d29fd303f8f4abbb3c_wt_account_id": "1",
"extension_25d0b095048440d29fd303f8f4abbb3c_wt_pending_invite": false,
}, },
expectedUserData: UserData{ expectedUserData: UserData{
Email: "test1@test.com", Email: "test1@test.com",
Name: "John Doe", Name: "John Doe",
ID: "test1", ID: "test1",
AppMetadata: AppMetadata{
WTAccountID: "1",
},
}, },
} }
azureProfileTestCase2 := azureProfileTest{ azureProfileTestCase2 := azureProfileTest{
name: "Missing User ID", name: "Missing User ID",
clientID: "25d0b095-0484-40d2-9fd3-03f8f4abbb3c",
invite: true, invite: true,
inputProfile: azureProfile{ inputProfile: azureProfile{
"displayName": "John Doe", "displayName": "John Doe",
"userPrincipalName": "test2@test.com", "userPrincipalName": "test2@test.com",
"extension_25d0b095048440d29fd303f8f4abbb3c_wt_account_id": "1",
"extension_25d0b095048440d29fd303f8f4abbb3c_wt_pending_invite": true,
}, },
expectedUserData: UserData{ expectedUserData: UserData{
Email: "test2@test.com", Email: "test2@test.com",
Name: "John Doe", Name: "John Doe",
AppMetadata: AppMetadata{
WTAccountID: "1",
},
}, },
} }
azureProfileTestCase3 := azureProfileTest{ azureProfileTestCase3 := azureProfileTest{
name: "Missing User Name", name: "Missing User Name",
clientID: "25d0b095-0484-40d2-9fd3-03f8f4abbb3c",
invite: false, invite: false,
inputProfile: azureProfile{ inputProfile: azureProfile{
"id": "test3", "id": "test3",
"userPrincipalName": "test3@test.com", "userPrincipalName": "test3@test.com",
"extension_25d0b095048440d29fd303f8f4abbb3c_wt_account_id": "1",
"extension_25d0b095048440d29fd303f8f4abbb3c_wt_pending_invite": false,
}, },
expectedUserData: UserData{ expectedUserData: UserData{
ID: "test3", ID: "test3",
Email: "test3@test.com", Email: "test3@test.com",
AppMetadata: AppMetadata{
WTAccountID: "1",
},
}, },
} }
azureProfileTestCase4 := azureProfileTest{ for _, testCase := range []azureProfileTest{azureProfileTestCase1, azureProfileTestCase2, azureProfileTestCase3} {
name: "Missing Extension Fields",
clientID: "25d0b095-0484-40d2-9fd3-03f8f4abbb3c",
invite: false,
inputProfile: azureProfile{
"id": "test4",
"displayName": "John Doe",
"userPrincipalName": "test4@test.com",
},
expectedUserData: UserData{
ID: "test4",
Name: "John Doe",
Email: "test4@test.com",
AppMetadata: AppMetadata{},
},
}
for _, testCase := range []azureProfileTest{azureProfileTestCase1, azureProfileTestCase2, azureProfileTestCase3, azureProfileTestCase4} {
t.Run(testCase.name, func(t *testing.T) { t.Run(testCase.name, func(t *testing.T) {
testCase.expectedUserData.AppMetadata.WTPendingInvite = &testCase.invite testCase.expectedUserData.AppMetadata.WTPendingInvite = &testCase.invite
userData := testCase.inputProfile.userData(testCase.clientID) userData := testCase.inputProfile.userData()
assert.Equal(t, testCase.expectedUserData.ID, userData.ID, "User id should match") assert.Equal(t, testCase.expectedUserData.ID, userData.ID, "User id should match")
assert.Equal(t, testCase.expectedUserData.Email, userData.Email, "User email should match") assert.Equal(t, testCase.expectedUserData.Email, userData.Email, "User email should match")
assert.Equal(t, testCase.expectedUserData.Name, userData.Name, "User name should match") assert.Equal(t, testCase.expectedUserData.Name, userData.Name, "User name should match")
assert.Equal(t, testCase.expectedUserData.AppMetadata.WTAccountID, userData.AppMetadata.WTAccountID, "Account id should match")
assert.Equal(t, testCase.expectedUserData.AppMetadata.WTPendingInvite, userData.AppMetadata.WTPendingInvite, "Pending invite should match")
}) })
} }
} }

View File

@ -5,15 +5,14 @@ import (
"encoding/base64" "encoding/base64"
"fmt" "fmt"
"net/http" "net/http"
"strings"
"time" "time"
"github.com/netbirdio/netbird/management/server/telemetry"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"golang.org/x/oauth2/google" "golang.org/x/oauth2/google"
admin "google.golang.org/api/admin/directory/v1" admin "google.golang.org/api/admin/directory/v1"
"google.golang.org/api/googleapi"
"google.golang.org/api/option" "google.golang.org/api/option"
"github.com/netbirdio/netbird/management/server/telemetry"
) )
// GoogleWorkspaceManager Google Workspace manager client instance. // GoogleWorkspaceManager Google Workspace manager client instance.
@ -73,17 +72,13 @@ func NewGoogleWorkspaceManager(config GoogleWorkspaceClientConfig, appMetrics te
} }
service, err := admin.NewService(context.Background(), service, err := admin.NewService(context.Background(),
option.WithScopes(admin.AdminDirectoryUserScope, admin.AdminDirectoryUserschemaScope), option.WithScopes(admin.AdminDirectoryUserReadonlyScope),
option.WithCredentials(adminCredentials), option.WithCredentials(adminCredentials),
) )
if err != nil { if err != nil {
return nil, err return nil, err
} }
if err = configureAppMetadataSchema(service, config.CustomerID); err != nil {
return nil, err
}
return &GoogleWorkspaceManager{ return &GoogleWorkspaceManager{
usersService: service.Users, usersService: service.Users,
CustomerID: config.CustomerID, CustomerID: config.CustomerID,
@ -95,27 +90,7 @@ func NewGoogleWorkspaceManager(config GoogleWorkspaceClientConfig, appMetrics te
} }
// UpdateUserAppMetadata updates user app metadata based on userID and metadata map. // UpdateUserAppMetadata updates user app metadata based on userID and metadata map.
func (gm *GoogleWorkspaceManager) UpdateUserAppMetadata(userID string, appMetadata AppMetadata) error { func (gm *GoogleWorkspaceManager) UpdateUserAppMetadata(_ string, _ AppMetadata) error {
metadata, err := gm.helper.Marshal(appMetadata)
if err != nil {
return err
}
user := &admin.User{
CustomSchemas: map[string]googleapi.RawMessage{
"app_metadata": metadata,
},
}
_, err = gm.usersService.Update(userID, user).Do()
if err != nil {
return err
}
if gm.appMetrics != nil {
gm.appMetrics.IDPMetrics().CountUpdateUserAppMetadata()
}
return nil return nil
} }
@ -130,23 +105,23 @@ func (gm *GoogleWorkspaceManager) GetUserDataByID(userID string, appMetadata App
gm.appMetrics.IDPMetrics().CountGetUserDataByID() gm.appMetrics.IDPMetrics().CountGetUserDataByID()
} }
return parseGoogleWorkspaceUser(user) userData := parseGoogleWorkspaceUser(user)
userData.AppMetadata = appMetadata
return userData, nil
} }
// GetAccount returns all the users for a given profile. // GetAccount returns all the users for a given profile.
func (gm *GoogleWorkspaceManager) GetAccount(accountID string) ([]*UserData, error) { func (gm *GoogleWorkspaceManager) GetAccount(accountID string) ([]*UserData, error) {
query := fmt.Sprintf("app_metadata.wt_account_id=\"%s\"", accountID) usersList, err := gm.usersService.List().Customer(gm.CustomerID).Projection("full").Do()
usersList, err := gm.usersService.List().Customer(gm.CustomerID).Query(query).Projection("full").Do()
if err != nil { if err != nil {
return nil, err return nil, err
} }
usersData := make([]*UserData, 0) usersData := make([]*UserData, 0)
for _, user := range usersList.Users { for _, user := range usersList.Users {
userData, err := parseGoogleWorkspaceUser(user) userData := parseGoogleWorkspaceUser(user)
if err != nil { userData.AppMetadata.WTAccountID = accountID
return nil, err
}
usersData = append(usersData, userData) usersData = append(usersData, userData)
} }
@ -168,61 +143,16 @@ func (gm *GoogleWorkspaceManager) GetAllAccounts() (map[string][]*UserData, erro
indexedUsers := make(map[string][]*UserData) indexedUsers := make(map[string][]*UserData)
for _, user := range usersList.Users { for _, user := range usersList.Users {
userData, err := parseGoogleWorkspaceUser(user) userData := parseGoogleWorkspaceUser(user)
if err != nil { indexedUsers[UnsetAccountID] = append(indexedUsers[UnsetAccountID], userData)
return nil, err
}
accountID := userData.AppMetadata.WTAccountID
if accountID != "" {
if _, ok := indexedUsers[accountID]; !ok {
indexedUsers[accountID] = make([]*UserData, 0)
}
indexedUsers[accountID] = append(indexedUsers[accountID], userData)
}
} }
return indexedUsers, nil return indexedUsers, nil
} }
// CreateUser creates a new user in Google Workspace and sends an invitation. // CreateUser creates a new user in Google Workspace and sends an invitation.
func (gm *GoogleWorkspaceManager) CreateUser(email, name, accountID, invitedByEmail string) (*UserData, error) { func (gm *GoogleWorkspaceManager) CreateUser(_, _, _, _ string) (*UserData, error) {
invite := true return nil, fmt.Errorf("method CreateUser not implemented")
metadata := AppMetadata{
WTAccountID: accountID,
WTPendingInvite: &invite,
}
username := &admin.UserName{}
fields := strings.Fields(name)
if n := len(fields); n > 0 {
username.GivenName = strings.Join(fields[:n-1], " ")
username.FamilyName = fields[n-1]
}
payload, err := gm.helper.Marshal(metadata)
if err != nil {
return nil, err
}
user := &admin.User{
Name: username,
PrimaryEmail: email,
CustomSchemas: map[string]googleapi.RawMessage{
"app_metadata": payload,
},
Password: GeneratePassword(8, 1, 1, 1),
}
user, err = gm.usersService.Insert(user).Do()
if err != nil {
return nil, err
}
if gm.appMetrics != nil {
gm.appMetrics.IDPMetrics().CountCreateUser()
}
return parseGoogleWorkspaceUser(user)
} }
// GetUserByEmail searches users with a given email. // GetUserByEmail searches users with a given email.
@ -237,13 +167,8 @@ func (gm *GoogleWorkspaceManager) GetUserByEmail(email string) ([]*UserData, err
gm.appMetrics.IDPMetrics().CountGetUserByEmail() gm.appMetrics.IDPMetrics().CountGetUserByEmail()
} }
userData, err := parseGoogleWorkspaceUser(user)
if err != nil {
return nil, err
}
users := make([]*UserData, 0) users := make([]*UserData, 0)
users = append(users, userData) users = append(users, parseGoogleWorkspaceUser(user))
return users, nil return users, nil
} }
@ -281,8 +206,7 @@ func getGoogleCredentials(serviceAccountKey string) (*google.Credentials, error)
creds, err := google.CredentialsFromJSON( creds, err := google.CredentialsFromJSON(
context.Background(), context.Background(),
decodeKey, decodeKey,
admin.AdminDirectoryUserschemaScope, admin.AdminDirectoryUserReadonlyScope,
admin.AdminDirectoryUserScope,
) )
if err == nil { if err == nil {
// No need to fallback to the default Google credentials path // No need to fallback to the default Google credentials path
@ -294,8 +218,7 @@ func getGoogleCredentials(serviceAccountKey string) (*google.Credentials, error)
creds, err = google.FindDefaultCredentials( creds, err = google.FindDefaultCredentials(
context.Background(), context.Background(),
admin.AdminDirectoryUserschemaScope, admin.AdminDirectoryUserReadonlyScope,
admin.AdminDirectoryUserScope,
) )
if err != nil { if err != nil {
return nil, err return nil, err
@ -304,62 +227,11 @@ func getGoogleCredentials(serviceAccountKey string) (*google.Credentials, error)
return creds, nil return creds, nil
} }
// configureAppMetadataSchema create a custom schema for managing app metadata fields in Google Workspace.
func configureAppMetadataSchema(service *admin.Service, customerID string) error {
schemaList, err := service.Schemas.List(customerID).Do()
if err != nil {
return err
}
// checks if app_metadata schema is already created
for _, schema := range schemaList.Schemas {
if schema.SchemaName == "app_metadata" {
return nil
}
}
// create new app_metadata schema
appMetadataSchema := &admin.Schema{
SchemaName: "app_metadata",
Fields: []*admin.SchemaFieldSpec{
{
FieldName: "wt_account_id",
FieldType: "STRING",
MultiValued: false,
},
{
FieldName: "wt_pending_invite",
FieldType: "BOOL",
MultiValued: false,
},
},
}
_, err = service.Schemas.Insert(customerID, appMetadataSchema).Do()
if err != nil {
return err
}
return nil
}
// parseGoogleWorkspaceUser parse google user to UserData. // parseGoogleWorkspaceUser parse google user to UserData.
func parseGoogleWorkspaceUser(user *admin.User) (*UserData, error) { func parseGoogleWorkspaceUser(user *admin.User) *UserData {
var appMetadata AppMetadata
// Get app metadata from custom schemas
if user.CustomSchemas != nil {
rawMessage := user.CustomSchemas["app_metadata"]
helper := JsonParser{}
if err := helper.Unmarshal(rawMessage, &appMetadata); err != nil {
return nil, err
}
}
return &UserData{ return &UserData{
ID: user.Id, ID: user.Id,
Email: user.PrimaryEmail, Email: user.PrimaryEmail,
Name: user.Name.FullName, Name: user.Name.FullName,
AppMetadata: appMetadata, }
}, nil
} }

View File

@ -9,6 +9,11 @@ import (
"github.com/netbirdio/netbird/management/server/telemetry" "github.com/netbirdio/netbird/management/server/telemetry"
) )
const (
// UnsetAccountID is a special key to map users without an account ID
UnsetAccountID = "unset"
)
// Manager idp manager interface // Manager idp manager interface
type Manager interface { type Manager interface {
UpdateUserAppMetadata(userId string, appMetadata AppMetadata) error UpdateUserAppMetadata(userId string, appMetadata AppMetadata) error
@ -38,10 +43,10 @@ type Config struct {
ManagerType string ManagerType string
ClientConfig *ClientConfig ClientConfig *ClientConfig
ExtraConfig ExtraConfig ExtraConfig ExtraConfig
Auth0ClientCredentials Auth0ClientConfig Auth0ClientCredentials *Auth0ClientConfig
AzureClientCredentials AzureClientConfig AzureClientCredentials *AzureClientConfig
KeycloakClientCredentials KeycloakClientConfig KeycloakClientCredentials *KeycloakClientConfig
ZitadelClientCredentials ZitadelClientConfig ZitadelClientCredentials *ZitadelClientConfig
} }
// ManagerCredentials interface that authenticates using the credential of each type of idp // ManagerCredentials interface that authenticates using the credential of each type of idp
@ -97,7 +102,7 @@ func NewManager(config Config, appMetrics telemetry.AppMetrics) (Manager, error)
case "auth0": case "auth0":
auth0ClientConfig := config.Auth0ClientCredentials auth0ClientConfig := config.Auth0ClientCredentials
if config.ClientConfig != nil { if config.ClientConfig != nil {
auth0ClientConfig = Auth0ClientConfig{ auth0ClientConfig = &Auth0ClientConfig{
Audience: config.ExtraConfig["Audience"], Audience: config.ExtraConfig["Audience"],
AuthIssuer: config.ClientConfig.Issuer, AuthIssuer: config.ClientConfig.Issuer,
ClientID: config.ClientConfig.ClientID, ClientID: config.ClientConfig.ClientID,
@ -106,11 +111,11 @@ func NewManager(config Config, appMetrics telemetry.AppMetrics) (Manager, error)
} }
} }
return NewAuth0Manager(auth0ClientConfig, appMetrics) return NewAuth0Manager(*auth0ClientConfig, appMetrics)
case "azure": case "azure":
azureClientConfig := config.AzureClientCredentials azureClientConfig := config.AzureClientCredentials
if config.ClientConfig != nil { if config.ClientConfig != nil {
azureClientConfig = AzureClientConfig{ azureClientConfig = &AzureClientConfig{
ClientID: config.ClientConfig.ClientID, ClientID: config.ClientConfig.ClientID,
ClientSecret: config.ClientConfig.ClientSecret, ClientSecret: config.ClientConfig.ClientSecret,
GrantType: config.ClientConfig.GrantType, GrantType: config.ClientConfig.GrantType,
@ -120,11 +125,11 @@ func NewManager(config Config, appMetrics telemetry.AppMetrics) (Manager, error)
} }
} }
return NewAzureManager(azureClientConfig, appMetrics) return NewAzureManager(*azureClientConfig, appMetrics)
case "keycloak": case "keycloak":
keycloakClientConfig := config.KeycloakClientCredentials keycloakClientConfig := config.KeycloakClientCredentials
if config.ClientConfig != nil { if config.ClientConfig != nil {
keycloakClientConfig = KeycloakClientConfig{ keycloakClientConfig = &KeycloakClientConfig{
ClientID: config.ClientConfig.ClientID, ClientID: config.ClientConfig.ClientID,
ClientSecret: config.ClientConfig.ClientSecret, ClientSecret: config.ClientConfig.ClientSecret,
GrantType: config.ClientConfig.GrantType, GrantType: config.ClientConfig.GrantType,
@ -133,11 +138,11 @@ func NewManager(config Config, appMetrics telemetry.AppMetrics) (Manager, error)
} }
} }
return NewKeycloakManager(keycloakClientConfig, appMetrics) return NewKeycloakManager(*keycloakClientConfig, appMetrics)
case "zitadel": case "zitadel":
zitadelClientConfig := config.ZitadelClientCredentials zitadelClientConfig := config.ZitadelClientCredentials
if config.ClientConfig != nil { if config.ClientConfig != nil {
zitadelClientConfig = ZitadelClientConfig{ zitadelClientConfig = &ZitadelClientConfig{
ClientID: config.ClientConfig.ClientID, ClientID: config.ClientConfig.ClientID,
ClientSecret: config.ClientConfig.ClientSecret, ClientSecret: config.ClientConfig.ClientSecret,
GrantType: config.ClientConfig.GrantType, GrantType: config.ClientConfig.GrantType,
@ -146,7 +151,7 @@ func NewManager(config Config, appMetrics telemetry.AppMetrics) (Manager, error)
} }
} }
return NewZitadelManager(zitadelClientConfig, appMetrics) return NewZitadelManager(*zitadelClientConfig, appMetrics)
case "authentik": case "authentik":
authentikConfig := AuthentikClientConfig{ authentikConfig := AuthentikClientConfig{
Issuer: config.ClientConfig.Issuer, Issuer: config.ClientConfig.Issuer,
@ -171,7 +176,11 @@ func NewManager(config Config, appMetrics telemetry.AppMetrics) (Manager, error)
CustomerID: config.ExtraConfig["CustomerId"], CustomerID: config.ExtraConfig["CustomerId"],
} }
return NewGoogleWorkspaceManager(googleClientConfig, appMetrics) return NewGoogleWorkspaceManager(googleClientConfig, appMetrics)
case "jumpcloud":
jumpcloudConfig := JumpCloudClientConfig{
APIToken: config.ExtraConfig["ApiToken"],
}
return NewJumpCloudManager(jumpcloudConfig, appMetrics)
default: default:
return nil, fmt.Errorf("invalid manager type: %s", config.ManagerType) return nil, fmt.Errorf("invalid manager type: %s", config.ManagerType)
} }

View File

@ -0,0 +1,257 @@
package idp
import (
"context"
"fmt"
"net/http"
"strings"
"time"
v1 "github.com/TheJumpCloud/jcapi-go/v1"
"github.com/netbirdio/netbird/management/server/telemetry"
)
const (
contentType = "application/json"
accept = "application/json"
)
// JumpCloudManager JumpCloud manager client instance.
type JumpCloudManager struct {
client *v1.APIClient
apiToken string
httpClient ManagerHTTPClient
credentials ManagerCredentials
helper ManagerHelper
appMetrics telemetry.AppMetrics
}
// JumpCloudClientConfig JumpCloud manager client configurations.
type JumpCloudClientConfig struct {
APIToken string
}
// JumpCloudCredentials JumpCloud authentication information.
type JumpCloudCredentials struct {
clientConfig JumpCloudClientConfig
helper ManagerHelper
httpClient ManagerHTTPClient
appMetrics telemetry.AppMetrics
}
// NewJumpCloudManager creates a new instance of the JumpCloudManager.
func NewJumpCloudManager(config JumpCloudClientConfig, appMetrics telemetry.AppMetrics) (*JumpCloudManager, error) {
httpTransport := http.DefaultTransport.(*http.Transport).Clone()
httpTransport.MaxIdleConns = 5
httpClient := &http.Client{
Timeout: 10 * time.Second,
Transport: httpTransport,
}
helper := JsonParser{}
if config.APIToken == "" {
return nil, fmt.Errorf("jumpCloud IdP configuration is incomplete, ApiToken is missing")
}
client := v1.NewAPIClient(v1.NewConfiguration())
credentials := &JumpCloudCredentials{
clientConfig: config,
httpClient: httpClient,
helper: helper,
appMetrics: appMetrics,
}
return &JumpCloudManager{
client: client,
apiToken: config.APIToken,
httpClient: httpClient,
credentials: credentials,
helper: helper,
appMetrics: appMetrics,
}, nil
}
// Authenticate retrieves access token to use the JumpCloud user API.
func (jc *JumpCloudCredentials) Authenticate() (JWTToken, error) {
return JWTToken{}, nil
}
func (jm *JumpCloudManager) authenticationContext() context.Context {
return context.WithValue(context.Background(), v1.ContextAPIKey, v1.APIKey{
Key: jm.apiToken,
})
}
// UpdateUserAppMetadata updates user app metadata based on userID and metadata map.
func (jm *JumpCloudManager) UpdateUserAppMetadata(_ string, _ AppMetadata) error {
return nil
}
// GetUserDataByID requests user data from JumpCloud via ID.
func (jm *JumpCloudManager) GetUserDataByID(userID string, appMetadata AppMetadata) (*UserData, error) {
authCtx := jm.authenticationContext()
user, resp, err := jm.client.SystemusersApi.SystemusersGet(authCtx, userID, contentType, accept, nil)
if err != nil {
return nil, err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
if jm.appMetrics != nil {
jm.appMetrics.IDPMetrics().CountRequestStatusError()
}
return nil, fmt.Errorf("unable to get user %s, statusCode %d", userID, resp.StatusCode)
}
if jm.appMetrics != nil {
jm.appMetrics.IDPMetrics().CountGetUserDataByID()
}
userData := parseJumpCloudUser(user)
userData.AppMetadata = appMetadata
return userData, nil
}
// GetAccount returns all the users for a given profile.
func (jm *JumpCloudManager) GetAccount(accountID string) ([]*UserData, error) {
authCtx := jm.authenticationContext()
userList, resp, err := jm.client.SearchApi.SearchSystemusersPost(authCtx, contentType, accept, nil)
if err != nil {
return nil, err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
if jm.appMetrics != nil {
jm.appMetrics.IDPMetrics().CountRequestStatusError()
}
return nil, fmt.Errorf("unable to get account %s users, statusCode %d", accountID, resp.StatusCode)
}
if jm.appMetrics != nil {
jm.appMetrics.IDPMetrics().CountGetAccount()
}
users := make([]*UserData, 0)
for _, user := range userList.Results {
userData := parseJumpCloudUser(user)
userData.AppMetadata.WTAccountID = accountID
users = append(users, userData)
}
return users, nil
}
// GetAllAccounts gets all registered accounts with corresponding user data.
// It returns a list of users indexed by accountID.
func (jm *JumpCloudManager) GetAllAccounts() (map[string][]*UserData, error) {
authCtx := jm.authenticationContext()
userList, resp, err := jm.client.SearchApi.SearchSystemusersPost(authCtx, contentType, accept, nil)
if err != nil {
return nil, err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
if jm.appMetrics != nil {
jm.appMetrics.IDPMetrics().CountRequestStatusError()
}
return nil, fmt.Errorf("unable to get all accounts, statusCode %d", resp.StatusCode)
}
if jm.appMetrics != nil {
jm.appMetrics.IDPMetrics().CountGetAllAccounts()
}
indexedUsers := make(map[string][]*UserData)
for _, user := range userList.Results {
userData := parseJumpCloudUser(user)
indexedUsers[UnsetAccountID] = append(indexedUsers[UnsetAccountID], userData)
}
return indexedUsers, nil
}
// CreateUser creates a new user in JumpCloud Idp and sends an invitation.
func (jm *JumpCloudManager) CreateUser(_, _, _, _ string) (*UserData, error) {
return nil, fmt.Errorf("method CreateUser not implemented")
}
// GetUserByEmail searches users with a given email.
// If no users have been found, this function returns an empty list.
func (jm *JumpCloudManager) GetUserByEmail(email string) ([]*UserData, error) {
searchFilter := map[string]interface{}{
"searchFilter": map[string]interface{}{
"filter": []string{email},
"fields": []string{"email"},
},
}
authCtx := jm.authenticationContext()
userList, resp, err := jm.client.SearchApi.SearchSystemusersPost(authCtx, contentType, accept, searchFilter)
if err != nil {
return nil, err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
if jm.appMetrics != nil {
jm.appMetrics.IDPMetrics().CountRequestStatusError()
}
return nil, fmt.Errorf("unable to get user %s, statusCode %d", email, resp.StatusCode)
}
if jm.appMetrics != nil {
jm.appMetrics.IDPMetrics().CountGetUserByEmail()
}
usersData := make([]*UserData, 0)
for _, user := range userList.Results {
usersData = append(usersData, parseJumpCloudUser(user))
}
return usersData, nil
}
// InviteUserByID resend invitations to users who haven't activated,
// their accounts prior to the expiration period.
func (jm *JumpCloudManager) InviteUserByID(_ string) error {
return fmt.Errorf("method InviteUserByID not implemented")
}
// DeleteUser from jumpCloud directory
func (jm *JumpCloudManager) DeleteUser(userID string) error {
authCtx := jm.authenticationContext()
_, resp, err := jm.client.SystemusersApi.SystemusersDelete(authCtx, userID, contentType, accept, nil)
if err != nil {
return err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
if jm.appMetrics != nil {
jm.appMetrics.IDPMetrics().CountRequestStatusError()
}
return fmt.Errorf("unable to delete user, statusCode %d", resp.StatusCode)
}
if jm.appMetrics != nil {
jm.appMetrics.IDPMetrics().CountDeleteUser()
}
return nil
}
// parseJumpCloudUser parse JumpCloud system user returned from API V1 to UserData.
func parseJumpCloudUser(user v1.Systemuserreturn) *UserData {
names := []string{user.Firstname, user.Middlename, user.Lastname}
return &UserData{
Email: user.Email,
Name: strings.Join(names, " "),
ID: user.Id,
}
}

View File

@ -0,0 +1,46 @@
package idp
import (
"testing"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/management/server/telemetry"
)
func TestNewJumpCloudManager(t *testing.T) {
type test struct {
name string
inputConfig JumpCloudClientConfig
assertErrFunc require.ErrorAssertionFunc
assertErrFuncMessage string
}
defaultTestConfig := JumpCloudClientConfig{
APIToken: "test123",
}
testCase1 := test{
name: "Good Configuration",
inputConfig: defaultTestConfig,
assertErrFunc: require.NoError,
assertErrFuncMessage: "shouldn't return error",
}
testCase2Config := defaultTestConfig
testCase2Config.APIToken = ""
testCase2 := test{
name: "Missing APIToken Configuration",
inputConfig: testCase2Config,
assertErrFunc: require.Error,
assertErrFuncMessage: "should return error when field empty",
}
for _, testCase := range []test{testCase1, testCase2} {
t.Run(testCase.name, func(t *testing.T) {
_, err := NewJumpCloudManager(testCase.inputConfig, &telemetry.MockAppMetrics{})
testCase.assertErrFunc(t, err, testCase.assertErrFuncMessage)
})
}
}

View File

@ -1,12 +1,10 @@
package idp package idp
import ( import (
"encoding/json"
"fmt" "fmt"
"io" "io"
"net/http" "net/http"
"net/url" "net/url"
"path"
"strconv" "strconv"
"strings" "strings"
"sync" "sync"
@ -18,11 +16,6 @@ import (
"github.com/netbirdio/netbird/management/server/telemetry" "github.com/netbirdio/netbird/management/server/telemetry"
) )
const (
wtAccountID = "wt_account_id"
wtPendingInvite = "wt_pending_invite"
)
// KeycloakManager keycloak manager client instance. // KeycloakManager keycloak manager client instance.
type KeycloakManager struct { type KeycloakManager struct {
adminEndpoint string adminEndpoint string
@ -51,28 +44,10 @@ type KeycloakCredentials struct {
appMetrics telemetry.AppMetrics appMetrics telemetry.AppMetrics
} }
// keycloakUserCredential describe the authentication method for,
// newly created user profile.
type keycloakUserCredential struct {
Type string `json:"type"`
Value string `json:"value"`
Temporary bool `json:"temporary"`
}
// keycloakUserAttributes holds additional user data fields. // keycloakUserAttributes holds additional user data fields.
type keycloakUserAttributes map[string][]string type keycloakUserAttributes map[string][]string
// createUserRequest is a user create request. // keycloakProfile represents a keycloak user profile response.
type keycloakCreateUserRequest struct {
Email string `json:"email"`
Username string `json:"username"`
Enabled bool `json:"enabled"`
EmailVerified bool `json:"emailVerified"`
Credentials []keycloakUserCredential `json:"credentials"`
Attributes keycloakUserAttributes `json:"attributes"`
}
// keycloakProfile represents an keycloak user profile response.
type keycloakProfile struct { type keycloakProfile struct {
ID string `json:"id"` ID string `json:"id"`
CreatedTimestamp int64 `json:"createdTimestamp"` CreatedTimestamp int64 `json:"createdTimestamp"`
@ -230,62 +205,8 @@ func (kc *KeycloakCredentials) Authenticate() (JWTToken, error) {
} }
// CreateUser creates a new user in keycloak Idp and sends an invite. // CreateUser creates a new user in keycloak Idp and sends an invite.
func (km *KeycloakManager) CreateUser(email, name, accountID, invitedByEmail string) (*UserData, error) { func (km *KeycloakManager) CreateUser(_, _, _, _ string) (*UserData, error) {
jwtToken, err := km.credentials.Authenticate() return nil, fmt.Errorf("method CreateUser not implemented")
if err != nil {
return nil, err
}
invite := true
appMetadata := AppMetadata{
WTAccountID: accountID,
WTPendingInvite: &invite,
}
payloadString, err := buildKeycloakCreateUserRequestPayload(email, name, appMetadata)
if err != nil {
return nil, err
}
reqURL := fmt.Sprintf("%s/users", km.adminEndpoint)
payload := strings.NewReader(payloadString)
req, err := http.NewRequest(http.MethodPost, reqURL, payload)
if err != nil {
return nil, err
}
req.Header.Add("authorization", "Bearer "+jwtToken.AccessToken)
req.Header.Add("content-type", "application/json")
if km.appMetrics != nil {
km.appMetrics.IDPMetrics().CountCreateUser()
}
resp, err := km.httpClient.Do(req)
if err != nil {
if km.appMetrics != nil {
km.appMetrics.IDPMetrics().CountRequestError()
}
return nil, err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusCreated {
if km.appMetrics != nil {
km.appMetrics.IDPMetrics().CountRequestStatusError()
}
return nil, fmt.Errorf("unable to create user, statusCode %d", resp.StatusCode)
}
locationHeader := resp.Header.Get("location")
userID, err := extractUserIDFromLocationHeader(locationHeader)
if err != nil {
return nil, err
}
return km.GetUserDataByID(userID, appMetadata)
} }
// GetUserByEmail searches users with a given email. // GetUserByEmail searches users with a given email.
@ -319,7 +240,7 @@ func (km *KeycloakManager) GetUserByEmail(email string) ([]*UserData, error) {
} }
// GetUserDataByID requests user data from keycloak via ID. // GetUserDataByID requests user data from keycloak via ID.
func (km *KeycloakManager) GetUserDataByID(userID string, appMetadata AppMetadata) (*UserData, error) { func (km *KeycloakManager) GetUserDataByID(userID string, _ AppMetadata) (*UserData, error) {
body, err := km.get("users/"+userID, nil) body, err := km.get("users/"+userID, nil)
if err != nil { if err != nil {
return nil, err return nil, err
@ -338,12 +259,9 @@ func (km *KeycloakManager) GetUserDataByID(userID string, appMetadata AppMetadat
return profile.userData(), nil return profile.userData(), nil
} }
// GetAccount returns all the users for a given profile. // GetAccount returns all the users for a given account profile.
func (km *KeycloakManager) GetAccount(accountID string) ([]*UserData, error) { func (km *KeycloakManager) GetAccount(accountID string) ([]*UserData, error) {
q := url.Values{} profiles, err := km.fetchAllUserProfiles()
q.Add("q", wtAccountID+":"+accountID)
body, err := km.get("users", q)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -352,15 +270,12 @@ func (km *KeycloakManager) GetAccount(accountID string) ([]*UserData, error) {
km.appMetrics.IDPMetrics().CountGetAccount() km.appMetrics.IDPMetrics().CountGetAccount()
} }
profiles := make([]keycloakProfile, 0)
err = km.helper.Unmarshal(body, &profiles)
if err != nil {
return nil, err
}
users := make([]*UserData, 0) users := make([]*UserData, 0)
for _, profile := range profiles { for _, profile := range profiles {
users = append(users, profile.userData()) userData := profile.userData()
userData.AppMetadata.WTAccountID = accountID
users = append(users, userData)
} }
return users, nil return users, nil
@ -369,15 +284,7 @@ func (km *KeycloakManager) GetAccount(accountID string) ([]*UserData, error) {
// GetAllAccounts gets all registered accounts with corresponding user data. // GetAllAccounts gets all registered accounts with corresponding user data.
// It returns a list of users indexed by accountID. // It returns a list of users indexed by accountID.
func (km *KeycloakManager) GetAllAccounts() (map[string][]*UserData, error) { func (km *KeycloakManager) GetAllAccounts() (map[string][]*UserData, error) {
totalUsers, err := km.totalUsersCount() profiles, err := km.fetchAllUserProfiles()
if err != nil {
return nil, err
}
q := url.Values{}
q.Add("max", fmt.Sprint(*totalUsers))
body, err := km.get("users", q)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -386,78 +293,17 @@ func (km *KeycloakManager) GetAllAccounts() (map[string][]*UserData, error) {
km.appMetrics.IDPMetrics().CountGetAllAccounts() km.appMetrics.IDPMetrics().CountGetAllAccounts()
} }
profiles := make([]keycloakProfile, 0)
err = km.helper.Unmarshal(body, &profiles)
if err != nil {
return nil, err
}
indexedUsers := make(map[string][]*UserData) indexedUsers := make(map[string][]*UserData)
for _, profile := range profiles { for _, profile := range profiles {
userData := profile.userData() userData := profile.userData()
indexedUsers[UnsetAccountID] = append(indexedUsers[UnsetAccountID], userData)
accountID := userData.AppMetadata.WTAccountID
if accountID != "" {
if _, ok := indexedUsers[accountID]; !ok {
indexedUsers[accountID] = make([]*UserData, 0)
}
indexedUsers[accountID] = append(indexedUsers[accountID], userData)
}
} }
return indexedUsers, nil return indexedUsers, nil
} }
// UpdateUserAppMetadata updates user app metadata based on userID and metadata map. // UpdateUserAppMetadata updates user app metadata based on userID and metadata map.
func (km *KeycloakManager) UpdateUserAppMetadata(userID string, appMetadata AppMetadata) error { func (km *KeycloakManager) UpdateUserAppMetadata(_ string, _ AppMetadata) error {
jwtToken, err := km.credentials.Authenticate()
if err != nil {
return err
}
attrs := keycloakUserAttributes{}
attrs.Set(wtAccountID, appMetadata.WTAccountID)
if appMetadata.WTPendingInvite != nil {
attrs.Set(wtPendingInvite, strconv.FormatBool(*appMetadata.WTPendingInvite))
} else {
attrs.Set(wtPendingInvite, "false")
}
reqURL := fmt.Sprintf("%s/users/%s", km.adminEndpoint, userID)
data, err := km.helper.Marshal(map[string]any{
"attributes": attrs,
})
if err != nil {
return err
}
payload := strings.NewReader(string(data))
req, err := http.NewRequest(http.MethodPut, reqURL, payload)
if err != nil {
return err
}
req.Header.Add("authorization", "Bearer "+jwtToken.AccessToken)
req.Header.Add("content-type", "application/json")
log.Debugf("updating IdP metadata for user %s", userID)
resp, err := km.httpClient.Do(req)
if err != nil {
if km.appMetrics != nil {
km.appMetrics.IDPMetrics().CountRequestError()
}
return err
}
defer resp.Body.Close()
if km.appMetrics != nil {
km.appMetrics.IDPMetrics().CountUpdateUserAppMetadata()
}
if resp.StatusCode != http.StatusNoContent {
return fmt.Errorf("unable to update the appMetadata, statusCode %d", resp.StatusCode)
}
return nil return nil
} }
@ -467,7 +313,7 @@ func (km *KeycloakManager) InviteUserByID(_ string) error {
return fmt.Errorf("method InviteUserByID not implemented") return fmt.Errorf("method InviteUserByID not implemented")
} }
// DeleteUser from Keycloack // DeleteUser from Keycloak by user ID.
func (km *KeycloakManager) DeleteUser(userID string) error { func (km *KeycloakManager) DeleteUser(userID string) error {
jwtToken, err := km.credentials.Authenticate() jwtToken, err := km.credentials.Authenticate()
if err != nil { if err != nil {
@ -475,7 +321,6 @@ func (km *KeycloakManager) DeleteUser(userID string) error {
} }
reqURL := fmt.Sprintf("%s/users/%s", km.adminEndpoint, url.QueryEscape(userID)) reqURL := fmt.Sprintf("%s/users/%s", km.adminEndpoint, url.QueryEscape(userID))
req, err := http.NewRequest(http.MethodDelete, reqURL, nil) req, err := http.NewRequest(http.MethodDelete, reqURL, nil)
if err != nil { if err != nil {
return err return err
@ -508,32 +353,27 @@ func (km *KeycloakManager) DeleteUser(userID string) error {
return nil return nil
} }
func buildKeycloakCreateUserRequestPayload(email string, name string, appMetadata AppMetadata) (string, error) { func (km *KeycloakManager) fetchAllUserProfiles() ([]keycloakProfile, error) {
attrs := keycloakUserAttributes{} totalUsers, err := km.totalUsersCount()
attrs.Set(wtAccountID, appMetadata.WTAccountID)
attrs.Set(wtPendingInvite, strconv.FormatBool(*appMetadata.WTPendingInvite))
req := &keycloakCreateUserRequest{
Email: email,
Username: name,
Enabled: true,
EmailVerified: true,
Credentials: []keycloakUserCredential{
{
Type: "password",
Value: GeneratePassword(8, 1, 1, 1),
Temporary: false,
},
},
Attributes: attrs,
}
str, err := json.Marshal(req)
if err != nil { if err != nil {
return "", err return nil, err
} }
return string(str), nil q := url.Values{}
q.Add("max", fmt.Sprint(*totalUsers))
body, err := km.get("users", q)
if err != nil {
return nil, err
}
profiles := make([]keycloakProfile, 0)
err = km.helper.Unmarshal(body, &profiles)
if err != nil {
return nil, err
}
return profiles, nil
} }
// get perform Get requests. // get perform Get requests.
@ -588,53 +428,11 @@ func (km *KeycloakManager) totalUsersCount() (*int, error) {
return &count, nil return &count, nil
} }
// extractUserIDFromLocationHeader extracts the user ID from the location,
// header once the user is created successfully
func extractUserIDFromLocationHeader(locationHeader string) (string, error) {
userURL, err := url.Parse(locationHeader)
if err != nil {
return "", err
}
return path.Base(userURL.Path), nil
}
// userData construct user data from keycloak profile. // userData construct user data from keycloak profile.
func (kp keycloakProfile) userData() *UserData { func (kp keycloakProfile) userData() *UserData {
accountID := kp.Attributes.Get(wtAccountID)
pendingInvite, err := strconv.ParseBool(kp.Attributes.Get(wtPendingInvite))
if err != nil {
pendingInvite = false
}
return &UserData{ return &UserData{
Email: kp.Email, Email: kp.Email,
Name: kp.Username, Name: kp.Username,
ID: kp.ID, ID: kp.ID,
AppMetadata: AppMetadata{
WTAccountID: accountID,
WTPendingInvite: &pendingInvite,
},
} }
} }
// Set sets the key to value. It replaces any existing
// values.
func (ka keycloakUserAttributes) Set(key, value string) {
ka[key] = []string{value}
}
// Get returns the first value associated with the given key.
// If there are no values associated with the key, Get returns
// the empty string.
func (ka keycloakUserAttributes) Get(key string) string {
if ka == nil {
return ""
}
values := ka[key]
if len(values) == 0 {
return ""
}
return values[0]
}

View File

@ -84,15 +84,6 @@ func TestNewKeycloakManager(t *testing.T) {
} }
} }
type mockKeycloakCredentials struct {
jwtToken JWTToken
err error
}
func (mc *mockKeycloakCredentials) Authenticate() (JWTToken, error) {
return mc.jwtToken, mc.err
}
func TestKeycloakRequestJWTToken(t *testing.T) { func TestKeycloakRequestJWTToken(t *testing.T) {
type requestJWTTokenTest struct { type requestJWTTokenTest struct {
@ -316,108 +307,3 @@ func TestKeycloakAuthenticate(t *testing.T) {
}) })
} }
} }
func TestKeycloakUpdateUserAppMetadata(t *testing.T) {
type updateUserAppMetadataTest struct {
name string
inputReqBody string
expectedReqBody string
appMetadata AppMetadata
statusCode int
helper ManagerHelper
managerCreds ManagerCredentials
assertErrFunc assert.ErrorAssertionFunc
assertErrFuncMessage string
}
appMetadata := AppMetadata{WTAccountID: "ok"}
updateUserAppMetadataTestCase1 := updateUserAppMetadataTest{
name: "Bad Authentication",
expectedReqBody: "",
appMetadata: appMetadata,
statusCode: 400,
helper: JsonParser{},
managerCreds: &mockKeycloakCredentials{
jwtToken: JWTToken{},
err: fmt.Errorf("error"),
},
assertErrFunc: assert.Error,
assertErrFuncMessage: "should return error",
}
updateUserAppMetadataTestCase2 := updateUserAppMetadataTest{
name: "Bad Status Code",
expectedReqBody: fmt.Sprintf("{\"attributes\":{\"wt_account_id\":[\"%s\"],\"wt_pending_invite\":[\"false\"]}}", appMetadata.WTAccountID),
appMetadata: appMetadata,
statusCode: 400,
helper: JsonParser{},
managerCreds: &mockKeycloakCredentials{
jwtToken: JWTToken{},
},
assertErrFunc: assert.Error,
assertErrFuncMessage: "should return error",
}
updateUserAppMetadataTestCase3 := updateUserAppMetadataTest{
name: "Bad Response Parsing",
statusCode: 400,
helper: &mockJsonParser{marshalErrorString: "error"},
managerCreds: &mockKeycloakCredentials{
jwtToken: JWTToken{},
},
assertErrFunc: assert.Error,
assertErrFuncMessage: "should return error",
}
updateUserAppMetadataTestCase4 := updateUserAppMetadataTest{
name: "Good request",
expectedReqBody: fmt.Sprintf("{\"attributes\":{\"wt_account_id\":[\"%s\"],\"wt_pending_invite\":[\"false\"]}}", appMetadata.WTAccountID),
appMetadata: appMetadata,
statusCode: 204,
helper: JsonParser{},
managerCreds: &mockKeycloakCredentials{
jwtToken: JWTToken{},
},
assertErrFunc: assert.NoError,
assertErrFuncMessage: "shouldn't return error",
}
invite := true
updateUserAppMetadataTestCase5 := updateUserAppMetadataTest{
name: "Update Pending Invite",
expectedReqBody: fmt.Sprintf("{\"attributes\":{\"wt_account_id\":[\"%s\"],\"wt_pending_invite\":[\"true\"]}}", appMetadata.WTAccountID),
appMetadata: AppMetadata{
WTAccountID: "ok",
WTPendingInvite: &invite,
},
statusCode: 204,
helper: JsonParser{},
managerCreds: &mockKeycloakCredentials{
jwtToken: JWTToken{},
},
assertErrFunc: assert.NoError,
assertErrFuncMessage: "shouldn't return error",
}
for _, testCase := range []updateUserAppMetadataTest{updateUserAppMetadataTestCase1, updateUserAppMetadataTestCase2,
updateUserAppMetadataTestCase3, updateUserAppMetadataTestCase4, updateUserAppMetadataTestCase5} {
t.Run(testCase.name, func(t *testing.T) {
reqClient := mockHTTPClient{
resBody: testCase.inputReqBody,
code: testCase.statusCode,
}
manager := &KeycloakManager{
httpClient: &reqClient,
credentials: testCase.managerCreds,
helper: testCase.helper,
}
err := manager.UpdateUserAppMetadata("1", testCase.appMetadata)
testCase.assertErrFunc(t, err, testCase.assertErrFuncMessage)
assert.Equal(t, testCase.expectedReqBody, reqClient.reqBody, "request body should match")
})
}
}

View File

@ -8,9 +8,9 @@ import (
"strings" "strings"
"time" "time"
"github.com/netbirdio/netbird/management/server/telemetry"
"github.com/okta/okta-sdk-golang/v2/okta" "github.com/okta/okta-sdk-golang/v2/okta"
"github.com/okta/okta-sdk-golang/v2/okta/query"
"github.com/netbirdio/netbird/management/server/telemetry"
) )
// OktaManager okta manager client instance. // OktaManager okta manager client instance.
@ -76,11 +76,6 @@ func NewOktaManager(config OktaClientConfig, appMetrics telemetry.AppMetrics) (*
return nil, err return nil, err
} }
err = updateUserProfileSchema(client)
if err != nil {
return nil, err
}
credentials := &OktaCredentials{ credentials := &OktaCredentials{
clientConfig: config, clientConfig: config,
httpClient: httpClient, httpClient: httpClient,
@ -103,49 +98,8 @@ func (oc *OktaCredentials) Authenticate() (JWTToken, error) {
} }
// CreateUser creates a new user in okta Idp and sends an invitation. // CreateUser creates a new user in okta Idp and sends an invitation.
func (om *OktaManager) CreateUser(email, name, accountID, invitedByEmail string) (*UserData, error) { func (om *OktaManager) CreateUser(_, _, _, _ string) (*UserData, error) {
var ( return nil, fmt.Errorf("method CreateUser not implemented")
sendEmail = true
activate = true
userProfile = okta.UserProfile{
"email": email,
"login": email,
wtAccountID: accountID,
wtPendingInvite: true,
}
)
fields := strings.Fields(name)
if n := len(fields); n > 0 {
userProfile["firstName"] = strings.Join(fields[:n-1], " ")
userProfile["lastName"] = fields[n-1]
}
user, resp, err := om.client.User.CreateUser(context.Background(),
okta.CreateUserRequest{
Profile: &userProfile,
},
&query.Params{
Activate: &activate,
SendEmail: &sendEmail,
},
)
if err != nil {
return nil, err
}
if om.appMetrics != nil {
om.appMetrics.IDPMetrics().CountCreateUser()
}
if resp.StatusCode != http.StatusOK {
if om.appMetrics != nil {
om.appMetrics.IDPMetrics().CountRequestStatusError()
}
return nil, fmt.Errorf("unable to create user, statusCode %d", resp.StatusCode)
}
return parseOktaUser(user)
} }
// GetUserDataByID requests user data from keycloak via ID. // GetUserDataByID requests user data from keycloak via ID.
@ -166,7 +120,13 @@ func (om *OktaManager) GetUserDataByID(userID string, appMetadata AppMetadata) (
return nil, fmt.Errorf("unable to get user %s, statusCode %d", userID, resp.StatusCode) return nil, fmt.Errorf("unable to get user %s, statusCode %d", userID, resp.StatusCode)
} }
return parseOktaUser(user) userData, err := parseOktaUser(user)
if err != nil {
return nil, err
}
userData.AppMetadata = appMetadata
return userData, nil
} }
// GetUserByEmail searches users with a given email. // GetUserByEmail searches users with a given email.
@ -200,8 +160,7 @@ func (om *OktaManager) GetUserByEmail(email string) ([]*UserData, error) {
// GetAccount returns all the users for a given profile. // GetAccount returns all the users for a given profile.
func (om *OktaManager) GetAccount(accountID string) ([]*UserData, error) { func (om *OktaManager) GetAccount(accountID string) ([]*UserData, error) {
search := fmt.Sprintf("profile.wt_account_id eq %q", accountID) users, resp, err := om.client.User.ListUsers(context.Background(), nil)
users, resp, err := om.client.User.ListUsers(context.Background(), &query.Params{Search: search})
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -223,6 +182,7 @@ func (om *OktaManager) GetAccount(accountID string) ([]*UserData, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
userData.AppMetadata.WTAccountID = accountID
list = append(list, userData) list = append(list, userData)
} }
@ -256,13 +216,7 @@ func (om *OktaManager) GetAllAccounts() (map[string][]*UserData, error) {
return nil, err return nil, err
} }
accountID := userData.AppMetadata.WTAccountID indexedUsers[UnsetAccountID] = append(indexedUsers[UnsetAccountID], userData)
if accountID != "" {
if _, ok := indexedUsers[accountID]; !ok {
indexedUsers[accountID] = make([]*UserData, 0)
}
indexedUsers[accountID] = append(indexedUsers[accountID], userData)
}
} }
return indexedUsers, nil return indexedUsers, nil
@ -270,46 +224,6 @@ func (om *OktaManager) GetAllAccounts() (map[string][]*UserData, error) {
// UpdateUserAppMetadata updates user app metadata based on userID and metadata map. // UpdateUserAppMetadata updates user app metadata based on userID and metadata map.
func (om *OktaManager) UpdateUserAppMetadata(userID string, appMetadata AppMetadata) error { func (om *OktaManager) UpdateUserAppMetadata(userID string, appMetadata AppMetadata) error {
user, resp, err := om.client.User.GetUser(context.Background(), userID)
if err != nil {
return err
}
if resp.StatusCode != http.StatusOK {
if om.appMetrics != nil {
om.appMetrics.IDPMetrics().CountRequestStatusError()
}
return fmt.Errorf("unable to update user, statusCode %d", resp.StatusCode)
}
profile := *user.Profile
if appMetadata.WTPendingInvite != nil {
profile[wtPendingInvite] = *appMetadata.WTPendingInvite
}
if appMetadata.WTAccountID != "" {
profile[wtAccountID] = appMetadata.WTAccountID
}
user.Profile = &profile
_, resp, err = om.client.User.UpdateUser(context.Background(), userID, *user, nil)
if err != nil {
fmt.Println(err.Error())
return err
}
if om.appMetrics != nil {
om.appMetrics.IDPMetrics().CountUpdateUserAppMetadata()
}
if resp.StatusCode != http.StatusOK {
if om.appMetrics != nil {
om.appMetrics.IDPMetrics().CountRequestStatusError()
}
return fmt.Errorf("unable to update user, statusCode %d", resp.StatusCode)
}
return nil return nil
} }
@ -341,60 +255,12 @@ func (om *OktaManager) DeleteUser(userID string) error {
return nil return nil
} }
// updateUserProfileSchema updates the Okta user schema to include custom fields,
// wt_account_id and wt_pending_invite.
func updateUserProfileSchema(client *okta.Client) error {
// Ensure Okta doesn't enforce user input for these fields, as they are solely used by Netbird
userPermissions := []*okta.UserSchemaAttributePermission{{Action: "HIDE", Principal: "SELF"}}
_, resp, err := client.UserSchema.UpdateUserProfile(
context.Background(),
"default",
okta.UserSchema{
Definitions: &okta.UserSchemaDefinitions{
Custom: &okta.UserSchemaPublic{
Id: "#custom",
Type: "object",
Properties: map[string]*okta.UserSchemaAttribute{
wtAccountID: {
MaxLength: 100,
MinLength: 1,
Required: new(bool),
Scope: "NONE",
Title: "Wt Account Id",
Type: "string",
Permissions: userPermissions,
},
wtPendingInvite: {
Required: new(bool),
Scope: "NONE",
Title: "Wt Pending Invite",
Type: "boolean",
Permissions: userPermissions,
},
},
},
},
})
if err != nil {
return err
}
if resp.StatusCode != http.StatusOK {
return fmt.Errorf("unable to update user profile schema, statusCode %d", resp.StatusCode)
}
return nil
}
// parseOktaUserToUserData parse okta user to UserData. // parseOktaUserToUserData parse okta user to UserData.
func parseOktaUser(user *okta.User) (*UserData, error) { func parseOktaUser(user *okta.User) (*UserData, error) {
var oktaUser struct { var oktaUser struct {
Email string `json:"email"` Email string `json:"email"`
FirstName string `json:"firstName"` FirstName string `json:"firstName"`
LastName string `json:"lastName"` LastName string `json:"lastName"`
AccountID string `json:"wt_account_id"`
PendingInvite bool `json:"wt_pending_invite"`
} }
if user == nil { if user == nil {
@ -418,9 +284,5 @@ func parseOktaUser(user *okta.User) (*UserData, error) {
Email: oktaUser.Email, Email: oktaUser.Email,
Name: strings.Join([]string{oktaUser.FirstName, oktaUser.LastName}, " "), Name: strings.Join([]string{oktaUser.FirstName, oktaUser.LastName}, " "),
ID: user.Id, ID: user.Id,
AppMetadata: AppMetadata{
WTAccountID: oktaUser.AccountID,
WTPendingInvite: &oktaUser.PendingInvite,
},
}, nil }, nil
} }

View File

@ -1,15 +1,15 @@
package idp package idp
import ( import (
"testing"
"github.com/okta/okta-sdk-golang/v2/okta" "github.com/okta/okta-sdk-golang/v2/okta"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"testing"
) )
func TestParseOktaUser(t *testing.T) { func TestParseOktaUser(t *testing.T) {
type parseOktaUserTest struct { type parseOktaUserTest struct {
name string name string
invite bool
inputProfile *okta.User inputProfile *okta.User
expectedUserData *UserData expectedUserData *UserData
assertErrFunc assert.ErrorAssertionFunc assertErrFunc assert.ErrorAssertionFunc
@ -17,15 +17,12 @@ func TestParseOktaUser(t *testing.T) {
parseOktaTestCase1 := parseOktaUserTest{ parseOktaTestCase1 := parseOktaUserTest{
name: "Good Request", name: "Good Request",
invite: true,
inputProfile: &okta.User{ inputProfile: &okta.User{
Id: "123", Id: "123",
Profile: &okta.UserProfile{ Profile: &okta.UserProfile{
"email": "test@example.com", "email": "test@example.com",
"firstName": "John", "firstName": "John",
"lastName": "Doe", "lastName": "Doe",
"wt_account_id": "456",
"wt_pending_invite": true,
}, },
}, },
expectedUserData: &UserData{ expectedUserData: &UserData{
@ -41,36 +38,17 @@ func TestParseOktaUser(t *testing.T) {
parseOktaTestCase2 := parseOktaUserTest{ parseOktaTestCase2 := parseOktaUserTest{
name: "Invalid okta user", name: "Invalid okta user",
invite: true,
inputProfile: nil, inputProfile: nil,
expectedUserData: nil, expectedUserData: nil,
assertErrFunc: assert.Error, assertErrFunc: assert.Error,
} }
parseOktaTestCase3 := parseOktaUserTest{ for _, testCase := range []parseOktaUserTest{parseOktaTestCase1, parseOktaTestCase2} {
name: "Invalid pending invite type",
invite: false,
inputProfile: &okta.User{
Id: "123",
Profile: &okta.UserProfile{
"email": "test@example.com",
"firstName": "John",
"lastName": "Doe",
"wt_account_id": "456",
"wt_pending_invite": "true",
},
},
expectedUserData: nil,
assertErrFunc: assert.Error,
}
for _, testCase := range []parseOktaUserTest{parseOktaTestCase1, parseOktaTestCase2, parseOktaTestCase3} {
t.Run(testCase.name, func(t *testing.T) { t.Run(testCase.name, func(t *testing.T) {
userData, err := parseOktaUser(testCase.inputProfile) userData, err := parseOktaUser(testCase.inputProfile)
testCase.assertErrFunc(t, err, testCase.assertErrFunc) testCase.assertErrFunc(t, err, testCase.assertErrFunc)
if err == nil { if err == nil {
testCase.expectedUserData.AppMetadata.WTPendingInvite = &testCase.invite
assert.True(t, userDataEqual(testCase.expectedUserData, userData), "user data should match") assert.True(t, userDataEqual(testCase.expectedUserData, userData), "user data should match")
} }
}) })
@ -83,13 +61,5 @@ func userDataEqual(a, b *UserData) bool {
if a.Email != b.Email || a.Name != b.Name || a.ID != b.ID { if a.Email != b.Email || a.Name != b.Name || a.ID != b.ID {
return false return false
} }
if a.AppMetadata.WTAccountID != b.AppMetadata.WTAccountID {
return false
}
if a.AppMetadata.WTPendingInvite != nil && b.AppMetadata.WTPendingInvite != nil &&
*a.AppMetadata.WTPendingInvite != *b.AppMetadata.WTPendingInvite {
return false
}
return true return true
} }

View File

@ -1,20 +1,18 @@
package idp package idp
import ( import (
"encoding/base64"
"encoding/json"
"fmt" "fmt"
"io" "io"
"net/http" "net/http"
"net/url" "net/url"
"strconv"
"strings" "strings"
"sync" "sync"
"time" "time"
"github.com/golang-jwt/jwt" "github.com/golang-jwt/jwt"
"github.com/netbirdio/netbird/management/server/telemetry"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/server/telemetry"
) )
// ZitadelManager zitadel manager client instance. // ZitadelManager zitadel manager client instance.
@ -67,12 +65,6 @@ type zitadelUser struct {
type zitadelAttributes map[string][]map[string]any type zitadelAttributes map[string][]map[string]any
// zitadelMetadata holds additional user data.
type zitadelMetadata struct {
Key string `json:"key"`
Value string `json:"value"`
}
// zitadelProfile represents an zitadel user profile response. // zitadelProfile represents an zitadel user profile response.
type zitadelProfile struct { type zitadelProfile struct {
ID string `json:"id"` ID string `json:"id"`
@ -81,7 +73,6 @@ type zitadelProfile struct {
PreferredLoginName string `json:"preferredLoginName"` PreferredLoginName string `json:"preferredLoginName"`
LoginNames []string `json:"loginNames"` LoginNames []string `json:"loginNames"`
Human *zitadelUser `json:"human"` Human *zitadelUser `json:"human"`
Metadata []zitadelMetadata
} }
// NewZitadelManager creates a new instance of the ZitadelManager. // NewZitadelManager creates a new instance of the ZitadelManager.
@ -234,42 +225,8 @@ func (zc *ZitadelCredentials) Authenticate() (JWTToken, error) {
} }
// CreateUser creates a new user in zitadel Idp and sends an invite. // CreateUser creates a new user in zitadel Idp and sends an invite.
func (zm *ZitadelManager) CreateUser(email, name, accountID, invitedByEmail string) (*UserData, error) { func (zm *ZitadelManager) CreateUser(_, _, _, _ string) (*UserData, error) {
payload, err := buildZitadelCreateUserRequestPayload(email, name) return nil, fmt.Errorf("method CreateUser not implemented")
if err != nil {
return nil, err
}
body, err := zm.post("users/human/_import", payload)
if err != nil {
return nil, err
}
if zm.appMetrics != nil {
zm.appMetrics.IDPMetrics().CountCreateUser()
}
var result struct {
UserID string `json:"userId"`
}
err = zm.helper.Unmarshal(body, &result)
if err != nil {
return nil, err
}
invite := true
appMetadata := AppMetadata{
WTAccountID: accountID,
WTPendingInvite: &invite,
}
// Add metadata to new user
err = zm.UpdateUserAppMetadata(result.UserID, appMetadata)
if err != nil {
return nil, err
}
return zm.GetUserDataByID(result.UserID, appMetadata)
} }
// GetUserByEmail searches users with a given email. // GetUserByEmail searches users with a given email.
@ -307,12 +264,6 @@ func (zm *ZitadelManager) GetUserByEmail(email string) ([]*UserData, error) {
users := make([]*UserData, 0) users := make([]*UserData, 0)
for _, profile := range profiles.Result { for _, profile := range profiles.Result {
metadata, err := zm.getUserMetadata(profile.ID)
if err != nil {
return nil, err
}
profile.Metadata = metadata
users = append(users, profile.userData()) users = append(users, profile.userData())
} }
@ -336,18 +287,15 @@ func (zm *ZitadelManager) GetUserDataByID(userID string, appMetadata AppMetadata
return nil, err return nil, err
} }
metadata, err := zm.getUserMetadata(userID) userData := profile.User.userData()
if err != nil { userData.AppMetadata = appMetadata
return nil, err
}
profile.User.Metadata = metadata
return profile.User.userData(), nil return userData, nil
} }
// GetAccount returns all the users for a given profile. // GetAccount returns all the users for a given profile.
func (zm *ZitadelManager) GetAccount(accountID string) ([]*UserData, error) { func (zm *ZitadelManager) GetAccount(accountID string) ([]*UserData, error) {
accounts, err := zm.GetAllAccounts() body, err := zm.post("users/_search", "")
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -356,7 +304,21 @@ func (zm *ZitadelManager) GetAccount(accountID string) ([]*UserData, error) {
zm.appMetrics.IDPMetrics().CountGetAccount() zm.appMetrics.IDPMetrics().CountGetAccount()
} }
return accounts[accountID], nil var profiles struct{ Result []zitadelProfile }
err = zm.helper.Unmarshal(body, &profiles)
if err != nil {
return nil, err
}
users := make([]*UserData, 0)
for _, profile := range profiles.Result {
userData := profile.userData()
userData.AppMetadata.WTAccountID = accountID
users = append(users, userData)
}
return users, nil
} }
// GetAllAccounts gets all registered accounts with corresponding user data. // GetAllAccounts gets all registered accounts with corresponding user data.
@ -379,22 +341,8 @@ func (zm *ZitadelManager) GetAllAccounts() (map[string][]*UserData, error) {
indexedUsers := make(map[string][]*UserData) indexedUsers := make(map[string][]*UserData)
for _, profile := range profiles.Result { for _, profile := range profiles.Result {
// fetch user metadata
metadata, err := zm.getUserMetadata(profile.ID)
if err != nil {
return nil, err
}
profile.Metadata = metadata
userData := profile.userData() userData := profile.userData()
accountID := userData.AppMetadata.WTAccountID indexedUsers[UnsetAccountID] = append(indexedUsers[UnsetAccountID], userData)
if accountID != "" {
if _, ok := indexedUsers[accountID]; !ok {
indexedUsers[accountID] = make([]*UserData, 0)
}
indexedUsers[accountID] = append(indexedUsers[accountID], userData)
}
} }
return indexedUsers, nil return indexedUsers, nil
@ -402,42 +350,7 @@ func (zm *ZitadelManager) GetAllAccounts() (map[string][]*UserData, error) {
// UpdateUserAppMetadata updates user app metadata based on userID and metadata map. // UpdateUserAppMetadata updates user app metadata based on userID and metadata map.
// Metadata values are base64 encoded. // Metadata values are base64 encoded.
func (zm *ZitadelManager) UpdateUserAppMetadata(userID string, appMetadata AppMetadata) error { func (zm *ZitadelManager) UpdateUserAppMetadata(_ string, _ AppMetadata) error {
if appMetadata.WTPendingInvite == nil {
appMetadata.WTPendingInvite = new(bool)
}
pendingInviteBuf := strconv.AppendBool([]byte{}, *appMetadata.WTPendingInvite)
wtAccountIDValue := base64.StdEncoding.EncodeToString([]byte(appMetadata.WTAccountID))
wtPendingInviteValue := base64.StdEncoding.EncodeToString(pendingInviteBuf)
metadata := zitadelAttributes{
"metadata": {
{
"key": wtAccountID,
"value": wtAccountIDValue,
},
{
"key": wtPendingInvite,
"value": wtPendingInviteValue,
},
},
}
payload, err := zm.helper.Marshal(metadata)
if err != nil {
return err
}
resource := fmt.Sprintf("users/%s", userID)
_, err = zm.post(resource, string(payload))
if err != nil {
return err
}
if zm.appMetrics != nil {
zm.appMetrics.IDPMetrics().CountUpdateUserAppMetadata()
}
return nil return nil
} }
@ -459,24 +372,6 @@ func (zm *ZitadelManager) DeleteUser(userID string) error {
} }
return nil return nil
}
// getUserMetadata requests user metadata from zitadel via ID.
func (zm *ZitadelManager) getUserMetadata(userID string) ([]zitadelMetadata, error) {
resource := fmt.Sprintf("users/%s/metadata/_search", userID)
body, err := zm.post(resource, "")
if err != nil {
return nil, err
}
var metadata struct{ Result []zitadelMetadata }
err = zm.helper.Unmarshal(body, &metadata)
if err != nil {
return nil, err
}
return metadata.Result, nil
} }
// post perform Post requests. // post perform Post requests.
@ -516,38 +411,7 @@ func (zm *ZitadelManager) post(resource string, body string) ([]byte, error) {
} }
// delete perform Delete requests. // delete perform Delete requests.
func (zm *ZitadelManager) delete(resource string) error { func (zm *ZitadelManager) delete(_ string) error {
jwtToken, err := zm.credentials.Authenticate()
if err != nil {
return err
}
reqURL := fmt.Sprintf("%s/%s", zm.managementEndpoint, resource)
req, err := http.NewRequest(http.MethodDelete, reqURL, nil)
if err != nil {
return err
}
req.Header.Add("authorization", "Bearer "+jwtToken.AccessToken)
req.Header.Add("content-type", "application/json")
resp, err := zm.httpClient.Do(req)
if err != nil {
if zm.appMetrics != nil {
zm.appMetrics.IDPMetrics().CountRequestError()
}
return err
}
defer resp.Body.Close()
if resp.StatusCode != 200 {
if zm.appMetrics != nil {
zm.appMetrics.IDPMetrics().CountRequestStatusError()
}
return fmt.Errorf("unable to delete %s, statusCode %d", reqURL, resp.StatusCode)
}
return nil return nil
} }
@ -587,38 +451,13 @@ func (zm *ZitadelManager) get(resource string, q url.Values) ([]byte, error) {
return io.ReadAll(resp.Body) return io.ReadAll(resp.Body)
} }
// value returns string represented by the base64 string value.
func (zm zitadelMetadata) value() string {
value, err := base64.StdEncoding.DecodeString(zm.Value)
if err != nil {
return ""
}
return string(value)
}
// userData construct user data from zitadel profile. // userData construct user data from zitadel profile.
func (zp zitadelProfile) userData() *UserData { func (zp zitadelProfile) userData() *UserData {
var ( var (
email string email string
name string name string
wtAccountIDValue string
wtPendingInviteValue bool
) )
for _, metadata := range zp.Metadata {
if metadata.Key == wtAccountID {
wtAccountIDValue = metadata.value()
}
if metadata.Key == wtPendingInvite {
value, err := strconv.ParseBool(metadata.value())
if err == nil {
wtPendingInviteValue = value
}
}
}
// Obtain the email for the human account and the login name, // Obtain the email for the human account and the login name,
// for the machine account. // for the machine account.
if zp.Human != nil { if zp.Human != nil {
@ -635,39 +474,5 @@ func (zp zitadelProfile) userData() *UserData {
Email: email, Email: email,
Name: name, Name: name,
ID: zp.ID, ID: zp.ID,
AppMetadata: AppMetadata{
WTAccountID: wtAccountIDValue,
WTPendingInvite: &wtPendingInviteValue,
},
} }
} }
func buildZitadelCreateUserRequestPayload(email string, name string) (string, error) {
var firstName, lastName string
words := strings.Fields(name)
if n := len(words); n > 0 {
firstName = strings.Join(words[:n-1], " ")
lastName = words[n-1]
}
req := &zitadelUser{
UserName: name,
Profile: zitadelUserInfo{
FirstName: strings.TrimSpace(firstName),
LastName: strings.TrimSpace(lastName),
DisplayName: name,
},
Email: zitadelEmail{
Email: email,
IsEmailVerified: false,
},
}
str, err := json.Marshal(req)
if err != nil {
return "", err
}
return string(str), nil
}

View File

@ -7,9 +7,10 @@ import (
"testing" "testing"
"time" "time"
"github.com/netbirdio/netbird/management/server/telemetry"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/management/server/telemetry"
) )
func TestNewZitadelManager(t *testing.T) { func TestNewZitadelManager(t *testing.T) {
@ -63,15 +64,6 @@ func TestNewZitadelManager(t *testing.T) {
} }
} }
type mockZitadelCredentials struct {
jwtToken JWTToken
err error
}
func (mc *mockZitadelCredentials) Authenticate() (JWTToken, error) {
return mc.jwtToken, mc.err
}
func TestZitadelRequestJWTToken(t *testing.T) { func TestZitadelRequestJWTToken(t *testing.T) {
type requestJWTTokenTest struct { type requestJWTTokenTest struct {
@ -296,98 +288,6 @@ func TestZitadelAuthenticate(t *testing.T) {
} }
} }
func TestZitadelUpdateUserAppMetadata(t *testing.T) {
type updateUserAppMetadataTest struct {
name string
inputReqBody string
expectedReqBody string
appMetadata AppMetadata
statusCode int
helper ManagerHelper
managerCreds ManagerCredentials
assertErrFunc assert.ErrorAssertionFunc
assertErrFuncMessage string
}
appMetadata := AppMetadata{WTAccountID: "ok"}
updateUserAppMetadataTestCase1 := updateUserAppMetadataTest{
name: "Bad Authentication",
expectedReqBody: "",
appMetadata: appMetadata,
statusCode: 400,
helper: JsonParser{},
managerCreds: &mockZitadelCredentials{
jwtToken: JWTToken{},
err: fmt.Errorf("error"),
},
assertErrFunc: assert.Error,
assertErrFuncMessage: "should return error",
}
updateUserAppMetadataTestCase2 := updateUserAppMetadataTest{
name: "Bad Response Parsing",
statusCode: 400,
helper: &mockJsonParser{marshalErrorString: "error"},
managerCreds: &mockZitadelCredentials{
jwtToken: JWTToken{},
},
assertErrFunc: assert.Error,
assertErrFuncMessage: "should return error",
}
updateUserAppMetadataTestCase3 := updateUserAppMetadataTest{
name: "Good request",
expectedReqBody: "{\"metadata\":[{\"key\":\"wt_account_id\",\"value\":\"b2s=\"},{\"key\":\"wt_pending_invite\",\"value\":\"ZmFsc2U=\"}]}",
appMetadata: appMetadata,
statusCode: 200,
helper: JsonParser{},
managerCreds: &mockZitadelCredentials{
jwtToken: JWTToken{},
},
assertErrFunc: assert.NoError,
assertErrFuncMessage: "shouldn't return error",
}
invite := true
updateUserAppMetadataTestCase4 := updateUserAppMetadataTest{
name: "Update Pending Invite",
expectedReqBody: "{\"metadata\":[{\"key\":\"wt_account_id\",\"value\":\"b2s=\"},{\"key\":\"wt_pending_invite\",\"value\":\"dHJ1ZQ==\"}]}",
appMetadata: AppMetadata{
WTAccountID: "ok",
WTPendingInvite: &invite,
},
statusCode: 200,
helper: JsonParser{},
managerCreds: &mockZitadelCredentials{
jwtToken: JWTToken{},
},
assertErrFunc: assert.NoError,
assertErrFuncMessage: "shouldn't return error",
}
for _, testCase := range []updateUserAppMetadataTest{updateUserAppMetadataTestCase1, updateUserAppMetadataTestCase2,
updateUserAppMetadataTestCase3, updateUserAppMetadataTestCase4} {
t.Run(testCase.name, func(t *testing.T) {
reqClient := mockHTTPClient{
resBody: testCase.inputReqBody,
code: testCase.statusCode,
}
manager := &ZitadelManager{
httpClient: &reqClient,
credentials: testCase.managerCreds,
helper: testCase.helper,
}
err := manager.UpdateUserAppMetadata("1", testCase.appMetadata)
testCase.assertErrFunc(t, err, testCase.assertErrFuncMessage)
assert.Equal(t, testCase.expectedReqBody, reqClient.reqBody, "request body should match")
})
}
}
func TestZitadelProfile(t *testing.T) { func TestZitadelProfile(t *testing.T) {
type azureProfileTest struct { type azureProfileTest struct {
name string name string
@ -418,16 +318,6 @@ func TestZitadelProfile(t *testing.T) {
IsEmailVerified: true, IsEmailVerified: true,
}, },
}, },
Metadata: []zitadelMetadata{
{
Key: "wt_account_id",
Value: "MQ==",
},
{
Key: "wt_pending_invite",
Value: "ZmFsc2U=",
},
},
}, },
expectedUserData: UserData{ expectedUserData: UserData{
ID: "test1", ID: "test1",
@ -451,16 +341,6 @@ func TestZitadelProfile(t *testing.T) {
"machine", "machine",
}, },
Human: nil, Human: nil,
Metadata: []zitadelMetadata{
{
Key: "wt_account_id",
Value: "MQ==",
},
{
Key: "wt_pending_invite",
Value: "dHJ1ZQ==",
},
},
}, },
expectedUserData: UserData{ expectedUserData: UserData{
ID: "test2", ID: "test2",
@ -480,8 +360,6 @@ func TestZitadelProfile(t *testing.T) {
assert.Equal(t, testCase.expectedUserData.ID, userData.ID, "User id should match") assert.Equal(t, testCase.expectedUserData.ID, userData.ID, "User id should match")
assert.Equal(t, testCase.expectedUserData.Email, userData.Email, "User email should match") assert.Equal(t, testCase.expectedUserData.Email, userData.Email, "User email should match")
assert.Equal(t, testCase.expectedUserData.Name, userData.Name, "User name should match") assert.Equal(t, testCase.expectedUserData.Name, userData.Name, "User name should match")
assert.Equal(t, testCase.expectedUserData.AppMetadata.WTAccountID, userData.AppMetadata.WTAccountID, "Account id should match")
assert.Equal(t, testCase.expectedUserData.AppMetadata.WTPendingInvite, userData.AppMetadata.WTPendingInvite, "Pending invite should match")
}) })
} }
} }

View File

@ -176,6 +176,7 @@ func (w *Worker) generateProperties() properties {
rulesDirection map[string]int rulesDirection map[string]int
groups int groups int
routes int routes int
routesWithRGGroups int
nameservers int nameservers int
uiClient int uiClient int
version string version string
@ -201,6 +202,11 @@ func (w *Worker) generateProperties() properties {
groups = groups + len(account.Groups) groups = groups + len(account.Groups)
routes = routes + len(account.Routes) routes = routes + len(account.Routes)
for _, route := range account.Routes {
if len(route.PeerGroups) > 0 {
routesWithRGGroups++
}
}
nameservers = nameservers + len(account.NameServerGroups) nameservers = nameservers + len(account.NameServerGroups)
for _, policy := range account.Policies { for _, policy := range account.Policies {
@ -282,6 +288,7 @@ func (w *Worker) generateProperties() properties {
metricsProperties["rules"] = rules metricsProperties["rules"] = rules
metricsProperties["groups"] = groups metricsProperties["groups"] = groups
metricsProperties["routes"] = routes metricsProperties["routes"] = routes
metricsProperties["routes_with_routing_groups"] = routesWithRGGroups
metricsProperties["nameservers"] = nameservers metricsProperties["nameservers"] = nameservers
metricsProperties["version"] = version metricsProperties["version"] = version
metricsProperties["min_active_peer_version"] = minActivePeerVersion metricsProperties["min_active_peer_version"] = minActivePeerVersion

View File

@ -0,0 +1,239 @@
package metrics
import (
"testing"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/route"
)
type mockDatasource struct{}
// GetAllConnectedPeers returns a map of connected peer IDs for use in tests with predefined information
func (mockDatasource) GetAllConnectedPeers() map[string]struct{} {
return map[string]struct{}{
"1": {},
}
}
// GetAllAccounts returns a list of *server.Account for use in tests with predefined information
func (mockDatasource) GetAllAccounts() []*server.Account {
return []*server.Account{
{
Id: "1",
Settings: &server.Settings{PeerLoginExpirationEnabled: true},
SetupKeys: map[string]*server.SetupKey{
"1": {
Id: "1",
Ephemeral: true,
UsedTimes: 1,
},
},
Groups: map[string]*server.Group{
"1": {},
"2": {},
},
NameServerGroups: map[string]*nbdns.NameServerGroup{
"1": {},
},
Peers: map[string]*server.Peer{
"1": {
ID: "1",
UserID: "test",
SSHEnabled: true,
Meta: server.PeerSystemMeta{GoOS: "linux", WtVersion: "0.0.1"},
},
},
Policies: []*server.Policy{
{
Rules: []*server.PolicyRule{
{
Bidirectional: true,
Protocol: server.PolicyRuleProtocolTCP,
},
},
},
{
Rules: []*server.PolicyRule{
{
Bidirectional: false,
Protocol: server.PolicyRuleProtocolTCP,
},
},
},
},
Routes: map[string]*route.Route{
"1": {
ID: "1",
PeerGroups: make([]string, 1),
},
},
Users: map[string]*server.User{
"1": {
IsServiceUser: true,
PATs: map[string]*server.PersonalAccessToken{
"1": {},
},
},
"2": {
IsServiceUser: false,
PATs: map[string]*server.PersonalAccessToken{
"1": {},
},
},
},
},
{
Id: "2",
Settings: &server.Settings{PeerLoginExpirationEnabled: true},
SetupKeys: map[string]*server.SetupKey{
"1": {
Id: "1",
Ephemeral: true,
UsedTimes: 1,
},
},
Groups: map[string]*server.Group{
"1": {},
"2": {},
},
NameServerGroups: map[string]*nbdns.NameServerGroup{
"1": {},
},
Peers: map[string]*server.Peer{
"1": {
ID: "1",
UserID: "test",
SSHEnabled: true,
Meta: server.PeerSystemMeta{GoOS: "linux", WtVersion: "0.0.1"},
},
},
Policies: []*server.Policy{
{
Rules: []*server.PolicyRule{
{
Bidirectional: true,
Protocol: server.PolicyRuleProtocolTCP,
},
},
},
{
Rules: []*server.PolicyRule{
{
Bidirectional: false,
Protocol: server.PolicyRuleProtocolTCP,
},
},
},
},
Routes: map[string]*route.Route{
"1": {
ID: "1",
PeerGroups: make([]string, 1),
},
},
Users: map[string]*server.User{
"1": {
IsServiceUser: true,
PATs: map[string]*server.PersonalAccessToken{
"1": {},
},
},
"2": {
IsServiceUser: false,
PATs: map[string]*server.PersonalAccessToken{
"1": {},
},
},
},
},
}
}
// TestGenerateProperties tests and validate the properties generation by using the mockDatasource for the Worker.generateProperties
func TestGenerateProperties(t *testing.T) {
ds := mockDatasource{}
worker := Worker{
dataSource: ds,
connManager: ds,
}
properties := worker.generateProperties()
if properties["accounts"] != 2 {
t.Errorf("expected 2 accounts, got %d", properties["accounts"])
}
if properties["peers"] != 2 {
t.Errorf("expected 2 peers, got %d", properties["peers"])
}
if properties["routes"] != 2 {
t.Errorf("expected 2 routes, got %d", properties["routes"])
}
if properties["rules"] != 4 {
t.Errorf("expected 4 rules, got %d", properties["rules"])
}
if properties["users"] != 2 {
t.Errorf("expected 1 users, got %d", properties["users"])
}
if properties["setup_keys_usage"] != 2 {
t.Errorf("expected 1 setup_keys_usage, got %d", properties["setup_keys_usage"])
}
if properties["pats"] != 4 {
t.Errorf("expected 4 personal_access_tokens, got %d", properties["pats"])
}
if properties["peers_ssh_enabled"] != 2 {
t.Errorf("expected 2 peers_ssh_enabled, got %d", properties["peers_ssh_enabled"])
}
if properties["routes_with_routing_groups"] != 2 {
t.Errorf("expected 2 routes_with_routing_groups, got %d", properties["routes_with_routing_groups"])
}
if properties["rules_protocol_tcp"] != 4 {
t.Errorf("expected 4 rules_protocol_tcp, got %d", properties["rules_protocol_tcp"])
}
if properties["rules_direction_oneway"] != 2 {
t.Errorf("expected 2 rules_direction_oneway, got %d", properties["rules_direction_oneway"])
}
if properties["active_peers_last_day"] != 2 {
t.Errorf("expected 2 active_peers_last_day, got %d", properties["active_peers_last_day"])
}
if properties["min_active_peer_version"] != "0.0.1" {
t.Errorf("expected 0.0.1 min_active_peer_version, got %s", properties["min_active_peer_version"])
}
if properties["max_active_peer_version"] != "0.0.1" {
t.Errorf("expected 0.0.1 max_active_peer_version, got %s", properties["max_active_peer_version"])
}
if properties["peers_login_expiration_enabled"] != 2 {
t.Errorf("expected 2 peers_login_expiration_enabled, got %d", properties["peers_login_expiration_enabled"])
}
if properties["service_users"] != 2 {
t.Errorf("expected 2 service_users, got %d", properties["service_users"])
}
if properties["peer_os_linux"] != 2 {
t.Errorf("expected 2 peer_os_linux, got %d", properties["peer_os_linux"])
}
if properties["ephemeral_peers_setup_keys"] != 2 {
t.Errorf("expected 2 ephemeral_peers_setup_keys, got %d", properties["ephemeral_peers_setup_keys_usage"])
}
if properties["ephemeral_peers_setup_keys_usage"] != 2 {
t.Errorf("expected 2 ephemeral_peers_setup_keys_usage, got %d", properties["ephemeral_peers_setup_keys_usage"])
}
if properties["nameservers"] != 2 {
t.Errorf("expected 2 nameservers, got %d", properties["nameservers"])
}
if properties["groups"] != 4 {
t.Errorf("expected 4 groups, got %d", properties["groups"])
}
if properties["user_peers"] != 2 {
t.Errorf("expected 2 user_peers, got %d", properties["user_peers"])
}
}

View File

@ -20,12 +20,9 @@ type MockAccountManager struct {
GetSetupKeyFunc func(accountID, userID, keyID string) (*server.SetupKey, error) GetSetupKeyFunc func(accountID, userID, keyID string) (*server.SetupKey, error)
GetAccountByUserOrAccountIdFunc func(userId, accountId, domain string) (*server.Account, error) GetAccountByUserOrAccountIdFunc func(userId, accountId, domain string) (*server.Account, error)
GetUserFunc func(claims jwtclaims.AuthorizationClaims) (*server.User, error) GetUserFunc func(claims jwtclaims.AuthorizationClaims) (*server.User, error)
AccountExistsFunc func(accountId string) (*bool, error)
GetPeerByKeyFunc func(peerKey string) (*server.Peer, error)
GetPeersFunc func(accountID, userID string) ([]*server.Peer, error) GetPeersFunc func(accountID, userID string) ([]*server.Peer, error)
MarkPeerConnectedFunc func(peerKey string, connected bool) error MarkPeerConnectedFunc func(peerKey string, connected bool) error
DeletePeerFunc func(accountID, peerKey, userID string) (*server.Peer, error) DeletePeerFunc func(accountID, peerKey, userID string) error
GetPeerByIPFunc func(accountId string, peerIP string) (*server.Peer, error)
GetNetworkMapFunc func(peerKey string) (*server.NetworkMap, error) GetNetworkMapFunc func(peerKey string) (*server.NetworkMap, error)
GetPeerNetworkFunc func(peerKey string) (*server.Network, error) GetPeerNetworkFunc func(peerKey string) (*server.Network, error)
AddPeerFunc func(setupKey string, userId string, peer *server.Peer) (*server.Peer, *server.NetworkMap, error) AddPeerFunc func(setupKey string, userId string, peer *server.Peer) (*server.Peer, *server.NetworkMap, error)
@ -33,9 +30,8 @@ type MockAccountManager struct {
SaveGroupFunc func(accountID, userID string, group *server.Group) error SaveGroupFunc func(accountID, userID string, group *server.Group) error
DeleteGroupFunc func(accountID, userId, groupID string) error DeleteGroupFunc func(accountID, userId, groupID string) error
ListGroupsFunc func(accountID string) ([]*server.Group, error) ListGroupsFunc func(accountID string) ([]*server.Group, error)
GroupAddPeerFunc func(accountID, groupID, peerKey string) error GroupAddPeerFunc func(accountID, groupID, peerID string) error
GroupDeletePeerFunc func(accountID, groupID, peerKey string) error GroupDeletePeerFunc func(accountID, groupID, peerID string) error
GroupListPeersFunc func(accountID, groupID string) ([]*server.Peer, error)
GetRuleFunc func(accountID, ruleID, userID string) (*server.Rule, error) GetRuleFunc func(accountID, ruleID, userID string) (*server.Rule, error)
SaveRuleFunc func(accountID, userID string, rule *server.Rule) error SaveRuleFunc func(accountID, userID string, rule *server.Rule) error
DeleteRuleFunc func(accountID, ruleID, userID string) error DeleteRuleFunc func(accountID, ruleID, userID string) error
@ -50,7 +46,7 @@ type MockAccountManager struct {
UpdatePeerMetaFunc func(peerID string, meta server.PeerSystemMeta) error UpdatePeerMetaFunc func(peerID string, meta server.PeerSystemMeta) error
UpdatePeerSSHKeyFunc func(peerID string, sshKey string) error UpdatePeerSSHKeyFunc func(peerID string, sshKey string) error
UpdatePeerFunc func(accountID, userID string, peer *server.Peer) (*server.Peer, error) UpdatePeerFunc func(accountID, userID string, peer *server.Peer) (*server.Peer, error)
CreateRouteFunc func(accountID string, prefix, peer, description, netID string, masquerade bool, metric int, groups []string, enabled bool, userID string) (*route.Route, error) CreateRouteFunc func(accountID, prefix, peer string, peerGroups []string, description, netID string, masquerade bool, metric int, groups []string, enabled bool, userID string) (*route.Route, error)
GetRouteFunc func(accountID, routeID, userID string) (*route.Route, error) GetRouteFunc func(accountID, routeID, userID string) (*route.Route, error)
SaveRouteFunc func(accountID, userID string, route *route.Route) error SaveRouteFunc func(accountID, userID string, route *route.Route) error
DeleteRouteFunc func(accountID, routeID, userID string) error DeleteRouteFunc func(accountID, routeID, userID string) error
@ -90,11 +86,11 @@ func (am *MockAccountManager) GetUsersFromAccount(accountID string, userID strin
} }
// DeletePeer mock implementation of DeletePeer from server.AccountManager interface // DeletePeer mock implementation of DeletePeer from server.AccountManager interface
func (am *MockAccountManager) DeletePeer(accountID, peerID, userID string) (*server.Peer, error) { func (am *MockAccountManager) DeletePeer(accountID, peerID, userID string) error {
if am.DeletePeerFunc != nil { if am.DeletePeerFunc != nil {
return am.DeletePeerFunc(accountID, peerID, userID) return am.DeletePeerFunc(accountID, peerID, userID)
} }
return nil, status.Errorf(codes.Unimplemented, "method DeletePeer is not implemented") return status.Errorf(codes.Unimplemented, "method DeletePeer is not implemented")
} }
// GetOrCreateAccountByUser mock implementation of GetOrCreateAccountByUser from server.AccountManager interface // GetOrCreateAccountByUser mock implementation of GetOrCreateAccountByUser from server.AccountManager interface
@ -140,22 +136,6 @@ func (am *MockAccountManager) GetAccountByUserOrAccountID(
) )
} }
// AccountExists mock implementation of AccountExists from server.AccountManager interface
func (am *MockAccountManager) AccountExists(accountId string) (*bool, error) {
if am.AccountExistsFunc != nil {
return am.AccountExistsFunc(accountId)
}
return nil, status.Errorf(codes.Unimplemented, "method AccountExists is not implemented")
}
// GetPeerByKey mocks implementation of GetPeerByKey from server.AccountManager interface
func (am *MockAccountManager) GetPeerByKey(peerKey string) (*server.Peer, error) {
if am.GetPeerByKeyFunc != nil {
return am.GetPeerByKeyFunc(peerKey)
}
return nil, status.Errorf(codes.Unimplemented, "method GetPeerByKey is not implemented")
}
// MarkPeerConnected mock implementation of MarkPeerConnected from server.AccountManager interface // MarkPeerConnected mock implementation of MarkPeerConnected from server.AccountManager interface
func (am *MockAccountManager) MarkPeerConnected(peerKey string, connected bool) error { func (am *MockAccountManager) MarkPeerConnected(peerKey string, connected bool) error {
if am.MarkPeerConnectedFunc != nil { if am.MarkPeerConnectedFunc != nil {
@ -164,14 +144,6 @@ func (am *MockAccountManager) MarkPeerConnected(peerKey string, connected bool)
return status.Errorf(codes.Unimplemented, "method MarkPeerConnected is not implemented") return status.Errorf(codes.Unimplemented, "method MarkPeerConnected is not implemented")
} }
// GetPeerByIP mock implementation of GetPeerByIP from server.AccountManager interface
func (am *MockAccountManager) GetPeerByIP(accountId string, peerIP string) (*server.Peer, error) {
if am.GetPeerByIPFunc != nil {
return am.GetPeerByIPFunc(accountId, peerIP)
}
return nil, status.Errorf(codes.Unimplemented, "method GetPeerByIP is not implemented")
}
// GetAccountFromPAT mock implementation of GetAccountFromPAT from server.AccountManager interface // GetAccountFromPAT mock implementation of GetAccountFromPAT from server.AccountManager interface
func (am *MockAccountManager) GetAccountFromPAT(pat string) (*server.Account, *server.User, *server.PersonalAccessToken, error) { func (am *MockAccountManager) GetAccountFromPAT(pat string) (*server.Account, *server.User, *server.PersonalAccessToken, error) {
if am.GetAccountFromPATFunc != nil { if am.GetAccountFromPATFunc != nil {
@ -281,29 +253,21 @@ func (am *MockAccountManager) ListGroups(accountID string) ([]*server.Group, err
} }
// GroupAddPeer mock implementation of GroupAddPeer from server.AccountManager interface // GroupAddPeer mock implementation of GroupAddPeer from server.AccountManager interface
func (am *MockAccountManager) GroupAddPeer(accountID, groupID, peerKey string) error { func (am *MockAccountManager) GroupAddPeer(accountID, groupID, peerID string) error {
if am.GroupAddPeerFunc != nil { if am.GroupAddPeerFunc != nil {
return am.GroupAddPeerFunc(accountID, groupID, peerKey) return am.GroupAddPeerFunc(accountID, groupID, peerID)
} }
return status.Errorf(codes.Unimplemented, "method GroupAddPeer is not implemented") return status.Errorf(codes.Unimplemented, "method GroupAddPeer is not implemented")
} }
// GroupDeletePeer mock implementation of GroupDeletePeer from server.AccountManager interface // GroupDeletePeer mock implementation of GroupDeletePeer from server.AccountManager interface
func (am *MockAccountManager) GroupDeletePeer(accountID, groupID, peerKey string) error { func (am *MockAccountManager) GroupDeletePeer(accountID, groupID, peerID string) error {
if am.GroupDeletePeerFunc != nil { if am.GroupDeletePeerFunc != nil {
return am.GroupDeletePeerFunc(accountID, groupID, peerKey) return am.GroupDeletePeerFunc(accountID, groupID, peerID)
} }
return status.Errorf(codes.Unimplemented, "method GroupDeletePeer is not implemented") return status.Errorf(codes.Unimplemented, "method GroupDeletePeer is not implemented")
} }
// GroupListPeers mock implementation of GroupListPeers from server.AccountManager interface
func (am *MockAccountManager) GroupListPeers(accountID, groupID string) ([]*server.Peer, error) {
if am.GroupListPeersFunc != nil {
return am.GroupListPeersFunc(accountID, groupID)
}
return nil, status.Errorf(codes.Unimplemented, "method GroupListPeers is not implemented")
}
// GetRule mock implementation of GetRule from server.AccountManager interface // GetRule mock implementation of GetRule from server.AccountManager interface
func (am *MockAccountManager) GetRule(accountID, ruleID, userID string) (*server.Rule, error) { func (am *MockAccountManager) GetRule(accountID, ruleID, userID string) (*server.Rule, error) {
if am.GetRuleFunc != nil { if am.GetRuleFunc != nil {
@ -401,9 +365,9 @@ func (am *MockAccountManager) UpdatePeer(accountID, userID string, peer *server.
} }
// CreateRoute mock implementation of CreateRoute from server.AccountManager interface // CreateRoute mock implementation of CreateRoute from server.AccountManager interface
func (am *MockAccountManager) CreateRoute(accountID string, network, peerID, description, netID string, masquerade bool, metric int, groups []string, enabled bool, userID string) (*route.Route, error) { func (am *MockAccountManager) CreateRoute(accountID, network, peerID string, peerGroups []string, description, netID string, masquerade bool, metric int, groups []string, enabled bool, userID string) (*route.Route, error) {
if am.CreateRouteFunc != nil { if am.CreateRouteFunc != nil {
return am.CreateRouteFunc(accountID, network, peerID, description, netID, masquerade, metric, groups, enabled, userID) return am.CreateRouteFunc(accountID, network, peerID, peerGroups, description, netID, masquerade, metric, groups, enabled, userID)
} }
return nil, status.Errorf(codes.Unimplemented, "method CreateRoute is not implemented") return nil, status.Errorf(codes.Unimplemented, "method CreateRoute is not implemented")
} }

View File

@ -7,7 +7,6 @@ import (
"github.com/miekg/dns" "github.com/miekg/dns"
"github.com/rs/xid" "github.com/rs/xid"
log "github.com/sirupsen/logrus"
nbdns "github.com/netbirdio/netbird/dns" nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/activity"
@ -74,11 +73,7 @@ func (am *DefaultAccountManager) CreateNameServerGroup(accountID string, name, d
return nil, err return nil, err
} }
err = am.updateAccountPeers(account) am.updateAccountPeers(account)
if err != nil {
log.Error(err)
return newNSGroup.Copy(), status.Errorf(status.Internal, "failed to update peers after create nameserver %s", name)
}
am.storeEvent(userID, newNSGroup.ID, accountID, activity.NameserverGroupCreated, newNSGroup.EventMeta()) am.storeEvent(userID, newNSGroup.ID, accountID, activity.NameserverGroupCreated, newNSGroup.EventMeta())
@ -113,11 +108,7 @@ func (am *DefaultAccountManager) SaveNameServerGroup(accountID, userID string, n
return err return err
} }
err = am.updateAccountPeers(account) am.updateAccountPeers(account)
if err != nil {
log.Error(err)
return status.Errorf(status.Internal, "failed to update peers after update nameserver %s", nsGroupToSave.Name)
}
am.storeEvent(userID, nsGroupToSave.ID, accountID, activity.NameserverGroupUpdated, nsGroupToSave.EventMeta()) am.storeEvent(userID, nsGroupToSave.ID, accountID, activity.NameserverGroupUpdated, nsGroupToSave.EventMeta())
@ -147,10 +138,7 @@ func (am *DefaultAccountManager) DeleteNameServerGroup(accountID, nsGroupID, use
return err return err
} }
err = am.updateAccountPeers(account) am.updateAccountPeers(account)
if err != nil {
return status.Errorf(status.Internal, "failed to update peers after deleting nameserver %s", nsGroupID)
}
am.storeEvent(userID, nsGroup.ID, accountID, activity.NameserverGroupDeleted, nsGroup.EventMeta()) am.storeEvent(userID, nsGroup.ID, accountID, activity.NameserverGroupDeleted, nsGroup.EventMeta())

View File

@ -195,16 +195,6 @@ func (p *PeerStatus) Copy() *PeerStatus {
} }
} }
// GetPeerByKey looks up peer by its public WireGuard key
func (am *DefaultAccountManager) GetPeerByKey(peerPubKey string) (*Peer, error) {
account, err := am.Store.GetAccountByPeerPubKey(peerPubKey)
if err != nil {
return nil, err
}
return account.FindPeerByPubKey(peerPubKey)
}
// GetPeers returns a list of peers under the given account filtering out peers that do not belong to a user if // GetPeers returns a list of peers under the given account filtering out peers that do not belong to a user if
// the current user is not an admin. // the current user is not an admin.
func (am *DefaultAccountManager) GetPeers(accountID, userID string) ([]*Peer, error) { func (am *DefaultAccountManager) GetPeers(accountID, userID string) ([]*Peer, error) {
@ -290,10 +280,7 @@ func (am *DefaultAccountManager) MarkPeerConnected(peerPubKey string, connected
if oldStatus.LoginExpired { if oldStatus.LoginExpired {
// we need to update other peers because when peer login expires all other peers are notified to disconnect from // we need to update other peers because when peer login expires all other peers are notified to disconnect from
// the expired one. Here we notify them that connection is now allowed again. // the expired one. Here we notify them that connection is now allowed again.
err = am.updateAccountPeers(account) am.updateAccountPeers(account)
if err != nil {
return err
}
} }
return nil return nil
@ -364,37 +351,30 @@ func (am *DefaultAccountManager) UpdatePeer(accountID, userID string, update *Pe
return nil, err return nil, err
} }
err = am.updateAccountPeers(account) am.updateAccountPeers(account)
if err != nil {
return nil, err
}
return peer, nil return peer, nil
} }
// DeletePeer removes peer from the account by its IP // deletePeers will delete all specified peers and send updates to the remote peers. Don't call without acquiring account lock
func (am *DefaultAccountManager) DeletePeer(accountID, peerID, userID string) (*Peer, error) { func (am *DefaultAccountManager) deletePeers(account *Account, peerIDs []string, userID string) error {
unlock := am.Store.AcquireAccountLock(accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID) // the first loop is needed to ensure all peers present under the account before modifying, otherwise
if err != nil { // we might have some inconsistencies
return nil, err peers := make([]*Peer, 0, len(peerIDs))
} for _, peerID := range peerIDs {
peer := account.GetPeer(peerID) peer := account.GetPeer(peerID)
if peer == nil { if peer == nil {
return nil, status.Errorf(status.NotFound, "peer %s not found", peerID) return status.Errorf(status.NotFound, "peer %s not found", peerID)
}
peers = append(peers, peer)
} }
account.DeletePeer(peerID) // the 2nd loop performs the actual modification
for _, peer := range peers {
err = am.Store.SaveAccount(account) account.DeletePeer(peer.ID)
if err != nil { am.peersUpdateManager.SendUpdate(peer.ID,
return nil, err
}
err = am.peersUpdateManager.SendUpdate(peer.ID,
&UpdateMessage{ &UpdateMessage{
Update: &proto.SyncResponse{ Update: &proto.SyncResponse{
// fill those field for backward compatibility // fill those field for backward compatibility
@ -410,36 +390,36 @@ func (am *DefaultAccountManager) DeletePeer(accountID, peerID, userID string) (*
}, },
}, },
}) })
if err != nil { am.peersUpdateManager.CloseChannel(peer.ID)
return nil, err
}
if err := am.updateAccountPeers(account); err != nil {
return nil, err
}
am.peersUpdateManager.CloseChannel(peerID)
am.storeEvent(userID, peer.ID, account.Id, activity.PeerRemovedByUser, peer.EventMeta(am.GetDNSDomain())) am.storeEvent(userID, peer.ID, account.Id, activity.PeerRemovedByUser, peer.EventMeta(am.GetDNSDomain()))
return peer, nil
} }
// GetPeerByIP returns peer by its IP return nil
func (am *DefaultAccountManager) GetPeerByIP(accountID string, peerIP string) (*Peer, error) { }
// DeletePeer removes peer from the account by its IP
func (am *DefaultAccountManager) DeletePeer(accountID, peerID, userID string) error {
unlock := am.Store.AcquireAccountLock(accountID) unlock := am.Store.AcquireAccountLock(accountID)
defer unlock() defer unlock()
account, err := am.Store.GetAccount(accountID) account, err := am.Store.GetAccount(accountID)
if err != nil { if err != nil {
return nil, err return err
} }
for _, peer := range account.Peers { err = am.deletePeers(account, []string{peerID}, userID)
if peerIP == peer.IP.String() { if err != nil {
return peer, nil return err
}
} }
return nil, status.Errorf(status.NotFound, "peer with IP %s not found", peerIP) err = am.Store.SaveAccount(account)
if err != nil {
return err
}
am.updateAccountPeers(account)
return nil
} }
// GetNetworkMap returns Network map for a given peer (omits original peer from the Peers result) // GetNetworkMap returns Network map for a given peer (omits original peer from the Peers result)
@ -609,10 +589,7 @@ func (am *DefaultAccountManager) AddPeer(setupKey, userID string, peer *Peer) (*
opEvent.Meta = newPeer.EventMeta(am.GetDNSDomain()) opEvent.Meta = newPeer.EventMeta(am.GetDNSDomain())
am.storeEvent(opEvent.InitiatorID, opEvent.TargetID, opEvent.AccountID, opEvent.Activity, opEvent.Meta) am.storeEvent(opEvent.InitiatorID, opEvent.TargetID, opEvent.AccountID, opEvent.Activity, opEvent.Meta)
err = am.updateAccountPeers(account) am.updateAccountPeers(account)
if err != nil {
return nil, nil, err
}
networkMap := account.GetPeerNetworkMap(newPeer.ID, am.dnsDomain) networkMap := account.GetPeerNetworkMap(newPeer.ID, am.dnsDomain)
return newPeer, networkMap, nil return newPeer, networkMap, nil
@ -727,10 +704,7 @@ func (am *DefaultAccountManager) LoginPeer(login PeerLogin) (*Peer, *NetworkMap,
} }
if updateRemotePeers { if updateRemotePeers {
err = am.updateAccountPeers(account) am.updateAccountPeers(account)
if err != nil {
return nil, nil, err
}
} }
return peer, account.GetPeerNetworkMap(peer.ID, am.dnsDomain), nil return peer, account.GetPeerNetworkMap(peer.ID, am.dnsDomain), nil
} }
@ -804,10 +778,7 @@ func (am *DefaultAccountManager) checkAndUpdatePeerSSHKey(peer *Peer, account *A
} }
// trigger network map update // trigger network map update
err = am.updateAccountPeers(account) am.updateAccountPeers(account)
if err != nil {
return nil, err
}
return peer, nil return peer, nil
} }
@ -852,7 +823,9 @@ func (am *DefaultAccountManager) UpdatePeerSSHKey(peerID string, sshKey string)
} }
// trigger network map update // trigger network map update
return am.updateAccountPeers(account) am.updateAccountPeers(account)
return nil
} }
// GetPeer for a given accountID, peerID and userID error if not found. // GetPeer for a given accountID, peerID and userID error if not found.
@ -909,21 +882,12 @@ func updatePeerMeta(peer *Peer, meta PeerSystemMeta, account *Account) (*Peer, b
// updateAccountPeers updates all peers that belong to an account. // updateAccountPeers updates all peers that belong to an account.
// Should be called when changes have to be synced to peers. // Should be called when changes have to be synced to peers.
func (am *DefaultAccountManager) updateAccountPeers(account *Account) error { func (am *DefaultAccountManager) updateAccountPeers(account *Account) {
peers := account.GetPeers() peers := account.GetPeers()
for _, peer := range peers { for _, peer := range peers {
remotePeerNetworkMap, err := am.GetNetworkMap(peer.ID) remotePeerNetworkMap := account.GetPeerNetworkMap(peer.ID, am.dnsDomain)
if err != nil {
return err
}
update := toSyncResponse(nil, peer, nil, remotePeerNetworkMap, am.GetDNSDomain()) update := toSyncResponse(nil, peer, nil, remotePeerNetworkMap, am.GetDNSDomain())
err = am.peersUpdateManager.SendUpdate(peer.ID, &UpdateMessage{Update: update}) am.peersUpdateManager.SendUpdate(peer.ID, &UpdateMessage{Update: update})
if err != nil {
return err
} }
} }
return nil
}

View File

@ -350,7 +350,9 @@ func (am *DefaultAccountManager) SavePolicy(accountID, userID string, policy *Po
} }
am.storeEvent(userID, policy.ID, accountID, action, policy.EventMeta()) am.storeEvent(userID, policy.ID, accountID, action, policy.EventMeta())
return am.updateAccountPeers(account) am.updateAccountPeers(account)
return nil
} }
// DeletePolicy from the store // DeletePolicy from the store
@ -375,7 +377,9 @@ func (am *DefaultAccountManager) DeletePolicy(accountID, policyID, userID string
am.storeEvent(userID, policy.ID, accountID, activity.PolicyRemoved, policy.EventMeta()) am.storeEvent(userID, policy.ID, accountID, activity.PolicyRemoved, policy.EventMeta())
return am.updateAccountPeers(account) am.updateAccountPeers(account)
return nil
} }
// ListPolicies from the store // ListPolicies from the store

View File

@ -9,7 +9,6 @@ import (
"github.com/netbirdio/netbird/management/server/status" "github.com/netbirdio/netbird/management/server/status"
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/route"
"github.com/rs/xid" "github.com/rs/xid"
log "github.com/sirupsen/logrus"
) )
// GetRoute gets a route object from account and route IDs // GetRoute gets a route object from account and route IDs
@ -39,30 +38,82 @@ func (am *DefaultAccountManager) GetRoute(accountID, routeID, userID string) (*r
return nil, status.Errorf(status.NotFound, "route with ID %s not found", routeID) return nil, status.Errorf(status.NotFound, "route with ID %s not found", routeID)
} }
// checkPrefixPeerExists checks the combination of prefix and peer id, if it exists returns an error, otherwise returns nil // checkRoutePrefixExistsForPeers checks if a route with a given prefix exists for a single peer or multiple peer groups.
func (am *DefaultAccountManager) checkPrefixPeerExists(accountID, peerID string, prefix netip.Prefix) error { func (am *DefaultAccountManager) checkRoutePrefixExistsForPeers(account *Account, peerID, routeID string, peerGroupIDs []string, prefix netip.Prefix) error {
// routes can have both peer and peer_groups
if peerID == "" {
return nil
}
account, err := am.Store.GetAccount(accountID)
if err != nil {
return err
}
routesWithPrefix := account.GetRoutesByPrefix(prefix) routesWithPrefix := account.GetRoutesByPrefix(prefix)
// lets remember all the peers and the peer groups from routesWithPrefix
seenPeers := make(map[string]bool)
seenPeerGroups := make(map[string]bool)
for _, prefixRoute := range routesWithPrefix { for _, prefixRoute := range routesWithPrefix {
if prefixRoute.Peer == peerID { // we skip route(s) with the same network ID as we want to allow updating of the existing route
return status.Errorf(status.AlreadyExists, "failed to add route with prefix %s - peer already has this route", prefix.String()) // when create a new route routeID is newly generated so nothing will be skipped
if routeID == prefixRoute.ID {
continue
}
if prefixRoute.Peer != "" {
seenPeers[prefixRoute.ID] = true
}
for _, groupID := range prefixRoute.PeerGroups {
seenPeerGroups[groupID] = true
group := account.GetGroup(groupID)
if group == nil {
return status.Errorf(
status.InvalidArgument, "failed to add route with prefix %s - peer group %s doesn't exist",
prefix.String(), groupID)
}
for _, pID := range group.Peers {
seenPeers[pID] = true
} }
} }
}
if peerID != "" {
// check that peerID exists and is not in any route as single peer or part of the group
peer := account.GetPeer(peerID)
if peer == nil {
return status.Errorf(status.InvalidArgument, "peer with ID %s not found", peerID)
}
if _, ok := seenPeers[peerID]; ok {
return status.Errorf(status.AlreadyExists,
"failed to add route with prefix %s - peer %s already has this route", prefix.String(), peerID)
}
}
// check that peerGroupIDs are not in any route peerGroups list
for _, groupID := range peerGroupIDs {
group := account.GetGroup(groupID) // we validated the group existent before entering this function, o need to check again.
if _, ok := seenPeerGroups[groupID]; ok {
return status.Errorf(
status.AlreadyExists, "failed to add route with prefix %s - peer group %s already has this route",
prefix.String(), group.Name)
}
// check that the peers from peerGroupIDs groups are not the same peers we saw in routesWithPrefix
for _, id := range group.Peers {
if _, ok := seenPeers[id]; ok {
peer := account.GetPeer(peerID)
if peer == nil {
return status.Errorf(status.InvalidArgument, "peer with ID %s not found", peerID)
}
return status.Errorf(status.AlreadyExists,
"failed to add route with prefix %s - peer %s from the group %s already has this route",
prefix.String(), peer.Name, group.Name)
}
}
}
return nil return nil
} }
// CreateRoute creates and saves a new route // CreateRoute creates and saves a new route
func (am *DefaultAccountManager) CreateRoute(accountID string, network, peerID, description, netID string, masquerade bool, metric int, groups []string, enabled bool, userID string) (*route.Route, error) { func (am *DefaultAccountManager) CreateRoute(accountID, network, peerID string, peerGroupIDs []string, description, netID string, masquerade bool, metric int, groups []string, enabled bool, userID string) (*route.Route, error) {
unlock := am.Store.AcquireAccountLock(accountID) unlock := am.Store.AcquireAccountLock(accountID)
defer unlock() defer unlock()
@ -71,19 +122,29 @@ func (am *DefaultAccountManager) CreateRoute(accountID string, network, peerID,
return nil, err return nil, err
} }
if peerID != "" { if peerID != "" && len(peerGroupIDs) != 0 {
peer := account.GetPeer(peerID) return nil, status.Errorf(
if peer == nil { status.InvalidArgument,
return nil, status.Errorf(status.InvalidArgument, "peer with ID %s not found", peerID) "peer with ID %s and peers group %s should not be provided at the same time",
} peerID, peerGroupIDs)
} }
var newRoute route.Route var newRoute route.Route
newRoute.ID = xid.New().String()
prefixType, newPrefix, err := route.ParseNetwork(network) prefixType, newPrefix, err := route.ParseNetwork(network)
if err != nil { if err != nil {
return nil, status.Errorf(status.InvalidArgument, "failed to parse IP %s", network) return nil, status.Errorf(status.InvalidArgument, "failed to parse IP %s", network)
} }
err = am.checkPrefixPeerExists(accountID, peerID, newPrefix)
if len(peerGroupIDs) > 0 {
err = validateGroups(peerGroupIDs, account.Groups)
if err != nil {
return nil, err
}
}
err = am.checkRoutePrefixExistsForPeers(account, peerID, newRoute.ID, peerGroupIDs, newPrefix)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -102,7 +163,7 @@ func (am *DefaultAccountManager) CreateRoute(accountID string, network, peerID,
} }
newRoute.Peer = peerID newRoute.Peer = peerID
newRoute.ID = xid.New().String() newRoute.PeerGroups = peerGroupIDs
newRoute.Network = newPrefix newRoute.Network = newPrefix
newRoute.NetworkType = prefixType newRoute.NetworkType = prefixType
newRoute.Description = description newRoute.Description = description
@ -123,11 +184,7 @@ func (am *DefaultAccountManager) CreateRoute(accountID string, network, peerID,
return nil, err return nil, err
} }
err = am.updateAccountPeers(account) am.updateAccountPeers(account)
if err != nil {
log.Error(err)
return &newRoute, status.Errorf(status.Internal, "failed to update peers after create route %s", newPrefix)
}
am.storeEvent(userID, newRoute.ID, accountID, activity.RouteCreated, newRoute.EventMeta()) am.storeEvent(userID, newRoute.ID, accountID, activity.RouteCreated, newRoute.EventMeta())
@ -160,11 +217,20 @@ func (am *DefaultAccountManager) SaveRoute(accountID, userID string, routeToSave
return err return err
} }
if routeToSave.Peer != "" { if routeToSave.Peer != "" && len(routeToSave.PeerGroups) != 0 {
peer := account.GetPeer(routeToSave.Peer) return status.Errorf(status.InvalidArgument, "peer with ID and peer groups should not be provided at the same time")
if peer == nil {
return status.Errorf(status.InvalidArgument, "peer with ID %s not found", routeToSave.Peer)
} }
if len(routeToSave.PeerGroups) > 0 {
err = validateGroups(routeToSave.PeerGroups, account.Groups)
if err != nil {
return err
}
}
err = am.checkRoutePrefixExistsForPeers(account, routeToSave.Peer, routeToSave.ID, routeToSave.Copy().PeerGroups, routeToSave.Network)
if err != nil {
return err
} }
err = validateGroups(routeToSave.Groups, account.Groups) err = validateGroups(routeToSave.Groups, account.Groups)
@ -179,10 +245,7 @@ func (am *DefaultAccountManager) SaveRoute(accountID, userID string, routeToSave
return err return err
} }
err = am.updateAccountPeers(account) am.updateAccountPeers(account)
if err != nil {
return err
}
am.storeEvent(userID, routeToSave.ID, accountID, activity.RouteUpdated, routeToSave.EventMeta()) am.storeEvent(userID, routeToSave.ID, accountID, activity.RouteUpdated, routeToSave.EventMeta())
@ -212,7 +275,9 @@ func (am *DefaultAccountManager) DeleteRoute(accountID, routeID, userID string)
am.storeEvent(userID, routy.ID, accountID, activity.RouteRemoved, routy.EventMeta()) am.storeEvent(userID, routy.ID, accountID, activity.RouteRemoved, routy.EventMeta())
return am.updateAccountPeers(account) am.updateAccountPeers(account)
return nil
} }
// ListRoutes returns a list of routes from account // ListRoutes returns a list of routes from account

View File

@ -14,12 +14,24 @@ import (
const ( const (
peer1Key = "BhRPtynAAYRDy08+q4HTMsos8fs4plTP4NOSh7C1ry8=" peer1Key = "BhRPtynAAYRDy08+q4HTMsos8fs4plTP4NOSh7C1ry8="
peer2Key = "/yF0+vCfv+mRR5k0dca0TrGdO/oiNeAI58gToZm5NyI=" peer2Key = "/yF0+vCfv+mRR5k0dca0TrGdO/oiNeAI58gToZm5NyI="
peer3Key = "ayF0+vCfv+mRR5k0dca0TrGdO/oiNeAI58gToZm5NaF="
peer4Key = "ayF0+vCfv+mRR5k0dca0TrGdO/oiNeAI58gToZm5acc="
peer5Key = "ayF0+vCfv+mRR5k0dca0TrGdO/oiNeAI58gToZm5a55="
peer1ID = "peer-1-id" peer1ID = "peer-1-id"
peer2ID = "peer-2-id" peer2ID = "peer-2-id"
peer3ID = "peer-3-id"
peer4ID = "peer-4-id"
peer5ID = "peer-5-id"
routeGroup1 = "routeGroup1" routeGroup1 = "routeGroup1"
routeGroup2 = "routeGroup2" routeGroup2 = "routeGroup2"
routeGroup3 = "routeGroup3" // for existing route
routeGroup4 = "routeGroup4" // for existing route
routeGroupHA1 = "routeGroupHA1"
routeGroupHA2 = "routeGroupHA2"
routeInvalidGroup1 = "routeInvalidGroup1" routeInvalidGroup1 = "routeInvalidGroup1"
userID = "testingUser" userID = "testingUser"
existingNetwork = "10.10.10.0/24"
existingRouteID = "random-id"
) )
func TestCreateRoute(t *testing.T) { func TestCreateRoute(t *testing.T) {
@ -27,6 +39,7 @@ func TestCreateRoute(t *testing.T) {
network string network string
netID string netID string
peerKey string peerKey string
peerGroupIDs []string
description string description string
masquerade bool masquerade bool
metric int metric int
@ -67,6 +80,48 @@ func TestCreateRoute(t *testing.T) {
Groups: []string{routeGroup1}, Groups: []string{routeGroup1},
}, },
}, },
{
name: "Happy Path Peer Groups",
inputArgs: input{
network: "192.168.0.0/16",
netID: "happy",
peerGroupIDs: []string{routeGroupHA1, routeGroupHA2},
description: "super",
masquerade: false,
metric: 9999,
enabled: true,
groups: []string{routeGroup1, routeGroup2},
},
errFunc: require.NoError,
shouldCreate: true,
expectedRoute: &route.Route{
Network: netip.MustParsePrefix("192.168.0.0/16"),
NetworkType: route.IPv4Network,
NetID: "happy",
PeerGroups: []string{routeGroupHA1, routeGroupHA2},
Description: "super",
Masquerade: false,
Metric: 9999,
Enabled: true,
Groups: []string{routeGroup1, routeGroup2},
},
},
{
name: "Both peer and peer_groups Provided Should Fail",
inputArgs: input{
network: "192.168.0.0/16",
netID: "happy",
peerKey: peer1ID,
peerGroupIDs: []string{routeGroupHA1},
description: "super",
masquerade: false,
metric: 9999,
enabled: true,
groups: []string{routeGroup1},
},
errFunc: require.Error,
shouldCreate: false,
},
{ {
name: "Bad Prefix Should Fail", name: "Bad Prefix Should Fail",
inputArgs: input{ inputArgs: input{
@ -97,6 +152,36 @@ func TestCreateRoute(t *testing.T) {
errFunc: require.Error, errFunc: require.Error,
shouldCreate: false, shouldCreate: false,
}, },
{
name: "Bad Peer already has this route",
inputArgs: input{
network: existingNetwork,
netID: "bad",
peerKey: peer5ID,
description: "super",
masquerade: false,
metric: 9999,
enabled: true,
groups: []string{routeGroup1},
},
errFunc: require.Error,
shouldCreate: false,
},
{
name: "Bad Peers Group already has this route",
inputArgs: input{
network: existingNetwork,
netID: "bad",
peerGroupIDs: []string{routeGroup1, routeGroup3},
description: "super",
masquerade: false,
metric: 9999,
enabled: true,
groups: []string{routeGroup1},
},
errFunc: require.Error,
shouldCreate: false,
},
{ {
name: "Empty Peer Should Create", name: "Empty Peer Should Create",
inputArgs: input{ inputArgs: input{
@ -238,13 +323,14 @@ func TestCreateRoute(t *testing.T) {
account, err := initTestRouteAccount(t, am) account, err := initTestRouteAccount(t, am)
if err != nil { if err != nil {
t.Error("failed to init testing account") t.Errorf("failed to init testing account: %s", err)
} }
outRoute, err := am.CreateRoute( outRoute, err := am.CreateRoute(
account.Id, account.Id,
testCase.inputArgs.network, testCase.inputArgs.network,
testCase.inputArgs.peerKey, testCase.inputArgs.peerKey,
testCase.inputArgs.peerGroupIDs,
testCase.inputArgs.description, testCase.inputArgs.description,
testCase.inputArgs.netID, testCase.inputArgs.netID,
testCase.inputArgs.masquerade, testCase.inputArgs.masquerade,
@ -272,6 +358,7 @@ func TestCreateRoute(t *testing.T) {
func TestSaveRoute(t *testing.T) { func TestSaveRoute(t *testing.T) {
validPeer := peer2ID validPeer := peer2ID
validUsedPeer := peer5ID
invalidPeer := "nonExisting" invalidPeer := "nonExisting"
validPrefix := netip.MustParsePrefix("192.168.0.0/24") validPrefix := netip.MustParsePrefix("192.168.0.0/24")
invalidPrefix, _ := netip.ParsePrefix("192.168.0.0/34") invalidPrefix, _ := netip.ParsePrefix("192.168.0.0/34")
@ -279,11 +366,14 @@ func TestSaveRoute(t *testing.T) {
invalidMetric := 99999 invalidMetric := 99999
validNetID := "12345678901234567890qw" validNetID := "12345678901234567890qw"
invalidNetID := "12345678901234567890qwertyuiopqwertyuiop1" invalidNetID := "12345678901234567890qwertyuiopqwertyuiop1"
validGroupHA1 := routeGroupHA1
validGroupHA2 := routeGroupHA2
testCases := []struct { testCases := []struct {
name string name string
existingRoute *route.Route existingRoute *route.Route
newPeer *string newPeer *string
newPeerGroups []string
newMetric *int newMetric *int
newPrefix *netip.Prefix newPrefix *netip.Prefix
newGroups []string newGroups []string
@ -325,6 +415,55 @@ func TestSaveRoute(t *testing.T) {
Groups: []string{routeGroup2}, Groups: []string{routeGroup2},
}, },
}, },
{
name: "Happy Path Peer Groups",
existingRoute: &route.Route{
ID: "testingRoute",
Network: netip.MustParsePrefix("192.168.0.0/16"),
NetID: validNetID,
NetworkType: route.IPv4Network,
Description: "super",
Masquerade: false,
Metric: 9999,
Enabled: true,
Groups: []string{routeGroup1},
},
newPeerGroups: []string{validGroupHA1, validGroupHA2},
newMetric: &validMetric,
newPrefix: &validPrefix,
newGroups: []string{routeGroup2},
errFunc: require.NoError,
shouldCreate: true,
expectedRoute: &route.Route{
ID: "testingRoute",
Network: validPrefix,
NetID: validNetID,
NetworkType: route.IPv4Network,
PeerGroups: []string{validGroupHA1, validGroupHA2},
Description: "super",
Masquerade: false,
Metric: validMetric,
Enabled: true,
Groups: []string{routeGroup2},
},
},
{
name: "Both peer and peers_roup Provided Should Fail",
existingRoute: &route.Route{
ID: "testingRoute",
Network: netip.MustParsePrefix("192.168.0.0/16"),
NetID: validNetID,
NetworkType: route.IPv4Network,
Description: "super",
Masquerade: false,
Metric: 9999,
Enabled: true,
Groups: []string{routeGroup1},
},
newPeer: &validPeer,
newPeerGroups: []string{validGroupHA1},
errFunc: require.Error,
},
{ {
name: "Bad Prefix Should Fail", name: "Bad Prefix Should Fail",
existingRoute: &route.Route{ existingRoute: &route.Route{
@ -461,6 +600,71 @@ func TestSaveRoute(t *testing.T) {
newGroups: []string{routeInvalidGroup1}, newGroups: []string{routeInvalidGroup1},
errFunc: require.Error, errFunc: require.Error,
}, },
{
name: "Allow to modify existing route with new peer",
existingRoute: &route.Route{
ID: "testingRoute",
Network: netip.MustParsePrefix(existingNetwork),
NetID: validNetID,
NetworkType: route.IPv4Network,
Peer: peer1ID,
Description: "super",
Masquerade: false,
Metric: 9999,
Enabled: true,
Groups: []string{routeGroup1},
},
newPeer: &validPeer,
errFunc: require.NoError,
shouldCreate: true,
expectedRoute: &route.Route{
ID: "testingRoute",
Network: netip.MustParsePrefix(existingNetwork),
NetID: validNetID,
NetworkType: route.IPv4Network,
Peer: validPeer,
PeerGroups: []string{},
Description: "super",
Masquerade: false,
Metric: 9999,
Enabled: true,
Groups: []string{routeGroup1},
},
},
{
name: "Do not allow to modify existing route with a peer from another route",
existingRoute: &route.Route{
ID: "testingRoute",
Network: netip.MustParsePrefix(existingNetwork),
NetID: validNetID,
NetworkType: route.IPv4Network,
Peer: peer1ID,
Description: "super",
Masquerade: false,
Metric: 9999,
Enabled: true,
Groups: []string{routeGroup1},
},
newPeer: &validUsedPeer,
errFunc: require.Error,
},
{
name: "Do not allow to modify existing route with a peers group from another route",
existingRoute: &route.Route{
ID: "testingRoute",
Network: netip.MustParsePrefix(existingNetwork),
NetID: validNetID,
NetworkType: route.IPv4Network,
PeerGroups: []string{routeGroup3},
Description: "super",
Masquerade: false,
Metric: 9999,
Enabled: true,
Groups: []string{routeGroup1},
},
newPeerGroups: []string{routeGroup4},
errFunc: require.Error,
},
} }
for _, testCase := range testCases { for _, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) { t.Run(testCase.name, func(t *testing.T) {
@ -488,6 +692,9 @@ func TestSaveRoute(t *testing.T) {
if testCase.newPeer != nil { if testCase.newPeer != nil {
routeToSave.Peer = *testCase.newPeer routeToSave.Peer = *testCase.newPeer
} }
if len(testCase.newPeerGroups) != 0 {
routeToSave.PeerGroups = testCase.newPeerGroups
}
if testCase.newMetric != nil { if testCase.newMetric != nil {
routeToSave.Metric = *testCase.newMetric routeToSave.Metric = *testCase.newMetric
} }
@ -569,6 +776,96 @@ func TestDeleteRoute(t *testing.T) {
} }
} }
func TestGetNetworkMap_RouteSyncPeerGroups(t *testing.T) {
baseRoute := &route.Route{
Network: netip.MustParsePrefix("192.168.0.0/16"),
NetID: "superNet",
NetworkType: route.IPv4Network,
PeerGroups: []string{routeGroupHA1, routeGroupHA2},
Description: "ha route",
Masquerade: false,
Metric: 9999,
Enabled: true,
Groups: []string{routeGroup1, routeGroup2},
}
am, err := createRouterManager(t)
if err != nil {
t.Error("failed to create account manager")
}
account, err := initTestRouteAccount(t, am)
if err != nil {
t.Error("failed to init testing account")
}
newAccountRoutes, err := am.GetNetworkMap(peer1ID)
require.NoError(t, err)
require.Len(t, newAccountRoutes.Routes, 0, "new accounts should have no routes")
newRoute, err := am.CreateRoute(
account.Id, baseRoute.Network.String(), baseRoute.Peer, baseRoute.PeerGroups, baseRoute.Description,
baseRoute.NetID, baseRoute.Masquerade, baseRoute.Metric, baseRoute.Groups, baseRoute.Enabled, userID)
require.NoError(t, err)
require.Equal(t, newRoute.Enabled, true)
peer1Routes, err := am.GetNetworkMap(peer1ID)
require.NoError(t, err)
require.Len(t, peer1Routes.Routes, 3, "HA route should have more than 1 routes")
peer2Routes, err := am.GetNetworkMap(peer2ID)
require.NoError(t, err)
require.Len(t, peer2Routes.Routes, 3, "HA route should have more than 1 routes")
peer4Routes, err := am.GetNetworkMap(peer4ID)
require.NoError(t, err)
require.Len(t, peer4Routes.Routes, 3, "HA route should have more than 1 routes")
groups, err := am.ListGroups(account.Id)
require.NoError(t, err)
var groupHA1, groupHA2 *Group
for _, group := range groups {
switch group.Name {
case routeGroupHA1:
groupHA1 = group
case routeGroupHA2:
groupHA2 = group
}
}
err = am.GroupDeletePeer(account.Id, groupHA1.ID, peer2ID)
require.NoError(t, err)
peer2RoutesAfterDelete, err := am.GetNetworkMap(peer2ID)
require.NoError(t, err)
require.Len(t, peer2RoutesAfterDelete.Routes, 2, "after peer deletion group should have only 2 route")
err = am.GroupDeletePeer(account.Id, groupHA2.ID, peer4ID)
require.NoError(t, err)
peer2RoutesAfterDelete, err = am.GetNetworkMap(peer2ID)
require.NoError(t, err)
require.Len(t, peer2RoutesAfterDelete.Routes, 1, "after peer deletion group should have only 1 route")
err = am.GroupAddPeer(account.Id, groupHA2.ID, peer4ID)
require.NoError(t, err)
peer1RoutesAfterAdd, err := am.GetNetworkMap(peer1ID)
require.NoError(t, err)
require.Len(t, peer1RoutesAfterAdd.Routes, 2, "HA route should have more than 1 route")
peer2RoutesAfterAdd, err := am.GetNetworkMap(peer2ID)
require.NoError(t, err)
require.Len(t, peer2RoutesAfterAdd.Routes, 2, "HA route should have more than 1 route")
err = am.DeleteRoute(account.Id, newRoute.ID, userID)
require.NoError(t, err)
peer1DeletedRoute, err := am.GetNetworkMap(peer1ID)
require.NoError(t, err)
require.Len(t, peer1DeletedRoute.Routes, 0, "we should receive one route for peer1")
}
func TestGetNetworkMap_RouteSync(t *testing.T) { func TestGetNetworkMap_RouteSync(t *testing.T) {
// no routes for peer in different groups // no routes for peer in different groups
// no routes when route is deleted // no routes when route is deleted
@ -599,7 +896,7 @@ func TestGetNetworkMap_RouteSync(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
require.Len(t, newAccountRoutes.Routes, 0, "new accounts should have no routes") require.Len(t, newAccountRoutes.Routes, 0, "new accounts should have no routes")
createdRoute, err := am.CreateRoute(account.Id, baseRoute.Network.String(), peer1ID, createdRoute, err := am.CreateRoute(account.Id, baseRoute.Network.String(), peer1ID, []string{},
baseRoute.Description, baseRoute.NetID, baseRoute.Masquerade, baseRoute.Metric, baseRoute.Groups, false, baseRoute.Description, baseRoute.NetID, baseRoute.Masquerade, baseRoute.Metric, baseRoute.Groups, false,
userID) userID)
require.NoError(t, err) require.NoError(t, err)
@ -695,6 +992,8 @@ func createRouterStore(t *testing.T) (Store, error) {
} }
func initTestRouteAccount(t *testing.T, am *DefaultAccountManager) (*Account, error) { func initTestRouteAccount(t *testing.T, am *DefaultAccountManager) (*Account, error) {
t.Helper()
accountID := "testingAcc" accountID := "testingAcc"
domain := "example.com" domain := "example.com"
@ -754,6 +1053,81 @@ func initTestRouteAccount(t *testing.T, am *DefaultAccountManager) (*Account, er
} }
account.Peers[peer2.ID] = peer2 account.Peers[peer2.ID] = peer2
ips = account.getTakenIPs()
peer3IP, err := AllocatePeerIP(account.Network.Net, ips)
if err != nil {
return nil, err
}
peer3 := &Peer{
IP: peer3IP,
ID: peer3ID,
Key: peer3Key,
Name: "test-host3@netbird.io",
UserID: userID,
Meta: PeerSystemMeta{
Hostname: "test-host3@netbird.io",
GoOS: "darwin",
Kernel: "Darwin",
Core: "13.4.1",
Platform: "arm64",
OS: "darwin",
WtVersion: "development",
UIVersion: "development",
},
}
account.Peers[peer3.ID] = peer3
ips = account.getTakenIPs()
peer4IP, err := AllocatePeerIP(account.Network.Net, ips)
if err != nil {
return nil, err
}
peer4 := &Peer{
IP: peer4IP,
ID: peer4ID,
Key: peer4Key,
Name: "test-host4@netbird.io",
UserID: userID,
Meta: PeerSystemMeta{
Hostname: "test-host4@netbird.io",
GoOS: "linux",
Kernel: "Linux",
Core: "21.04",
Platform: "x86_64",
OS: "Ubuntu",
WtVersion: "development",
UIVersion: "development",
},
}
account.Peers[peer4.ID] = peer4
ips = account.getTakenIPs()
peer5IP, err := AllocatePeerIP(account.Network.Net, ips)
if err != nil {
return nil, err
}
peer5 := &Peer{
IP: peer5IP,
ID: peer5ID,
Key: peer5Key,
Name: "test-host4@netbird.io",
UserID: userID,
Meta: PeerSystemMeta{
Hostname: "test-host4@netbird.io",
GoOS: "linux",
Kernel: "Linux",
Core: "21.04",
Platform: "x86_64",
OS: "Ubuntu",
WtVersion: "development",
UIVersion: "development",
},
}
account.Peers[peer5.ID] = peer5
err = am.Store.SaveAccount(account) err = am.Store.SaveAccount(account)
if err != nil { if err != nil {
return nil, err return nil, err
@ -770,24 +1144,57 @@ func initTestRouteAccount(t *testing.T, am *DefaultAccountManager) (*Account, er
if err != nil { if err != nil {
return nil, err return nil, err
} }
err = am.GroupAddPeer(accountID, groupAll.ID, peer3ID)
newGroup := &Group{ if err != nil {
ID: routeGroup1, return nil, err
Name: routeGroup1,
Peers: []string{peer1.ID},
} }
err = am.SaveGroup(accountID, userID, newGroup) err = am.GroupAddPeer(accountID, groupAll.ID, peer4ID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
newGroup = &Group{ newGroup := []*Group{
{
ID: routeGroup1,
Name: routeGroup1,
Peers: []string{peer1.ID},
},
{
ID: routeGroup2, ID: routeGroup2,
Name: routeGroup2, Name: routeGroup2,
Peers: []string{peer2.ID}, Peers: []string{peer2.ID},
},
{
ID: routeGroup3,
Name: routeGroup3,
Peers: []string{peer5.ID},
},
{
ID: routeGroup4,
Name: routeGroup4,
Peers: []string{peer5.ID},
},
{
ID: routeGroupHA1,
Name: routeGroupHA1,
Peers: []string{peer1.ID, peer2.ID, peer3.ID}, // we have one non Linux peer, see peer3
},
{
ID: routeGroupHA2,
Name: routeGroupHA2,
Peers: []string{peer1.ID, peer4.ID},
},
} }
err = am.SaveGroup(accountID, userID, newGroup) for _, group := range newGroup {
err = am.SaveGroup(accountID, userID, group)
if err != nil {
return nil, err
}
}
_, err = am.CreateRoute(account.Id, existingNetwork, "", []string{routeGroup3, routeGroup4},
"", existingRouteID, false, 1000, []string{groupAll.ID}, true, userID)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -317,7 +317,9 @@ func (am *DefaultAccountManager) SaveSetupKey(accountID string, keyToSave *Setup
} }
}() }()
return newKey, am.updateAccountPeers(account) am.updateAccountPeers(account)
return newKey, nil
} }
// ListSetupKeys returns a list of all setup keys of the account // ListSetupKeys returns a list of all setup keys of the account

View File

@ -19,6 +19,7 @@ type GRPCMetrics struct {
activeStreamsGauge asyncint64.Gauge activeStreamsGauge asyncint64.Gauge
syncRequestDuration syncint64.Histogram syncRequestDuration syncint64.Histogram
loginRequestDuration syncint64.Histogram loginRequestDuration syncint64.Histogram
channelQueueLength syncint64.Histogram
ctx context.Context ctx context.Context
} }
@ -52,6 +53,18 @@ func NewGRPCMetrics(ctx context.Context, meter metric.Meter) (*GRPCMetrics, erro
return nil, err return nil, err
} }
// We use histogram here as we have multiple channel at the same time and we want to see a slice at any given time
// Then we should be able to extract min, manx, mean and the percentiles.
// TODO(yury): This needs custom bucketing as we are interested in the values from 0 to server.channelBufferSize (100)
channelQueue, err := meter.SyncInt64().Histogram(
"management.grpc.updatechannel.queue",
instrument.WithDescription("Number of update messages in the channel queue"),
instrument.WithUnit("length"),
)
if err != nil {
return nil, err
}
return &GRPCMetrics{ return &GRPCMetrics{
meter: meter, meter: meter,
syncRequestsCounter: syncRequestsCounter, syncRequestsCounter: syncRequestsCounter,
@ -60,6 +73,7 @@ func NewGRPCMetrics(ctx context.Context, meter metric.Meter) (*GRPCMetrics, erro
activeStreamsGauge: activeStreamsGauge, activeStreamsGauge: activeStreamsGauge,
syncRequestDuration: syncRequestDuration, syncRequestDuration: syncRequestDuration,
loginRequestDuration: loginRequestDuration, loginRequestDuration: loginRequestDuration,
channelQueueLength: channelQueue,
ctx: ctx, ctx: ctx,
}, err }, err
} }
@ -100,3 +114,8 @@ func (grpcMetrics *GRPCMetrics) RegisterConnectedStreams(producer func() int64)
}, },
) )
} }
// UpdateChannelQueueLength update the histogram that keep distribution of the update messages channel queue
func (metrics *GRPCMetrics) UpdateChannelQueueLength(len int) {
metrics.channelQueueLength.Record(metrics.ctx, int64(len))
}

View File

@ -11,37 +11,54 @@ import (
// StoreMetrics represents all metrics related to the FileStore // StoreMetrics represents all metrics related to the FileStore
type StoreMetrics struct { type StoreMetrics struct {
globalLockAcquisitionDuration syncint64.Histogram globalLockAcquisitionDurationMicro syncint64.Histogram
persistenceDuration syncint64.Histogram globalLockAcquisitionDurationMs syncint64.Histogram
persistenceDurationMicro syncint64.Histogram
persistenceDurationMs syncint64.Histogram
ctx context.Context ctx context.Context
} }
// NewStoreMetrics creates an instance of StoreMetrics // NewStoreMetrics creates an instance of StoreMetrics
func NewStoreMetrics(ctx context.Context, meter metric.Meter) (*StoreMetrics, error) { func NewStoreMetrics(ctx context.Context, meter metric.Meter) (*StoreMetrics, error) {
globalLockAcquisitionDuration, err := meter.SyncInt64().Histogram("management.store.global.lock.acquisition.duration.micro", globalLockAcquisitionDurationMicro, err := meter.SyncInt64().Histogram("management.store.global.lock.acquisition.duration.micro",
instrument.WithUnit("microseconds"))
if err != nil {
return nil, err
}
persistenceDuration, err := meter.SyncInt64().Histogram("management.store.persistence.duration.micro",
instrument.WithUnit("microseconds")) instrument.WithUnit("microseconds"))
if err != nil { if err != nil {
return nil, err return nil, err
} }
globalLockAcquisitionDurationMs, err := meter.SyncInt64().Histogram("management.store.global.lock.acquisition.duration.ms")
if err != nil {
return nil, err
}
persistenceDurationMicro, err := meter.SyncInt64().Histogram("management.store.persistence.duration.micro",
instrument.WithUnit("microseconds"))
if err != nil {
return nil, err
}
persistenceDurationMs, err := meter.SyncInt64().Histogram("management.store.persistence.duration.ms")
if err != nil {
return nil, err
}
return &StoreMetrics{ return &StoreMetrics{
globalLockAcquisitionDuration: globalLockAcquisitionDuration, globalLockAcquisitionDurationMicro: globalLockAcquisitionDurationMicro,
persistenceDuration: persistenceDuration, globalLockAcquisitionDurationMs: globalLockAcquisitionDurationMs,
persistenceDurationMicro: persistenceDurationMicro,
persistenceDurationMs: persistenceDurationMs,
ctx: ctx, ctx: ctx,
}, nil }, nil
} }
// CountGlobalLockAcquisitionDuration counts the duration of the global lock acquisition // CountGlobalLockAcquisitionDuration counts the duration of the global lock acquisition
func (metrics *StoreMetrics) CountGlobalLockAcquisitionDuration(duration time.Duration) { func (metrics *StoreMetrics) CountGlobalLockAcquisitionDuration(duration time.Duration) {
metrics.globalLockAcquisitionDuration.Record(metrics.ctx, duration.Microseconds()) metrics.globalLockAcquisitionDurationMicro.Record(metrics.ctx, duration.Microseconds())
metrics.globalLockAcquisitionDurationMs.Record(metrics.ctx, duration.Milliseconds())
} }
// CountPersistenceDuration counts the duration of a store persistence operation // CountPersistenceDuration counts the duration of a store persistence operation
func (metrics *StoreMetrics) CountPersistenceDuration(duration time.Duration) { func (metrics *StoreMetrics) CountPersistenceDuration(duration time.Duration) {
metrics.persistenceDuration.Record(metrics.ctx, duration.Microseconds()) metrics.persistenceDurationMicro.Record(metrics.ctx, duration.Microseconds())
metrics.persistenceDurationMs.Record(metrics.ctx, duration.Milliseconds())
} }

View File

@ -118,11 +118,7 @@ func (m *TimeBasedAuthSecretsManager) SetupRefresh(peerID string) {
}, },
} }
log.Debugf("sending new TURN credentials to peer %s", peerID) log.Debugf("sending new TURN credentials to peer %s", peerID)
err := m.updateManager.SendUpdate(peerID, &UpdateMessage{Update: update}) m.updateManager.SendUpdate(peerID, &UpdateMessage{Update: update})
if err != nil {
log.Errorf("error while sending TURN update to peer %s %v", peerID, err)
// todo maybe continue trying?
}
} }
} }
}() }()

View File

@ -29,7 +29,7 @@ func NewPeersUpdateManager() *PeersUpdateManager {
} }
// SendUpdate sends update message to the peer's channel // SendUpdate sends update message to the peer's channel
func (p *PeersUpdateManager) SendUpdate(peerID string, update *UpdateMessage) error { func (p *PeersUpdateManager) SendUpdate(peerID string, update *UpdateMessage) {
p.channelsMux.Lock() p.channelsMux.Lock()
defer p.channelsMux.Unlock() defer p.channelsMux.Unlock()
if channel, ok := p.peerChannels[peerID]; ok { if channel, ok := p.peerChannels[peerID]; ok {
@ -39,10 +39,9 @@ func (p *PeersUpdateManager) SendUpdate(peerID string, update *UpdateMessage) er
default: default:
log.Warnf("channel for peer %s is %d full", peerID, len(channel)) log.Warnf("channel for peer %s is %d full", peerID, len(channel))
} }
return nil } else {
}
log.Debugf("peer %s has no channel", peerID) log.Debugf("peer %s has no channel", peerID)
return nil }
} }
// CreateChannel creates a go channel for a given peer used to deliver updates relevant to the peer. // CreateChannel creates a go channel for a given peer used to deliver updates relevant to the peer.

View File

@ -31,10 +31,7 @@ func TestSendUpdate(t *testing.T) {
if _, ok := peersUpdater.peerChannels[peer]; !ok { if _, ok := peersUpdater.peerChannels[peer]; !ok {
t.Error("Error creating the channel") t.Error("Error creating the channel")
} }
err := peersUpdater.SendUpdate(peer, update1) peersUpdater.SendUpdate(peer, update1)
if err != nil {
t.Error("Error sending update: ", err)
}
select { select {
case <-peersUpdater.peerChannels[peer]: case <-peersUpdater.peerChannels[peer]:
default: default:
@ -42,10 +39,7 @@ func TestSendUpdate(t *testing.T) {
} }
for range [channelBufferSize]int{} { for range [channelBufferSize]int{} {
err = peersUpdater.SendUpdate(peer, update1) peersUpdater.SendUpdate(peer, update1)
if err != nil {
t.Errorf("got an early error sending update: %v ", err)
}
} }
update2 := &UpdateMessage{Update: &proto.SyncResponse{ update2 := &UpdateMessage{Update: &proto.SyncResponse{
@ -54,10 +48,7 @@ func TestSendUpdate(t *testing.T) {
}, },
}} }}
err = peersUpdater.SendUpdate(peer, update2) peersUpdater.SendUpdate(peer, update2)
if err != nil {
t.Error("update shouldn't return an error when channel buffer is full")
}
timeout := time.After(5 * time.Second) timeout := time.After(5 * time.Second)
for range [channelBufferSize]int{} { for range [channelBufferSize]int{} {
select { select {

View File

@ -307,8 +307,17 @@ func (am *DefaultAccountManager) GetUser(claims jwtclaims.AuthorizationClaims) (
return user, nil return user, nil
} }
func (am *DefaultAccountManager) deleteServiceUser(account *Account, initiatorUserID string, targetUser *User) {
meta := map[string]any{"name": targetUser.ServiceUserName}
am.storeEvent(initiatorUserID, targetUser.Id, account.Id, activity.ServiceUserDeleted, meta)
delete(account.Users, targetUser.Id)
}
// DeleteUser deletes a user from the given account. // DeleteUser deletes a user from the given account.
func (am *DefaultAccountManager) DeleteUser(accountID, initiatorUserID string, targetUserID string) error { func (am *DefaultAccountManager) DeleteUser(accountID, initiatorUserID string, targetUserID string) error {
if initiatorUserID == targetUserID {
return status.Errorf(status.InvalidArgument, "self deletion is not allowed")
}
unlock := am.Store.AcquireAccountLock(accountID) unlock := am.Store.AcquireAccountLock(accountID)
defer unlock() defer unlock()
@ -317,11 +326,6 @@ func (am *DefaultAccountManager) DeleteUser(accountID, initiatorUserID string, t
return err return err
} }
targetUser := account.Users[targetUserID]
if targetUser == nil {
return status.Errorf(status.NotFound, "user not found")
}
executingUser := account.Users[initiatorUserID] executingUser := account.Users[initiatorUserID]
if executingUser == nil { if executingUser == nil {
return status.Errorf(status.NotFound, "user not found") return status.Errorf(status.NotFound, "user not found")
@ -330,51 +334,68 @@ func (am *DefaultAccountManager) DeleteUser(accountID, initiatorUserID string, t
return status.Errorf(status.PermissionDenied, "only admins can delete users") return status.Errorf(status.PermissionDenied, "only admins can delete users")
} }
peers, err := account.FindUserPeers(targetUserID) targetUser := account.Users[targetUserID]
if err != nil { if targetUser == nil {
return status.Errorf(status.Internal, "failed to find user peers") return status.Errorf(status.NotFound, "target user not found")
} }
if err := am.expireAndUpdatePeers(account, peers); err != nil { // handle service user first and exit, no need to fetch extra data from IDP, etc
log.Errorf("failed update deleted peers expiration: %s", err) if targetUser.IsServiceUser {
return err am.deleteServiceUser(account, initiatorUserID, targetUser)
return am.Store.SaveAccount(account)
} }
targetUserEmail, err := am.getEmailOfTargetUser(account.Id, initiatorUserID, targetUserID) return am.deleteRegularUser(account, initiatorUserID, targetUserID)
}
func (am *DefaultAccountManager) deleteRegularUser(account *Account, initiatorUserID, targetUserID string) error {
tuEmail, tuName, err := am.getEmailAndNameOfTargetUser(account.Id, initiatorUserID, targetUserID)
if err != nil { if err != nil {
log.Errorf("failed to resolve email address: %s", err) log.Errorf("failed to resolve email address: %s", err)
return err return err
} }
var meta map[string]any
var eventAction activity.Activity
if targetUser.IsServiceUser {
meta = map[string]any{"name": targetUser.ServiceUserName}
eventAction = activity.ServiceUserDeleted
} else {
meta = map[string]any{"email": targetUserEmail}
eventAction = activity.UserDeleted
}
am.storeEvent(initiatorUserID, targetUserID, accountID, eventAction, meta)
if !isNil(am.idpManager) { if !isNil(am.idpManager) {
err := am.deleteUserFromIDP(targetUserID, accountID) err = am.deleteUserFromIDP(targetUserID, account.Id)
if err != nil { if err != nil {
log.Debugf("failed to delete user from IDP: %s", targetUserID)
return err return err
} }
} }
delete(account.Users, targetUserID) err = am.deleteUserPeers(initiatorUserID, targetUserID, account)
if err != nil {
return err
}
delete(account.Users, targetUserID)
err = am.Store.SaveAccount(account) err = am.Store.SaveAccount(account)
if err != nil { if err != nil {
return err return err
} }
meta := map[string]any{"name": tuName, "email": tuEmail}
am.storeEvent(initiatorUserID, targetUserID, account.Id, activity.UserDeleted, meta)
am.updateAccountPeers(account)
return nil return nil
} }
func (am *DefaultAccountManager) deleteUserPeers(initiatorUserID string, targetUserID string, account *Account) error {
peers, err := account.FindUserPeers(targetUserID)
if err != nil {
return status.Errorf(status.Internal, "failed to find user peers")
}
peerIDs := make([]string, 0, len(peers))
for _, peer := range peers {
peerIDs = append(peerIDs, peer.ID)
}
return am.deletePeers(account, peerIDs, initiatorUserID)
}
// InviteUser resend invitations to users who haven't activated their accounts prior to the expiration period. // InviteUser resend invitations to users who haven't activated their accounts prior to the expiration period.
func (am *DefaultAccountManager) InviteUser(accountID string, initiatorUserID string, targetUserID string) error { func (am *DefaultAccountManager) InviteUser(accountID string, initiatorUserID string, targetUserID string) error {
unlock := am.Store.AcquireAccountLock(accountID) unlock := am.Store.AcquireAccountLock(accountID)
@ -655,9 +676,7 @@ func (am *DefaultAccountManager) SaveUser(accountID, initiatorUserID string, upd
return nil, err return nil, err
} }
if err := am.updateAccountPeers(account); err != nil { am.updateAccountPeers(account)
log.Errorf("failed updating account peers while updating user %s", accountID)
}
} else { } else {
if err = am.Store.SaveAccount(account); err != nil { if err = am.Store.SaveAccount(account); err != nil {
return nil, err return nil, err
@ -833,6 +852,9 @@ func (am *DefaultAccountManager) GetUsersFromAccount(accountID, userID string) (
func (am *DefaultAccountManager) expireAndUpdatePeers(account *Account, peers []*Peer) error { func (am *DefaultAccountManager) expireAndUpdatePeers(account *Account, peers []*Peer) error {
var peerIDs []string var peerIDs []string
for _, peer := range peers { for _, peer := range peers {
if peer.Status.LoginExpired {
continue
}
peerIDs = append(peerIDs, peer.ID) peerIDs = append(peerIDs, peer.ID)
peer.MarkLoginExpired(true) peer.MarkLoginExpired(true)
account.UpdatePeer(peer) account.UpdatePeer(peer)
@ -848,9 +870,7 @@ func (am *DefaultAccountManager) expireAndUpdatePeers(account *Account, peers []
if len(peerIDs) != 0 { if len(peerIDs) != 0 {
// this will trigger peer disconnect from the management service // this will trigger peer disconnect from the management service
am.peersUpdateManager.CloseChannels(peerIDs) am.peersUpdateManager.CloseChannels(peerIDs)
if err := am.updateAccountPeers(account); err != nil { am.updateAccountPeers(account)
return err
}
} }
return nil return nil
} }
@ -876,18 +896,18 @@ func (am *DefaultAccountManager) deleteUserFromIDP(targetUserID, accountID strin
return nil return nil
} }
func (am *DefaultAccountManager) getEmailOfTargetUser(accountId string, initiatorId, targetId string) (string, error) { func (am *DefaultAccountManager) getEmailAndNameOfTargetUser(accountId, initiatorId, targetId string) (string, string, error) {
userInfos, err := am.GetUsersFromAccount(accountId, initiatorId) userInfos, err := am.GetUsersFromAccount(accountId, initiatorId)
if err != nil { if err != nil {
return "", err return "", "", err
} }
for _, ui := range userInfos { for _, ui := range userInfos {
if ui.ID == targetId { if ui.ID == targetId {
return ui.Email, nil return ui.Email, ui.Name, nil
} }
} }
return "", fmt.Errorf("email not found for user: %s", targetId) return "", "", fmt.Errorf("user info not found for user: %s", targetId)
} }
func findUserInIDPUserdata(userID string, userData []*idp.UserData) (*idp.UserData, bool) { func findUserInIDPUserdata(userID string, userData []*idp.UserData) (*idp.UserData, bool) {

View File

@ -424,7 +424,7 @@ func TestUser_DeleteUser_ServiceUser(t *testing.T) {
assert.Nil(t, store.Accounts[mockAccountID].Users[mockServiceUserID]) assert.Nil(t, store.Accounts[mockAccountID].Users[mockServiceUserID])
} }
func TestUser_DeleteUser_regularUser(t *testing.T) { func TestUser_DeleteUser_SelfDelete(t *testing.T) {
store := newStore(t) store := newStore(t)
account := newAccountWithId(mockAccountID, mockUserID, "") account := newAccountWithId(mockAccountID, mockUserID, "")
@ -439,6 +439,32 @@ func TestUser_DeleteUser_regularUser(t *testing.T) {
} }
err = am.DeleteUser(mockAccountID, mockUserID, mockUserID) err = am.DeleteUser(mockAccountID, mockUserID, mockUserID)
if err == nil {
t.Fatalf("failed to prevent self deletion")
}
}
func TestUser_DeleteUser_regularUser(t *testing.T) {
store := newStore(t)
account := newAccountWithId(mockAccountID, mockUserID, "")
targetId := "user2"
account.Users[targetId] = &User{
Id: targetId,
IsServiceUser: true,
ServiceUserName: "user2username",
}
err := store.SaveAccount(account)
if err != nil {
t.Fatalf("Error when saving account: %s", err)
}
am := DefaultAccountManager{
Store: store,
eventStore: &activity.InMemoryEventStore{},
}
err = am.DeleteUser(mockAccountID, mockUserID, targetId)
if err != nil { if err != nil {
t.Errorf("unexpected error: %s", err) t.Errorf("unexpected error: %s", err)
} }

View File

@ -1,4 +1,3 @@
#!/bin/sh
# This code is based on the netbird-installer contribution by physk on GitHub. # This code is based on the netbird-installer contribution by physk on GitHub.
# Source: https://github.com/physk/netbird-installer # Source: https://github.com/physk/netbird-installer
set -e set -e
@ -17,6 +16,12 @@ OS_TYPE=""
ARCH="$(uname -m)" ARCH="$(uname -m)"
PACKAGE_MANAGER="bin" PACKAGE_MANAGER="bin"
INSTALL_DIR="" INSTALL_DIR=""
SUDO=""
if command -v sudo > /dev/null && [ "$(id -u)" -ne 0 ]; then
SUDO="sudo"
fi
get_latest_release() { get_latest_release() {
if [ -n "$GITHUB_TOKEN" ]; then if [ -n "$GITHUB_TOKEN" ]; then
@ -65,27 +70,35 @@ download_release_binary() {
unzip -q -o "$BINARY_NAME" unzip -q -o "$BINARY_NAME"
mv "netbird_ui_${OS_TYPE}_${ARCH}" "$INSTALL_DIR" mv "netbird_ui_${OS_TYPE}_${ARCH}" "$INSTALL_DIR"
else else
sudo mkdir -p "$INSTALL_DIR" ${SUDO} mkdir -p "$INSTALL_DIR"
tar -xzvf "$BINARY_NAME" tar -xzvf "$BINARY_NAME"
sudo mv "${1%_"${BINARY_BASE_NAME}"}" "$INSTALL_DIR/" ${SUDO} mv "${1%_"${BINARY_BASE_NAME}"}" "$INSTALL_DIR/"
fi fi
} }
add_apt_repo() { add_apt_repo() {
sudo apt-get update ${SUDO} apt-get update
sudo apt-get install ca-certificates gnupg -y ${SUDO} apt-get install ca-certificates curl gnupg -y
curl -sSL https://pkgs.wiretrustee.com/debian/public.key \ # Remove old keys and repo source files
| sudo gpg --dearmor --output /usr/share/keyrings/wiretrustee-archive-keyring.gpg ${SUDO} rm -f \
/etc/apt/sources.list.d/netbird.list \
/etc/apt/sources.list.d/wiretrustee.list \
/etc/apt/trusted.gpg.d/wiretrustee.gpg \
/usr/share/keyrings/netbird-archive-keyring.gpg \
/usr/share/keyrings/wiretrustee-archive-keyring.gpg
APT_REPO="deb [signed-by=/usr/share/keyrings/wiretrustee-archive-keyring.gpg] https://pkgs.wiretrustee.com/debian stable main" curl -sSL https://pkgs.netbird.io/debian/public.key \
echo "$APT_REPO" | sudo tee /etc/apt/sources.list.d/wiretrustee.list | ${SUDO} gpg --dearmor -o /usr/share/keyrings/netbird-archive-keyring.gpg
sudo apt-get update echo 'deb [signed-by=/usr/share/keyrings/netbird-archive-keyring.gpg] https://pkgs.netbird.io/debian stable main' \
| ${SUDO} tee /etc/apt/sources.list.d/netbird.list
${SUDO} apt-get update
} }
add_rpm_repo() { add_rpm_repo() {
cat <<-EOF | sudo tee /etc/yum.repos.d/netbird.repo cat <<-EOF | ${SUDO} tee /etc/yum.repos.d/netbird.repo
[NetBird] [NetBird]
name=NetBird name=NetBird
baseurl=https://pkgs.netbird.io/yum/ baseurl=https://pkgs.netbird.io/yum/
@ -104,7 +117,7 @@ add_aur_repo() {
for PKG in $INSTALL_PKGS; do for PKG in $INSTALL_PKGS; do
if ! pacman -Q "$PKG" > /dev/null 2>&1; then if ! pacman -Q "$PKG" > /dev/null 2>&1; then
# Install missing package(s) # Install missing package(s)
sudo pacman -S "$PKG" --noconfirm ${SUDO} pacman -S "$PKG" --noconfirm
# Add installed package for clean up later # Add installed package for clean up later
REMOVE_PKGS="$REMOVE_PKGS $PKG" REMOVE_PKGS="$REMOVE_PKGS $PKG"
@ -121,7 +134,7 @@ add_aur_repo() {
fi fi
# Clean up the installed packages # Clean up the installed packages
sudo pacman -Rs "$REMOVE_PKGS" --noconfirm ${SUDO} pacman -Rs "$REMOVE_PKGS" --noconfirm
} }
install_native_binaries() { install_native_binaries() {
@ -181,6 +194,134 @@ install_netbird() {
fi fi
fi fi
# Run the installation, if a desktop environment is not detected
# only the CLI will be installed
case "$PACKAGE_MANAGER" in
apt)
add_apt_repo
${SUDO} apt-get install netbird -y
if ! $SKIP_UI_APP; then
${SUDO} apt-get install netbird-ui -y
fi
;;
yum)
add_rpm_repo
${SUDO} yum -y install netbird
if ! $SKIP_UI_APP; then
${SUDO} yum -y install netbird-ui
fi
;;
dnf)
add_rpm_repo
${SUDO} dnf -y install dnf-plugin-config-manager
${SUDO} dnf config-manager --add-repo /etc/yum.repos.d/netbird.repo
${SUDO} dnf -y install netbird
if ! $SKIP_UI_APP; then
${SUDO} dnf -y install netbird-ui
fi
;;
pacman)
${SUDO} pacman -Syy
add_aur_repo
;;
brew)
# Remove Wiretrustee if it had been installed using Homebrew before
if brew ls --versions wiretrustee >/dev/null 2>&1; then
echo "Removing existing wiretrustee client"
# Stop and uninstall daemon service:
wiretrustee service stop
wiretrustee service uninstall
# Unlik the app
brew unlink wiretrustee
fi
brew install netbirdio/tap/netbird
if ! $SKIP_UI_APP; then
brew install --cask netbirdio/tap/netbird-ui
fi
;;
*)
if [ "$OS_NAME" = "nixos" ];then
echo "Please add NetBird to your NixOS configuration.nix directly:"
echo ""
echo "services.netbird.enable = true;"
if ! $SKIP_UI_APP; then
echo "environment.systemPackages = [ pkgs.netbird-ui ];"
fi
echo "Build and apply new configuration:"
echo ""
echo "${SUDO} nixos-rebuild switch"
exit 0
fi
install_native_binaries
;;
esac
# Add package manager to config
${SUDO} mkdir -p "$CONFIG_FOLDER"
echo "package_manager=$PACKAGE_MANAGER" | ${SUDO} tee "$CONFIG_FILE" > /dev/null
# Load and start netbird service
if ! ${SUDO} netbird service install 2>&1; then
echo "NetBird service has already been loaded"
fi
if ! ${SUDO} netbird service start 2>&1; then
echo "NetBird service has already been started"
fi
echo "Installation has been finished. To connect, you need to run NetBird by executing the following command:"
echo ""
echo "netbird up"
}
version_greater_equal() {
printf '%s\n%s\n' "$2" "$1" | sort -V -C
}
is_bin_package_manager() {
if ${SUDO} test -f "$1" && ${SUDO} grep -q "package_manager=bin" "$1" ; then
return 0
else
return 1
fi
}
update_netbird() {
if is_bin_package_manager "$CONFIG_FILE"; then
latest_release=$(get_latest_release)
latest_version=${latest_release#v}
installed_version=$(netbird version)
if [ "$latest_version" = "$installed_version" ]; then
echo "Installed netbird version ($installed_version) is up-to-date"
exit 0
fi
if version_greater_equal "$latest_version" "$installed_version"; then
echo "NetBird new version ($latest_version) available. Updating..."
echo ""
echo "Initiating NetBird update. This will stop the netbird service and restart it after the update"
${SUDO} netbird service stop
${SUDO} netbird service uninstall
install_native_binaries
${SUDO} netbird service install
${SUDO} netbird service start
fi
else
echo "NetBird installation was done using a package manager. Please use your system's package manager to update"
fi
}
# Identify OS name and default package manager # Identify OS name and default package manager
if type uname >/dev/null 2>&1; then if type uname >/dev/null 2>&1; then
case "$(uname)" in case "$(uname)" in
@ -235,135 +376,6 @@ install_netbird() {
esac esac
fi fi
# Run the installation, if a desktop environment is not detected
# only the CLI will be installed
case "$PACKAGE_MANAGER" in
apt)
add_apt_repo
sudo apt-get install netbird -y
if ! $SKIP_UI_APP; then
sudo apt-get install netbird-ui -y
fi
;;
yum)
add_rpm_repo
sudo yum -y install netbird
if ! $SKIP_UI_APP; then
sudo yum -y install netbird-ui
fi
;;
dnf)
add_rpm_repo
sudo dnf -y install dnf-plugin-config-manager
sudo dnf config-manager --add-repo /etc/yum.repos.d/netbird.repo
sudo dnf -y install netbird
if ! $SKIP_UI_APP; then
sudo dnf -y install netbird-ui
fi
;;
pacman)
sudo pacman -Syy
add_aur_repo
;;
brew)
# Remove Wiretrustee if it had been installed using Homebrew before
if brew ls --versions wiretrustee >/dev/null 2>&1; then
echo "Removing existing wiretrustee client"
# Stop and uninstall daemon service:
wiretrustee service stop
wiretrustee service uninstall
# Unlik the app
brew unlink wiretrustee
fi
brew install netbirdio/tap/netbird
if ! $SKIP_UI_APP; then
brew install --cask netbirdio/tap/netbird-ui
fi
;;
*)
if [ "$OS_NAME" = "nixos" ];then
echo "Please add NetBird to your NixOS configuration.nix directly:"
echo ""
echo "services.netbird.enable = true;"
if ! $SKIP_UI_APP; then
echo "environment.systemPackages = [ pkgs.netbird-ui ];"
fi
echo "Build and apply new configuration:"
echo ""
echo "sudo nixos-rebuild switch"
exit 0
fi
install_native_binaries
;;
esac
# Add package manager to config
sudo mkdir -p "$CONFIG_FOLDER"
echo "package_manager=$PACKAGE_MANAGER" | sudo tee "$CONFIG_FILE" > /dev/null
# Load and start netbird service
if ! sudo netbird service install 2>&1; then
echo "NetBird service has already been loaded"
fi
if ! sudo netbird service start 2>&1; then
echo "NetBird service has already been started"
fi
echo "Installation has been finished. To connect, you need to run NetBird by executing the following command:"
echo ""
echo "sudo netbird up"
}
version_greater_equal() {
printf '%s\n%s\n' "$2" "$1" | sort -V -C
}
is_bin_package_manager() {
if sudo test -f "$1" && sudo grep -q "package_manager=bin" "$1" ; then
return 0
else
return 1
fi
}
update_netbird() {
if is_bin_package_manager "$CONFIG_FILE"; then
latest_release=$(get_latest_release)
latest_version=${latest_release#v}
installed_version=$(netbird version)
if [ "$latest_version" = "$installed_version" ]; then
echo "Installed netbird version ($installed_version) is up-to-date"
exit 0
fi
if version_greater_equal "$latest_version" "$installed_version"; then
echo "NetBird new version ($latest_version) available. Updating..."
echo ""
echo "Initiating NetBird update. This will stop the netbird service and restart it after the update"
sudo netbird service stop
sudo netbird service uninstall
install_native_binaries
sudo netbird service install
sudo netbird service start
fi
else
echo "NetBird installation was done using a package manager. Please use your system's package manager to update"
fi
}
case "$1" in case "$1" in
--update) --update)
update_netbird update_netbird

View File

@ -70,6 +70,7 @@ type Route struct {
NetID string NetID string
Description string Description string
Peer string Peer string
PeerGroups []string
NetworkType NetworkType NetworkType NetworkType
Masquerade bool Masquerade bool
Metric int Metric int
@ -79,7 +80,7 @@ type Route struct {
// EventMeta returns activity event meta related to the route // EventMeta returns activity event meta related to the route
func (r *Route) EventMeta() map[string]any { func (r *Route) EventMeta() map[string]any {
return map[string]any{"name": r.NetID, "network_range": r.Network.String(), "peer_id": r.Peer} return map[string]any{"name": r.NetID, "network_range": r.Network.String(), "peer_id": r.Peer, "peer_groups": r.PeerGroups}
} }
// Copy copies a route object // Copy copies a route object
@ -91,12 +92,14 @@ func (r *Route) Copy() *Route {
Network: r.Network, Network: r.Network,
NetworkType: r.NetworkType, NetworkType: r.NetworkType,
Peer: r.Peer, Peer: r.Peer,
PeerGroups: make([]string, len(r.PeerGroups)),
Metric: r.Metric, Metric: r.Metric,
Masquerade: r.Masquerade, Masquerade: r.Masquerade,
Enabled: r.Enabled, Enabled: r.Enabled,
Groups: make([]string, len(r.Groups)), Groups: make([]string, len(r.Groups)),
} }
copy(route.Groups, r.Groups) copy(route.Groups, r.Groups)
copy(route.PeerGroups, r.PeerGroups)
return route return route
} }
@ -111,7 +114,8 @@ func (r *Route) IsEqual(other *Route) bool {
other.Metric == r.Metric && other.Metric == r.Metric &&
other.Masquerade == r.Masquerade && other.Masquerade == r.Masquerade &&
other.Enabled == r.Enabled && other.Enabled == r.Enabled &&
compareGroupsList(r.Groups, other.Groups) compareList(r.Groups, other.Groups) &&
compareList(r.PeerGroups, other.PeerGroups)
} }
// ParseNetwork Parses a network prefix string and returns a netip.Prefix object and if is invalid, IPv4 or IPv6 // ParseNetwork Parses a network prefix string and returns a netip.Prefix object and if is invalid, IPv4 or IPv6
@ -134,7 +138,7 @@ func ParseNetwork(networkString string) (NetworkType, netip.Prefix, error) {
return IPv4Network, masked, nil return IPv4Network, masked, nil
} }
func compareGroupsList(list, other []string) bool { func compareList(list, other []string) bool {
if len(list) != len(other) { if len(list) != len(other) {
return false return false
} }

View File

@ -5,6 +5,8 @@ import (
"io" "io"
"os" "os"
"path/filepath" "path/filepath"
log "github.com/sirupsen/logrus"
) )
// WriteJson writes JSON config object to a file creating parent directories if required // WriteJson writes JSON config object to a file creating parent directories if required
@ -54,6 +56,68 @@ func WriteJson(file string, obj interface{}) error {
return nil return nil
} }
// DirectWriteJson writes JSON config object to a file creating parent directories if required without creating a temporary file
func DirectWriteJson(file string, obj interface{}) error {
_, _, err := prepareConfigFileDir(file)
if err != nil {
return err
}
targetFile, err := openOrCreateFile(file)
if err != nil {
return err
}
defer func() {
err = targetFile.Close()
if err != nil {
log.Errorf("failed to close file %s: %v", file, err)
}
}()
// make it pretty
bs, err := json.MarshalIndent(obj, "", " ")
if err != nil {
return err
}
err = targetFile.Truncate(0)
if err != nil {
return err
}
_, err = targetFile.Write(bs)
if err != nil {
return err
}
return nil
}
func openOrCreateFile(file string) (*os.File, error) {
s, err := os.Stat(file)
if err == nil {
return os.OpenFile(file, os.O_WRONLY, s.Mode())
}
if !os.IsNotExist(err) {
return nil, err
}
targetFile, err := os.Create(file)
if err != nil {
return nil, err
}
//no:lint
err = targetFile.Chmod(0640)
if err != nil {
_ = targetFile.Close()
return nil, err
}
return targetFile, nil
}
// ReadJson reads JSON config file and maps to a provided interface // ReadJson reads JSON config file and maps to a provided interface
func ReadJson(file string, res interface{}) (interface{}, error) { func ReadJson(file string, res interface{}) (interface{}, error) {