mirror of
https://github.com/netbirdio/netbird.git
synced 2025-03-12 22:02:43 +01:00
Merge remote-tracking branch 'origin/chore/benchmark-with-large-runner' into chore/benchmark-with-large-runner
This commit is contained in:
commit
2644f37025
1
.gitignore
vendored
1
.gitignore
vendored
@ -29,3 +29,4 @@ infrastructure_files/setup.env
|
||||
infrastructure_files/setup-*.env
|
||||
.vscode
|
||||
.DS_Store
|
||||
vendor/
|
||||
|
@ -1,4 +1,9 @@
|
||||
<div align="center">
|
||||
<a href="https://netbird.io/webinars/achieve-zero-trust-access-to-k8s?utm_source=github&utm_campaign=2502%20-%20webinar%20-%20How%20to%20Achieve%20Zero%20Trust%20Access%20to%20Kubernetes%20-%20Effortlessly&utm_medium=github">
|
||||
Webinar: How to Achieve Zero Trust Access to Kubernetes — Effortlessly
|
||||
</a>
|
||||
<br/>
|
||||
<br/>
|
||||
<p align="center">
|
||||
<img width="234" src="docs/media/logo-full.png"/>
|
||||
</p>
|
||||
|
@ -1,4 +1,4 @@
|
||||
FROM alpine:3.21.0
|
||||
FROM alpine:3.21.3
|
||||
RUN apk add --no-cache ca-certificates iptables ip6tables
|
||||
ENV NB_FOREGROUND_MODE=true
|
||||
ENTRYPOINT [ "/usr/local/bin/netbird","up"]
|
||||
|
@ -95,7 +95,7 @@ func startManagement(t *testing.T, config *mgmt.Config, testFile string) (*grpc.
|
||||
}
|
||||
|
||||
secretsManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay)
|
||||
mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, settings.NewManager(store), peersUpdateManager, secretsManager, nil, nil)
|
||||
mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, settings.NewManager(store), peersUpdateManager, secretsManager, nil, nil, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
@ -1,123 +0,0 @@
|
||||
package iface
|
||||
|
||||
import (
|
||||
"net"
|
||||
"time"
|
||||
|
||||
wgdevice "golang.zx2c4.com/wireguard/device"
|
||||
"golang.zx2c4.com/wireguard/tun/netstack"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/bind"
|
||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
"github.com/netbirdio/netbird/client/iface/wgproxy"
|
||||
)
|
||||
|
||||
type MockWGIface struct {
|
||||
CreateFunc func() error
|
||||
CreateOnAndroidFunc func(routeRange []string, ip string, domains []string) error
|
||||
IsUserspaceBindFunc func() bool
|
||||
NameFunc func() string
|
||||
AddressFunc func() device.WGAddress
|
||||
ToInterfaceFunc func() *net.Interface
|
||||
UpFunc func() (*bind.UniversalUDPMuxDefault, error)
|
||||
UpdateAddrFunc func(newAddr string) error
|
||||
UpdatePeerFunc func(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error
|
||||
RemovePeerFunc func(peerKey string) error
|
||||
AddAllowedIPFunc func(peerKey string, allowedIP string) error
|
||||
RemoveAllowedIPFunc func(peerKey string, allowedIP string) error
|
||||
CloseFunc func() error
|
||||
SetFilterFunc func(filter device.PacketFilter) error
|
||||
GetFilterFunc func() device.PacketFilter
|
||||
GetDeviceFunc func() *device.FilteredDevice
|
||||
GetWGDeviceFunc func() *wgdevice.Device
|
||||
GetStatsFunc func(peerKey string) (configurer.WGStats, error)
|
||||
GetInterfaceGUIDStringFunc func() (string, error)
|
||||
GetProxyFunc func() wgproxy.Proxy
|
||||
GetNetFunc func() *netstack.Net
|
||||
}
|
||||
|
||||
func (m *MockWGIface) GetInterfaceGUIDString() (string, error) {
|
||||
return m.GetInterfaceGUIDStringFunc()
|
||||
}
|
||||
|
||||
func (m *MockWGIface) Create() error {
|
||||
return m.CreateFunc()
|
||||
}
|
||||
|
||||
func (m *MockWGIface) CreateOnAndroid(routeRange []string, ip string, domains []string) error {
|
||||
return m.CreateOnAndroidFunc(routeRange, ip, domains)
|
||||
}
|
||||
|
||||
func (m *MockWGIface) IsUserspaceBind() bool {
|
||||
return m.IsUserspaceBindFunc()
|
||||
}
|
||||
|
||||
func (m *MockWGIface) Name() string {
|
||||
return m.NameFunc()
|
||||
}
|
||||
|
||||
func (m *MockWGIface) Address() device.WGAddress {
|
||||
return m.AddressFunc()
|
||||
}
|
||||
|
||||
func (m *MockWGIface) ToInterface() *net.Interface {
|
||||
return m.ToInterfaceFunc()
|
||||
}
|
||||
|
||||
func (m *MockWGIface) Up() (*bind.UniversalUDPMuxDefault, error) {
|
||||
return m.UpFunc()
|
||||
}
|
||||
|
||||
func (m *MockWGIface) UpdateAddr(newAddr string) error {
|
||||
return m.UpdateAddrFunc(newAddr)
|
||||
}
|
||||
|
||||
func (m *MockWGIface) UpdatePeer(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error {
|
||||
return m.UpdatePeerFunc(peerKey, allowedIps, keepAlive, endpoint, preSharedKey)
|
||||
}
|
||||
|
||||
func (m *MockWGIface) RemovePeer(peerKey string) error {
|
||||
return m.RemovePeerFunc(peerKey)
|
||||
}
|
||||
|
||||
func (m *MockWGIface) AddAllowedIP(peerKey string, allowedIP string) error {
|
||||
return m.AddAllowedIPFunc(peerKey, allowedIP)
|
||||
}
|
||||
|
||||
func (m *MockWGIface) RemoveAllowedIP(peerKey string, allowedIP string) error {
|
||||
return m.RemoveAllowedIPFunc(peerKey, allowedIP)
|
||||
}
|
||||
|
||||
func (m *MockWGIface) Close() error {
|
||||
return m.CloseFunc()
|
||||
}
|
||||
|
||||
func (m *MockWGIface) SetFilter(filter device.PacketFilter) error {
|
||||
return m.SetFilterFunc(filter)
|
||||
}
|
||||
|
||||
func (m *MockWGIface) GetFilter() device.PacketFilter {
|
||||
return m.GetFilterFunc()
|
||||
}
|
||||
|
||||
func (m *MockWGIface) GetDevice() *device.FilteredDevice {
|
||||
return m.GetDeviceFunc()
|
||||
}
|
||||
|
||||
func (m *MockWGIface) GetWGDevice() *wgdevice.Device {
|
||||
return m.GetWGDeviceFunc()
|
||||
}
|
||||
|
||||
func (m *MockWGIface) GetStats(peerKey string) (configurer.WGStats, error) {
|
||||
return m.GetStatsFunc(peerKey)
|
||||
}
|
||||
|
||||
func (m *MockWGIface) GetProxy() wgproxy.Proxy {
|
||||
return m.GetProxyFunc()
|
||||
}
|
||||
|
||||
func (m *MockWGIface) GetNet() *netstack.Net {
|
||||
return m.GetNetFunc()
|
||||
}
|
@ -1,39 +0,0 @@
|
||||
package iface
|
||||
|
||||
import (
|
||||
"net"
|
||||
"time"
|
||||
|
||||
wgdevice "golang.zx2c4.com/wireguard/device"
|
||||
"golang.zx2c4.com/wireguard/tun/netstack"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/bind"
|
||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
"github.com/netbirdio/netbird/client/iface/wgproxy"
|
||||
)
|
||||
|
||||
type IWGIface interface {
|
||||
Create() error
|
||||
CreateOnAndroid(routeRange []string, ip string, domains []string) error
|
||||
IsUserspaceBind() bool
|
||||
Name() string
|
||||
Address() device.WGAddress
|
||||
ToInterface() *net.Interface
|
||||
Up() (*bind.UniversalUDPMuxDefault, error)
|
||||
UpdateAddr(newAddr string) error
|
||||
GetProxy() wgproxy.Proxy
|
||||
UpdatePeer(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error
|
||||
RemovePeer(peerKey string) error
|
||||
AddAllowedIP(peerKey string, allowedIP string) error
|
||||
RemoveAllowedIP(peerKey string, allowedIP string) error
|
||||
Close() error
|
||||
SetFilter(filter device.PacketFilter) error
|
||||
GetFilter() device.PacketFilter
|
||||
GetDevice() *device.FilteredDevice
|
||||
GetWGDevice() *wgdevice.Device
|
||||
GetStats(peerKey string) (configurer.WGStats, error)
|
||||
GetInterfaceGUIDString() (string, error)
|
||||
GetNet() *netstack.Net
|
||||
}
|
111
client/internal/dns.go
Normal file
111
client/internal/dns.go
Normal file
@ -0,0 +1,111 @@
|
||||
package internal
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
)
|
||||
|
||||
func createPTRRecord(aRecord nbdns.SimpleRecord, ipNet *net.IPNet) (nbdns.SimpleRecord, bool) {
|
||||
ip := net.ParseIP(aRecord.RData)
|
||||
if ip == nil || ip.To4() == nil {
|
||||
return nbdns.SimpleRecord{}, false
|
||||
}
|
||||
|
||||
if !ipNet.Contains(ip) {
|
||||
return nbdns.SimpleRecord{}, false
|
||||
}
|
||||
|
||||
ipOctets := strings.Split(ip.String(), ".")
|
||||
slices.Reverse(ipOctets)
|
||||
rdnsName := dns.Fqdn(strings.Join(ipOctets, ".") + ".in-addr.arpa")
|
||||
|
||||
return nbdns.SimpleRecord{
|
||||
Name: rdnsName,
|
||||
Type: int(dns.TypePTR),
|
||||
Class: aRecord.Class,
|
||||
TTL: aRecord.TTL,
|
||||
RData: dns.Fqdn(aRecord.Name),
|
||||
}, true
|
||||
}
|
||||
|
||||
// generateReverseZoneName creates the reverse DNS zone name for a given network
|
||||
func generateReverseZoneName(ipNet *net.IPNet) (string, error) {
|
||||
networkIP := ipNet.IP.Mask(ipNet.Mask)
|
||||
maskOnes, _ := ipNet.Mask.Size()
|
||||
|
||||
// round up to nearest byte
|
||||
octetsToUse := (maskOnes + 7) / 8
|
||||
|
||||
octets := strings.Split(networkIP.String(), ".")
|
||||
if octetsToUse > len(octets) {
|
||||
return "", fmt.Errorf("invalid network mask size for reverse DNS: %d", maskOnes)
|
||||
}
|
||||
|
||||
reverseOctets := make([]string, octetsToUse)
|
||||
for i := 0; i < octetsToUse; i++ {
|
||||
reverseOctets[octetsToUse-1-i] = octets[i]
|
||||
}
|
||||
|
||||
return dns.Fqdn(strings.Join(reverseOctets, ".") + ".in-addr.arpa"), nil
|
||||
}
|
||||
|
||||
// zoneExists checks if a zone with the given name already exists in the configuration
|
||||
func zoneExists(config *nbdns.Config, zoneName string) bool {
|
||||
for _, zone := range config.CustomZones {
|
||||
if zone.Domain == zoneName {
|
||||
log.Debugf("reverse DNS zone %s already exists", zoneName)
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// collectPTRRecords gathers all PTR records for the given network from A records
|
||||
func collectPTRRecords(config *nbdns.Config, ipNet *net.IPNet) []nbdns.SimpleRecord {
|
||||
var records []nbdns.SimpleRecord
|
||||
|
||||
for _, zone := range config.CustomZones {
|
||||
for _, record := range zone.Records {
|
||||
if record.Type != int(dns.TypeA) {
|
||||
continue
|
||||
}
|
||||
|
||||
if ptrRecord, ok := createPTRRecord(record, ipNet); ok {
|
||||
records = append(records, ptrRecord)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return records
|
||||
}
|
||||
|
||||
// addReverseZone adds a reverse DNS zone to the configuration for the given network
|
||||
func addReverseZone(config *nbdns.Config, ipNet *net.IPNet) {
|
||||
zoneName, err := generateReverseZoneName(ipNet)
|
||||
if err != nil {
|
||||
log.Warn(err)
|
||||
return
|
||||
}
|
||||
|
||||
if zoneExists(config, zoneName) {
|
||||
log.Debugf("reverse DNS zone %s already exists", zoneName)
|
||||
return
|
||||
}
|
||||
|
||||
records := collectPTRRecords(config, ipNet)
|
||||
|
||||
reverseZone := nbdns.CustomZone{
|
||||
Domain: zoneName,
|
||||
Records: records,
|
||||
}
|
||||
|
||||
config.CustomZones = append(config.CustomZones, reverseZone)
|
||||
log.Debugf("added reverse DNS zone: %s with %d records", zoneName, len(records))
|
||||
}
|
@ -9,6 +9,11 @@ import (
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
)
|
||||
|
||||
const (
|
||||
ipv4ReverseZone = ".in-addr.arpa"
|
||||
ipv6ReverseZone = ".ip6.arpa"
|
||||
)
|
||||
|
||||
type hostManager interface {
|
||||
applyDNSConfig(config HostDNSConfig, stateManager *statemanager.Manager) error
|
||||
restoreHostDNS() error
|
||||
@ -94,9 +99,10 @@ func dnsConfigToHostDNSConfig(dnsConfig nbdns.Config, ip string, port int) HostD
|
||||
}
|
||||
|
||||
for _, customZone := range dnsConfig.CustomZones {
|
||||
matchOnly := strings.HasSuffix(customZone.Domain, ipv4ReverseZone) || strings.HasSuffix(customZone.Domain, ipv6ReverseZone)
|
||||
config.Domains = append(config.Domains, DomainConfig{
|
||||
Domain: strings.TrimSuffix(customZone.Domain, "."),
|
||||
MatchOnly: false,
|
||||
MatchOnly: matchOnly,
|
||||
})
|
||||
}
|
||||
|
||||
|
@ -395,12 +395,12 @@ func (s *DefaultServer) applyConfiguration(update nbdns.Config) error {
|
||||
|
||||
localMuxUpdates, localRecordsByDomain, err := s.buildLocalHandlerUpdate(update.CustomZones)
|
||||
if err != nil {
|
||||
return fmt.Errorf("not applying dns update, error: %v", err)
|
||||
return fmt.Errorf("local handler updater: %w", err)
|
||||
}
|
||||
|
||||
upstreamMuxUpdates, err := s.buildUpstreamHandlerUpdate(update.NameServerGroups)
|
||||
if err != nil {
|
||||
return fmt.Errorf("not applying dns update, error: %v", err)
|
||||
return fmt.Errorf("upstream handler updater: %w", err)
|
||||
}
|
||||
muxUpdates := append(localMuxUpdates, upstreamMuxUpdates...) //nolint:gocritic
|
||||
|
||||
@ -447,7 +447,8 @@ func (s *DefaultServer) buildLocalHandlerUpdate(
|
||||
|
||||
for _, customZone := range customZones {
|
||||
if len(customZone.Records) == 0 {
|
||||
return nil, nil, fmt.Errorf("received an empty list of records")
|
||||
log.Warnf("received a custom zone with empty records, skipping domain: %s", customZone.Domain)
|
||||
continue
|
||||
}
|
||||
|
||||
muxUpdates = append(muxUpdates, handlerWrapper{
|
||||
@ -460,7 +461,8 @@ func (s *DefaultServer) buildLocalHandlerUpdate(
|
||||
for _, record := range customZone.Records {
|
||||
var class uint16 = dns.ClassINET
|
||||
if record.Class != nbdns.DefaultClass {
|
||||
return nil, nil, fmt.Errorf("received an invalid class type: %s", record.Class)
|
||||
log.Warnf("received an invalid class type: %s", record.Class)
|
||||
continue
|
||||
}
|
||||
|
||||
key := buildRecordKey(record.Name, class, uint16(record.Type))
|
||||
|
@ -266,7 +266,7 @@ func TestUpdateDNSServer(t *testing.T) {
|
||||
shouldFail: true,
|
||||
},
|
||||
{
|
||||
name: "Invalid Custom Zone Records list Should Fail",
|
||||
name: "Invalid Custom Zone Records list Should Skip",
|
||||
initLocalMap: make(registrationMap),
|
||||
initUpstreamMap: make(registeredHandlerMap),
|
||||
initSerial: 0,
|
||||
@ -285,7 +285,11 @@ func TestUpdateDNSServer(t *testing.T) {
|
||||
},
|
||||
},
|
||||
},
|
||||
shouldFail: true,
|
||||
expectedUpstreamMap: registeredHandlerMap{generateDummyHandler(".", nameServers).id(): handlerWrapper{
|
||||
domain: ".",
|
||||
handler: dummyHandler,
|
||||
priority: PriorityDefault,
|
||||
}},
|
||||
},
|
||||
{
|
||||
name: "Empty Config Should Succeed and Clean Maps",
|
||||
|
@ -154,7 +154,7 @@ type Engine struct {
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
|
||||
wgInterface iface.IWGIface
|
||||
wgInterface WGIface
|
||||
|
||||
udpMux *bind.UniversalUDPMuxDefault
|
||||
|
||||
@ -953,7 +953,7 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
|
||||
protoDNSConfig = &mgmProto.DNSConfig{}
|
||||
}
|
||||
|
||||
if err := e.dnsServer.UpdateDNSServer(serial, toDNSConfig(protoDNSConfig)); err != nil {
|
||||
if err := e.dnsServer.UpdateDNSServer(serial, toDNSConfig(protoDNSConfig, e.wgInterface.Address().Network)); err != nil {
|
||||
log.Errorf("failed to update dns server, err: %v", err)
|
||||
}
|
||||
|
||||
@ -1022,7 +1022,7 @@ func toRouteDomains(myPubKey string, protoRoutes []*mgmProto.Route) []string {
|
||||
return dnsRoutes
|
||||
}
|
||||
|
||||
func toDNSConfig(protoDNSConfig *mgmProto.DNSConfig) nbdns.Config {
|
||||
func toDNSConfig(protoDNSConfig *mgmProto.DNSConfig, network *net.IPNet) nbdns.Config {
|
||||
dnsUpdate := nbdns.Config{
|
||||
ServiceEnable: protoDNSConfig.GetServiceEnable(),
|
||||
CustomZones: make([]nbdns.CustomZone, 0),
|
||||
@ -1062,6 +1062,11 @@ func toDNSConfig(protoDNSConfig *mgmProto.DNSConfig) nbdns.Config {
|
||||
}
|
||||
dnsUpdate.NameServerGroups = append(dnsUpdate.NameServerGroups, dnsNSGroup)
|
||||
}
|
||||
|
||||
if len(dnsUpdate.CustomZones) > 0 {
|
||||
addReverseZone(&dnsUpdate, network)
|
||||
}
|
||||
|
||||
return dnsUpdate
|
||||
}
|
||||
|
||||
@ -1368,7 +1373,7 @@ func (e *Engine) readInitialSettings() ([]*route.Route, *nbdns.Config, error) {
|
||||
return nil, nil, err
|
||||
}
|
||||
routes := toRoutes(netMap.GetRoutes())
|
||||
dnsCfg := toDNSConfig(netMap.GetDNSConfig())
|
||||
dnsCfg := toDNSConfig(netMap.GetDNSConfig(), e.wgInterface.Address().Network)
|
||||
return routes, &dnsCfg, nil
|
||||
}
|
||||
|
||||
|
@ -23,10 +23,11 @@ import (
|
||||
"google.golang.org/grpc/keepalive"
|
||||
|
||||
"github.com/netbirdio/management-integrations/integrations"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/iface/bind"
|
||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
"github.com/netbirdio/netbird/client/iface/wgproxy"
|
||||
"github.com/netbirdio/netbird/client/internal/dns"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/internal/peer/guard"
|
||||
@ -48,6 +49,8 @@ import (
|
||||
"github.com/netbirdio/netbird/signal/proto"
|
||||
signalServer "github.com/netbirdio/netbird/signal/server"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
wgdevice "golang.zx2c4.com/wireguard/device"
|
||||
"golang.zx2c4.com/wireguard/tun/netstack"
|
||||
)
|
||||
|
||||
var (
|
||||
@ -64,6 +67,114 @@ var (
|
||||
}
|
||||
)
|
||||
|
||||
type MockWGIface struct {
|
||||
CreateFunc func() error
|
||||
CreateOnAndroidFunc func(routeRange []string, ip string, domains []string) error
|
||||
IsUserspaceBindFunc func() bool
|
||||
NameFunc func() string
|
||||
AddressFunc func() device.WGAddress
|
||||
ToInterfaceFunc func() *net.Interface
|
||||
UpFunc func() (*bind.UniversalUDPMuxDefault, error)
|
||||
UpdateAddrFunc func(newAddr string) error
|
||||
UpdatePeerFunc func(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error
|
||||
RemovePeerFunc func(peerKey string) error
|
||||
AddAllowedIPFunc func(peerKey string, allowedIP string) error
|
||||
RemoveAllowedIPFunc func(peerKey string, allowedIP string) error
|
||||
CloseFunc func() error
|
||||
SetFilterFunc func(filter device.PacketFilter) error
|
||||
GetFilterFunc func() device.PacketFilter
|
||||
GetDeviceFunc func() *device.FilteredDevice
|
||||
GetWGDeviceFunc func() *wgdevice.Device
|
||||
GetStatsFunc func(peerKey string) (configurer.WGStats, error)
|
||||
GetInterfaceGUIDStringFunc func() (string, error)
|
||||
GetProxyFunc func() wgproxy.Proxy
|
||||
GetNetFunc func() *netstack.Net
|
||||
}
|
||||
|
||||
func (m *MockWGIface) GetInterfaceGUIDString() (string, error) {
|
||||
return m.GetInterfaceGUIDStringFunc()
|
||||
}
|
||||
|
||||
func (m *MockWGIface) Create() error {
|
||||
return m.CreateFunc()
|
||||
}
|
||||
|
||||
func (m *MockWGIface) CreateOnAndroid(routeRange []string, ip string, domains []string) error {
|
||||
return m.CreateOnAndroidFunc(routeRange, ip, domains)
|
||||
}
|
||||
|
||||
func (m *MockWGIface) IsUserspaceBind() bool {
|
||||
return m.IsUserspaceBindFunc()
|
||||
}
|
||||
|
||||
func (m *MockWGIface) Name() string {
|
||||
return m.NameFunc()
|
||||
}
|
||||
|
||||
func (m *MockWGIface) Address() device.WGAddress {
|
||||
return m.AddressFunc()
|
||||
}
|
||||
|
||||
func (m *MockWGIface) ToInterface() *net.Interface {
|
||||
return m.ToInterfaceFunc()
|
||||
}
|
||||
|
||||
func (m *MockWGIface) Up() (*bind.UniversalUDPMuxDefault, error) {
|
||||
return m.UpFunc()
|
||||
}
|
||||
|
||||
func (m *MockWGIface) UpdateAddr(newAddr string) error {
|
||||
return m.UpdateAddrFunc(newAddr)
|
||||
}
|
||||
|
||||
func (m *MockWGIface) UpdatePeer(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error {
|
||||
return m.UpdatePeerFunc(peerKey, allowedIps, keepAlive, endpoint, preSharedKey)
|
||||
}
|
||||
|
||||
func (m *MockWGIface) RemovePeer(peerKey string) error {
|
||||
return m.RemovePeerFunc(peerKey)
|
||||
}
|
||||
|
||||
func (m *MockWGIface) AddAllowedIP(peerKey string, allowedIP string) error {
|
||||
return m.AddAllowedIPFunc(peerKey, allowedIP)
|
||||
}
|
||||
|
||||
func (m *MockWGIface) RemoveAllowedIP(peerKey string, allowedIP string) error {
|
||||
return m.RemoveAllowedIPFunc(peerKey, allowedIP)
|
||||
}
|
||||
|
||||
func (m *MockWGIface) Close() error {
|
||||
return m.CloseFunc()
|
||||
}
|
||||
|
||||
func (m *MockWGIface) SetFilter(filter device.PacketFilter) error {
|
||||
return m.SetFilterFunc(filter)
|
||||
}
|
||||
|
||||
func (m *MockWGIface) GetFilter() device.PacketFilter {
|
||||
return m.GetFilterFunc()
|
||||
}
|
||||
|
||||
func (m *MockWGIface) GetDevice() *device.FilteredDevice {
|
||||
return m.GetDeviceFunc()
|
||||
}
|
||||
|
||||
func (m *MockWGIface) GetWGDevice() *wgdevice.Device {
|
||||
return m.GetWGDeviceFunc()
|
||||
}
|
||||
|
||||
func (m *MockWGIface) GetStats(peerKey string) (configurer.WGStats, error) {
|
||||
return m.GetStatsFunc(peerKey)
|
||||
}
|
||||
|
||||
func (m *MockWGIface) GetProxy() wgproxy.Proxy {
|
||||
return m.GetProxyFunc()
|
||||
}
|
||||
|
||||
func (m *MockWGIface) GetNet() *netstack.Net {
|
||||
return m.GetNetFunc()
|
||||
}
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
_ = util.InitLog("debug", "console")
|
||||
code := m.Run()
|
||||
@ -245,11 +356,20 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
|
||||
peer.NewRecorder("https://mgm"),
|
||||
nil)
|
||||
|
||||
wgIface := &iface.MockWGIface{
|
||||
wgIface := &MockWGIface{
|
||||
NameFunc: func() string { return "utun102" },
|
||||
RemovePeerFunc: func(peerKey string) error {
|
||||
return nil
|
||||
},
|
||||
AddressFunc: func() iface.WGAddress {
|
||||
return iface.WGAddress{
|
||||
IP: net.ParseIP("10.20.0.1"),
|
||||
Network: &net.IPNet{
|
||||
IP: net.ParseIP("10.20.0.0"),
|
||||
Mask: net.IPv4Mask(255, 255, 255, 0),
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
engine.wgInterface = wgIface
|
||||
engine.routeManager = routemanager.NewManager(routemanager.ManagerConfig{
|
||||
@ -692,6 +812,9 @@ func TestEngine_UpdateNetworkMapWithDNSUpdate(t *testing.T) {
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Domain: "0.66.100.in-addr.arpa.",
|
||||
},
|
||||
},
|
||||
NameServerGroups: []*mgmtProto.NameServerGroup{
|
||||
{
|
||||
@ -721,6 +844,9 @@ func TestEngine_UpdateNetworkMapWithDNSUpdate(t *testing.T) {
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Domain: "0.66.100.in-addr.arpa.",
|
||||
},
|
||||
},
|
||||
expectedNSGroupsLen: 1,
|
||||
expectedNSGroups: []*nbdns.NameServerGroup{
|
||||
@ -1226,7 +1352,7 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri
|
||||
}
|
||||
|
||||
secretsManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay)
|
||||
mgmtServer, err := server.NewServer(context.Background(), config, accountManager, settings.NewManager(store), peersUpdateManager, secretsManager, nil, nil)
|
||||
mgmtServer, err := server.NewServer(context.Background(), config, accountManager, settings.NewManager(store), peersUpdateManager, secretsManager, nil, nil, nil)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
8
client/internal/iface.go
Normal file
8
client/internal/iface.go
Normal file
@ -0,0 +1,8 @@
|
||||
//go:build !windows
|
||||
// +build !windows
|
||||
|
||||
package internal
|
||||
|
||||
type WGIface interface {
|
||||
wgIfaceBase
|
||||
}
|
@ -1,6 +1,4 @@
|
||||
//go:build !windows
|
||||
|
||||
package iface
|
||||
package internal
|
||||
|
||||
import (
|
||||
"net"
|
||||
@ -16,7 +14,7 @@ import (
|
||||
"github.com/netbirdio/netbird/client/iface/wgproxy"
|
||||
)
|
||||
|
||||
type IWGIface interface {
|
||||
type wgIfaceBase interface {
|
||||
Create() error
|
||||
CreateOnAndroid(routeRange []string, ip string, domains []string) error
|
||||
IsUserspaceBind() bool
|
6
client/internal/iface_windows.go
Normal file
6
client/internal/iface_windows.go
Normal file
@ -0,0 +1,6 @@
|
||||
package internal
|
||||
|
||||
type WGIface interface {
|
||||
wgIfaceBase
|
||||
GetInterfaceGUIDString() (string, error)
|
||||
}
|
@ -15,7 +15,6 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||
"github.com/netbirdio/netbird/client/iface/wgproxy"
|
||||
"github.com/netbirdio/netbird/client/internal/peer/guard"
|
||||
@ -56,7 +55,7 @@ const (
|
||||
type WgConfig struct {
|
||||
WgListenPort int
|
||||
RemoteKey string
|
||||
WgInterface iface.IWGIface
|
||||
WgInterface WGIface
|
||||
AllowedIps string
|
||||
PreSharedKey *wgtypes.Key
|
||||
}
|
||||
|
17
client/internal/peer/iface.go
Normal file
17
client/internal/peer/iface.go
Normal file
@ -0,0 +1,17 @@
|
||||
package peer
|
||||
|
||||
import (
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||
"github.com/netbirdio/netbird/client/iface/wgproxy"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
)
|
||||
|
||||
type WGIface interface {
|
||||
UpdatePeer(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error
|
||||
RemovePeer(peerKey string) error
|
||||
GetStats(peerKey string) (configurer.WGStats, error)
|
||||
GetProxy() wgproxy.Proxy
|
||||
}
|
@ -4,19 +4,19 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"reflect"
|
||||
runtime "runtime"
|
||||
"runtime"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/go-multierror"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
nbdns "github.com/netbirdio/netbird/client/internal/dns"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/internal/peerstore"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/dnsinterceptor"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/dynamic"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/iface"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/static"
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
@ -62,7 +62,7 @@ type clientNetwork struct {
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
statusRecorder *peer.Status
|
||||
wgInterface iface.IWGIface
|
||||
wgInterface iface.WGIface
|
||||
routes map[route.ID]*route.Route
|
||||
routeUpdate chan routesUpdate
|
||||
peerStateUpdate chan struct{}
|
||||
@ -75,7 +75,7 @@ type clientNetwork struct {
|
||||
func newClientNetworkWatcher(
|
||||
ctx context.Context,
|
||||
dnsRouteInterval time.Duration,
|
||||
wgInterface iface.IWGIface,
|
||||
wgInterface iface.WGIface,
|
||||
statusRecorder *peer.Status,
|
||||
rt *route.Route,
|
||||
routeRefCounter *refcounter.RouteRefCounter,
|
||||
@ -468,7 +468,7 @@ func handlerFromRoute(
|
||||
allowedIPsRefCounter *refcounter.AllowedIPsRefCounter,
|
||||
dnsRouterInteval time.Duration,
|
||||
statusRecorder *peer.Status,
|
||||
wgInterface iface.IWGIface,
|
||||
wgInterface iface.WGIface,
|
||||
dnsServer nbdns.Server,
|
||||
peerStore *peerstore.Store,
|
||||
useNewDNSRoute bool,
|
||||
|
@ -13,8 +13,8 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/iface"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/util"
|
||||
"github.com/netbirdio/netbird/management/domain"
|
||||
@ -48,7 +48,7 @@ type Route struct {
|
||||
currentPeerKey string
|
||||
cancel context.CancelFunc
|
||||
statusRecorder *peer.Status
|
||||
wgInterface iface.IWGIface
|
||||
wgInterface iface.WGIface
|
||||
resolverAddr string
|
||||
}
|
||||
|
||||
@ -58,7 +58,7 @@ func NewRoute(
|
||||
allowedIPsRefCounter *refcounter.AllowedIPsRefCounter,
|
||||
interval time.Duration,
|
||||
statusRecorder *peer.Status,
|
||||
wgInterface iface.IWGIface,
|
||||
wgInterface iface.WGIface,
|
||||
resolverAddr string,
|
||||
) *Route {
|
||||
return &Route{
|
||||
|
9
client/internal/routemanager/iface/iface.go
Normal file
9
client/internal/routemanager/iface/iface.go
Normal file
@ -0,0 +1,9 @@
|
||||
//go:build !windows
|
||||
// +build !windows
|
||||
|
||||
package iface
|
||||
|
||||
// WGIface defines subset methods of interface required for router
|
||||
type WGIface interface {
|
||||
wgIfaceBase
|
||||
}
|
22
client/internal/routemanager/iface/iface_common.go
Normal file
22
client/internal/routemanager/iface/iface_common.go
Normal file
@ -0,0 +1,22 @@
|
||||
package iface
|
||||
|
||||
import (
|
||||
"net"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
)
|
||||
|
||||
type wgIfaceBase interface {
|
||||
AddAllowedIP(peerKey string, allowedIP string) error
|
||||
RemoveAllowedIP(peerKey string, allowedIP string) error
|
||||
|
||||
Name() string
|
||||
Address() iface.WGAddress
|
||||
ToInterface() *net.Interface
|
||||
IsUserspaceBind() bool
|
||||
GetFilter() device.PacketFilter
|
||||
GetDevice() *device.FilteredDevice
|
||||
GetStats(peerKey string) (configurer.WGStats, error)
|
||||
}
|
7
client/internal/routemanager/iface/iface_windows.go
Normal file
7
client/internal/routemanager/iface/iface_windows.go
Normal file
@ -0,0 +1,7 @@
|
||||
package iface
|
||||
|
||||
// WGIface defines subset methods of interface required for router
|
||||
type WGIface interface {
|
||||
wgIfaceBase
|
||||
GetInterfaceGUIDString() (string, error)
|
||||
}
|
@ -15,13 +15,13 @@ import (
|
||||
"golang.org/x/exp/maps"
|
||||
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||
"github.com/netbirdio/netbird/client/iface/netstack"
|
||||
"github.com/netbirdio/netbird/client/internal/dns"
|
||||
"github.com/netbirdio/netbird/client/internal/listener"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/internal/peerstore"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/iface"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/notifier"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
|
||||
@ -52,7 +52,7 @@ type ManagerConfig struct {
|
||||
Context context.Context
|
||||
PublicKey string
|
||||
DNSRouteInterval time.Duration
|
||||
WGInterface iface.IWGIface
|
||||
WGInterface iface.WGIface
|
||||
StatusRecorder *peer.Status
|
||||
RelayManager *relayClient.Manager
|
||||
InitialRoutes []*route.Route
|
||||
@ -74,7 +74,7 @@ type DefaultManager struct {
|
||||
sysOps *systemops.SysOps
|
||||
statusRecorder *peer.Status
|
||||
relayMgr *relayClient.Manager
|
||||
wgInterface iface.IWGIface
|
||||
wgInterface iface.WGIface
|
||||
pubKey string
|
||||
notifier *notifier.Notifier
|
||||
routeRefCounter *refcounter.RouteRefCounter
|
||||
|
@ -7,8 +7,8 @@ import (
|
||||
"fmt"
|
||||
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/iface"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
)
|
||||
|
||||
@ -22,6 +22,6 @@ func (r serverRouter) updateRoutes(map[route.ID]*route.Route) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func newServerRouter(context.Context, iface.IWGIface, firewall.Manager, *peer.Status) (*serverRouter, error) {
|
||||
func newServerRouter(context.Context, iface.WGIface, firewall.Manager, *peer.Status) (*serverRouter, error) {
|
||||
return nil, fmt.Errorf("server route not supported on this os")
|
||||
}
|
||||
|
@ -11,8 +11,8 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/iface"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
)
|
||||
@ -22,11 +22,11 @@ type serverRouter struct {
|
||||
ctx context.Context
|
||||
routes map[route.ID]*route.Route
|
||||
firewall firewall.Manager
|
||||
wgInterface iface.IWGIface
|
||||
wgInterface iface.WGIface
|
||||
statusRecorder *peer.Status
|
||||
}
|
||||
|
||||
func newServerRouter(ctx context.Context, wgInterface iface.IWGIface, firewall firewall.Manager, statusRecorder *peer.Status) (*serverRouter, error) {
|
||||
func newServerRouter(ctx context.Context, wgInterface iface.WGIface, firewall firewall.Manager, statusRecorder *peer.Status) (*serverRouter, error) {
|
||||
return &serverRouter{
|
||||
ctx: ctx,
|
||||
routes: make(map[route.ID]*route.Route),
|
||||
|
@ -13,7 +13,7 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/iface"
|
||||
)
|
||||
|
||||
const (
|
||||
@ -23,7 +23,7 @@ const (
|
||||
)
|
||||
|
||||
// Setup configures sysctl settings for RP filtering and source validation.
|
||||
func Setup(wgIface iface.IWGIface) (map[string]int, error) {
|
||||
func Setup(wgIface iface.WGIface) (map[string]int, error) {
|
||||
keys := map[string]int{}
|
||||
var result *multierror.Error
|
||||
|
||||
|
@ -5,7 +5,7 @@ import (
|
||||
"net/netip"
|
||||
"sync"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/iface"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/notifier"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
||||
)
|
||||
@ -19,7 +19,7 @@ type ExclusionCounter = refcounter.Counter[netip.Prefix, struct{}, Nexthop]
|
||||
|
||||
type SysOps struct {
|
||||
refCounter *ExclusionCounter
|
||||
wgInterface iface.IWGIface
|
||||
wgInterface iface.WGIface
|
||||
// prefixes is tracking all the current added prefixes im memory
|
||||
// (this is used in iOS as all route updates require a full table update)
|
||||
//nolint
|
||||
@ -30,7 +30,7 @@ type SysOps struct {
|
||||
notifier *notifier.Notifier
|
||||
}
|
||||
|
||||
func NewSysOps(wgInterface iface.IWGIface, notifier *notifier.Notifier) *SysOps {
|
||||
func NewSysOps(wgInterface iface.WGIface, notifier *notifier.Notifier) *SysOps {
|
||||
return &SysOps{
|
||||
wgInterface: wgInterface,
|
||||
notifier: notifier,
|
||||
|
@ -16,8 +16,8 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/iface/netstack"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/iface"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/util"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/vars"
|
||||
@ -149,7 +149,7 @@ func (r *SysOps) addRouteForCurrentDefaultGateway(prefix netip.Prefix) error {
|
||||
|
||||
// addRouteToNonVPNIntf adds a new route to the routing table for the given prefix and returns the next hop and interface.
|
||||
// If the next hop or interface is pointing to the VPN interface, it will return the initial values.
|
||||
func (r *SysOps) addRouteToNonVPNIntf(prefix netip.Prefix, vpnIntf iface.IWGIface, initialNextHop Nexthop) (Nexthop, error) {
|
||||
func (r *SysOps) addRouteToNonVPNIntf(prefix netip.Prefix, vpnIntf iface.WGIface, initialNextHop Nexthop) (Nexthop, error) {
|
||||
addr := prefix.Addr()
|
||||
switch {
|
||||
case addr.IsLoopback(),
|
||||
|
@ -134,7 +134,7 @@ func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Serve
|
||||
}
|
||||
|
||||
secretsManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay)
|
||||
mgmtServer, err := server.NewServer(context.Background(), config, accountManager, settings.NewManager(store), peersUpdateManager, secretsManager, nil, nil)
|
||||
mgmtServer, err := server.NewServer(context.Background(), config, accountManager, settings.NewManager(store), peersUpdateManager, secretsManager, nil, nil, nil)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
50
go.mod
50
go.mod
@ -24,8 +24,8 @@ require (
|
||||
golang.zx2c4.com/wireguard v0.0.0-20230704135630-469159ecf7d1
|
||||
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6
|
||||
golang.zx2c4.com/wireguard/windows v0.5.3
|
||||
google.golang.org/grpc v1.64.1
|
||||
google.golang.org/protobuf v1.34.2
|
||||
google.golang.org/grpc v1.70.0
|
||||
google.golang.org/protobuf v1.36.4
|
||||
gopkg.in/natefinch/lumberjack.v2 v2.0.0
|
||||
)
|
||||
|
||||
@ -60,7 +60,7 @@ require (
|
||||
github.com/miekg/dns v1.1.59
|
||||
github.com/mitchellh/hashstructure/v2 v2.0.2
|
||||
github.com/nadoo/ipset v0.5.0
|
||||
github.com/netbirdio/management-integrations/integrations v0.0.0-20250115083837-a09722b8d2a6
|
||||
github.com/netbirdio/management-integrations/integrations v0.0.0-20250220173202-e599d83524fc
|
||||
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20241010133937-e0df50df217d
|
||||
github.com/okta/okta-sdk-golang/v2 v2.18.0
|
||||
github.com/oschwald/maxminddb-golang v1.12.0
|
||||
@ -76,27 +76,27 @@ require (
|
||||
github.com/shirou/gopsutil/v3 v3.24.4
|
||||
github.com/skratchdot/open-golang v0.0.0-20200116055534-eef842397966
|
||||
github.com/songgao/water v0.0.0-20200317203138-2b4b6d7c09d8
|
||||
github.com/stretchr/testify v1.9.0
|
||||
github.com/stretchr/testify v1.10.0
|
||||
github.com/testcontainers/testcontainers-go v0.31.0
|
||||
github.com/testcontainers/testcontainers-go/modules/mysql v0.31.0
|
||||
github.com/testcontainers/testcontainers-go/modules/postgres v0.31.0
|
||||
github.com/things-go/go-socks5 v0.0.4
|
||||
github.com/yusufpapurcu/wmi v1.2.4
|
||||
github.com/zcalusic/sysinfo v1.1.3
|
||||
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.49.0
|
||||
go.opentelemetry.io/otel v1.26.0
|
||||
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.58.0
|
||||
go.opentelemetry.io/otel v1.34.0
|
||||
go.opentelemetry.io/otel/exporters/prometheus v0.48.0
|
||||
go.opentelemetry.io/otel/metric v1.26.0
|
||||
go.opentelemetry.io/otel/sdk/metric v1.26.0
|
||||
go.opentelemetry.io/otel/metric v1.34.0
|
||||
go.opentelemetry.io/otel/sdk/metric v1.32.0
|
||||
go.uber.org/zap v1.27.0
|
||||
goauthentik.io/api/v3 v3.2023051.3
|
||||
golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842
|
||||
golang.org/x/mobile v0.0.0-20231127183840-76ac6878050a
|
||||
golang.org/x/net v0.33.0
|
||||
golang.org/x/oauth2 v0.19.0
|
||||
golang.org/x/net v0.34.0
|
||||
golang.org/x/oauth2 v0.26.0
|
||||
golang.org/x/sync v0.10.0
|
||||
golang.org/x/term v0.28.0
|
||||
google.golang.org/api v0.177.0
|
||||
google.golang.org/api v0.220.0
|
||||
gopkg.in/yaml.v3 v3.0.1
|
||||
gorm.io/driver/mysql v1.5.7
|
||||
gorm.io/driver/postgres v1.5.7
|
||||
@ -106,9 +106,9 @@ require (
|
||||
)
|
||||
|
||||
require (
|
||||
cloud.google.com/go/auth v0.3.0 // indirect
|
||||
cloud.google.com/go/auth/oauth2adapt v0.2.2 // indirect
|
||||
cloud.google.com/go/compute/metadata v0.3.0 // indirect
|
||||
cloud.google.com/go/auth v0.14.1 // indirect
|
||||
cloud.google.com/go/auth/oauth2adapt v0.2.7 // indirect
|
||||
cloud.google.com/go/compute/metadata v0.6.0 // indirect
|
||||
dario.cat/mergo v1.0.0 // indirect
|
||||
filippo.io/edwards25519 v1.1.0 // indirect
|
||||
github.com/Azure/go-ansiterm v0.0.0-20230124172434-306776ec8161 // indirect
|
||||
@ -151,7 +151,7 @@ require (
|
||||
github.com/fyne-io/image v0.0.0-20220602074514-4956b0afb3d2 // indirect
|
||||
github.com/go-gl/gl v0.0.0-20211210172815-726fda9656d6 // indirect
|
||||
github.com/go-gl/glfw/v3.3/glfw v0.0.0-20240506104042-037f3cc74f2a // indirect
|
||||
github.com/go-logr/logr v1.4.1 // indirect
|
||||
github.com/go-logr/logr v1.4.2 // indirect
|
||||
github.com/go-logr/stdr v1.2.2 // indirect
|
||||
github.com/go-ole/go-ole v1.3.0 // indirect
|
||||
github.com/go-redis/redis/v8 v8.11.5 // indirect
|
||||
@ -160,12 +160,11 @@ require (
|
||||
github.com/go-text/render v0.2.0 // indirect
|
||||
github.com/go-text/typesetting v0.2.0 // indirect
|
||||
github.com/gogo/protobuf v1.3.2 // indirect
|
||||
github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect
|
||||
github.com/google/btree v1.1.2 // indirect
|
||||
github.com/google/pprof v0.0.0-20211214055906-6f57359322fd // indirect
|
||||
github.com/google/s2a-go v0.1.7 // indirect
|
||||
github.com/googleapis/enterprise-certificate-proxy v0.3.2 // indirect
|
||||
github.com/googleapis/gax-go/v2 v2.12.3 // indirect
|
||||
github.com/google/s2a-go v0.1.9 // indirect
|
||||
github.com/googleapis/enterprise-certificate-proxy v0.3.4 // indirect
|
||||
github.com/googleapis/gax-go/v2 v2.14.1 // indirect
|
||||
github.com/gopherjs/gopherjs v1.17.2 // indirect
|
||||
github.com/hashicorp/errwrap v1.1.0 // indirect
|
||||
github.com/hashicorp/go-uuid v1.0.3 // indirect
|
||||
@ -221,20 +220,19 @@ require (
|
||||
github.com/vishvananda/netns v0.0.4 // indirect
|
||||
github.com/yuin/goldmark v1.7.1 // indirect
|
||||
github.com/zeebo/blake3 v0.2.3 // indirect
|
||||
go.opencensus.io v0.24.0 // indirect
|
||||
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.51.0 // indirect
|
||||
go.opentelemetry.io/otel/sdk v1.26.0 // indirect
|
||||
go.opentelemetry.io/otel/trace v1.26.0 // indirect
|
||||
go.opentelemetry.io/auto/sdk v1.1.0 // indirect
|
||||
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.58.0 // indirect
|
||||
go.opentelemetry.io/otel/sdk v1.34.0 // indirect
|
||||
go.opentelemetry.io/otel/trace v1.34.0 // indirect
|
||||
go.uber.org/mock v0.4.0 // indirect
|
||||
go.uber.org/multierr v1.11.0 // indirect
|
||||
golang.org/x/image v0.18.0 // indirect
|
||||
golang.org/x/mod v0.17.0 // indirect
|
||||
golang.org/x/text v0.21.0 // indirect
|
||||
golang.org/x/time v0.5.0 // indirect
|
||||
golang.org/x/time v0.10.0 // indirect
|
||||
golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d // indirect
|
||||
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect
|
||||
google.golang.org/genproto/googleapis/api v0.0.0-20240509183442-62759503f434 // indirect
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20240903143218-8af14fe29dc1 // indirect
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20250127172529-29210b9bc287 // indirect
|
||||
gopkg.in/square/go-jose.v2 v2.6.0 // indirect
|
||||
gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 // indirect
|
||||
gopkg.in/tomb.v2 v2.0.0-20161208151619-d5d1b5820637 // indirect
|
||||
|
111
go.sum
111
go.sum
@ -18,10 +18,10 @@ cloud.google.com/go v0.74.0/go.mod h1:VV1xSbzvo+9QJOxLDaJfTjx5e+MePCpCWwvftOeQmW
|
||||
cloud.google.com/go v0.78.0/go.mod h1:QjdrLG0uq+YwhjoVOLsS1t7TW8fs36kLs4XO5R5ECHg=
|
||||
cloud.google.com/go v0.79.0/go.mod h1:3bzgcEeQlzbuEAYu4mrWhKqWjmpprinYgKJLgKHnbb8=
|
||||
cloud.google.com/go v0.81.0/go.mod h1:mk/AM35KwGk/Nm2YSeZbxXdrNK3KZOYHmLkOqC2V6E0=
|
||||
cloud.google.com/go/auth v0.3.0 h1:PRyzEpGfx/Z9e8+lHsbkoUVXD0gnu4MNmm7Gp8TQNIs=
|
||||
cloud.google.com/go/auth v0.3.0/go.mod h1:lBv6NKTWp8E3LPzmO1TbiiRKc4drLOfHsgmlH9ogv5w=
|
||||
cloud.google.com/go/auth/oauth2adapt v0.2.2 h1:+TTV8aXpjeChS9M+aTtN/TjdQnzJvmzKFt//oWu7HX4=
|
||||
cloud.google.com/go/auth/oauth2adapt v0.2.2/go.mod h1:wcYjgpZI9+Yu7LyYBg4pqSiaRkfEK3GQcpb7C/uyF1Q=
|
||||
cloud.google.com/go/auth v0.14.1 h1:AwoJbzUdxA/whv1qj3TLKwh3XX5sikny2fc40wUl+h0=
|
||||
cloud.google.com/go/auth v0.14.1/go.mod h1:4JHUxlGXisL0AW8kXPtUF6ztuOksyfUQNFjfsOCXkPM=
|
||||
cloud.google.com/go/auth/oauth2adapt v0.2.7 h1:/Lc7xODdqcEw8IrZ9SvwnlLX6j9FHQM74z6cBk9Rw6M=
|
||||
cloud.google.com/go/auth/oauth2adapt v0.2.7/go.mod h1:NTbTTzfvPl1Y3V1nPpOgl2w6d/FjO7NNUQaWSox6ZMc=
|
||||
cloud.google.com/go/bigquery v1.0.1/go.mod h1:i/xbL2UlR5RvWAURpBYZTtm/cXjCha9lbfbpx4poX+o=
|
||||
cloud.google.com/go/bigquery v1.3.0/go.mod h1:PjpwJnslEMmckchkHFfq+HTD2DmtT67aNFKH1/VBDHE=
|
||||
cloud.google.com/go/bigquery v1.4.0/go.mod h1:S8dzgnTigyfTmLBfrtrhyYhwRxG72rYxvftPBK2Dvzc=
|
||||
@ -29,8 +29,8 @@ cloud.google.com/go/bigquery v1.5.0/go.mod h1:snEHRnqQbz117VIFhE8bmtwIDY80NLUZUM
|
||||
cloud.google.com/go/bigquery v1.7.0/go.mod h1://okPTzCYNXSlb24MZs83e2Do+h+VXtc4gLoIoXIAPc=
|
||||
cloud.google.com/go/bigquery v1.8.0/go.mod h1:J5hqkt3O0uAFnINi6JXValWIb1v0goeZM77hZzJN/fQ=
|
||||
cloud.google.com/go/compute/metadata v0.2.0/go.mod h1:zFmK7XCadkQkj6TtorcaGlCW1hT1fIilQDwofLpJ20k=
|
||||
cloud.google.com/go/compute/metadata v0.3.0 h1:Tz+eQXMEqDIKRsmY3cHTL6FVaynIjX2QxYC4trgAKZc=
|
||||
cloud.google.com/go/compute/metadata v0.3.0/go.mod h1:zFmK7XCadkQkj6TtorcaGlCW1hT1fIilQDwofLpJ20k=
|
||||
cloud.google.com/go/compute/metadata v0.6.0 h1:A6hENjEsCDtC1k8byVsgwvVcioamEHvZ4j01OwKxG9I=
|
||||
cloud.google.com/go/compute/metadata v0.6.0/go.mod h1:FjyFAW1MW0C203CEOMDTu3Dk1FlqW3Rga40jzHL4hfg=
|
||||
cloud.google.com/go/datastore v1.0.0/go.mod h1:LXYbyblFSglQ5pkeyhO+Qmw7ukd3C+pD7TKLgZqpHYE=
|
||||
cloud.google.com/go/datastore v1.1.0/go.mod h1:umbIZjpQpHh4hmRpGhH4tLFup+FVzqBi1b3c64qFpCk=
|
||||
cloud.google.com/go/firestore v1.1.0/go.mod h1:ulACoGHTpvq5r8rxGJ4ddJZBZqakUQqClKRT5SZwBmk=
|
||||
@ -225,8 +225,8 @@ github.com/go-gl/glfw/v3.3/glfw v0.0.0-20240506104042-037f3cc74f2a h1:vxnBhFDDT+
|
||||
github.com/go-gl/glfw/v3.3/glfw v0.0.0-20240506104042-037f3cc74f2a/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8=
|
||||
github.com/go-logr/logr v0.1.0/go.mod h1:ixOQHD9gLJUVQQ2ZOR7zLEifBX6tGkNJF4QyIY7sIas=
|
||||
github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=
|
||||
github.com/go-logr/logr v1.4.1 h1:pKouT5E8xu9zeFC39JXRDukb6JFQPXM5p5I91188VAQ=
|
||||
github.com/go-logr/logr v1.4.1/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY=
|
||||
github.com/go-logr/logr v1.4.2 h1:6pFjapn8bFcIbiKo3XT4j/BhANplGihG6tvd+8rYgrY=
|
||||
github.com/go-logr/logr v1.4.2/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY=
|
||||
github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag=
|
||||
github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE=
|
||||
github.com/go-ole/go-ole v1.2.6/go.mod h1:pprOEPIfldk/42T2oK7lQ4v4JSDwmV0As9GaiUsvbm0=
|
||||
@ -263,14 +263,12 @@ github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69
|
||||
github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keLg81eXfW3O+oY=
|
||||
github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I=
|
||||
github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q=
|
||||
github.com/golang/glog v1.2.0 h1:uCdmnmatrKCgMBlM4rMuJZWOkPDqdbZPnrMXDY4gI68=
|
||||
github.com/golang/glog v1.2.0/go.mod h1:6AhwSGph0fcJtXVM/PEHPqZlFeoLxhs7/t5UDAwmO+w=
|
||||
github.com/golang/glog v1.2.3 h1:oDTdz9f5VGVVNGu/Q7UXKWYsD0873HXLHdJUNBsSEKM=
|
||||
github.com/golang/glog v1.2.3/go.mod h1:6AhwSGph0fcJtXVM/PEHPqZlFeoLxhs7/t5UDAwmO+w=
|
||||
github.com/golang/groupcache v0.0.0-20160516000752-02826c3e7903/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc=
|
||||
github.com/golang/groupcache v0.0.0-20190702054246-869f871628b6/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc=
|
||||
github.com/golang/groupcache v0.0.0-20191227052852-215e87163ea7/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc=
|
||||
github.com/golang/groupcache v0.0.0-20200121045136-8c9f03a8e57e/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc=
|
||||
github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da h1:oI5xCqsCo564l8iNU+DwB5epxmsaqB+rhGL0m5jtYqE=
|
||||
github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc=
|
||||
github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A=
|
||||
github.com/golang/mock v1.2.0/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A=
|
||||
github.com/golang/mock v1.3.1/go.mod h1:sBzyDLLjw3U8JLTeZvSv8jJB+tU5PVekmnlKIyFUx0Y=
|
||||
@ -345,18 +343,18 @@ github.com/google/pprof v0.0.0-20210226084205-cbba55b83ad5/go.mod h1:kpwsk12EmLe
|
||||
github.com/google/pprof v0.0.0-20211214055906-6f57359322fd h1:1FjCyPC+syAzJ5/2S8fqdZK1R22vvA0J7JZKcuOIQ7Y=
|
||||
github.com/google/pprof v0.0.0-20211214055906-6f57359322fd/go.mod h1:KgnwoLYCZ8IQu3XUZ8Nc/bM9CCZFOyjUNOSygVozoDg=
|
||||
github.com/google/renameio v0.1.0/go.mod h1:KWCgfxg9yswjAJkECMjeO8J8rahYeXnNhOm40UhjYkI=
|
||||
github.com/google/s2a-go v0.1.7 h1:60BLSyTrOV4/haCDW4zb1guZItoSq8foHCXrAnjBo/o=
|
||||
github.com/google/s2a-go v0.1.7/go.mod h1:50CgR4k1jNlWBu4UfS4AcfhVe1r6pdZPygJ3R8F0Qdw=
|
||||
github.com/google/s2a-go v0.1.9 h1:LGD7gtMgezd8a/Xak7mEWL0PjoTQFvpRudN895yqKW0=
|
||||
github.com/google/s2a-go v0.1.9/go.mod h1:YA0Ei2ZQL3acow2O62kdp9UlnvMmU7kA6Eutn0dXayM=
|
||||
github.com/google/uuid v1.1.1/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/googleapis/enterprise-certificate-proxy v0.3.2 h1:Vie5ybvEvT75RniqhfFxPRy3Bf7vr3h0cechB90XaQs=
|
||||
github.com/googleapis/enterprise-certificate-proxy v0.3.2/go.mod h1:VLSiSSBs/ksPL8kq3OBOQ6WRI2QnaFynd1DCjZ62+V0=
|
||||
github.com/googleapis/enterprise-certificate-proxy v0.3.4 h1:XYIDZApgAnrN1c855gTgghdIA6Stxb52D5RnLI1SLyw=
|
||||
github.com/googleapis/enterprise-certificate-proxy v0.3.4/go.mod h1:YKe7cfqYXjKGpGvmSg28/fFvhNzinZQm8DGnaburhGA=
|
||||
github.com/googleapis/gax-go/v2 v2.0.4/go.mod h1:0Wqv26UfaUD9n4G6kQubkQ+KchISgw+vpHVxEJEs9eg=
|
||||
github.com/googleapis/gax-go/v2 v2.0.5/go.mod h1:DWXyrwAJ9X0FpwwEdw+IPEYBICEFu5mhpdKc/us6bOk=
|
||||
github.com/googleapis/gax-go/v2 v2.12.3 h1:5/zPPDvw8Q1SuXjrqrZslrqT7dL/uJT2CQii/cLCKqA=
|
||||
github.com/googleapis/gax-go/v2 v2.12.3/go.mod h1:AKloxT6GtNbaLm8QTNSidHUVsHYcBHwWRvkNFJUQcS4=
|
||||
github.com/googleapis/gax-go/v2 v2.14.1 h1:hb0FFeiPaQskmvakKu5EbCbpntQn48jyHuvrkurSS/Q=
|
||||
github.com/googleapis/gax-go/v2 v2.14.1/go.mod h1:Hb/NubMaVM88SrNkvl8X/o8XWwDJEPqouaLeN2IUxoA=
|
||||
github.com/googleapis/gnostic v0.0.0-20170729233727-0c5108395e2d/go.mod h1:sJBsCZ4ayReDTBIg8b9dl28c5xFWyhBTVRp3pOg5EKY=
|
||||
github.com/gopacket/gopacket v1.1.1 h1:zbx9F9d6A7sWNkFKrvMBZTfGgxFoY4NgUudFVVHMfcw=
|
||||
github.com/gopacket/gopacket v1.1.1/go.mod h1:HavMeONEl7W9036of9LbSWoonqhH7HA1+ZRO+rMIvFs=
|
||||
@ -529,8 +527,8 @@ github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944 h1:TDtJKmM6S
|
||||
github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944/go.mod h1:sHA6TRxjQ6RLbnI+3R4DZo2Eseg/iKiPRfNmcuNySVQ=
|
||||
github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e h1:PURA50S8u4mF6RrkYYCAvvPCixhqqEiEy3Ej6avh04c=
|
||||
github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e/go.mod h1:YMLU7qbKfVjmEv7EoZPIVEI+kNYxWCdPK3VS0BU+U4Q=
|
||||
github.com/netbirdio/management-integrations/integrations v0.0.0-20250115083837-a09722b8d2a6 h1:I/ODkZ8rSDOzlJbhEjD2luSI71zl+s5JgNvFHY0+mBU=
|
||||
github.com/netbirdio/management-integrations/integrations v0.0.0-20250115083837-a09722b8d2a6/go.mod h1:izUUs1NT7ja+PwSX3kJ7ox8Kkn478tboBJSjL4kU6J0=
|
||||
github.com/netbirdio/management-integrations/integrations v0.0.0-20250220173202-e599d83524fc h1:18xvjOy2tZVIK7rihNpf9DF/3mAiljYKWaQlWa9vJgI=
|
||||
github.com/netbirdio/management-integrations/integrations v0.0.0-20250220173202-e599d83524fc/go.mod h1:izUUs1NT7ja+PwSX3kJ7ox8Kkn478tboBJSjL4kU6J0=
|
||||
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9axERMVN63dqyFqnvuD+EMJHzM7mNGON8=
|
||||
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM=
|
||||
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20241010133937-e0df50df217d h1:bRq5TKgC7Iq20pDiuC54yXaWnAVeS5PdGpSokFTlR28=
|
||||
@ -617,8 +615,8 @@ github.com/quic-go/quic-go v0.48.2 h1:wsKXZPeGWpMpCGSWqOcqpW2wZYic/8T3aqiOID0/KW
|
||||
github.com/quic-go/quic-go v0.48.2/go.mod h1:yBgs3rWBOADpga7F+jJsb6Ybg1LSYiQvwWlLX+/6HMs=
|
||||
github.com/rogpeppe/fastuuid v1.2.0/go.mod h1:jVj6XXZzXRy/MSR5jhDC/2q6DgLz+nrA6LYCDYWNEvQ=
|
||||
github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4=
|
||||
github.com/rogpeppe/go-internal v1.11.0 h1:cWPaGQEPrBb5/AsnsZesgZZ9yb1OQ+GOISoDNXVBh4M=
|
||||
github.com/rogpeppe/go-internal v1.11.0/go.mod h1:ddIwULY96R17DhadqLgMfk9H9tvdUzkipdSkR5nkCZA=
|
||||
github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR38lUII=
|
||||
github.com/rogpeppe/go-internal v1.13.1/go.mod h1:uMEvuHeurkdAXX61udpOXGD/AzZDWNMNyH2VO9fmH0o=
|
||||
github.com/rs/cors v1.8.0 h1:P2KMzcFwrPoSjkF1WLRPsp3UMLyql8L4v9hQpVeK5so=
|
||||
github.com/rs/cors v1.8.0/go.mod h1:EBwu+T5AvHOcXwvZIkQFjUN6s8Czyqw12GL/Y0tUyRM=
|
||||
github.com/rs/xid v1.3.0 h1:6NjYksEUlhurdVehpc7S7dk6DAmcKv8V9gG0FsVN2U4=
|
||||
@ -683,11 +681,11 @@ github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/
|
||||
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
|
||||
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
|
||||
github.com/stretchr/testify v1.8.3/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
|
||||
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
|
||||
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
|
||||
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
|
||||
github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
|
||||
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
|
||||
github.com/subosito/gotenv v1.2.0/go.mod h1:N0PQaV/YGNqwC0u51sEeR/aUtSLEXKX9iv69rRypqCw=
|
||||
github.com/testcontainers/testcontainers-go v0.31.0 h1:W0VwIhcEVhRflwL9as3dhY6jXjVCA27AkmbnZ+UTh3U=
|
||||
github.com/testcontainers/testcontainers-go v0.31.0/go.mod h1:D2lAoA0zUFiSY+eAflqK5mcUx/A5hrrORaEQrd0SefI=
|
||||
@ -739,28 +737,28 @@ go.opencensus.io v0.22.3/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw=
|
||||
go.opencensus.io v0.22.4/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw=
|
||||
go.opencensus.io v0.22.5/go.mod h1:5pWMHQbX5EPX2/62yrJeAkowc+lfs/XD7Uxpq3pI6kk=
|
||||
go.opencensus.io v0.23.0/go.mod h1:XItmlyltB5F7CS4xOC1DcqMoFqwtC6OG2xF7mCv7P7E=
|
||||
go.opencensus.io v0.24.0 h1:y73uSU6J157QMP2kn2r30vwW1A2W2WFwSCGnAVxeaD0=
|
||||
go.opencensus.io v0.24.0/go.mod h1:vNK8G9p7aAivkbmorf4v+7Hgx+Zs0yY+0fOtgBfjQKo=
|
||||
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.49.0 h1:4Pp6oUg3+e/6M4C0A/3kJ2VYa++dsWVTtGgLVj5xtHg=
|
||||
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.49.0/go.mod h1:Mjt1i1INqiaoZOMGR1RIUJN+i3ChKoFRqzrRQhlkbs0=
|
||||
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.51.0 h1:Xs2Ncz0gNihqu9iosIZ5SkBbWo5T8JhhLJFMQL1qmLI=
|
||||
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.51.0/go.mod h1:vy+2G/6NvVMpwGX/NyLqcC41fxepnuKHk16E6IZUcJc=
|
||||
go.opentelemetry.io/otel v1.26.0 h1:LQwgL5s/1W7YiiRwxf03QGnWLb2HW4pLiAhaA5cZXBs=
|
||||
go.opentelemetry.io/otel v1.26.0/go.mod h1:UmLkJHUAidDval2EICqBMbnAd0/m2vmpf/dAM+fvFs4=
|
||||
go.opentelemetry.io/auto/sdk v1.1.0 h1:cH53jehLUN6UFLY71z+NDOiNJqDdPRaXzTel0sJySYA=
|
||||
go.opentelemetry.io/auto/sdk v1.1.0/go.mod h1:3wSPjt5PWp2RhlCcmmOial7AvC4DQqZb7a7wCow3W8A=
|
||||
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.58.0 h1:PS8wXpbyaDJQ2VDHHncMe9Vct0Zn1fEjpsjrLxGJoSc=
|
||||
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.58.0/go.mod h1:HDBUsEjOuRC0EzKZ1bSaRGZWUBAzo+MhAcUUORSr4D0=
|
||||
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.58.0 h1:yd02MEjBdJkG3uabWP9apV+OuWRIXGDuJEUJbOHmCFU=
|
||||
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.58.0/go.mod h1:umTcuxiv1n/s/S6/c2AT/g2CQ7u5C59sHDNmfSwgz7Q=
|
||||
go.opentelemetry.io/otel v1.34.0 h1:zRLXxLCgL1WyKsPVrgbSdMN4c0FMkDAskSTQP+0hdUY=
|
||||
go.opentelemetry.io/otel v1.34.0/go.mod h1:OWFPOQ+h4G8xpyjgqo4SxJYdDQ/qmRH+wivy7zzx9oI=
|
||||
go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.19.0 h1:Mne5On7VWdx7omSrSSZvM4Kw7cS7NQkOOmLcgscI51U=
|
||||
go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.19.0/go.mod h1:IPtUMKL4O3tH5y+iXVyAXqpAwMuzC1IrxVS81rummfE=
|
||||
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.19.0 h1:IeMeyr1aBvBiPVYihXIaeIZba6b8E1bYp7lbdxK8CQg=
|
||||
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.19.0/go.mod h1:oVdCUtjq9MK9BlS7TtucsQwUcXcymNiEDjgDD2jMtZU=
|
||||
go.opentelemetry.io/otel/exporters/prometheus v0.48.0 h1:sBQe3VNGUjY9IKWQC6z2lNqa5iGbDSxhs60ABwK4y0s=
|
||||
go.opentelemetry.io/otel/exporters/prometheus v0.48.0/go.mod h1:DtrbMzoZWwQHyrQmCfLam5DZbnmorsGbOtTbYHycU5o=
|
||||
go.opentelemetry.io/otel/metric v1.26.0 h1:7S39CLuY5Jgg9CrnA9HHiEjGMF/X2VHvoXGgSllRz30=
|
||||
go.opentelemetry.io/otel/metric v1.26.0/go.mod h1:SY+rHOI4cEawI9a7N1A4nIg/nTQXe1ccCNWYOJUrpX4=
|
||||
go.opentelemetry.io/otel/sdk v1.26.0 h1:Y7bumHf5tAiDlRYFmGqetNcLaVUZmh4iYfmGxtmz7F8=
|
||||
go.opentelemetry.io/otel/sdk v1.26.0/go.mod h1:0p8MXpqLeJ0pzcszQQN4F0S5FVjBLgypeGSngLsmirs=
|
||||
go.opentelemetry.io/otel/sdk/metric v1.26.0 h1:cWSks5tfriHPdWFnl+qpX3P681aAYqlZHcAyHw5aU9Y=
|
||||
go.opentelemetry.io/otel/sdk/metric v1.26.0/go.mod h1:ClMFFknnThJCksebJwz7KIyEDHO+nTB6gK8obLy8RyE=
|
||||
go.opentelemetry.io/otel/trace v1.26.0 h1:1ieeAUb4y0TE26jUFrCIXKpTuVK7uJGN9/Z/2LP5sQA=
|
||||
go.opentelemetry.io/otel/trace v1.26.0/go.mod h1:4iDxvGDQuUkHve82hJJ8UqrwswHYsZuWCBllGV2U2y0=
|
||||
go.opentelemetry.io/otel/metric v1.34.0 h1:+eTR3U0MyfWjRDhmFMxe2SsW64QrZ84AOhvqS7Y+PoQ=
|
||||
go.opentelemetry.io/otel/metric v1.34.0/go.mod h1:CEDrp0fy2D0MvkXE+dPV7cMi8tWZwX3dmaIhwPOaqHE=
|
||||
go.opentelemetry.io/otel/sdk v1.34.0 h1:95zS4k/2GOy069d321O8jWgYsW3MzVV+KuSPKp7Wr1A=
|
||||
go.opentelemetry.io/otel/sdk v1.34.0/go.mod h1:0e/pNiaMAqaykJGKbi+tSjWfNNHMTxoC9qANsCzbyxU=
|
||||
go.opentelemetry.io/otel/sdk/metric v1.32.0 h1:rZvFnvmvawYb0alrYkjraqJq0Z4ZUJAiyYCU9snn1CU=
|
||||
go.opentelemetry.io/otel/sdk/metric v1.32.0/go.mod h1:PWeZlq0zt9YkYAp3gjKZ0eicRYvOh1Gd+X99x6GHpCQ=
|
||||
go.opentelemetry.io/otel/trace v1.34.0 h1:+ouXS2V8Rd4hp4580a8q23bg0azF2nI8cqLYnC8mh/k=
|
||||
go.opentelemetry.io/otel/trace v1.34.0/go.mod h1:Svm7lSjQD7kG7KJ/MUHPVXSDGz2OX4h0M2jHBhmSfRE=
|
||||
go.opentelemetry.io/proto/otlp v1.0.0 h1:T0TX0tmXU8a3CbNXzEKGeU5mIVOdf0oykP+u2lIVU/I=
|
||||
go.opentelemetry.io/proto/otlp v1.0.0/go.mod h1:Sy6pihPLfYHkr3NkUbEhGHFhINUSI/v80hjKIs5JXpM=
|
||||
go.uber.org/atomic v1.7.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc=
|
||||
@ -885,8 +883,8 @@ golang.org/x/net v0.9.0/go.mod h1:d48xBJpPfHeWQsugry2m+kC02ZBRGRgulfHnEXEuWns=
|
||||
golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg=
|
||||
golang.org/x/net v0.14.0/go.mod h1:PpSgVXXLK0OxS0F31C1/tv6XNguvCrnXIDrFMspZIUI=
|
||||
golang.org/x/net v0.20.0/go.mod h1:z8BVo6PvndSri0LbOE3hAn0apkU+1YvI6E70E9jsnvY=
|
||||
golang.org/x/net v0.33.0 h1:74SYHlV8BIgHIFC/LrYkOGIwL19eTYXQ5wc6TBuO36I=
|
||||
golang.org/x/net v0.33.0/go.mod h1:HXLR5J+9DxmrqMwG9qjGCxZ+zKXxBru04zlTvWlWuN4=
|
||||
golang.org/x/net v0.34.0 h1:Mb7Mrk043xzHgnRM88suvJFwzVrRfHEHJEl5/71CKw0=
|
||||
golang.org/x/net v0.34.0/go.mod h1:di0qlW3YNM5oh6GqDGQr92MyTozJPmybPK4Ev/Gm31k=
|
||||
golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
|
||||
golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
|
||||
golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
|
||||
@ -900,8 +898,8 @@ golang.org/x/oauth2 v0.0.0-20210220000619-9bb904979d93/go.mod h1:KelEdhl1UZF7XfJ
|
||||
golang.org/x/oauth2 v0.0.0-20210313182246-cd4f82c27b84/go.mod h1:KelEdhl1UZF7XfJ4dDtk6s++YSgaE7mD/BuKKDLBl4A=
|
||||
golang.org/x/oauth2 v0.0.0-20210402161424-2e8d93401602/go.mod h1:KelEdhl1UZF7XfJ4dDtk6s++YSgaE7mD/BuKKDLBl4A=
|
||||
golang.org/x/oauth2 v0.8.0/go.mod h1:yr7u4HXZRm1R1kBWqr/xKNqewf0plRYoB7sla+BCIXE=
|
||||
golang.org/x/oauth2 v0.19.0 h1:9+E/EZBCbTLNrbN35fHv/a/d/mOBatymz1zbtQrXpIg=
|
||||
golang.org/x/oauth2 v0.19.0/go.mod h1:vYi7skDa1x015PmRRYZ7+s1cWyPgrPiSYRe4rnsexc8=
|
||||
golang.org/x/oauth2 v0.26.0 h1:afQXWNNaeC4nvZ0Ed9XvCCzXM6UHJG7iCg0W4fPqSBE=
|
||||
golang.org/x/oauth2 v0.26.0/go.mod h1:XYTD2NtWslqkgxebSiOHnXEap4TF09sJSc7H1sXbhtI=
|
||||
golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
@ -1019,8 +1017,8 @@ golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ=
|
||||
golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
|
||||
golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
|
||||
golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
|
||||
golang.org/x/time v0.5.0 h1:o7cqy6amK/52YcAKIPlM3a+Fpj35zvRj2TP+e1xFSfk=
|
||||
golang.org/x/time v0.5.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
|
||||
golang.org/x/time v0.10.0 h1:3usCWA8tQn0L8+hFJQNgzpWbd89begxN66o1Ojdn5L4=
|
||||
golang.org/x/time v0.10.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
|
||||
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||
golang.org/x/tools v0.0.0-20181011042414-1f849cf54d09/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||
golang.org/x/tools v0.0.0-20181030221726-6c7e314b6563/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||
@ -1114,8 +1112,8 @@ google.golang.org/api v0.40.0/go.mod h1:fYKFpnQN0DsDSKRVRcQSDQNtqWPfM9i+zNPxepjR
|
||||
google.golang.org/api v0.41.0/go.mod h1:RkxM5lITDfTzmyKFPt+wGrCJbVfniCr2ool8kTBzRTU=
|
||||
google.golang.org/api v0.43.0/go.mod h1:nQsDGjRXMo4lvh5hP0TKqF244gqhGcr/YSIykhUk/94=
|
||||
google.golang.org/api v0.44.0/go.mod h1:EBOGZqzyhtvMDoxwS97ctnh0zUmYY6CxqXsc1AvkYD8=
|
||||
google.golang.org/api v0.177.0 h1:8a0p/BbPa65GlqGWtUKxot4p0TV8OGOfyTjtmkXNXmk=
|
||||
google.golang.org/api v0.177.0/go.mod h1:srbhue4MLjkjbkux5p3dw/ocYOSZTaIEvf7bCOnFQDw=
|
||||
google.golang.org/api v0.220.0 h1:3oMI4gdBgB72WFVwE1nerDD8W3HUOS4kypK6rRLbGns=
|
||||
google.golang.org/api v0.220.0/go.mod h1:26ZAlY6aN/8WgpCzjPNy18QpYaz7Zgg1h0qe1GkZEmY=
|
||||
google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM=
|
||||
google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4=
|
||||
google.golang.org/appengine v1.5.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4=
|
||||
@ -1164,10 +1162,11 @@ google.golang.org/genproto v0.0.0-20210310155132-4ce2db91004e/go.mod h1:FWY/as6D
|
||||
google.golang.org/genproto v0.0.0-20210319143718-93e7006c17a6/go.mod h1:FWY/as6DDZQgahTzZj3fqbO1CbirC29ZNUFHwi0/+no=
|
||||
google.golang.org/genproto v0.0.0-20210402141018-6c239bbf2bb1/go.mod h1:9lPAdzaEmUacj36I+k7YKbEc5CXzPIeORRgDAUOu28A=
|
||||
google.golang.org/genproto v0.0.0-20210602131652-f16073e35f0c/go.mod h1:UODoCrxHCcBojKKwX1terBiRUaqAsFqJiF615XL43r0=
|
||||
google.golang.org/genproto/googleapis/api v0.0.0-20240509183442-62759503f434 h1:OpXbo8JnN8+jZGPrL4SSfaDjSCjupr8lXyBAbexEm/U=
|
||||
google.golang.org/genproto/googleapis/api v0.0.0-20240509183442-62759503f434/go.mod h1:FfiGhwUm6CJviekPrc0oJ+7h29e+DmWU6UtjX0ZvI7Y=
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20240903143218-8af14fe29dc1 h1:pPJltXNxVzT4pK9yD8vR9X75DaWYYmLGMsEvBfFQZzQ=
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20240903143218-8af14fe29dc1/go.mod h1:UqMtugtsSgubUsoxbuAoiCXvqvErP7Gf0so0mK9tHxU=
|
||||
google.golang.org/genproto v0.0.0-20240123012728-ef4313101c80 h1:KAeGQVN3M9nD0/bQXnr/ClcEMJ968gUXJQ9pwfSynuQ=
|
||||
google.golang.org/genproto/googleapis/api v0.0.0-20241209162323-e6fa225c2576 h1:CkkIfIt50+lT6NHAVoRYEyAvQGFM7xEwXUUywFvEb3Q=
|
||||
google.golang.org/genproto/googleapis/api v0.0.0-20241209162323-e6fa225c2576/go.mod h1:1R3kvZ1dtP3+4p4d3G8uJ8rFk/fWlScl38vanWACI08=
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20250127172529-29210b9bc287 h1:J1H9f+LEdWAfHcez/4cvaVBox7cOYT+IU6rgqj5x++8=
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20250127172529-29210b9bc287/go.mod h1:8BS3B93F/U1juMFq9+EDk+qOT5CO1R9IzXxG3PTqiRk=
|
||||
google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c=
|
||||
google.golang.org/grpc v1.20.1/go.mod h1:10oTOabMzJvdu6/UiuZezV6QK5dSlG84ov/aaiqXj38=
|
||||
google.golang.org/grpc v1.21.1/go.mod h1:oYelfM1adQP15Ek0mdvEgi9Df8B9CZIaU1084ijfRaM=
|
||||
@ -1188,8 +1187,8 @@ google.golang.org/grpc v1.35.0/go.mod h1:qjiiYl8FncCW8feJPdyg3v6XW24KsRHe+dy9BAG
|
||||
google.golang.org/grpc v1.36.0/go.mod h1:qjiiYl8FncCW8feJPdyg3v6XW24KsRHe+dy9BAGRRjU=
|
||||
google.golang.org/grpc v1.36.1/go.mod h1:qjiiYl8FncCW8feJPdyg3v6XW24KsRHe+dy9BAGRRjU=
|
||||
google.golang.org/grpc v1.38.0/go.mod h1:NREThFqKR1f3iQ6oBuvc5LadQuXVGo9rkm5ZGrQdJfM=
|
||||
google.golang.org/grpc v1.64.1 h1:LKtvyfbX3UGVPFcGqJ9ItpVWW6oN/2XqTxfAnwRRXiA=
|
||||
google.golang.org/grpc v1.64.1/go.mod h1:hiQF4LFZelK2WKaP6W0L92zGHtiQdZxk8CrSdvyjeP0=
|
||||
google.golang.org/grpc v1.70.0 h1:pWFv03aZoHzlRKHWicjsZytKAiYCtNS0dHbXnIdq7jQ=
|
||||
google.golang.org/grpc v1.70.0/go.mod h1:ofIJqVKDXx/JiXrwr2IG4/zwdH9txy3IlF40RmcJSQw=
|
||||
google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8=
|
||||
google.golang.org/protobuf v0.0.0-20200221191635-4d8936d0db64/go.mod h1:kwYJMbMJ01Woi6D6+Kah6886xMZcty6N08ah7+eCXa0=
|
||||
google.golang.org/protobuf v0.0.0-20200228230310-ab0ca4ff8a60/go.mod h1:cfTl7dwQJ+fmap5saPgwCLgHXTUD7jkjRqWcaiX5VyM=
|
||||
@ -1204,8 +1203,8 @@ google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp0
|
||||
google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc=
|
||||
google.golang.org/protobuf v1.28.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=
|
||||
google.golang.org/protobuf v1.30.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=
|
||||
google.golang.org/protobuf v1.34.2 h1:6xV6lTsCfpGD21XK49h7MhtcApnLqkfYgPcdHftf6hg=
|
||||
google.golang.org/protobuf v1.34.2/go.mod h1:qYOHts0dSfpeUzUFpOMr/WGzszTmLH+DiWniOlNbLDw=
|
||||
google.golang.org/protobuf v1.36.4 h1:6A3ZDJHn/eNqc1i+IdefRzy/9PokBTPvcqMySR7NNIM=
|
||||
google.golang.org/protobuf v1.36.4/go.mod h1:9fA7Ob0pmnwhb644+1+CVWFRbNajQ6iRojtC/QF5bRE=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
|
@ -78,7 +78,7 @@ func startManagement(t *testing.T) (*grpc.Server, net.Listener) {
|
||||
}
|
||||
|
||||
secretsManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay)
|
||||
mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, settings.NewManager(store), peersUpdateManager, secretsManager, nil, nil)
|
||||
mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, settings.NewManager(store), peersUpdateManager, secretsManager, nil, nil, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
@ -39,13 +39,12 @@ import (
|
||||
"github.com/netbirdio/netbird/formatter"
|
||||
mgmtProto "github.com/netbirdio/netbird/management/proto"
|
||||
"github.com/netbirdio/netbird/management/server"
|
||||
"github.com/netbirdio/netbird/management/server/auth"
|
||||
nbContext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/geolocation"
|
||||
"github.com/netbirdio/netbird/management/server/groups"
|
||||
nbhttp "github.com/netbirdio/netbird/management/server/http"
|
||||
"github.com/netbirdio/netbird/management/server/http/configs"
|
||||
"github.com/netbirdio/netbird/management/server/idp"
|
||||
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
||||
"github.com/netbirdio/netbird/management/server/metrics"
|
||||
"github.com/netbirdio/netbird/management/server/networks"
|
||||
"github.com/netbirdio/netbird/management/server/networks/resources"
|
||||
@ -255,24 +254,13 @@ var (
|
||||
tlsEnabled = true
|
||||
}
|
||||
|
||||
jwtValidator, err := jwtclaims.NewJWTValidator(
|
||||
ctx,
|
||||
authManager := auth.NewManager(store,
|
||||
config.HttpConfig.AuthIssuer,
|
||||
config.GetAuthAudiences(),
|
||||
config.HttpConfig.AuthAudience,
|
||||
config.HttpConfig.AuthKeysLocation,
|
||||
config.HttpConfig.IdpSignKeyRefreshEnabled,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed creating JWT validator: %v", err)
|
||||
}
|
||||
|
||||
httpAPIAuthCfg := configs.AuthCfg{
|
||||
Issuer: config.HttpConfig.AuthIssuer,
|
||||
Audience: config.HttpConfig.AuthAudience,
|
||||
UserIDClaim: config.HttpConfig.AuthUserIDClaim,
|
||||
KeysLocation: config.HttpConfig.AuthKeysLocation,
|
||||
}
|
||||
|
||||
config.HttpConfig.AuthUserIDClaim,
|
||||
config.GetAuthAudiences(),
|
||||
config.HttpConfig.IdpSignKeyRefreshEnabled)
|
||||
userManager := users.NewManager(store)
|
||||
settingsManager := settings.NewManager(store)
|
||||
permissionsManager := permissions.NewManager(userManager, settingsManager)
|
||||
@ -281,7 +269,7 @@ var (
|
||||
routersManager := routers.NewManager(store, permissionsManager, accountManager)
|
||||
networksManager := networks.NewManager(store, permissionsManager, resourcesManager, routersManager, accountManager)
|
||||
|
||||
httpAPIHandler, err := nbhttp.NewAPIHandler(ctx, accountManager, networksManager, resourcesManager, routersManager, groupsManager, geo, jwtValidator, appMetrics, httpAPIAuthCfg, integratedPeerValidator)
|
||||
httpAPIHandler, err := nbhttp.NewAPIHandler(ctx, accountManager, networksManager, resourcesManager, routersManager, groupsManager, geo, authManager, appMetrics, config, integratedPeerValidator)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed creating HTTP API handler: %v", err)
|
||||
}
|
||||
@ -290,7 +278,7 @@ var (
|
||||
ephemeralManager.LoadInitialPeers(ctx)
|
||||
|
||||
gRPCAPIHandler := grpc.NewServer(gRPCOpts...)
|
||||
srv, err := server.NewServer(ctx, config, accountManager, settingsManager, peersUpdateManager, secretsManager, appMetrics, ephemeralManager)
|
||||
srv, err := server.NewServer(ctx, config, accountManager, settingsManager, peersUpdateManager, secretsManager, appMetrics, ephemeralManager, authManager)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed creating gRPC API handler: %v", err)
|
||||
}
|
||||
|
@ -2,11 +2,8 @@ package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
b64 "encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"hash/crc32"
|
||||
"math/rand"
|
||||
"net"
|
||||
"net/netip"
|
||||
@ -24,14 +21,13 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/exp/maps"
|
||||
|
||||
"github.com/netbirdio/netbird/base62"
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/management/domain"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/geolocation"
|
||||
"github.com/netbirdio/netbird/management/server/idp"
|
||||
"github.com/netbirdio/netbird/management/server/integrated_validator"
|
||||
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/posture"
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
@ -77,13 +73,10 @@ type AccountManager interface {
|
||||
GetAccountByID(ctx context.Context, accountID string, userID string) (*types.Account, error)
|
||||
AccountExists(ctx context.Context, accountID string) (bool, error)
|
||||
GetAccountIDByUserID(ctx context.Context, userID, domain string) (string, error)
|
||||
GetAccountIDFromToken(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error)
|
||||
CheckUserAccessByJWTGroups(ctx context.Context, claims jwtclaims.AuthorizationClaims) error
|
||||
GetPATInfo(ctx context.Context, token string) (*types.User, *types.PersonalAccessToken, string, string, error)
|
||||
GetAccountIDFromUserAuth(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error)
|
||||
DeleteAccount(ctx context.Context, accountID, userID string) error
|
||||
MarkPATUsed(ctx context.Context, tokenID string) error
|
||||
GetUserByID(ctx context.Context, id string) (*types.User, error)
|
||||
GetUser(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*types.User, error)
|
||||
GetUserFromUserAuth(ctx context.Context, userAuth nbcontext.UserAuth) (*types.User, error)
|
||||
ListUsers(ctx context.Context, accountID string) ([]*types.User, error)
|
||||
GetPeers(ctx context.Context, accountID, userID string) ([]*nbpeer.Peer, error)
|
||||
MarkPeerConnected(ctx context.Context, peerKey string, connected bool, realIP net.IP, accountID string) error
|
||||
@ -150,6 +143,7 @@ type AccountManager interface {
|
||||
DeleteSetupKey(ctx context.Context, accountID, userID, keyID string) error
|
||||
UpdateAccountPeers(ctx context.Context, accountID string)
|
||||
BuildUserInfosForAccount(ctx context.Context, accountID, initiatorUserID string, accountUsers []*types.User) (map[string]*types.UserInfo, error)
|
||||
SyncUserJWTGroups(ctx context.Context, userAuth nbcontext.UserAuth) error
|
||||
}
|
||||
|
||||
type DefaultAccountManager struct {
|
||||
@ -954,11 +948,11 @@ func (am *DefaultAccountManager) removeUserFromCache(ctx context.Context, accoun
|
||||
}
|
||||
|
||||
// updateAccountDomainAttributesIfNotUpToDate updates the account domain attributes if they are not up to date and then, saves the account changes
|
||||
func (am *DefaultAccountManager) updateAccountDomainAttributesIfNotUpToDate(ctx context.Context, accountID string, claims jwtclaims.AuthorizationClaims,
|
||||
func (am *DefaultAccountManager) updateAccountDomainAttributesIfNotUpToDate(ctx context.Context, accountID string, userAuth nbcontext.UserAuth,
|
||||
primaryDomain bool,
|
||||
) error {
|
||||
if claims.Domain == "" {
|
||||
log.WithContext(ctx).Errorf("claims don't contain a valid domain, skipping domain attributes update. Received claims: %v", claims)
|
||||
if userAuth.Domain == "" {
|
||||
log.WithContext(ctx).Errorf("claims don't contain a valid domain, skipping domain attributes update. Received claims: %v", userAuth)
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -971,11 +965,11 @@ func (am *DefaultAccountManager) updateAccountDomainAttributesIfNotUpToDate(ctx
|
||||
return err
|
||||
}
|
||||
|
||||
if domainIsUpToDate(accountDomain, domainCategory, claims) {
|
||||
if domainIsUpToDate(accountDomain, domainCategory, userAuth) {
|
||||
return nil
|
||||
}
|
||||
|
||||
user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, claims.UserId)
|
||||
user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userAuth.UserId)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("error getting user: %v", err)
|
||||
return err
|
||||
@ -984,13 +978,13 @@ func (am *DefaultAccountManager) updateAccountDomainAttributesIfNotUpToDate(ctx
|
||||
newDomain := accountDomain
|
||||
newCategoty := domainCategory
|
||||
|
||||
lowerDomain := strings.ToLower(claims.Domain)
|
||||
lowerDomain := strings.ToLower(userAuth.Domain)
|
||||
if accountDomain != lowerDomain && user.HasAdminPower() {
|
||||
newDomain = lowerDomain
|
||||
}
|
||||
|
||||
if accountDomain == lowerDomain {
|
||||
newCategoty = claims.DomainCategory
|
||||
newCategoty = userAuth.DomainCategory
|
||||
}
|
||||
|
||||
return am.Store.UpdateAccountDomainAttributes(ctx, accountID, newDomain, newCategoty, primaryDomain)
|
||||
@ -1006,16 +1000,16 @@ func (am *DefaultAccountManager) handleExistingUserAccount(
|
||||
ctx context.Context,
|
||||
userAccountID string,
|
||||
domainAccountID string,
|
||||
claims jwtclaims.AuthorizationClaims,
|
||||
userAuth nbcontext.UserAuth,
|
||||
) error {
|
||||
primaryDomain := domainAccountID == "" || userAccountID == domainAccountID
|
||||
err := am.updateAccountDomainAttributesIfNotUpToDate(ctx, userAccountID, claims, primaryDomain)
|
||||
err := am.updateAccountDomainAttributesIfNotUpToDate(ctx, userAccountID, userAuth, primaryDomain)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// we should register the account ID to this user's metadata in our IDP manager
|
||||
err = am.addAccountIDToIDPAppMeta(ctx, claims.UserId, userAccountID)
|
||||
err = am.addAccountIDToIDPAppMeta(ctx, userAuth.UserId, userAccountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -1025,20 +1019,20 @@ func (am *DefaultAccountManager) handleExistingUserAccount(
|
||||
|
||||
// addNewPrivateAccount validates if there is an existing primary account for the domain, if so it adds the new user to that account,
|
||||
// otherwise it will create a new account and make it primary account for the domain.
|
||||
func (am *DefaultAccountManager) addNewPrivateAccount(ctx context.Context, domainAccountID string, claims jwtclaims.AuthorizationClaims) (string, error) {
|
||||
if claims.UserId == "" {
|
||||
func (am *DefaultAccountManager) addNewPrivateAccount(ctx context.Context, domainAccountID string, userAuth nbcontext.UserAuth) (string, error) {
|
||||
if userAuth.UserId == "" {
|
||||
return "", fmt.Errorf("user ID is empty")
|
||||
}
|
||||
|
||||
lowerDomain := strings.ToLower(claims.Domain)
|
||||
lowerDomain := strings.ToLower(userAuth.Domain)
|
||||
|
||||
newAccount, err := am.newAccount(ctx, claims.UserId, lowerDomain)
|
||||
newAccount, err := am.newAccount(ctx, userAuth.UserId, lowerDomain)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
newAccount.Domain = lowerDomain
|
||||
newAccount.DomainCategory = claims.DomainCategory
|
||||
newAccount.DomainCategory = userAuth.DomainCategory
|
||||
newAccount.IsDomainPrimaryAccount = true
|
||||
|
||||
err = am.Store.SaveAccount(ctx, newAccount)
|
||||
@ -1046,33 +1040,33 @@ func (am *DefaultAccountManager) addNewPrivateAccount(ctx context.Context, domai
|
||||
return "", err
|
||||
}
|
||||
|
||||
err = am.addAccountIDToIDPAppMeta(ctx, claims.UserId, newAccount.Id)
|
||||
err = am.addAccountIDToIDPAppMeta(ctx, userAuth.UserId, newAccount.Id)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
am.StoreEvent(ctx, claims.UserId, claims.UserId, newAccount.Id, activity.UserJoined, nil)
|
||||
am.StoreEvent(ctx, userAuth.UserId, userAuth.UserId, newAccount.Id, activity.UserJoined, nil)
|
||||
|
||||
return newAccount.Id, nil
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) addNewUserToDomainAccount(ctx context.Context, domainAccountID string, claims jwtclaims.AuthorizationClaims) (string, error) {
|
||||
func (am *DefaultAccountManager) addNewUserToDomainAccount(ctx context.Context, domainAccountID string, userAuth nbcontext.UserAuth) (string, error) {
|
||||
unlockAccount := am.Store.AcquireWriteLockByUID(ctx, domainAccountID)
|
||||
defer unlockAccount()
|
||||
|
||||
newUser := types.NewRegularUser(claims.UserId)
|
||||
newUser := types.NewRegularUser(userAuth.UserId)
|
||||
newUser.AccountID = domainAccountID
|
||||
err := am.Store.SaveUser(ctx, store.LockingStrengthUpdate, newUser)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
err = am.addAccountIDToIDPAppMeta(ctx, claims.UserId, domainAccountID)
|
||||
err = am.addAccountIDToIDPAppMeta(ctx, userAuth.UserId, domainAccountID)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
am.StoreEvent(ctx, claims.UserId, claims.UserId, domainAccountID, activity.UserJoined, nil)
|
||||
am.StoreEvent(ctx, userAuth.UserId, userAuth.UserId, domainAccountID, activity.UserJoined, nil)
|
||||
|
||||
return domainAccountID, nil
|
||||
}
|
||||
@ -1112,76 +1106,11 @@ func (am *DefaultAccountManager) redeemInvite(ctx context.Context, accountID str
|
||||
return nil
|
||||
}
|
||||
|
||||
// MarkPATUsed marks a personal access token as used
|
||||
func (am *DefaultAccountManager) MarkPATUsed(ctx context.Context, tokenID string) error {
|
||||
return am.Store.MarkPATUsed(ctx, store.LockingStrengthUpdate, tokenID)
|
||||
}
|
||||
|
||||
// GetAccount returns an account associated with this account ID.
|
||||
func (am *DefaultAccountManager) GetAccount(ctx context.Context, accountID string) (*types.Account, error) {
|
||||
return am.Store.GetAccount(ctx, accountID)
|
||||
}
|
||||
|
||||
// GetPATInfo retrieves user, personal access token, domain, and category details from a personal access token.
|
||||
func (am *DefaultAccountManager) GetPATInfo(ctx context.Context, token string) (user *types.User, pat *types.PersonalAccessToken, domain string, category string, err error) {
|
||||
user, pat, err = am.extractPATFromToken(ctx, token)
|
||||
if err != nil {
|
||||
return nil, nil, "", "", err
|
||||
}
|
||||
|
||||
domain, category, err = am.Store.GetAccountDomainAndCategory(ctx, store.LockingStrengthShare, user.AccountID)
|
||||
if err != nil {
|
||||
return nil, nil, "", "", err
|
||||
}
|
||||
|
||||
return user, pat, domain, category, nil
|
||||
}
|
||||
|
||||
// extractPATFromToken validates the token structure and retrieves associated User and PAT.
|
||||
func (am *DefaultAccountManager) extractPATFromToken(ctx context.Context, token string) (*types.User, *types.PersonalAccessToken, error) {
|
||||
if len(token) != types.PATLength {
|
||||
return nil, nil, fmt.Errorf("token has incorrect length")
|
||||
}
|
||||
|
||||
prefix := token[:len(types.PATPrefix)]
|
||||
if prefix != types.PATPrefix {
|
||||
return nil, nil, fmt.Errorf("token has wrong prefix")
|
||||
}
|
||||
secret := token[len(types.PATPrefix) : len(types.PATPrefix)+types.PATSecretLength]
|
||||
encodedChecksum := token[len(types.PATPrefix)+types.PATSecretLength : len(types.PATPrefix)+types.PATSecretLength+types.PATChecksumLength]
|
||||
|
||||
verificationChecksum, err := base62.Decode(encodedChecksum)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("token checksum decoding failed: %w", err)
|
||||
}
|
||||
|
||||
secretChecksum := crc32.ChecksumIEEE([]byte(secret))
|
||||
if secretChecksum != verificationChecksum {
|
||||
return nil, nil, fmt.Errorf("token checksum does not match")
|
||||
}
|
||||
|
||||
hashedToken := sha256.Sum256([]byte(token))
|
||||
encodedHashedToken := b64.StdEncoding.EncodeToString(hashedToken[:])
|
||||
|
||||
var user *types.User
|
||||
var pat *types.PersonalAccessToken
|
||||
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
pat, err = transaction.GetPATByHashedToken(ctx, store.LockingStrengthShare, encodedHashedToken)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
user, err = transaction.GetUserByPATID(ctx, store.LockingStrengthShare, pat.ID)
|
||||
return err
|
||||
})
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
return user, pat, nil
|
||||
}
|
||||
|
||||
// GetAccountByID returns an account associated with this account ID.
|
||||
func (am *DefaultAccountManager) GetAccountByID(ctx context.Context, accountID string, userID string) (*types.Account, error) {
|
||||
user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID)
|
||||
@ -1196,58 +1125,56 @@ func (am *DefaultAccountManager) GetAccountByID(ctx context.Context, accountID s
|
||||
return am.Store.GetAccount(ctx, accountID)
|
||||
}
|
||||
|
||||
// GetAccountIDFromToken returns an account ID associated with this token.
|
||||
func (am *DefaultAccountManager) GetAccountIDFromToken(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) {
|
||||
if claims.UserId == "" {
|
||||
func (am *DefaultAccountManager) GetAccountIDFromUserAuth(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error) {
|
||||
if userAuth.UserId == "" {
|
||||
return "", "", errors.New(emptyUserID)
|
||||
}
|
||||
if am.singleAccountMode && am.singleAccountModeDomain != "" {
|
||||
// This section is mostly related to self-hosted installations.
|
||||
// We override incoming domain claims to group users under a single account.
|
||||
claims.Domain = am.singleAccountModeDomain
|
||||
claims.DomainCategory = types.PrivateCategory
|
||||
userAuth.Domain = am.singleAccountModeDomain
|
||||
userAuth.DomainCategory = types.PrivateCategory
|
||||
log.WithContext(ctx).Debugf("overriding JWT Domain and DomainCategory claims since single account mode is enabled")
|
||||
}
|
||||
|
||||
accountID, err := am.getAccountIDWithAuthorizationClaims(ctx, claims)
|
||||
accountID, err := am.getAccountIDWithAuthorizationClaims(ctx, userAuth)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
|
||||
user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, claims.UserId)
|
||||
user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userAuth.UserId)
|
||||
if err != nil {
|
||||
// this is not really possible because we got an account by user ID
|
||||
return "", "", status.Errorf(status.NotFound, "user %s not found", claims.UserId)
|
||||
return "", "", status.Errorf(status.NotFound, "user %s not found", userAuth.UserId)
|
||||
}
|
||||
|
||||
if userAuth.IsChild {
|
||||
return accountID, user.Id, nil
|
||||
}
|
||||
|
||||
if user.AccountID != accountID {
|
||||
return "", "", status.Errorf(status.PermissionDenied, "user %s is not part of the account %s", claims.UserId, accountID)
|
||||
return "", "", status.Errorf(status.PermissionDenied, "user %s is not part of the account %s", userAuth.UserId, accountID)
|
||||
}
|
||||
|
||||
if !user.IsServiceUser && claims.Invited {
|
||||
if !user.IsServiceUser && userAuth.Invited {
|
||||
err = am.redeemInvite(ctx, accountID, user.Id)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
}
|
||||
|
||||
if err = am.syncJWTGroups(ctx, accountID, claims); err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
|
||||
return accountID, user.Id, nil
|
||||
}
|
||||
|
||||
// syncJWTGroups processes the JWT groups for a user, updates the account based on the groups,
|
||||
// and propagates changes to peers if group propagation is enabled.
|
||||
func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID string, claims jwtclaims.AuthorizationClaims) error {
|
||||
if claim, exists := claims.Raw[jwtclaims.IsToken]; exists {
|
||||
if isToken, ok := claim.(bool); ok && isToken {
|
||||
return nil
|
||||
}
|
||||
// requires userAuth to have been ValidateAndParseToken and EnsureUserAccessByJWTGroups by the AuthManager
|
||||
func (am *DefaultAccountManager) SyncUserJWTGroups(ctx context.Context, userAuth nbcontext.UserAuth) error {
|
||||
if userAuth.IsChild || userAuth.IsPAT {
|
||||
return nil
|
||||
}
|
||||
|
||||
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID)
|
||||
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, userAuth.AccountId)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -1261,9 +1188,7 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st
|
||||
return nil
|
||||
}
|
||||
|
||||
jwtGroupsNames := extractJWTGroups(ctx, settings.JWTGroupsClaimName, claims)
|
||||
|
||||
unlockAccount := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
unlockAccount := am.Store.AcquireWriteLockByUID(ctx, userAuth.AccountId)
|
||||
defer func() {
|
||||
if unlockAccount != nil {
|
||||
unlockAccount()
|
||||
@ -1275,17 +1200,17 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st
|
||||
var hasChanges bool
|
||||
var user *types.User
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
user, err = transaction.GetUserByUserID(ctx, store.LockingStrengthShare, claims.UserId)
|
||||
user, err = transaction.GetUserByUserID(ctx, store.LockingStrengthShare, userAuth.UserId)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error getting user: %w", err)
|
||||
}
|
||||
|
||||
groups, err := transaction.GetAccountGroups(ctx, store.LockingStrengthShare, accountID)
|
||||
groups, err := transaction.GetAccountGroups(ctx, store.LockingStrengthShare, userAuth.AccountId)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error getting account groups: %w", err)
|
||||
}
|
||||
|
||||
changed, updatedAutoGroups, newGroupsToCreate, err := am.getJWTGroupsChanges(user, groups, jwtGroupsNames)
|
||||
changed, updatedAutoGroups, newGroupsToCreate, err := am.getJWTGroupsChanges(user, groups, userAuth.Groups)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error getting JWT groups changes: %w", err)
|
||||
}
|
||||
@ -1310,7 +1235,7 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st
|
||||
|
||||
// Propagate changes to peers if group propagation is enabled
|
||||
if settings.GroupsPropagationEnabled {
|
||||
groups, err = transaction.GetAccountGroups(ctx, store.LockingStrengthShare, accountID)
|
||||
groups, err = transaction.GetAccountGroups(ctx, store.LockingStrengthShare, userAuth.AccountId)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error getting account groups: %w", err)
|
||||
}
|
||||
@ -1320,7 +1245,7 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st
|
||||
groupsMap[group.ID] = group
|
||||
}
|
||||
|
||||
peers, err := transaction.GetUserPeers(ctx, store.LockingStrengthShare, accountID, claims.UserId)
|
||||
peers, err := transaction.GetUserPeers(ctx, store.LockingStrengthShare, userAuth.AccountId, userAuth.UserId)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error getting user peers: %w", err)
|
||||
}
|
||||
@ -1334,7 +1259,7 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st
|
||||
return fmt.Errorf("error saving groups: %w", err)
|
||||
}
|
||||
|
||||
if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil {
|
||||
if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, userAuth.AccountId); err != nil {
|
||||
return fmt.Errorf("error incrementing network serial: %w", err)
|
||||
}
|
||||
}
|
||||
@ -1352,45 +1277,45 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st
|
||||
}
|
||||
|
||||
for _, g := range addNewGroups {
|
||||
group, err := am.Store.GetGroupByID(ctx, store.LockingStrengthShare, accountID, g)
|
||||
group, err := am.Store.GetGroupByID(ctx, store.LockingStrengthShare, userAuth.AccountId, g)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Debugf("group %s not found while saving user activity event of account %s", g, accountID)
|
||||
log.WithContext(ctx).Debugf("group %s not found while saving user activity event of account %s", g, userAuth.AccountId)
|
||||
} else {
|
||||
meta := map[string]any{
|
||||
"group": group.Name, "group_id": group.ID,
|
||||
"is_service_user": user.IsServiceUser, "user_name": user.ServiceUserName,
|
||||
}
|
||||
am.StoreEvent(ctx, user.Id, user.Id, accountID, activity.GroupAddedToUser, meta)
|
||||
am.StoreEvent(ctx, user.Id, user.Id, userAuth.AccountId, activity.GroupAddedToUser, meta)
|
||||
}
|
||||
}
|
||||
|
||||
for _, g := range removeOldGroups {
|
||||
group, err := am.Store.GetGroupByID(ctx, store.LockingStrengthShare, accountID, g)
|
||||
group, err := am.Store.GetGroupByID(ctx, store.LockingStrengthShare, userAuth.AccountId, g)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Debugf("group %s not found while saving user activity event of account %s", g, accountID)
|
||||
log.WithContext(ctx).Debugf("group %s not found while saving user activity event of account %s", g, userAuth.AccountId)
|
||||
} else {
|
||||
meta := map[string]any{
|
||||
"group": group.Name, "group_id": group.ID,
|
||||
"is_service_user": user.IsServiceUser, "user_name": user.ServiceUserName,
|
||||
}
|
||||
am.StoreEvent(ctx, user.Id, user.Id, accountID, activity.GroupRemovedFromUser, meta)
|
||||
am.StoreEvent(ctx, user.Id, user.Id, userAuth.AccountId, activity.GroupRemovedFromUser, meta)
|
||||
}
|
||||
}
|
||||
|
||||
if settings.GroupsPropagationEnabled {
|
||||
removedGroupAffectsPeers, err := areGroupChangesAffectPeers(ctx, am.Store, accountID, removeOldGroups)
|
||||
removedGroupAffectsPeers, err := areGroupChangesAffectPeers(ctx, am.Store, userAuth.AccountId, removeOldGroups)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
newGroupsAffectsPeers, err := areGroupChangesAffectPeers(ctx, am.Store, accountID, addNewGroups)
|
||||
newGroupsAffectsPeers, err := areGroupChangesAffectPeers(ctx, am.Store, userAuth.AccountId, addNewGroups)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if removedGroupAffectsPeers || newGroupsAffectsPeers {
|
||||
log.WithContext(ctx).Tracef("user %s: JWT group membership changed, updating account peers", claims.UserId)
|
||||
am.UpdateAccountPeers(ctx, accountID)
|
||||
log.WithContext(ctx).Tracef("user %s: JWT group membership changed, updating account peers", userAuth.UserId)
|
||||
am.UpdateAccountPeers(ctx, userAuth.AccountId)
|
||||
}
|
||||
}
|
||||
|
||||
@ -1415,24 +1340,34 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st
|
||||
// Existing user + Existing account + Existing Indexed Domain -> Nothing changes
|
||||
//
|
||||
// Existing user + Existing account + Existing domain reclassified Domain as private -> Nothing changes (index domain)
|
||||
func (am *DefaultAccountManager) getAccountIDWithAuthorizationClaims(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, error) {
|
||||
//
|
||||
// UserAuth IsChild -> checks that account exists
|
||||
func (am *DefaultAccountManager) getAccountIDWithAuthorizationClaims(ctx context.Context, userAuth nbcontext.UserAuth) (string, error) {
|
||||
log.WithContext(ctx).Tracef("getting account with authorization claims. User ID: \"%s\", Account ID: \"%s\", Domain: \"%s\", Domain Category: \"%s\"",
|
||||
claims.UserId, claims.AccountId, claims.Domain, claims.DomainCategory)
|
||||
userAuth.UserId, userAuth.AccountId, userAuth.Domain, userAuth.DomainCategory)
|
||||
|
||||
if claims.UserId == "" {
|
||||
if userAuth.UserId == "" {
|
||||
return "", errors.New(emptyUserID)
|
||||
}
|
||||
|
||||
if claims.DomainCategory != types.PrivateCategory || !isDomainValid(claims.Domain) {
|
||||
return am.GetAccountIDByUserID(ctx, claims.UserId, claims.Domain)
|
||||
if userAuth.IsChild {
|
||||
exists, err := am.Store.AccountExists(ctx, store.LockingStrengthShare, userAuth.AccountId)
|
||||
if err != nil || !exists {
|
||||
return "", err
|
||||
}
|
||||
return userAuth.AccountId, nil
|
||||
}
|
||||
|
||||
if claims.AccountId != "" {
|
||||
return am.handlePrivateAccountWithIDFromClaim(ctx, claims)
|
||||
if userAuth.DomainCategory != types.PrivateCategory || !isDomainValid(userAuth.Domain) {
|
||||
return am.GetAccountIDByUserID(ctx, userAuth.UserId, userAuth.Domain)
|
||||
}
|
||||
|
||||
if userAuth.AccountId != "" {
|
||||
return am.handlePrivateAccountWithIDFromClaim(ctx, userAuth)
|
||||
}
|
||||
|
||||
// We checked if the domain has a primary account already
|
||||
domainAccountID, cancel, err := am.getPrivateDomainWithGlobalLock(ctx, claims.Domain)
|
||||
domainAccountID, cancel, err := am.getPrivateDomainWithGlobalLock(ctx, userAuth.Domain)
|
||||
if cancel != nil {
|
||||
defer cancel()
|
||||
}
|
||||
@ -1440,14 +1375,14 @@ func (am *DefaultAccountManager) getAccountIDWithAuthorizationClaims(ctx context
|
||||
return "", err
|
||||
}
|
||||
|
||||
userAccountID, err := am.Store.GetAccountIDByUserID(ctx, store.LockingStrengthShare, claims.UserId)
|
||||
userAccountID, err := am.Store.GetAccountIDByUserID(ctx, store.LockingStrengthShare, userAuth.UserId)
|
||||
if handleNotFound(err) != nil {
|
||||
log.WithContext(ctx).Errorf("error getting account ID by user ID: %v", err)
|
||||
return "", err
|
||||
}
|
||||
|
||||
if userAccountID != "" {
|
||||
if err = am.handleExistingUserAccount(ctx, userAccountID, domainAccountID, claims); err != nil {
|
||||
if err = am.handleExistingUserAccount(ctx, userAccountID, domainAccountID, userAuth); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
@ -1455,10 +1390,10 @@ func (am *DefaultAccountManager) getAccountIDWithAuthorizationClaims(ctx context
|
||||
}
|
||||
|
||||
if domainAccountID != "" {
|
||||
return am.addNewUserToDomainAccount(ctx, domainAccountID, claims)
|
||||
return am.addNewUserToDomainAccount(ctx, domainAccountID, userAuth)
|
||||
}
|
||||
|
||||
return am.addNewPrivateAccount(ctx, domainAccountID, claims)
|
||||
return am.addNewPrivateAccount(ctx, domainAccountID, userAuth)
|
||||
}
|
||||
func (am *DefaultAccountManager) getPrivateDomainWithGlobalLock(ctx context.Context, domain string) (string, context.CancelFunc, error) {
|
||||
domainAccountID, err := am.Store.GetAccountIDByPrivateDomain(ctx, store.LockingStrengthShare, domain)
|
||||
@ -1486,40 +1421,40 @@ func (am *DefaultAccountManager) getPrivateDomainWithGlobalLock(ctx context.Cont
|
||||
return domainAccountID, cancel, nil
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) handlePrivateAccountWithIDFromClaim(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, error) {
|
||||
userAccountID, err := am.Store.GetAccountIDByUserID(ctx, store.LockingStrengthShare, claims.UserId)
|
||||
func (am *DefaultAccountManager) handlePrivateAccountWithIDFromClaim(ctx context.Context, userAuth nbcontext.UserAuth) (string, error) {
|
||||
userAccountID, err := am.Store.GetAccountIDByUserID(ctx, store.LockingStrengthShare, userAuth.UserId)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("error getting account ID by user ID: %v", err)
|
||||
return "", err
|
||||
}
|
||||
|
||||
if userAccountID != claims.AccountId {
|
||||
return "", fmt.Errorf("user %s is not part of the account id %s", claims.UserId, claims.AccountId)
|
||||
if userAccountID != userAuth.AccountId {
|
||||
return "", fmt.Errorf("user %s is not part of the account id %s", userAuth.UserId, userAuth.AccountId)
|
||||
}
|
||||
|
||||
accountDomain, domainCategory, err := am.Store.GetAccountDomainAndCategory(ctx, store.LockingStrengthShare, claims.AccountId)
|
||||
accountDomain, domainCategory, err := am.Store.GetAccountDomainAndCategory(ctx, store.LockingStrengthShare, userAuth.AccountId)
|
||||
if handleNotFound(err) != nil {
|
||||
log.WithContext(ctx).Errorf("error getting account domain and category: %v", err)
|
||||
return "", err
|
||||
}
|
||||
|
||||
if domainIsUpToDate(accountDomain, domainCategory, claims) {
|
||||
return claims.AccountId, nil
|
||||
if domainIsUpToDate(accountDomain, domainCategory, userAuth) {
|
||||
return userAuth.AccountId, nil
|
||||
}
|
||||
|
||||
// We checked if the domain has a primary account already
|
||||
domainAccountID, err := am.Store.GetAccountIDByPrivateDomain(ctx, store.LockingStrengthShare, claims.Domain)
|
||||
domainAccountID, err := am.Store.GetAccountIDByPrivateDomain(ctx, store.LockingStrengthShare, userAuth.Domain)
|
||||
if handleNotFound(err) != nil {
|
||||
log.WithContext(ctx).Errorf(errorGettingDomainAccIDFmt, err)
|
||||
return "", err
|
||||
}
|
||||
|
||||
err = am.handleExistingUserAccount(ctx, claims.AccountId, domainAccountID, claims)
|
||||
err = am.handleExistingUserAccount(ctx, userAuth.AccountId, domainAccountID, userAuth)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return claims.AccountId, nil
|
||||
return userAuth.AccountId, nil
|
||||
}
|
||||
|
||||
func handleNotFound(err error) error {
|
||||
@ -1534,8 +1469,8 @@ func handleNotFound(err error) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func domainIsUpToDate(domain string, domainCategory string, claims jwtclaims.AuthorizationClaims) bool {
|
||||
return domainCategory == types.PrivateCategory || claims.DomainCategory != types.PrivateCategory || domain != claims.Domain
|
||||
func domainIsUpToDate(domain string, domainCategory string, userAuth nbcontext.UserAuth) bool {
|
||||
return domainCategory == types.PrivateCategory || userAuth.DomainCategory != types.PrivateCategory || domain != userAuth.Domain
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) {
|
||||
@ -1618,34 +1553,6 @@ func (am *DefaultAccountManager) GetDNSDomain() string {
|
||||
return am.dnsDomain
|
||||
}
|
||||
|
||||
// CheckUserAccessByJWTGroups checks if the user has access, particularly in cases where the admin enabled JWT
|
||||
// group propagation and set the list of groups with access permissions.
|
||||
func (am *DefaultAccountManager) CheckUserAccessByJWTGroups(ctx context.Context, claims jwtclaims.AuthorizationClaims) error {
|
||||
accountID, _, err := am.GetAccountIDFromToken(ctx, claims)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Ensures JWT group synchronization to the management is enabled before,
|
||||
// filtering access based on the allowed groups.
|
||||
if settings != nil && settings.JWTGroupsEnabled {
|
||||
if allowedGroups := settings.JWTAllowGroups; len(allowedGroups) > 0 {
|
||||
userJWTGroups := extractJWTGroups(ctx, settings.JWTGroupsClaimName, claims)
|
||||
|
||||
if !userHasAllowedGroup(allowedGroups, userJWTGroups) {
|
||||
return fmt.Errorf("user does not belong to any of the allowed JWT groups")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) onPeersInvalidated(ctx context.Context, accountID string) {
|
||||
log.WithContext(ctx).Debugf("validated peers has been invalidated for account %s", accountID)
|
||||
am.UpdateAccountPeers(ctx, accountID)
|
||||
@ -1804,39 +1711,6 @@ func newAccountWithId(ctx context.Context, accountID, userID, domain string) *ty
|
||||
return acc
|
||||
}
|
||||
|
||||
// extractJWTGroups extracts the group names from a JWT token's claims.
|
||||
func extractJWTGroups(ctx context.Context, claimName string, claims jwtclaims.AuthorizationClaims) []string {
|
||||
userJWTGroups := make([]string, 0)
|
||||
|
||||
if claim, ok := claims.Raw[claimName]; ok {
|
||||
if claimGroups, ok := claim.([]interface{}); ok {
|
||||
for _, g := range claimGroups {
|
||||
if group, ok := g.(string); ok {
|
||||
userJWTGroups = append(userJWTGroups, group)
|
||||
} else {
|
||||
log.WithContext(ctx).Debugf("JWT claim %q contains a non-string group (type: %T): %v", claimName, g, g)
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
log.WithContext(ctx).Debugf("JWT claim %q is not a string array", claimName)
|
||||
}
|
||||
|
||||
return userJWTGroups
|
||||
}
|
||||
|
||||
// userHasAllowedGroup checks if a user belongs to any of the allowed groups.
|
||||
func userHasAllowedGroup(allowedGroups []string, userGroups []string) bool {
|
||||
for _, userGroup := range userGroups {
|
||||
for _, allowedGroup := range allowedGroups {
|
||||
if userGroup == allowedGroup {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// separateGroups separates user's auto groups into non-JWT and JWT groups.
|
||||
// Returns the list of standard auto groups and a map of JWT auto groups,
|
||||
// where the keys are the group names and the values are the group IDs.
|
||||
|
@ -2,8 +2,6 @@ package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
b64 "encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net"
|
||||
@ -12,8 +10,6 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/util"
|
||||
|
||||
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
|
||||
@ -26,7 +22,7 @@ import (
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/posture"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
@ -433,7 +429,7 @@ func TestAccountManager_GetOrCreateAccountByUser(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) {
|
||||
type initUserParams jwtclaims.AuthorizationClaims
|
||||
type initUserParams nbcontext.UserAuth
|
||||
|
||||
var (
|
||||
publicDomain = "public.com"
|
||||
@ -456,7 +452,7 @@ func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) {
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
inputClaims jwtclaims.AuthorizationClaims
|
||||
inputClaims nbcontext.UserAuth
|
||||
inputInitUserParams initUserParams
|
||||
inputUpdateAttrs bool
|
||||
inputUpdateClaimAccount bool
|
||||
@ -471,7 +467,7 @@ func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) {
|
||||
}{
|
||||
{
|
||||
name: "New User With Public Domain",
|
||||
inputClaims: jwtclaims.AuthorizationClaims{
|
||||
inputClaims: nbcontext.UserAuth{
|
||||
Domain: publicDomain,
|
||||
UserId: "pub-domain-user",
|
||||
DomainCategory: types.PublicCategory,
|
||||
@ -488,7 +484,7 @@ func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "New User With Unknown Domain",
|
||||
inputClaims: jwtclaims.AuthorizationClaims{
|
||||
inputClaims: nbcontext.UserAuth{
|
||||
Domain: unknownDomain,
|
||||
UserId: "unknown-domain-user",
|
||||
DomainCategory: types.UnknownCategory,
|
||||
@ -505,7 +501,7 @@ func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "New User With Private Domain",
|
||||
inputClaims: jwtclaims.AuthorizationClaims{
|
||||
inputClaims: nbcontext.UserAuth{
|
||||
Domain: privateDomain,
|
||||
UserId: "pvt-domain-user",
|
||||
DomainCategory: types.PrivateCategory,
|
||||
@ -522,7 +518,7 @@ func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "New Regular User With Existing Private Domain",
|
||||
inputClaims: jwtclaims.AuthorizationClaims{
|
||||
inputClaims: nbcontext.UserAuth{
|
||||
Domain: privateDomain,
|
||||
UserId: "new-pvt-domain-user",
|
||||
DomainCategory: types.PrivateCategory,
|
||||
@ -540,7 +536,7 @@ func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "Existing User With Existing Reclassified Private Domain",
|
||||
inputClaims: jwtclaims.AuthorizationClaims{
|
||||
inputClaims: nbcontext.UserAuth{
|
||||
Domain: defaultInitAccount.Domain,
|
||||
UserId: defaultInitAccount.UserId,
|
||||
DomainCategory: types.PrivateCategory,
|
||||
@ -557,7 +553,7 @@ func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "Existing Account Id With Existing Reclassified Private Domain",
|
||||
inputClaims: jwtclaims.AuthorizationClaims{
|
||||
inputClaims: nbcontext.UserAuth{
|
||||
Domain: defaultInitAccount.Domain,
|
||||
UserId: defaultInitAccount.UserId,
|
||||
DomainCategory: types.PrivateCategory,
|
||||
@ -575,7 +571,7 @@ func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "User With Private Category And Empty Domain",
|
||||
inputClaims: jwtclaims.AuthorizationClaims{
|
||||
inputClaims: nbcontext.UserAuth{
|
||||
Domain: "",
|
||||
UserId: "pvt-domain-user",
|
||||
DomainCategory: types.PrivateCategory,
|
||||
@ -604,7 +600,7 @@ func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) {
|
||||
require.NoError(t, err, "get init account failed")
|
||||
|
||||
if testCase.inputUpdateAttrs {
|
||||
err = manager.updateAccountDomainAttributesIfNotUpToDate(context.Background(), initAccount.Id, jwtclaims.AuthorizationClaims{UserId: testCase.inputInitUserParams.UserId, Domain: testCase.inputInitUserParams.Domain, DomainCategory: testCase.inputInitUserParams.DomainCategory}, true)
|
||||
err = manager.updateAccountDomainAttributesIfNotUpToDate(context.Background(), initAccount.Id, nbcontext.UserAuth{UserId: testCase.inputInitUserParams.UserId, Domain: testCase.inputInitUserParams.Domain, DomainCategory: testCase.inputInitUserParams.DomainCategory}, true)
|
||||
require.NoError(t, err, "update init user failed")
|
||||
}
|
||||
|
||||
@ -612,7 +608,7 @@ func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) {
|
||||
testCase.inputClaims.AccountId = initAccount.Id
|
||||
}
|
||||
|
||||
accountID, _, err = manager.GetAccountIDFromToken(context.Background(), testCase.inputClaims)
|
||||
accountID, _, err = manager.GetAccountIDFromUserAuth(context.Background(), testCase.inputClaims)
|
||||
require.NoError(t, err, "support function failed")
|
||||
|
||||
account, err := manager.Store.GetAccount(context.Background(), accountID)
|
||||
@ -631,14 +627,12 @@ func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaultAccountManager_GetGroupsFromTheToken(t *testing.T) {
|
||||
func TestDefaultAccountManager_SyncUserJWTGroups(t *testing.T) {
|
||||
userId := "user-id"
|
||||
domain := "test.domain"
|
||||
|
||||
_ = newAccountWithId(context.Background(), "", userId, domain)
|
||||
manager, err := createManager(t)
|
||||
require.NoError(t, err, "unable to create account manager")
|
||||
|
||||
accountID, err := manager.GetAccountIDByUserID(context.Background(), userId, domain)
|
||||
require.NoError(t, err, "create init user failed")
|
||||
// as initAccount was created without account id we have to take the id after account initialization
|
||||
@ -646,65 +640,50 @@ func TestDefaultAccountManager_GetGroupsFromTheToken(t *testing.T) {
|
||||
// it is important to set the id as it help to avoid creating additional account with empty Id and re-pointing indices to it
|
||||
initAccount, err := manager.Store.GetAccount(context.Background(), accountID)
|
||||
require.NoError(t, err, "get init account failed")
|
||||
|
||||
claims := jwtclaims.AuthorizationClaims{
|
||||
claims := nbcontext.UserAuth{
|
||||
AccountId: accountID, // is empty as it is based on accountID right after initialization of initAccount
|
||||
Domain: domain,
|
||||
UserId: userId,
|
||||
DomainCategory: "test-category",
|
||||
Raw: jwt.MapClaims{"idp-groups": []interface{}{"group1", "group2"}},
|
||||
Groups: []string{"group1", "group2"},
|
||||
}
|
||||
|
||||
t.Run("JWT groups disabled", func(t *testing.T) {
|
||||
accountID, _, err := manager.GetAccountIDFromToken(context.Background(), claims)
|
||||
require.NoError(t, err, "get account by token failed")
|
||||
|
||||
err := manager.SyncUserJWTGroups(context.Background(), claims)
|
||||
require.NoError(t, err, "synt user jwt groups failed")
|
||||
account, err := manager.Store.GetAccount(context.Background(), accountID)
|
||||
require.NoError(t, err, "get account failed")
|
||||
|
||||
require.Len(t, account.Groups, 1, "only ALL group should exists")
|
||||
})
|
||||
|
||||
t.Run("JWT groups enabled without claim name", func(t *testing.T) {
|
||||
initAccount.Settings.JWTGroupsEnabled = true
|
||||
err := manager.Store.SaveAccount(context.Background(), initAccount)
|
||||
require.NoError(t, err, "save account failed")
|
||||
require.Len(t, manager.Store.GetAllAccounts(context.Background()), 1, "only one account should exist")
|
||||
|
||||
accountID, _, err := manager.GetAccountIDFromToken(context.Background(), claims)
|
||||
require.NoError(t, err, "get account by token failed")
|
||||
|
||||
err = manager.SyncUserJWTGroups(context.Background(), claims)
|
||||
require.NoError(t, err, "synt user jwt groups failed")
|
||||
account, err := manager.Store.GetAccount(context.Background(), accountID)
|
||||
require.NoError(t, err, "get account failed")
|
||||
|
||||
require.Len(t, account.Groups, 1, "if group claim is not set no group added from JWT")
|
||||
})
|
||||
|
||||
t.Run("JWT groups enabled", func(t *testing.T) {
|
||||
initAccount.Settings.JWTGroupsEnabled = true
|
||||
initAccount.Settings.JWTGroupsClaimName = "idp-groups"
|
||||
err := manager.Store.SaveAccount(context.Background(), initAccount)
|
||||
require.NoError(t, err, "save account failed")
|
||||
require.Len(t, manager.Store.GetAllAccounts(context.Background()), 1, "only one account should exist")
|
||||
|
||||
accountID, _, err := manager.GetAccountIDFromToken(context.Background(), claims)
|
||||
require.NoError(t, err, "get account by token failed")
|
||||
|
||||
err = manager.SyncUserJWTGroups(context.Background(), claims)
|
||||
require.NoError(t, err, "synt user jwt groups failed")
|
||||
account, err := manager.Store.GetAccount(context.Background(), accountID)
|
||||
require.NoError(t, err, "get account failed")
|
||||
|
||||
require.Len(t, account.Groups, 3, "groups should be added to the account")
|
||||
|
||||
groupsByNames := map[string]*types.Group{}
|
||||
for _, g := range account.Groups {
|
||||
groupsByNames[g.Name] = g
|
||||
}
|
||||
|
||||
g1, ok := groupsByNames["group1"]
|
||||
require.True(t, ok, "group1 should be added to the account")
|
||||
require.Equal(t, g1.Name, "group1", "group1 name should match")
|
||||
require.Equal(t, g1.Issued, types.GroupIssuedJWT, "group1 issued should match")
|
||||
|
||||
g2, ok := groupsByNames["group2"]
|
||||
require.True(t, ok, "group2 should be added to the account")
|
||||
require.Equal(t, g2.Name, "group2", "group2 name should match")
|
||||
@ -712,88 +691,6 @@ func TestDefaultAccountManager_GetGroupsFromTheToken(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestAccountManager_GetAccountFromPAT(t *testing.T) {
|
||||
store, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "", t.TempDir())
|
||||
if err != nil {
|
||||
t.Fatalf("Error when creating store: %s", err)
|
||||
}
|
||||
t.Cleanup(cleanup)
|
||||
account := newAccountWithId(context.Background(), "account_id", "testuser", "")
|
||||
|
||||
token := "nbp_9999EUDNdkeusjentDLSJEn1902u84390W6W"
|
||||
hashedToken := sha256.Sum256([]byte(token))
|
||||
encodedHashedToken := b64.StdEncoding.EncodeToString(hashedToken[:])
|
||||
account.Users["someUser"] = &types.User{
|
||||
Id: "someUser",
|
||||
PATs: map[string]*types.PersonalAccessToken{
|
||||
"tokenId": {
|
||||
ID: "tokenId",
|
||||
UserID: "someUser",
|
||||
HashedToken: encodedHashedToken,
|
||||
},
|
||||
},
|
||||
}
|
||||
err = store.SaveAccount(context.Background(), account)
|
||||
if err != nil {
|
||||
t.Fatalf("Error when saving account: %s", err)
|
||||
}
|
||||
|
||||
am := DefaultAccountManager{
|
||||
Store: store,
|
||||
}
|
||||
|
||||
user, pat, _, _, err := am.GetPATInfo(context.Background(), token)
|
||||
if err != nil {
|
||||
t.Fatalf("Error when getting Account from PAT: %s", err)
|
||||
}
|
||||
|
||||
assert.Equal(t, "account_id", user.AccountID)
|
||||
assert.Equal(t, "someUser", user.Id)
|
||||
assert.Equal(t, account.Users["someUser"].PATs["tokenId"].ID, pat.ID)
|
||||
}
|
||||
|
||||
func TestDefaultAccountManager_MarkPATUsed(t *testing.T) {
|
||||
store, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "", t.TempDir())
|
||||
if err != nil {
|
||||
t.Fatalf("Error when creating store: %s", err)
|
||||
}
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
account := newAccountWithId(context.Background(), "account_id", "testuser", "")
|
||||
|
||||
token := "nbp_9999EUDNdkeusjentDLSJEn1902u84390W6W"
|
||||
hashedToken := sha256.Sum256([]byte(token))
|
||||
encodedHashedToken := b64.StdEncoding.EncodeToString(hashedToken[:])
|
||||
account.Users["someUser"] = &types.User{
|
||||
Id: "someUser",
|
||||
PATs: map[string]*types.PersonalAccessToken{
|
||||
"tokenId": {
|
||||
ID: "tokenId",
|
||||
HashedToken: encodedHashedToken,
|
||||
},
|
||||
},
|
||||
}
|
||||
err = store.SaveAccount(context.Background(), account)
|
||||
if err != nil {
|
||||
t.Fatalf("Error when saving account: %s", err)
|
||||
}
|
||||
|
||||
am := DefaultAccountManager{
|
||||
Store: store,
|
||||
}
|
||||
|
||||
err = am.MarkPATUsed(context.Background(), "tokenId")
|
||||
if err != nil {
|
||||
t.Fatalf("Error when marking PAT used: %s", err)
|
||||
}
|
||||
|
||||
account, err = am.Store.GetAccount(context.Background(), "account_id")
|
||||
if err != nil {
|
||||
t.Fatalf("Error when getting account: %s", err)
|
||||
}
|
||||
assert.True(t, !account.Users["someUser"].PATs["tokenId"].GetLastUsed().IsZero())
|
||||
}
|
||||
|
||||
func TestAccountManager_PrivateAccount(t *testing.T) {
|
||||
manager, err := createManager(t)
|
||||
if err != nil {
|
||||
@ -2575,11 +2472,13 @@ func TestAccount_SetJWTGroups(t *testing.T) {
|
||||
assert.NoError(t, manager.Store.SaveAccount(context.Background(), account), "unable to save account")
|
||||
|
||||
t.Run("skip sync for token auth type", func(t *testing.T) {
|
||||
claims := jwtclaims.AuthorizationClaims{
|
||||
UserId: "user1",
|
||||
Raw: jwt.MapClaims{"groups": []interface{}{"group3"}, "is_token": true},
|
||||
claims := nbcontext.UserAuth{
|
||||
UserId: "user1",
|
||||
AccountId: "accountID",
|
||||
Groups: []string{"group3"},
|
||||
IsPAT: true,
|
||||
}
|
||||
err = manager.syncJWTGroups(context.Background(), "accountID", claims)
|
||||
err = manager.SyncUserJWTGroups(context.Background(), claims)
|
||||
assert.NoError(t, err, "unable to sync jwt groups")
|
||||
|
||||
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthShare, "user1")
|
||||
@ -2588,11 +2487,12 @@ func TestAccount_SetJWTGroups(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("empty jwt groups", func(t *testing.T) {
|
||||
claims := jwtclaims.AuthorizationClaims{
|
||||
UserId: "user1",
|
||||
Raw: jwt.MapClaims{"groups": []interface{}{}},
|
||||
claims := nbcontext.UserAuth{
|
||||
UserId: "user1",
|
||||
AccountId: "accountID",
|
||||
Groups: []string{},
|
||||
}
|
||||
err := manager.syncJWTGroups(context.Background(), "accountID", claims)
|
||||
err := manager.SyncUserJWTGroups(context.Background(), claims)
|
||||
assert.NoError(t, err, "unable to sync jwt groups")
|
||||
|
||||
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthShare, "user1")
|
||||
@ -2601,11 +2501,12 @@ func TestAccount_SetJWTGroups(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("jwt match existing api group", func(t *testing.T) {
|
||||
claims := jwtclaims.AuthorizationClaims{
|
||||
UserId: "user1",
|
||||
Raw: jwt.MapClaims{"groups": []interface{}{"group1"}},
|
||||
claims := nbcontext.UserAuth{
|
||||
UserId: "user1",
|
||||
AccountId: "accountID",
|
||||
Groups: []string{"group1"},
|
||||
}
|
||||
err := manager.syncJWTGroups(context.Background(), "accountID", claims)
|
||||
err := manager.SyncUserJWTGroups(context.Background(), claims)
|
||||
assert.NoError(t, err, "unable to sync jwt groups")
|
||||
|
||||
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthShare, "user1")
|
||||
@ -2621,11 +2522,12 @@ func TestAccount_SetJWTGroups(t *testing.T) {
|
||||
account.Users["user1"].AutoGroups = []string{"group1"}
|
||||
assert.NoError(t, manager.Store.SaveUser(context.Background(), store.LockingStrengthUpdate, account.Users["user1"]))
|
||||
|
||||
claims := jwtclaims.AuthorizationClaims{
|
||||
UserId: "user1",
|
||||
Raw: jwt.MapClaims{"groups": []interface{}{"group1"}},
|
||||
claims := nbcontext.UserAuth{
|
||||
UserId: "user1",
|
||||
AccountId: "accountID",
|
||||
Groups: []string{"group1"},
|
||||
}
|
||||
err = manager.syncJWTGroups(context.Background(), "accountID", claims)
|
||||
err = manager.SyncUserJWTGroups(context.Background(), claims)
|
||||
assert.NoError(t, err, "unable to sync jwt groups")
|
||||
|
||||
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthShare, "user1")
|
||||
@ -2638,11 +2540,12 @@ func TestAccount_SetJWTGroups(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("add jwt group", func(t *testing.T) {
|
||||
claims := jwtclaims.AuthorizationClaims{
|
||||
UserId: "user1",
|
||||
Raw: jwt.MapClaims{"groups": []interface{}{"group1", "group2"}},
|
||||
claims := nbcontext.UserAuth{
|
||||
UserId: "user1",
|
||||
AccountId: "accountID",
|
||||
Groups: []string{"group1", "group2"},
|
||||
}
|
||||
err = manager.syncJWTGroups(context.Background(), "accountID", claims)
|
||||
err = manager.SyncUserJWTGroups(context.Background(), claims)
|
||||
assert.NoError(t, err, "unable to sync jwt groups")
|
||||
|
||||
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthShare, "user1")
|
||||
@ -2651,11 +2554,12 @@ func TestAccount_SetJWTGroups(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("existed group not update", func(t *testing.T) {
|
||||
claims := jwtclaims.AuthorizationClaims{
|
||||
UserId: "user1",
|
||||
Raw: jwt.MapClaims{"groups": []interface{}{"group2"}},
|
||||
claims := nbcontext.UserAuth{
|
||||
UserId: "user1",
|
||||
AccountId: "accountID",
|
||||
Groups: []string{"group2"},
|
||||
}
|
||||
err = manager.syncJWTGroups(context.Background(), "accountID", claims)
|
||||
err = manager.SyncUserJWTGroups(context.Background(), claims)
|
||||
assert.NoError(t, err, "unable to sync jwt groups")
|
||||
|
||||
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthShare, "user1")
|
||||
@ -2664,11 +2568,12 @@ func TestAccount_SetJWTGroups(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("add new group", func(t *testing.T) {
|
||||
claims := jwtclaims.AuthorizationClaims{
|
||||
UserId: "user2",
|
||||
Raw: jwt.MapClaims{"groups": []interface{}{"group1", "group3"}},
|
||||
claims := nbcontext.UserAuth{
|
||||
UserId: "user2",
|
||||
AccountId: "accountID",
|
||||
Groups: []string{"group1", "group3"},
|
||||
}
|
||||
err = manager.syncJWTGroups(context.Background(), "accountID", claims)
|
||||
err = manager.SyncUserJWTGroups(context.Background(), claims)
|
||||
assert.NoError(t, err, "unable to sync jwt groups")
|
||||
|
||||
groups, err := manager.Store.GetAccountGroups(context.Background(), store.LockingStrengthShare, "accountID")
|
||||
@ -2681,11 +2586,12 @@ func TestAccount_SetJWTGroups(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("remove all JWT groups when list is empty", func(t *testing.T) {
|
||||
claims := jwtclaims.AuthorizationClaims{
|
||||
UserId: "user1",
|
||||
Raw: jwt.MapClaims{"groups": []interface{}{}},
|
||||
claims := nbcontext.UserAuth{
|
||||
UserId: "user1",
|
||||
AccountId: "accountID",
|
||||
Groups: []string{},
|
||||
}
|
||||
err = manager.syncJWTGroups(context.Background(), "accountID", claims)
|
||||
err = manager.SyncUserJWTGroups(context.Background(), claims)
|
||||
assert.NoError(t, err, "unable to sync jwt groups")
|
||||
|
||||
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthShare, "user1")
|
||||
@ -2695,11 +2601,12 @@ func TestAccount_SetJWTGroups(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("remove all JWT groups when claim does not exist", func(t *testing.T) {
|
||||
claims := jwtclaims.AuthorizationClaims{
|
||||
UserId: "user2",
|
||||
Raw: jwt.MapClaims{},
|
||||
claims := nbcontext.UserAuth{
|
||||
UserId: "user2",
|
||||
AccountId: "accountID",
|
||||
Groups: []string{},
|
||||
}
|
||||
err = manager.syncJWTGroups(context.Background(), "accountID", claims)
|
||||
err = manager.SyncUserJWTGroups(context.Background(), claims)
|
||||
assert.NoError(t, err, "unable to sync jwt groups")
|
||||
|
||||
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthShare, "user2")
|
||||
|
@ -1,15 +1,17 @@
|
||||
package jwtclaims
|
||||
package jwt
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"errors"
|
||||
"net/url"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
)
|
||||
|
||||
const (
|
||||
// TokenUserProperty key for the user property in the request context
|
||||
TokenUserProperty = "user"
|
||||
// AccountIDSuffix suffix for the account id claim
|
||||
AccountIDSuffix = "wt_account_id"
|
||||
// DomainIDSuffix suffix for the domain id claim
|
||||
@ -22,19 +24,16 @@ const (
|
||||
LastLoginSuffix = "nb_last_login"
|
||||
// Invited claim indicates that an incoming JWT is from a user that just accepted an invitation
|
||||
Invited = "nb_invited"
|
||||
// IsToken claim indicates that auth type from the user is a token
|
||||
IsToken = "is_token"
|
||||
)
|
||||
|
||||
// ExtractClaims Extract function type
|
||||
type ExtractClaims func(r *http.Request) AuthorizationClaims
|
||||
var (
|
||||
errUserIDClaimEmpty = errors.New("user ID claim token value is empty")
|
||||
)
|
||||
|
||||
// ClaimsExtractor struct that holds the extract function
|
||||
type ClaimsExtractor struct {
|
||||
authAudience string
|
||||
userIDClaim string
|
||||
|
||||
FromRequestContext ExtractClaims
|
||||
}
|
||||
|
||||
// ClaimsExtractorOption is a function that configures the ClaimsExtractor
|
||||
@ -54,13 +53,6 @@ func WithUserIDClaim(userIDClaim string) ClaimsExtractorOption {
|
||||
}
|
||||
}
|
||||
|
||||
// WithFromRequestContext sets the function that extracts claims from the request context
|
||||
func WithFromRequestContext(ec ExtractClaims) ClaimsExtractorOption {
|
||||
return func(c *ClaimsExtractor) {
|
||||
c.FromRequestContext = ec
|
||||
}
|
||||
}
|
||||
|
||||
// NewClaimsExtractor returns an extractor, and if provided with a function with ExtractClaims signature,
|
||||
// then it will use that logic. Uses ExtractClaimsFromRequestContext by default
|
||||
func NewClaimsExtractor(options ...ClaimsExtractorOption) *ClaimsExtractor {
|
||||
@ -68,49 +60,13 @@ func NewClaimsExtractor(options ...ClaimsExtractorOption) *ClaimsExtractor {
|
||||
for _, option := range options {
|
||||
option(ce)
|
||||
}
|
||||
if ce.FromRequestContext == nil {
|
||||
ce.FromRequestContext = ce.fromRequestContext
|
||||
}
|
||||
|
||||
if ce.userIDClaim == "" {
|
||||
ce.userIDClaim = UserIDClaim
|
||||
}
|
||||
return ce
|
||||
}
|
||||
|
||||
// FromToken extracts claims from the token (after auth)
|
||||
func (c *ClaimsExtractor) FromToken(token *jwt.Token) AuthorizationClaims {
|
||||
claims := token.Claims.(jwt.MapClaims)
|
||||
jwtClaims := AuthorizationClaims{
|
||||
Raw: claims,
|
||||
}
|
||||
userID, ok := claims[c.userIDClaim].(string)
|
||||
if !ok {
|
||||
return jwtClaims
|
||||
}
|
||||
jwtClaims.UserId = userID
|
||||
accountIDClaim, ok := claims[c.authAudience+AccountIDSuffix]
|
||||
if ok {
|
||||
jwtClaims.AccountId = accountIDClaim.(string)
|
||||
}
|
||||
domainClaim, ok := claims[c.authAudience+DomainIDSuffix]
|
||||
if ok {
|
||||
jwtClaims.Domain = domainClaim.(string)
|
||||
}
|
||||
domainCategoryClaim, ok := claims[c.authAudience+DomainCategorySuffix]
|
||||
if ok {
|
||||
jwtClaims.DomainCategory = domainCategoryClaim.(string)
|
||||
}
|
||||
LastLoginClaimString, ok := claims[c.authAudience+LastLoginSuffix]
|
||||
if ok {
|
||||
jwtClaims.LastLogin = parseTime(LastLoginClaimString.(string))
|
||||
}
|
||||
invitedBool, ok := claims[c.authAudience+Invited]
|
||||
if ok {
|
||||
jwtClaims.Invited = invitedBool.(bool)
|
||||
}
|
||||
return jwtClaims
|
||||
}
|
||||
|
||||
func parseTime(timeString string) time.Time {
|
||||
if timeString == "" {
|
||||
return time.Time{}
|
||||
@ -122,11 +78,67 @@ func parseTime(timeString string) time.Time {
|
||||
return parsedTime
|
||||
}
|
||||
|
||||
// fromRequestContext extracts claims from the request context previously filled by the JWT token (after auth)
|
||||
func (c *ClaimsExtractor) fromRequestContext(r *http.Request) AuthorizationClaims {
|
||||
if r.Context().Value(TokenUserProperty) == nil {
|
||||
return AuthorizationClaims{}
|
||||
func (c ClaimsExtractor) audienceClaim(claimName string) string {
|
||||
url, err := url.JoinPath(c.authAudience, claimName)
|
||||
if err != nil {
|
||||
return c.authAudience + claimName // as it was previously
|
||||
}
|
||||
token := r.Context().Value(TokenUserProperty).(*jwt.Token)
|
||||
return c.FromToken(token)
|
||||
|
||||
return url
|
||||
}
|
||||
|
||||
func (c *ClaimsExtractor) ToUserAuth(token *jwt.Token) (nbcontext.UserAuth, error) {
|
||||
claims := token.Claims.(jwt.MapClaims)
|
||||
userAuth := nbcontext.UserAuth{}
|
||||
|
||||
userID, ok := claims[c.userIDClaim].(string)
|
||||
if !ok {
|
||||
return userAuth, errUserIDClaimEmpty
|
||||
}
|
||||
userAuth.UserId = userID
|
||||
|
||||
if accountIDClaim, ok := claims[c.audienceClaim(AccountIDSuffix)]; ok {
|
||||
userAuth.AccountId = accountIDClaim.(string)
|
||||
}
|
||||
|
||||
if domainClaim, ok := claims[c.audienceClaim(DomainIDSuffix)]; ok {
|
||||
userAuth.Domain = domainClaim.(string)
|
||||
}
|
||||
|
||||
if domainCategoryClaim, ok := claims[c.audienceClaim(DomainCategorySuffix)]; ok {
|
||||
userAuth.DomainCategory = domainCategoryClaim.(string)
|
||||
}
|
||||
|
||||
if lastLoginClaimString, ok := claims[c.audienceClaim(LastLoginSuffix)]; ok {
|
||||
userAuth.LastLogin = parseTime(lastLoginClaimString.(string))
|
||||
}
|
||||
|
||||
if invitedBool, ok := claims[c.audienceClaim(Invited)]; ok {
|
||||
if value, ok := invitedBool.(bool); ok {
|
||||
userAuth.Invited = value
|
||||
}
|
||||
}
|
||||
|
||||
return userAuth, nil
|
||||
}
|
||||
|
||||
func (c *ClaimsExtractor) ToGroups(token *jwt.Token, claimName string) []string {
|
||||
claims := token.Claims.(jwt.MapClaims)
|
||||
userJWTGroups := make([]string, 0)
|
||||
|
||||
if claim, ok := claims[claimName]; ok {
|
||||
if claimGroups, ok := claim.([]interface{}); ok {
|
||||
for _, g := range claimGroups {
|
||||
if group, ok := g.(string); ok {
|
||||
userJWTGroups = append(userJWTGroups, group)
|
||||
} else {
|
||||
log.Debugf("JWT claim %q contains a non-string group (type: %T): %v", claimName, g, g)
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
log.Debugf("JWT claim %q is not a string array", claimName)
|
||||
}
|
||||
|
||||
return userJWTGroups
|
||||
}
|
302
management/server/auth/jwt/validator.go
Normal file
302
management/server/auth/jwt/validator.go
Normal file
@ -0,0 +1,302 @@
|
||||
package jwt
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/ecdsa"
|
||||
"crypto/elliptic"
|
||||
"crypto/rsa"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math/big"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// Jwks is a collection of JSONWebKey obtained from Config.HttpServerConfig.AuthKeysLocation
|
||||
type Jwks struct {
|
||||
Keys []JSONWebKey `json:"keys"`
|
||||
expiresInTime time.Time
|
||||
}
|
||||
|
||||
// The supported elliptic curves types
|
||||
const (
|
||||
// p256 represents a cryptographic elliptical curve type.
|
||||
p256 = "P-256"
|
||||
|
||||
// p384 represents a cryptographic elliptical curve type.
|
||||
p384 = "P-384"
|
||||
|
||||
// p521 represents a cryptographic elliptical curve type.
|
||||
p521 = "P-521"
|
||||
)
|
||||
|
||||
// JSONWebKey is a representation of a Jason Web Key
|
||||
type JSONWebKey struct {
|
||||
Kty string `json:"kty"`
|
||||
Kid string `json:"kid"`
|
||||
Use string `json:"use"`
|
||||
N string `json:"n"`
|
||||
E string `json:"e"`
|
||||
Crv string `json:"crv"`
|
||||
X string `json:"x"`
|
||||
Y string `json:"y"`
|
||||
X5c []string `json:"x5c"`
|
||||
}
|
||||
|
||||
type Validator struct {
|
||||
lock sync.Mutex
|
||||
issuer string
|
||||
audienceList []string
|
||||
keysLocation string
|
||||
idpSignkeyRefreshEnabled bool
|
||||
keys *Jwks
|
||||
}
|
||||
|
||||
var (
|
||||
errKeyNotFound = errors.New("unable to find appropriate key")
|
||||
errInvalidAudience = errors.New("invalid audience")
|
||||
errInvalidIssuer = errors.New("invalid issuer")
|
||||
errTokenEmpty = errors.New("required authorization token not found")
|
||||
errTokenInvalid = errors.New("token is invalid")
|
||||
errTokenParsing = errors.New("token could not be parsed")
|
||||
)
|
||||
|
||||
func NewValidator(issuer string, audienceList []string, keysLocation string, idpSignkeyRefreshEnabled bool) *Validator {
|
||||
keys, err := getPemKeys(keysLocation)
|
||||
if err != nil {
|
||||
log.WithField("keysLocation", keysLocation).Errorf("could not get keys from location: %s", err)
|
||||
}
|
||||
|
||||
return &Validator{
|
||||
keys: keys,
|
||||
issuer: issuer,
|
||||
audienceList: audienceList,
|
||||
keysLocation: keysLocation,
|
||||
idpSignkeyRefreshEnabled: idpSignkeyRefreshEnabled,
|
||||
}
|
||||
}
|
||||
|
||||
func (v *Validator) getKeyFunc(ctx context.Context) jwt.Keyfunc {
|
||||
return func(token *jwt.Token) (interface{}, error) {
|
||||
// Verify 'aud' claim
|
||||
var checkAud bool
|
||||
for _, audience := range v.audienceList {
|
||||
checkAud = token.Claims.(jwt.MapClaims).VerifyAudience(audience, false)
|
||||
if checkAud {
|
||||
break
|
||||
}
|
||||
}
|
||||
if !checkAud {
|
||||
return token, errInvalidAudience
|
||||
}
|
||||
|
||||
// Verify 'issuer' claim
|
||||
checkIss := token.Claims.(jwt.MapClaims).VerifyIssuer(v.issuer, false)
|
||||
if !checkIss {
|
||||
return token, errInvalidIssuer
|
||||
}
|
||||
|
||||
// If keys are rotated, verify the keys prior to token validation
|
||||
if v.idpSignkeyRefreshEnabled {
|
||||
// If the keys are invalid, retrieve new ones
|
||||
// @todo propose a separate go routine to regularly check these to prevent blocking when actually
|
||||
// validating the token
|
||||
if !v.keys.stillValid() {
|
||||
v.lock.Lock()
|
||||
defer v.lock.Unlock()
|
||||
|
||||
refreshedKeys, err := getPemKeys(v.keysLocation)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Debugf("cannot get JSONWebKey: %v, falling back to old keys", err)
|
||||
refreshedKeys = v.keys
|
||||
}
|
||||
|
||||
log.WithContext(ctx).Debugf("keys refreshed, new UTC expiration time: %s", refreshedKeys.expiresInTime.UTC())
|
||||
|
||||
v.keys = refreshedKeys
|
||||
}
|
||||
}
|
||||
|
||||
publicKey, err := getPublicKey(token, v.keys)
|
||||
if err == nil {
|
||||
return publicKey, nil
|
||||
}
|
||||
|
||||
msg := fmt.Sprintf("getPublicKey error: %s", err)
|
||||
if errors.Is(err, errKeyNotFound) && !v.idpSignkeyRefreshEnabled {
|
||||
msg = fmt.Sprintf("getPublicKey error: %s. You can enable key refresh by setting HttpServerConfig.IdpSignKeyRefreshEnabled to true in your management.json file and restart the service", err)
|
||||
}
|
||||
|
||||
log.WithContext(ctx).Error(msg)
|
||||
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
// ValidateAndParse validates the token and returns the parsed token
|
||||
func (m *Validator) ValidateAndParse(ctx context.Context, token string) (*jwt.Token, error) {
|
||||
// If the token is empty...
|
||||
if token == "" {
|
||||
// If we get here, the required token is missing
|
||||
log.WithContext(ctx).Debugf(" Error: No credentials found (CredentialsOptional=false)")
|
||||
return nil, errTokenEmpty
|
||||
}
|
||||
|
||||
// Now parse the token
|
||||
parsedToken, err := jwt.Parse(token, m.getKeyFunc(ctx))
|
||||
|
||||
// Check if there was an error in parsing...
|
||||
if err != nil {
|
||||
err = fmt.Errorf("%w: %s", errTokenParsing, err)
|
||||
log.WithContext(ctx).Error(err.Error())
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Check if the parsed token is valid...
|
||||
if !parsedToken.Valid {
|
||||
log.WithContext(ctx).Debug(errTokenInvalid.Error())
|
||||
return nil, errTokenInvalid
|
||||
}
|
||||
|
||||
return parsedToken, nil
|
||||
}
|
||||
|
||||
// stillValid returns true if the JSONWebKey still valid and have enough time to be used
|
||||
func (jwks *Jwks) stillValid() bool {
|
||||
return !jwks.expiresInTime.IsZero() && time.Now().Add(5*time.Second).Before(jwks.expiresInTime)
|
||||
}
|
||||
|
||||
func getPemKeys(keysLocation string) (*Jwks, error) {
|
||||
jwks := &Jwks{}
|
||||
|
||||
url, err := url.ParseRequestURI(keysLocation)
|
||||
if err != nil {
|
||||
return jwks, err
|
||||
}
|
||||
|
||||
resp, err := http.Get(url.String())
|
||||
if err != nil {
|
||||
return jwks, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
err = json.NewDecoder(resp.Body).Decode(jwks)
|
||||
if err != nil {
|
||||
return jwks, err
|
||||
}
|
||||
|
||||
cacheControlHeader := resp.Header.Get("Cache-Control")
|
||||
expiresIn := getMaxAgeFromCacheHeader(cacheControlHeader)
|
||||
jwks.expiresInTime = time.Now().Add(time.Duration(expiresIn) * time.Second)
|
||||
|
||||
return jwks, nil
|
||||
}
|
||||
|
||||
func getPublicKey(token *jwt.Token, jwks *Jwks) (interface{}, error) {
|
||||
// todo as we load the jkws when the server is starting, we should build a JKS map with the pem cert at the boot time
|
||||
for k := range jwks.Keys {
|
||||
if token.Header["kid"] != jwks.Keys[k].Kid {
|
||||
continue
|
||||
}
|
||||
|
||||
if len(jwks.Keys[k].X5c) != 0 {
|
||||
cert := "-----BEGIN CERTIFICATE-----\n" + jwks.Keys[k].X5c[0] + "\n-----END CERTIFICATE-----"
|
||||
return jwt.ParseRSAPublicKeyFromPEM([]byte(cert))
|
||||
}
|
||||
|
||||
if jwks.Keys[k].Kty == "RSA" {
|
||||
return getPublicKeyFromRSA(jwks.Keys[k])
|
||||
}
|
||||
if jwks.Keys[k].Kty == "EC" {
|
||||
return getPublicKeyFromECDSA(jwks.Keys[k])
|
||||
}
|
||||
}
|
||||
|
||||
return nil, errKeyNotFound
|
||||
}
|
||||
|
||||
func getPublicKeyFromECDSA(jwk JSONWebKey) (publicKey *ecdsa.PublicKey, err error) {
|
||||
if jwk.X == "" || jwk.Y == "" || jwk.Crv == "" {
|
||||
return nil, fmt.Errorf("ecdsa key incomplete")
|
||||
}
|
||||
|
||||
var xCoordinate []byte
|
||||
if xCoordinate, err = base64.RawURLEncoding.DecodeString(jwk.X); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var yCoordinate []byte
|
||||
if yCoordinate, err = base64.RawURLEncoding.DecodeString(jwk.Y); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
publicKey = &ecdsa.PublicKey{}
|
||||
|
||||
var curve elliptic.Curve
|
||||
switch jwk.Crv {
|
||||
case p256:
|
||||
curve = elliptic.P256()
|
||||
case p384:
|
||||
curve = elliptic.P384()
|
||||
case p521:
|
||||
curve = elliptic.P521()
|
||||
}
|
||||
|
||||
publicKey.Curve = curve
|
||||
publicKey.X = big.NewInt(0).SetBytes(xCoordinate)
|
||||
publicKey.Y = big.NewInt(0).SetBytes(yCoordinate)
|
||||
|
||||
return publicKey, nil
|
||||
}
|
||||
|
||||
func getPublicKeyFromRSA(jwk JSONWebKey) (*rsa.PublicKey, error) {
|
||||
decodedE, err := base64.RawURLEncoding.DecodeString(jwk.E)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
decodedN, err := base64.RawURLEncoding.DecodeString(jwk.N)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var n, e big.Int
|
||||
e.SetBytes(decodedE)
|
||||
n.SetBytes(decodedN)
|
||||
|
||||
return &rsa.PublicKey{
|
||||
E: int(e.Int64()),
|
||||
N: &n,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// getMaxAgeFromCacheHeader extracts max-age directive from the Cache-Control header
|
||||
func getMaxAgeFromCacheHeader(cacheControl string) int {
|
||||
// Split into individual directives
|
||||
directives := strings.Split(cacheControl, ",")
|
||||
|
||||
for _, directive := range directives {
|
||||
directive = strings.TrimSpace(directive)
|
||||
if strings.HasPrefix(directive, "max-age=") {
|
||||
// Extract the max-age value
|
||||
maxAgeStr := strings.TrimPrefix(directive, "max-age=")
|
||||
maxAge, err := strconv.Atoi(maxAgeStr)
|
||||
if err != nil {
|
||||
return 0
|
||||
}
|
||||
|
||||
return maxAge
|
||||
}
|
||||
}
|
||||
|
||||
return 0
|
||||
}
|
170
management/server/auth/manager.go
Normal file
170
management/server/auth/manager.go
Normal file
@ -0,0 +1,170 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"hash/crc32"
|
||||
|
||||
"github.com/golang-jwt/jwt"
|
||||
|
||||
"github.com/netbirdio/netbird/base62"
|
||||
nbjwt "github.com/netbirdio/netbird/management/server/auth/jwt"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
)
|
||||
|
||||
var _ Manager = (*manager)(nil)
|
||||
|
||||
type Manager interface {
|
||||
ValidateAndParseToken(ctx context.Context, value string) (nbcontext.UserAuth, *jwt.Token, error)
|
||||
EnsureUserAccessByJWTGroups(ctx context.Context, userAuth nbcontext.UserAuth, token *jwt.Token) (nbcontext.UserAuth, error)
|
||||
MarkPATUsed(ctx context.Context, tokenID string) error
|
||||
GetPATInfo(ctx context.Context, token string) (user *types.User, pat *types.PersonalAccessToken, domain string, category string, err error)
|
||||
}
|
||||
|
||||
type manager struct {
|
||||
store store.Store
|
||||
|
||||
validator *nbjwt.Validator
|
||||
extractor *nbjwt.ClaimsExtractor
|
||||
}
|
||||
|
||||
func NewManager(store store.Store, issuer, audience, keysLocation, userIdClaim string, allAudiences []string, idpRefreshKeys bool) Manager {
|
||||
// @note if invalid/missing parameters are sent the validator will instantiate
|
||||
// but it will fail when validating and parsing the token
|
||||
jwtValidator := nbjwt.NewValidator(
|
||||
issuer,
|
||||
allAudiences,
|
||||
keysLocation,
|
||||
idpRefreshKeys,
|
||||
)
|
||||
|
||||
claimsExtractor := nbjwt.NewClaimsExtractor(
|
||||
nbjwt.WithAudience(audience),
|
||||
nbjwt.WithUserIDClaim(userIdClaim),
|
||||
)
|
||||
|
||||
return &manager{
|
||||
store: store,
|
||||
|
||||
validator: jwtValidator,
|
||||
extractor: claimsExtractor,
|
||||
}
|
||||
}
|
||||
|
||||
func (m *manager) ValidateAndParseToken(ctx context.Context, value string) (nbcontext.UserAuth, *jwt.Token, error) {
|
||||
token, err := m.validator.ValidateAndParse(ctx, value)
|
||||
if err != nil {
|
||||
return nbcontext.UserAuth{}, nil, err
|
||||
}
|
||||
|
||||
userAuth, err := m.extractor.ToUserAuth(token)
|
||||
if err != nil {
|
||||
return nbcontext.UserAuth{}, nil, err
|
||||
}
|
||||
return userAuth, token, err
|
||||
}
|
||||
|
||||
func (m *manager) EnsureUserAccessByJWTGroups(ctx context.Context, userAuth nbcontext.UserAuth, token *jwt.Token) (nbcontext.UserAuth, error) {
|
||||
if userAuth.IsChild || userAuth.IsPAT {
|
||||
return userAuth, nil
|
||||
}
|
||||
|
||||
settings, err := m.store.GetAccountSettings(ctx, store.LockingStrengthShare, userAuth.AccountId)
|
||||
if err != nil {
|
||||
return userAuth, err
|
||||
}
|
||||
|
||||
// Ensures JWT group synchronization to the management is enabled before,
|
||||
// filtering access based on the allowed groups.
|
||||
if settings != nil && settings.JWTGroupsEnabled {
|
||||
userAuth.Groups = m.extractor.ToGroups(token, settings.JWTGroupsClaimName)
|
||||
if allowedGroups := settings.JWTAllowGroups; len(allowedGroups) > 0 {
|
||||
if !userHasAllowedGroup(allowedGroups, userAuth.Groups) {
|
||||
return userAuth, fmt.Errorf("user does not belong to any of the allowed JWT groups")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return userAuth, nil
|
||||
}
|
||||
|
||||
// MarkPATUsed marks a personal access token as used
|
||||
func (am *manager) MarkPATUsed(ctx context.Context, tokenID string) error {
|
||||
return am.store.MarkPATUsed(ctx, store.LockingStrengthUpdate, tokenID)
|
||||
}
|
||||
|
||||
// GetPATInfo retrieves user, personal access token, domain, and category details from a personal access token.
|
||||
func (am *manager) GetPATInfo(ctx context.Context, token string) (user *types.User, pat *types.PersonalAccessToken, domain string, category string, err error) {
|
||||
user, pat, err = am.extractPATFromToken(ctx, token)
|
||||
if err != nil {
|
||||
return nil, nil, "", "", err
|
||||
}
|
||||
|
||||
domain, category, err = am.store.GetAccountDomainAndCategory(ctx, store.LockingStrengthShare, user.AccountID)
|
||||
if err != nil {
|
||||
return nil, nil, "", "", err
|
||||
}
|
||||
|
||||
return user, pat, domain, category, nil
|
||||
}
|
||||
|
||||
// extractPATFromToken validates the token structure and retrieves associated User and PAT.
|
||||
func (am *manager) extractPATFromToken(ctx context.Context, token string) (*types.User, *types.PersonalAccessToken, error) {
|
||||
if len(token) != types.PATLength {
|
||||
return nil, nil, fmt.Errorf("PAT has incorrect length")
|
||||
}
|
||||
|
||||
prefix := token[:len(types.PATPrefix)]
|
||||
if prefix != types.PATPrefix {
|
||||
return nil, nil, fmt.Errorf("PAT has wrong prefix")
|
||||
}
|
||||
secret := token[len(types.PATPrefix) : len(types.PATPrefix)+types.PATSecretLength]
|
||||
encodedChecksum := token[len(types.PATPrefix)+types.PATSecretLength : len(types.PATPrefix)+types.PATSecretLength+types.PATChecksumLength]
|
||||
|
||||
verificationChecksum, err := base62.Decode(encodedChecksum)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("PAT checksum decoding failed: %w", err)
|
||||
}
|
||||
|
||||
secretChecksum := crc32.ChecksumIEEE([]byte(secret))
|
||||
if secretChecksum != verificationChecksum {
|
||||
return nil, nil, fmt.Errorf("PAT checksum does not match")
|
||||
}
|
||||
|
||||
hashedToken := sha256.Sum256([]byte(token))
|
||||
encodedHashedToken := base64.StdEncoding.EncodeToString(hashedToken[:])
|
||||
|
||||
var user *types.User
|
||||
var pat *types.PersonalAccessToken
|
||||
|
||||
err = am.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
pat, err = transaction.GetPATByHashedToken(ctx, store.LockingStrengthShare, encodedHashedToken)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
user, err = transaction.GetUserByPATID(ctx, store.LockingStrengthShare, pat.ID)
|
||||
return err
|
||||
})
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
return user, pat, nil
|
||||
}
|
||||
|
||||
// userHasAllowedGroup checks if a user belongs to any of the allowed groups.
|
||||
func userHasAllowedGroup(allowedGroups []string, userGroups []string) bool {
|
||||
for _, userGroup := range userGroups {
|
||||
for _, allowedGroup := range allowedGroups {
|
||||
if userGroup == allowedGroup {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
54
management/server/auth/manager_mock.go
Normal file
54
management/server/auth/manager_mock.go
Normal file
@ -0,0 +1,54 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/golang-jwt/jwt"
|
||||
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
)
|
||||
|
||||
var (
|
||||
_ Manager = (*MockManager)(nil)
|
||||
)
|
||||
|
||||
// @note really dislike this mocking approach but rather than have to do additional test refactoring.
|
||||
type MockManager struct {
|
||||
ValidateAndParseTokenFunc func(ctx context.Context, value string) (nbcontext.UserAuth, *jwt.Token, error)
|
||||
EnsureUserAccessByJWTGroupsFunc func(ctx context.Context, userAuth nbcontext.UserAuth, token *jwt.Token) (nbcontext.UserAuth, error)
|
||||
MarkPATUsedFunc func(ctx context.Context, tokenID string) error
|
||||
GetPATInfoFunc func(ctx context.Context, token string) (user *types.User, pat *types.PersonalAccessToken, domain string, category string, err error)
|
||||
}
|
||||
|
||||
// EnsureUserAccessByJWTGroups implements Manager.
|
||||
func (m *MockManager) EnsureUserAccessByJWTGroups(ctx context.Context, userAuth nbcontext.UserAuth, token *jwt.Token) (nbcontext.UserAuth, error) {
|
||||
if m.EnsureUserAccessByJWTGroupsFunc != nil {
|
||||
return m.EnsureUserAccessByJWTGroupsFunc(ctx, userAuth, token)
|
||||
}
|
||||
return nbcontext.UserAuth{}, nil
|
||||
}
|
||||
|
||||
// GetPATInfo implements Manager.
|
||||
func (m *MockManager) GetPATInfo(ctx context.Context, token string) (user *types.User, pat *types.PersonalAccessToken, domain string, category string, err error) {
|
||||
if m.GetPATInfoFunc != nil {
|
||||
return m.GetPATInfoFunc(ctx, token)
|
||||
}
|
||||
return &types.User{}, &types.PersonalAccessToken{}, "", "", nil
|
||||
}
|
||||
|
||||
// MarkPATUsed implements Manager.
|
||||
func (m *MockManager) MarkPATUsed(ctx context.Context, tokenID string) error {
|
||||
if m.MarkPATUsedFunc != nil {
|
||||
return m.MarkPATUsedFunc(ctx, tokenID)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ValidateAndParseToken implements Manager.
|
||||
func (m *MockManager) ValidateAndParseToken(ctx context.Context, value string) (nbcontext.UserAuth, *jwt.Token, error) {
|
||||
if m.ValidateAndParseTokenFunc != nil {
|
||||
return m.ValidateAndParseTokenFunc(ctx, value)
|
||||
}
|
||||
return nbcontext.UserAuth{}, &jwt.Token{}, nil
|
||||
}
|
407
management/server/auth/manager_test.go
Normal file
407
management/server/auth/manager_test.go
Normal file
@ -0,0 +1,407 @@
|
||||
package auth_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/auth"
|
||||
nbjwt "github.com/netbirdio/netbird/management/server/auth/jwt"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
)
|
||||
|
||||
func TestAuthManager_GetAccountInfoFromPAT(t *testing.T) {
|
||||
store, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "", t.TempDir())
|
||||
if err != nil {
|
||||
t.Fatalf("Error when creating store: %s", err)
|
||||
}
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
token := "nbp_9999EUDNdkeusjentDLSJEn1902u84390W6W"
|
||||
hashedToken := sha256.Sum256([]byte(token))
|
||||
encodedHashedToken := base64.StdEncoding.EncodeToString(hashedToken[:])
|
||||
account := &types.Account{
|
||||
Id: "account_id",
|
||||
Users: map[string]*types.User{"someUser": {
|
||||
Id: "someUser",
|
||||
PATs: map[string]*types.PersonalAccessToken{
|
||||
"tokenId": {
|
||||
ID: "tokenId",
|
||||
UserID: "someUser",
|
||||
HashedToken: encodedHashedToken,
|
||||
},
|
||||
},
|
||||
}},
|
||||
}
|
||||
|
||||
err = store.SaveAccount(context.Background(), account)
|
||||
if err != nil {
|
||||
t.Fatalf("Error when saving account: %s", err)
|
||||
}
|
||||
|
||||
manager := auth.NewManager(store, "", "", "", "", []string{}, false)
|
||||
|
||||
user, pat, _, _, err := manager.GetPATInfo(context.Background(), token)
|
||||
if err != nil {
|
||||
t.Fatalf("Error when getting Account from PAT: %s", err)
|
||||
}
|
||||
|
||||
assert.Equal(t, "account_id", user.AccountID)
|
||||
assert.Equal(t, "someUser", user.Id)
|
||||
assert.Equal(t, account.Users["someUser"].PATs["tokenId"].ID, pat.ID)
|
||||
}
|
||||
|
||||
func TestAuthManager_MarkPATUsed(t *testing.T) {
|
||||
store, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "", t.TempDir())
|
||||
if err != nil {
|
||||
t.Fatalf("Error when creating store: %s", err)
|
||||
}
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
token := "nbp_9999EUDNdkeusjentDLSJEn1902u84390W6W"
|
||||
hashedToken := sha256.Sum256([]byte(token))
|
||||
encodedHashedToken := base64.StdEncoding.EncodeToString(hashedToken[:])
|
||||
account := &types.Account{
|
||||
Id: "account_id",
|
||||
Users: map[string]*types.User{"someUser": {
|
||||
Id: "someUser",
|
||||
PATs: map[string]*types.PersonalAccessToken{
|
||||
"tokenId": {
|
||||
ID: "tokenId",
|
||||
HashedToken: encodedHashedToken,
|
||||
},
|
||||
},
|
||||
}},
|
||||
}
|
||||
|
||||
err = store.SaveAccount(context.Background(), account)
|
||||
if err != nil {
|
||||
t.Fatalf("Error when saving account: %s", err)
|
||||
}
|
||||
|
||||
manager := auth.NewManager(store, "", "", "", "", []string{}, false)
|
||||
|
||||
err = manager.MarkPATUsed(context.Background(), "tokenId")
|
||||
if err != nil {
|
||||
t.Fatalf("Error when marking PAT used: %s", err)
|
||||
}
|
||||
|
||||
account, err = store.GetAccount(context.Background(), "account_id")
|
||||
if err != nil {
|
||||
t.Fatalf("Error when getting account: %s", err)
|
||||
}
|
||||
assert.True(t, !account.Users["someUser"].PATs["tokenId"].GetLastUsed().IsZero())
|
||||
}
|
||||
|
||||
func TestAuthManager_EnsureUserAccessByJWTGroups(t *testing.T) {
|
||||
store, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "", t.TempDir())
|
||||
if err != nil {
|
||||
t.Fatalf("Error when creating store: %s", err)
|
||||
}
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
userId := "user-id"
|
||||
domain := "test.domain"
|
||||
|
||||
account := &types.Account{
|
||||
Id: "account_id",
|
||||
Domain: domain,
|
||||
Users: map[string]*types.User{"someUser": {
|
||||
Id: "someUser",
|
||||
}},
|
||||
Settings: &types.Settings{},
|
||||
}
|
||||
|
||||
err = store.SaveAccount(context.Background(), account)
|
||||
if err != nil {
|
||||
t.Fatalf("Error when saving account: %s", err)
|
||||
}
|
||||
|
||||
// this has been validated and parsed by ValidateAndParseToken
|
||||
userAuth := nbcontext.UserAuth{
|
||||
AccountId: account.Id,
|
||||
Domain: domain,
|
||||
UserId: userId,
|
||||
DomainCategory: "test-category",
|
||||
// Groups: []string{"group1", "group2"},
|
||||
}
|
||||
|
||||
// these tests only assert groups are parsed from token as per account settings
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{"idp-groups": []interface{}{"group1", "group2"}})
|
||||
|
||||
manager := auth.NewManager(store, "", "", "", "", []string{}, false)
|
||||
|
||||
t.Run("JWT groups disabled", func(t *testing.T) {
|
||||
userAuth, err := manager.EnsureUserAccessByJWTGroups(context.Background(), userAuth, token)
|
||||
require.NoError(t, err, "ensure user access by JWT groups failed")
|
||||
require.Len(t, userAuth.Groups, 0, "account not enabled to ensure access by groups")
|
||||
})
|
||||
|
||||
t.Run("User impersonated", func(t *testing.T) {
|
||||
userAuth, err := manager.EnsureUserAccessByJWTGroups(context.Background(), userAuth, token)
|
||||
require.NoError(t, err, "ensure user access by JWT groups failed")
|
||||
require.Len(t, userAuth.Groups, 0, "account not enabled to ensure access by groups")
|
||||
})
|
||||
|
||||
t.Run("User PAT", func(t *testing.T) {
|
||||
userAuth, err := manager.EnsureUserAccessByJWTGroups(context.Background(), userAuth, token)
|
||||
require.NoError(t, err, "ensure user access by JWT groups failed")
|
||||
require.Len(t, userAuth.Groups, 0, "account not enabled to ensure access by groups")
|
||||
})
|
||||
|
||||
t.Run("JWT groups enabled without claim name", func(t *testing.T) {
|
||||
account.Settings.JWTGroupsEnabled = true
|
||||
err := store.SaveAccount(context.Background(), account)
|
||||
require.NoError(t, err, "save account failed")
|
||||
|
||||
userAuth, err := manager.EnsureUserAccessByJWTGroups(context.Background(), userAuth, token)
|
||||
require.NoError(t, err, "ensure user access by JWT groups failed")
|
||||
require.Len(t, userAuth.Groups, 0, "account missing groups claim name")
|
||||
})
|
||||
|
||||
t.Run("JWT groups enabled without allowed groups", func(t *testing.T) {
|
||||
account.Settings.JWTGroupsEnabled = true
|
||||
account.Settings.JWTGroupsClaimName = "idp-groups"
|
||||
err := store.SaveAccount(context.Background(), account)
|
||||
require.NoError(t, err, "save account failed")
|
||||
|
||||
userAuth, err := manager.EnsureUserAccessByJWTGroups(context.Background(), userAuth, token)
|
||||
require.NoError(t, err, "ensure user access by JWT groups failed")
|
||||
require.Equal(t, []string{"group1", "group2"}, userAuth.Groups, "group parsed do not match")
|
||||
})
|
||||
|
||||
t.Run("User in allowed JWT groups", func(t *testing.T) {
|
||||
account.Settings.JWTGroupsEnabled = true
|
||||
account.Settings.JWTGroupsClaimName = "idp-groups"
|
||||
account.Settings.JWTAllowGroups = []string{"group1"}
|
||||
err := store.SaveAccount(context.Background(), account)
|
||||
require.NoError(t, err, "save account failed")
|
||||
|
||||
userAuth, err := manager.EnsureUserAccessByJWTGroups(context.Background(), userAuth, token)
|
||||
require.NoError(t, err, "ensure user access by JWT groups failed")
|
||||
|
||||
require.Equal(t, []string{"group1", "group2"}, userAuth.Groups, "group parsed do not match")
|
||||
})
|
||||
|
||||
t.Run("User not in allowed JWT groups", func(t *testing.T) {
|
||||
account.Settings.JWTGroupsEnabled = true
|
||||
account.Settings.JWTGroupsClaimName = "idp-groups"
|
||||
account.Settings.JWTAllowGroups = []string{"not-a-group"}
|
||||
err := store.SaveAccount(context.Background(), account)
|
||||
require.NoError(t, err, "save account failed")
|
||||
|
||||
_, err = manager.EnsureUserAccessByJWTGroups(context.Background(), userAuth, token)
|
||||
require.Error(t, err, "ensure user access is not in allowed groups")
|
||||
})
|
||||
}
|
||||
|
||||
func TestAuthManager_ValidateAndParseToken(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Add("Cache-Control", "max-age=30") // set a 30s expiry to these keys
|
||||
http.ServeFile(w, r, "test_data/jwks.json")
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
issuer := "http://issuer.local"
|
||||
audience := "http://audience.local"
|
||||
userIdClaim := "" // defaults to "sub"
|
||||
|
||||
// we're only testing with RSA256
|
||||
keyData, _ := os.ReadFile("test_data/sample_key")
|
||||
key, _ := jwt.ParseRSAPrivateKeyFromPEM(keyData)
|
||||
keyId := "test-key"
|
||||
|
||||
// note, we can use a nil store because ValidateAndParseToken does not use it in it's flow
|
||||
manager := auth.NewManager(nil, issuer, audience, server.URL, userIdClaim, []string{audience}, false)
|
||||
|
||||
customClaim := func(name string) string {
|
||||
return fmt.Sprintf("%s/%s", audience, name)
|
||||
}
|
||||
|
||||
lastLogin := time.Date(2025, 2, 12, 14, 25, 26, 0, time.UTC) //"2025-02-12T14:25:26.186Z"
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
tokenFunc func() string
|
||||
expected *nbcontext.UserAuth // nil indicates expected error
|
||||
}{
|
||||
{
|
||||
name: "Valid with custom claims",
|
||||
tokenFunc: func() string {
|
||||
token := jwt.New(jwt.SigningMethodRS256)
|
||||
token.Header["kid"] = keyId
|
||||
token.Claims = jwt.MapClaims{
|
||||
"iss": issuer,
|
||||
"aud": []string{audience},
|
||||
"iat": time.Now().Unix(),
|
||||
"exp": time.Now().Add(time.Hour * 1).Unix(),
|
||||
"sub": "user-id|123",
|
||||
customClaim(nbjwt.AccountIDSuffix): "account-id|567",
|
||||
customClaim(nbjwt.DomainIDSuffix): "http://localhost",
|
||||
customClaim(nbjwt.DomainCategorySuffix): "private",
|
||||
customClaim(nbjwt.LastLoginSuffix): lastLogin.Format(time.RFC3339),
|
||||
customClaim(nbjwt.Invited): false,
|
||||
}
|
||||
tokenString, _ := token.SignedString(key)
|
||||
return tokenString
|
||||
},
|
||||
expected: &nbcontext.UserAuth{
|
||||
UserId: "user-id|123",
|
||||
AccountId: "account-id|567",
|
||||
Domain: "http://localhost",
|
||||
DomainCategory: "private",
|
||||
LastLogin: lastLogin,
|
||||
Invited: false,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Valid without custom claims",
|
||||
tokenFunc: func() string {
|
||||
token := jwt.New(jwt.SigningMethodRS256)
|
||||
token.Header["kid"] = keyId
|
||||
token.Claims = jwt.MapClaims{
|
||||
"iss": issuer,
|
||||
"aud": []string{audience},
|
||||
"iat": time.Now().Unix(),
|
||||
"exp": time.Now().Add(time.Hour).Unix(),
|
||||
"sub": "user-id|123",
|
||||
}
|
||||
tokenString, _ := token.SignedString(key)
|
||||
return tokenString
|
||||
},
|
||||
expected: &nbcontext.UserAuth{
|
||||
UserId: "user-id|123",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Expired token",
|
||||
tokenFunc: func() string {
|
||||
token := jwt.New(jwt.SigningMethodRS256)
|
||||
token.Header["kid"] = keyId
|
||||
token.Claims = jwt.MapClaims{
|
||||
"iss": issuer,
|
||||
"aud": []string{audience},
|
||||
"iat": time.Now().Add(time.Hour * -2).Unix(),
|
||||
"exp": time.Now().Add(time.Hour * -1).Unix(),
|
||||
"sub": "user-id|123",
|
||||
}
|
||||
tokenString, _ := token.SignedString(key)
|
||||
return tokenString
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Not yet valid",
|
||||
tokenFunc: func() string {
|
||||
token := jwt.New(jwt.SigningMethodRS256)
|
||||
token.Header["kid"] = keyId
|
||||
token.Claims = jwt.MapClaims{
|
||||
"iss": issuer,
|
||||
"aud": []string{audience},
|
||||
"iat": time.Now().Add(time.Hour).Unix(),
|
||||
"exp": time.Now().Add(time.Hour * 2).Unix(),
|
||||
"sub": "user-id|123",
|
||||
}
|
||||
tokenString, _ := token.SignedString(key)
|
||||
return tokenString
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Invalid signature",
|
||||
tokenFunc: func() string {
|
||||
token := jwt.New(jwt.SigningMethodRS256)
|
||||
token.Header["kid"] = keyId
|
||||
token.Claims = jwt.MapClaims{
|
||||
"iss": issuer,
|
||||
"aud": []string{audience},
|
||||
"iat": time.Now().Unix(),
|
||||
"exp": time.Now().Add(time.Hour).Unix(),
|
||||
"sub": "user-id|123",
|
||||
}
|
||||
tokenString, _ := token.SignedString(key)
|
||||
parts := strings.Split(tokenString, ".")
|
||||
parts[2] = "invalid-signature"
|
||||
return strings.Join(parts, ".")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Invalid issuer",
|
||||
tokenFunc: func() string {
|
||||
token := jwt.New(jwt.SigningMethodRS256)
|
||||
token.Header["kid"] = keyId
|
||||
token.Claims = jwt.MapClaims{
|
||||
"iss": "not-the-issuer",
|
||||
"aud": []string{audience},
|
||||
"iat": time.Now().Unix(),
|
||||
"exp": time.Now().Add(time.Hour).Unix(),
|
||||
"sub": "user-id|123",
|
||||
}
|
||||
tokenString, _ := token.SignedString(key)
|
||||
return tokenString
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Invalid audience",
|
||||
tokenFunc: func() string {
|
||||
token := jwt.New(jwt.SigningMethodRS256)
|
||||
token.Header["kid"] = keyId
|
||||
token.Claims = jwt.MapClaims{
|
||||
"iss": issuer,
|
||||
"aud": []string{"not-the-audience"},
|
||||
"iat": time.Now().Unix(),
|
||||
"exp": time.Now().Add(time.Hour).Unix(),
|
||||
"sub": "user-id|123",
|
||||
}
|
||||
tokenString, _ := token.SignedString(key)
|
||||
return tokenString
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Invalid user claim",
|
||||
tokenFunc: func() string {
|
||||
token := jwt.New(jwt.SigningMethodRS256)
|
||||
token.Header["kid"] = keyId
|
||||
token.Claims = jwt.MapClaims{
|
||||
"iss": issuer,
|
||||
"aud": []string{audience},
|
||||
"iat": time.Now().Unix(),
|
||||
"exp": time.Now().Add(time.Hour).Unix(),
|
||||
"not-sub": "user-id|123",
|
||||
}
|
||||
tokenString, _ := token.SignedString(key)
|
||||
return tokenString
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
tokenString := tt.tokenFunc()
|
||||
|
||||
userAuth, token, err := manager.ValidateAndParseToken(context.Background(), tokenString)
|
||||
|
||||
if tt.expected != nil {
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, token.Valid)
|
||||
assert.Equal(t, *tt.expected, userAuth)
|
||||
} else {
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, token)
|
||||
assert.Empty(t, userAuth)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
}
|
11
management/server/auth/test_data/jwks.json
Normal file
11
management/server/auth/test_data/jwks.json
Normal file
@ -0,0 +1,11 @@
|
||||
{
|
||||
"keys": [
|
||||
{
|
||||
"kty": "RSA",
|
||||
"kid": "test-key",
|
||||
"use": "sig",
|
||||
"n": "4f5wg5l2hKsTeNem_V41fGnJm6gOdrj8ym3rFkEU_wT8RDtnSgFEZOQpHEgQ7JL38xUfU0Y3g6aYw9QT0hJ7mCpz9Er5qLaMXJwZxzHzAahlfA0icqabvJOMvQtzD6uQv6wPEyZtDTWiQi9AXwBpHssPnpYGIn20ZZuNlX2BrClciHhCPUIIZOQn_MmqTD31jSyjoQoV7MhhMTATKJx2XrHhR-1DcKJzQBSTAGnpYVaqpsARap-nwRipr3nUTuxyGohBTSmjJ2usSeQXHI3bODIRe1AuTyHceAbewn8b462yEWKARdpd9AjQW5SIVPfdsz5B6GlYQ5LdYKtznTuy7w",
|
||||
"e": "AQAB"
|
||||
}
|
||||
]
|
||||
}
|
27
management/server/auth/test_data/sample_key
Normal file
27
management/server/auth/test_data/sample_key
Normal file
@ -0,0 +1,27 @@
|
||||
-----BEGIN RSA PRIVATE KEY-----
|
||||
MIIEowIBAAKCAQEA4f5wg5l2hKsTeNem/V41fGnJm6gOdrj8ym3rFkEU/wT8RDtn
|
||||
SgFEZOQpHEgQ7JL38xUfU0Y3g6aYw9QT0hJ7mCpz9Er5qLaMXJwZxzHzAahlfA0i
|
||||
cqabvJOMvQtzD6uQv6wPEyZtDTWiQi9AXwBpHssPnpYGIn20ZZuNlX2BrClciHhC
|
||||
PUIIZOQn/MmqTD31jSyjoQoV7MhhMTATKJx2XrHhR+1DcKJzQBSTAGnpYVaqpsAR
|
||||
ap+nwRipr3nUTuxyGohBTSmjJ2usSeQXHI3bODIRe1AuTyHceAbewn8b462yEWKA
|
||||
Rdpd9AjQW5SIVPfdsz5B6GlYQ5LdYKtznTuy7wIDAQABAoIBAQCwia1k7+2oZ2d3
|
||||
n6agCAbqIE1QXfCmh41ZqJHbOY3oRQG3X1wpcGH4Gk+O+zDVTV2JszdcOt7E5dAy
|
||||
MaomETAhRxB7hlIOnEN7WKm+dGNrKRvV0wDU5ReFMRHg31/Lnu8c+5BvGjZX+ky9
|
||||
POIhFFYJqwCRlopGSUIxmVj5rSgtzk3iWOQXr+ah1bjEXvlxDOWkHN6YfpV5ThdE
|
||||
KdBIPGEVqa63r9n2h+qazKrtiRqJqGnOrHzOECYbRFYhexsNFz7YT02xdfSHn7gM
|
||||
IvabDDP/Qp0PjE1jdouiMaFHYnLBbgvlnZW9yuVf/rpXTUq/njxIXMmvmEyyvSDn
|
||||
FcFikB8pAoGBAPF77hK4m3/rdGT7X8a/gwvZ2R121aBcdPwEaUhvj/36dx596zvY
|
||||
mEOjrWfZhF083/nYWE2kVquj2wjs+otCLfifEEgXcVPTnEOPO9Zg3uNSL0nNQghj
|
||||
FuD3iGLTUBCtM66oTe0jLSslHe8gLGEQqyMzHOzYxNqibxcOZIe8Qt0NAoGBAO+U
|
||||
I5+XWjWEgDmvyC3TrOSf/KCGjtu0TSv30ipv27bDLMrpvPmD/5lpptTFwcxvVhCs
|
||||
2b+chCjlghFSWFbBULBrfci2FtliClOVMYrlNBdUSJhf3aYSG2Doe6Bgt1n2CpNn
|
||||
/iu37Y3NfemZBJA7hNl4dYe+f+uzM87cdQ214+jrAoGAXA0XxX8ll2+ToOLJsaNT
|
||||
OvNB9h9Uc5qK5X5w+7G7O998BN2PC/MWp8H+2fVqpXgNENpNXttkRm1hk1dych86
|
||||
EunfdPuqsX+as44oCyJGFHVBnWpm33eWQw9YqANRI+pCJzP08I5WK3osnPiwshd+
|
||||
hR54yjgfYhBFNI7B95PmEQkCgYBzFSz7h1+s34Ycr8SvxsOBWxymG5zaCsUbPsL0
|
||||
4aCgLScCHb9J+E86aVbbVFdglYa5Id7DPTL61ixhl7WZjujspeXZGSbmq0Kcnckb
|
||||
mDgqkLECiOJW2NHP/j0McAkDLL4tysF8TLDO8gvuvzNC+WQ6drO2ThrypLVZQ+ry
|
||||
eBIPmwKBgEZxhqa0gVvHQG/7Od69KWj4eJP28kq13RhKay8JOoN0vPmspXJo1HY3
|
||||
CKuHRG+AP579dncdUnOMvfXOtkdM4vk0+hWASBQzM9xzVcztCa+koAugjVaLS9A+
|
||||
9uQoqEeVNTckxx0S2bYevRy7hGQmUJTyQm3j1zEUR5jpdbL83Fbq
|
||||
-----END RSA PRIVATE KEY-----
|
9
management/server/auth/test_data/sample_key.pub
Normal file
9
management/server/auth/test_data/sample_key.pub
Normal file
@ -0,0 +1,9 @@
|
||||
-----BEGIN PUBLIC KEY-----
|
||||
MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEA4f5wg5l2hKsTeNem/V41
|
||||
fGnJm6gOdrj8ym3rFkEU/wT8RDtnSgFEZOQpHEgQ7JL38xUfU0Y3g6aYw9QT0hJ7
|
||||
mCpz9Er5qLaMXJwZxzHzAahlfA0icqabvJOMvQtzD6uQv6wPEyZtDTWiQi9AXwBp
|
||||
HssPnpYGIn20ZZuNlX2BrClciHhCPUIIZOQn/MmqTD31jSyjoQoV7MhhMTATKJx2
|
||||
XrHhR+1DcKJzQBSTAGnpYVaqpsARap+nwRipr3nUTuxyGohBTSmjJ2usSeQXHI3b
|
||||
ODIRe1AuTyHceAbewn8b462yEWKARdpd9AjQW5SIVPfdsz5B6GlYQ5LdYKtznTuy
|
||||
7wIDAQAB
|
||||
-----END PUBLIC KEY-----
|
@ -2,7 +2,6 @@ package server
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"net/url"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/idp"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
@ -180,9 +179,3 @@ type ReverseProxy struct {
|
||||
// trusted IP prefixes.
|
||||
TrustedPeers []netip.Prefix
|
||||
}
|
||||
|
||||
// validateURL validates input http url
|
||||
func validateURL(httpURL string) bool {
|
||||
_, err := url.ParseRequestURI(httpURL)
|
||||
return err == nil
|
||||
}
|
||||
|
60
management/server/context/auth.go
Normal file
60
management/server/context/auth.go
Normal file
@ -0,0 +1,60 @@
|
||||
package context
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
type key int
|
||||
|
||||
const (
|
||||
UserAuthContextKey key = iota
|
||||
)
|
||||
|
||||
type UserAuth struct {
|
||||
// The account id the user is accessing
|
||||
AccountId string
|
||||
// The account domain
|
||||
Domain string
|
||||
// The account domain category, TBC values
|
||||
DomainCategory string
|
||||
// Indicates whether this user was invited, TBC logic
|
||||
Invited bool
|
||||
// Indicates whether this is a child account
|
||||
IsChild bool
|
||||
|
||||
// The user id
|
||||
UserId string
|
||||
// Last login time for this user
|
||||
LastLogin time.Time
|
||||
// The Groups the user belongs to on this account
|
||||
Groups []string
|
||||
|
||||
// Indicates whether this user has authenticated with a Personal Access Token
|
||||
IsPAT bool
|
||||
}
|
||||
|
||||
func GetUserAuthFromRequest(r *http.Request) (UserAuth, error) {
|
||||
return GetUserAuthFromContext(r.Context())
|
||||
}
|
||||
|
||||
func SetUserAuthInRequest(r *http.Request, userAuth UserAuth) *http.Request {
|
||||
return r.WithContext(SetUserAuthInContext(r.Context(), userAuth))
|
||||
}
|
||||
|
||||
func GetUserAuthFromContext(ctx context.Context) (UserAuth, error) {
|
||||
if userAuth, ok := ctx.Value(UserAuthContextKey).(UserAuth); ok {
|
||||
return userAuth, nil
|
||||
}
|
||||
return UserAuth{}, fmt.Errorf("user auth not in context")
|
||||
}
|
||||
|
||||
func SetUserAuthInContext(ctx context.Context, userAuth UserAuth) context.Context {
|
||||
//nolint
|
||||
ctx = context.WithValue(ctx, UserIDKey, userAuth.UserId)
|
||||
//nolint
|
||||
ctx = context.WithValue(ctx, AccountIDKey, userAuth.AccountId)
|
||||
return context.WithValue(ctx, UserAuthContextKey, userAuth)
|
||||
}
|
@ -20,8 +20,8 @@ import (
|
||||
|
||||
"github.com/netbirdio/netbird/encryption"
|
||||
"github.com/netbirdio/netbird/management/proto"
|
||||
"github.com/netbirdio/netbird/management/server/auth"
|
||||
nbContext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/posture"
|
||||
"github.com/netbirdio/netbird/management/server/settings"
|
||||
@ -39,11 +39,10 @@ type GRPCServer struct {
|
||||
peersUpdateManager *PeersUpdateManager
|
||||
config *Config
|
||||
secretsManager SecretsManager
|
||||
jwtValidator jwtclaims.JWTValidator
|
||||
jwtClaimsExtractor *jwtclaims.ClaimsExtractor
|
||||
appMetrics telemetry.AppMetrics
|
||||
ephemeralManager *EphemeralManager
|
||||
peerLocks sync.Map
|
||||
authManager auth.Manager
|
||||
}
|
||||
|
||||
// NewServer creates a new Management server
|
||||
@ -56,29 +55,13 @@ func NewServer(
|
||||
secretsManager SecretsManager,
|
||||
appMetrics telemetry.AppMetrics,
|
||||
ephemeralManager *EphemeralManager,
|
||||
authManager auth.Manager,
|
||||
) (*GRPCServer, error) {
|
||||
key, err := wgtypes.GeneratePrivateKey()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var jwtValidator jwtclaims.JWTValidator
|
||||
|
||||
if config.HttpConfig != nil && config.HttpConfig.AuthIssuer != "" && config.HttpConfig.AuthAudience != "" && validateURL(config.HttpConfig.AuthKeysLocation) {
|
||||
jwtValidator, err = jwtclaims.NewJWTValidator(
|
||||
ctx,
|
||||
config.HttpConfig.AuthIssuer,
|
||||
config.GetAuthAudiences(),
|
||||
config.HttpConfig.AuthKeysLocation,
|
||||
config.HttpConfig.IdpSignKeyRefreshEnabled,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "unable to create new jwt middleware, err: %v", err)
|
||||
}
|
||||
} else {
|
||||
log.WithContext(ctx).Debug("unable to use http config to create new jwt middleware")
|
||||
}
|
||||
|
||||
if appMetrics != nil {
|
||||
// update gauge based on number of connected peers which is equal to open gRPC streams
|
||||
err = appMetrics.GRPCMetrics().RegisterConnectedStreams(func() int64 {
|
||||
@ -89,16 +72,6 @@ func NewServer(
|
||||
}
|
||||
}
|
||||
|
||||
var audience, userIDClaim string
|
||||
if config.HttpConfig != nil {
|
||||
audience = config.HttpConfig.AuthAudience
|
||||
userIDClaim = config.HttpConfig.AuthUserIDClaim
|
||||
}
|
||||
jwtClaimsExtractor := jwtclaims.NewClaimsExtractor(
|
||||
jwtclaims.WithAudience(audience),
|
||||
jwtclaims.WithUserIDClaim(userIDClaim),
|
||||
)
|
||||
|
||||
return &GRPCServer{
|
||||
wgKey: key,
|
||||
// peerKey -> event channel
|
||||
@ -107,8 +80,7 @@ func NewServer(
|
||||
settingsManager: settingsManager,
|
||||
config: config,
|
||||
secretsManager: secretsManager,
|
||||
jwtValidator: jwtValidator,
|
||||
jwtClaimsExtractor: jwtClaimsExtractor,
|
||||
authManager: authManager,
|
||||
appMetrics: appMetrics,
|
||||
ephemeralManager: ephemeralManager,
|
||||
}, nil
|
||||
@ -294,26 +266,37 @@ func (s *GRPCServer) cancelPeerRoutines(ctx context.Context, accountID string, p
|
||||
}
|
||||
|
||||
func (s *GRPCServer) validateToken(ctx context.Context, jwtToken string) (string, error) {
|
||||
if s.jwtValidator == nil {
|
||||
return "", status.Error(codes.Internal, "no jwt validator set")
|
||||
if s.authManager == nil {
|
||||
return "", status.Errorf(codes.Internal, "missing auth manager")
|
||||
}
|
||||
|
||||
token, err := s.jwtValidator.ValidateAndParse(ctx, jwtToken)
|
||||
userAuth, token, err := s.authManager.ValidateAndParseToken(ctx, jwtToken)
|
||||
if err != nil {
|
||||
return "", status.Errorf(codes.InvalidArgument, "invalid jwt token, err: %v", err)
|
||||
}
|
||||
claims := s.jwtClaimsExtractor.FromToken(token)
|
||||
|
||||
// we need to call this method because if user is new, we will automatically add it to existing or create a new account
|
||||
_, _, err = s.accountManager.GetAccountIDFromToken(ctx, claims)
|
||||
accountId, _, err := s.accountManager.GetAccountIDFromUserAuth(ctx, userAuth)
|
||||
if err != nil {
|
||||
return "", status.Errorf(codes.Internal, "unable to fetch account with claims, err: %v", err)
|
||||
}
|
||||
|
||||
if err := s.accountManager.CheckUserAccessByJWTGroups(ctx, claims); err != nil {
|
||||
if userAuth.AccountId != accountId {
|
||||
log.WithContext(ctx).Debugf("gRPC server sets accountId from ensure, before %s, now %s", userAuth.AccountId, accountId)
|
||||
userAuth.AccountId = accountId
|
||||
}
|
||||
|
||||
userAuth, err = s.authManager.EnsureUserAccessByJWTGroups(ctx, userAuth, token)
|
||||
if err != nil {
|
||||
return "", status.Error(codes.PermissionDenied, err.Error())
|
||||
}
|
||||
|
||||
return claims.UserId, nil
|
||||
err = s.accountManager.SyncUserJWTGroups(ctx, userAuth)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("gRPC server failed to sync user JWT groups: %s", err)
|
||||
}
|
||||
|
||||
return userAuth.UserId, nil
|
||||
}
|
||||
|
||||
func (s *GRPCServer) acquirePeerLockByUID(ctx context.Context, uniqueID string) (unlock func()) {
|
||||
|
@ -11,9 +11,9 @@ import (
|
||||
"github.com/netbirdio/management-integrations/integrations"
|
||||
|
||||
s "github.com/netbirdio/netbird/management/server"
|
||||
"github.com/netbirdio/netbird/management/server/auth"
|
||||
"github.com/netbirdio/netbird/management/server/geolocation"
|
||||
nbgroups "github.com/netbirdio/netbird/management/server/groups"
|
||||
"github.com/netbirdio/netbird/management/server/http/configs"
|
||||
"github.com/netbirdio/netbird/management/server/http/handlers/accounts"
|
||||
"github.com/netbirdio/netbird/management/server/http/handlers/dns"
|
||||
"github.com/netbirdio/netbird/management/server/http/handlers/events"
|
||||
@ -26,7 +26,6 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/http/handlers/users"
|
||||
"github.com/netbirdio/netbird/management/server/http/middleware"
|
||||
"github.com/netbirdio/netbird/management/server/integrated_validator"
|
||||
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
||||
nbnetworks "github.com/netbirdio/netbird/management/server/networks"
|
||||
"github.com/netbirdio/netbird/management/server/networks/resources"
|
||||
"github.com/netbirdio/netbird/management/server/networks/routers"
|
||||
@ -36,55 +35,51 @@ import (
|
||||
const apiPrefix = "/api"
|
||||
|
||||
// NewAPIHandler creates the Management service HTTP API handler registering all the available endpoints.
|
||||
func NewAPIHandler(ctx context.Context, accountManager s.AccountManager, networksManager nbnetworks.Manager, resourceManager resources.Manager, routerManager routers.Manager, groupsManager nbgroups.Manager, LocationManager geolocation.Geolocation, jwtValidator jwtclaims.JWTValidator, appMetrics telemetry.AppMetrics, authCfg configs.AuthCfg, integratedValidator integrated_validator.IntegratedValidator) (http.Handler, error) {
|
||||
claimsExtractor := jwtclaims.NewClaimsExtractor(
|
||||
jwtclaims.WithAudience(authCfg.Audience),
|
||||
jwtclaims.WithUserIDClaim(authCfg.UserIDClaim),
|
||||
)
|
||||
func NewAPIHandler(
|
||||
ctx context.Context,
|
||||
accountManager s.AccountManager,
|
||||
networksManager nbnetworks.Manager,
|
||||
resourceManager resources.Manager,
|
||||
routerManager routers.Manager,
|
||||
groupsManager nbgroups.Manager,
|
||||
LocationManager geolocation.Geolocation,
|
||||
authManager auth.Manager,
|
||||
appMetrics telemetry.AppMetrics,
|
||||
config *s.Config,
|
||||
integratedValidator integrated_validator.IntegratedValidator) (http.Handler, error) {
|
||||
|
||||
authMiddleware := middleware.NewAuthMiddleware(
|
||||
accountManager.GetPATInfo,
|
||||
jwtValidator.ValidateAndParse,
|
||||
accountManager.MarkPATUsed,
|
||||
accountManager.CheckUserAccessByJWTGroups,
|
||||
claimsExtractor,
|
||||
authCfg.Audience,
|
||||
authCfg.UserIDClaim,
|
||||
authManager,
|
||||
accountManager.GetAccountIDFromUserAuth,
|
||||
accountManager.SyncUserJWTGroups,
|
||||
)
|
||||
|
||||
corsMiddleware := cors.AllowAll()
|
||||
|
||||
claimsExtractor = jwtclaims.NewClaimsExtractor(
|
||||
jwtclaims.WithAudience(authCfg.Audience),
|
||||
jwtclaims.WithUserIDClaim(authCfg.UserIDClaim),
|
||||
)
|
||||
|
||||
acMiddleware := middleware.NewAccessControl(
|
||||
authCfg.Audience,
|
||||
authCfg.UserIDClaim,
|
||||
accountManager.GetUser)
|
||||
acMiddleware := middleware.NewAccessControl(accountManager.GetUserFromUserAuth)
|
||||
|
||||
rootRouter := mux.NewRouter()
|
||||
metricsMiddleware := appMetrics.HTTPMiddleware()
|
||||
|
||||
prefix := apiPrefix
|
||||
router := rootRouter.PathPrefix(prefix).Subrouter()
|
||||
|
||||
router.Use(metricsMiddleware.Handler, corsMiddleware.Handler, authMiddleware.Handler, acMiddleware.Handler)
|
||||
|
||||
if _, err := integrations.RegisterHandlers(ctx, prefix, router, accountManager, claimsExtractor, integratedValidator, appMetrics.GetMeter()); err != nil {
|
||||
if _, err := integrations.RegisterHandlers(ctx, prefix, router, accountManager, integratedValidator, appMetrics.GetMeter()); err != nil {
|
||||
return nil, fmt.Errorf("register integrations endpoints: %w", err)
|
||||
}
|
||||
|
||||
accounts.AddEndpoints(accountManager, authCfg, router)
|
||||
peers.AddEndpoints(accountManager, authCfg, router)
|
||||
users.AddEndpoints(accountManager, authCfg, router)
|
||||
setup_keys.AddEndpoints(accountManager, authCfg, router)
|
||||
policies.AddEndpoints(accountManager, LocationManager, authCfg, router)
|
||||
groups.AddEndpoints(accountManager, authCfg, router)
|
||||
routes.AddEndpoints(accountManager, authCfg, router)
|
||||
dns.AddEndpoints(accountManager, authCfg, router)
|
||||
events.AddEndpoints(accountManager, authCfg, router)
|
||||
networks.AddEndpoints(networksManager, resourceManager, routerManager, groupsManager, accountManager, accountManager.GetAccountIDFromToken, authCfg, router)
|
||||
accounts.AddEndpoints(accountManager, router)
|
||||
peers.AddEndpoints(accountManager, router)
|
||||
users.AddEndpoints(accountManager, router)
|
||||
setup_keys.AddEndpoints(accountManager, router)
|
||||
policies.AddEndpoints(accountManager, LocationManager, router)
|
||||
groups.AddEndpoints(accountManager, router)
|
||||
routes.AddEndpoints(accountManager, router)
|
||||
dns.AddEndpoints(accountManager, router)
|
||||
events.AddEndpoints(accountManager, router)
|
||||
networks.AddEndpoints(networksManager, resourceManager, routerManager, groupsManager, accountManager, router)
|
||||
|
||||
return rootRouter, nil
|
||||
}
|
||||
|
@ -9,47 +9,42 @@ import (
|
||||
|
||||
"github.com/netbirdio/netbird/management/server"
|
||||
"github.com/netbirdio/netbird/management/server/account"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/http/api"
|
||||
"github.com/netbirdio/netbird/management/server/http/configs"
|
||||
"github.com/netbirdio/netbird/management/server/http/util"
|
||||
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
)
|
||||
|
||||
// handler is a handler that handles the server.Account HTTP endpoints
|
||||
type handler struct {
|
||||
accountManager server.AccountManager
|
||||
claimsExtractor *jwtclaims.ClaimsExtractor
|
||||
accountManager server.AccountManager
|
||||
}
|
||||
|
||||
func AddEndpoints(accountManager server.AccountManager, authCfg configs.AuthCfg, router *mux.Router) {
|
||||
accountsHandler := newHandler(accountManager, authCfg)
|
||||
func AddEndpoints(accountManager server.AccountManager, router *mux.Router) {
|
||||
accountsHandler := newHandler(accountManager)
|
||||
router.HandleFunc("/accounts/{accountId}", accountsHandler.updateAccount).Methods("PUT", "OPTIONS")
|
||||
router.HandleFunc("/accounts/{accountId}", accountsHandler.deleteAccount).Methods("DELETE", "OPTIONS")
|
||||
router.HandleFunc("/accounts", accountsHandler.getAllAccounts).Methods("GET", "OPTIONS")
|
||||
}
|
||||
|
||||
// newHandler creates a new handler HTTP handler
|
||||
func newHandler(accountManager server.AccountManager, authCfg configs.AuthCfg) *handler {
|
||||
func newHandler(accountManager server.AccountManager) *handler {
|
||||
return &handler{
|
||||
accountManager: accountManager,
|
||||
claimsExtractor: jwtclaims.NewClaimsExtractor(
|
||||
jwtclaims.WithAudience(authCfg.Audience),
|
||||
jwtclaims.WithUserIDClaim(authCfg.UserIDClaim),
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
// getAllAccounts is HTTP GET handler that returns a list of accounts. Effectively returns just a single account.
|
||||
func (h *handler) getAllAccounts(w http.ResponseWriter, r *http.Request) {
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
|
||||
settings, err := h.accountManager.GetAccountSettings(r.Context(), accountID, userID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
@ -62,13 +57,14 @@ func (h *handler) getAllAccounts(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
// updateAccount is HTTP PUT handler that updates the provided account. Updates only account settings (server.Settings)
|
||||
func (h *handler) updateAccount(w http.ResponseWriter, r *http.Request) {
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
_, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
_, userID := userAuth.AccountId, userAuth.UserId
|
||||
|
||||
vars := mux.Vars(r)
|
||||
accountID := vars["accountId"]
|
||||
if len(accountID) == 0 {
|
||||
@ -125,7 +121,12 @@ func (h *handler) updateAccount(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
// deleteAccount is a HTTP DELETE handler to delete an account
|
||||
func (h *handler) deleteAccount(w http.ResponseWriter, r *http.Request) {
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
vars := mux.Vars(r)
|
||||
targetAccountID := vars["accountId"]
|
||||
if len(targetAccountID) == 0 {
|
||||
@ -133,7 +134,7 @@ func (h *handler) deleteAccount(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
err := h.accountManager.DeleteAccount(r.Context(), targetAccountID, claims.UserId)
|
||||
err = h.accountManager.DeleteAccount(r.Context(), targetAccountID, userAuth.UserId)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
|
@ -13,19 +13,16 @@ import (
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/http/api"
|
||||
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
||||
"github.com/netbirdio/netbird/management/server/mock_server"
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
)
|
||||
|
||||
func initAccountsTestData(account *types.Account, admin *types.User) *handler {
|
||||
func initAccountsTestData(account *types.Account) *handler {
|
||||
return &handler{
|
||||
accountManager: &mock_server.MockAccountManager{
|
||||
GetAccountIDFromTokenFunc: func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) {
|
||||
return account.Id, admin.Id, nil
|
||||
},
|
||||
GetAccountSettingsFunc: func(ctx context.Context, accountID string, userID string) (*types.Settings, error) {
|
||||
return account.Settings, nil
|
||||
},
|
||||
@ -44,15 +41,6 @@ func initAccountsTestData(account *types.Account, admin *types.User) *handler {
|
||||
return accCopy, nil
|
||||
},
|
||||
},
|
||||
claimsExtractor: jwtclaims.NewClaimsExtractor(
|
||||
jwtclaims.WithFromRequestContext(func(r *http.Request) jwtclaims.AuthorizationClaims {
|
||||
return jwtclaims.AuthorizationClaims{
|
||||
UserId: "test_user",
|
||||
Domain: "hotmail.com",
|
||||
AccountId: "test_account",
|
||||
}
|
||||
}),
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
@ -75,7 +63,7 @@ func TestAccounts_AccountsHandler(t *testing.T) {
|
||||
PeerLoginExpiration: time.Hour,
|
||||
RegularUsersViewBlocked: true,
|
||||
},
|
||||
}, adminUser)
|
||||
})
|
||||
|
||||
tt := []struct {
|
||||
name string
|
||||
@ -191,6 +179,11 @@ func TestAccounts_AccountsHandler(t *testing.T) {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
recorder := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
|
||||
req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{
|
||||
UserId: adminUser.Id,
|
||||
AccountId: accountID,
|
||||
Domain: "hotmail.com",
|
||||
})
|
||||
|
||||
router := mux.NewRouter()
|
||||
router.HandleFunc("/api/accounts", handler.getAllAccounts).Methods("GET")
|
||||
|
@ -8,51 +8,44 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/http/api"
|
||||
"github.com/netbirdio/netbird/management/server/http/configs"
|
||||
"github.com/netbirdio/netbird/management/server/http/util"
|
||||
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
)
|
||||
|
||||
// dnsSettingsHandler is a handler that returns the DNS settings of the account
|
||||
type dnsSettingsHandler struct {
|
||||
accountManager server.AccountManager
|
||||
claimsExtractor *jwtclaims.ClaimsExtractor
|
||||
accountManager server.AccountManager
|
||||
}
|
||||
|
||||
func AddEndpoints(accountManager server.AccountManager, authCfg configs.AuthCfg, router *mux.Router) {
|
||||
addDNSSettingEndpoint(accountManager, authCfg, router)
|
||||
addDNSNameserversEndpoint(accountManager, authCfg, router)
|
||||
func AddEndpoints(accountManager server.AccountManager, router *mux.Router) {
|
||||
addDNSSettingEndpoint(accountManager, router)
|
||||
addDNSNameserversEndpoint(accountManager, router)
|
||||
}
|
||||
|
||||
func addDNSSettingEndpoint(accountManager server.AccountManager, authCfg configs.AuthCfg, router *mux.Router) {
|
||||
dnsSettingsHandler := newDNSSettingsHandler(accountManager, authCfg)
|
||||
func addDNSSettingEndpoint(accountManager server.AccountManager, router *mux.Router) {
|
||||
dnsSettingsHandler := newDNSSettingsHandler(accountManager)
|
||||
router.HandleFunc("/dns/settings", dnsSettingsHandler.getDNSSettings).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/dns/settings", dnsSettingsHandler.updateDNSSettings).Methods("PUT", "OPTIONS")
|
||||
}
|
||||
|
||||
// newDNSSettingsHandler returns a new instance of dnsSettingsHandler handler
|
||||
func newDNSSettingsHandler(accountManager server.AccountManager, authCfg configs.AuthCfg) *dnsSettingsHandler {
|
||||
return &dnsSettingsHandler{
|
||||
accountManager: accountManager,
|
||||
claimsExtractor: jwtclaims.NewClaimsExtractor(
|
||||
jwtclaims.WithAudience(authCfg.Audience),
|
||||
jwtclaims.WithUserIDClaim(authCfg.UserIDClaim),
|
||||
),
|
||||
}
|
||||
func newDNSSettingsHandler(accountManager server.AccountManager) *dnsSettingsHandler {
|
||||
return &dnsSettingsHandler{accountManager: accountManager}
|
||||
}
|
||||
|
||||
// getDNSSettings returns the DNS settings for the account
|
||||
func (h *dnsSettingsHandler) getDNSSettings(w http.ResponseWriter, r *http.Request) {
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
log.WithContext(r.Context()).Error(err)
|
||||
http.Redirect(w, r, "/", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
|
||||
dnsSettings, err := h.accountManager.GetDNSSettings(r.Context(), accountID, userID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
@ -68,13 +61,14 @@ func (h *dnsSettingsHandler) getDNSSettings(w http.ResponseWriter, r *http.Reque
|
||||
|
||||
// updateDNSSettings handles update to DNS settings of an account
|
||||
func (h *dnsSettingsHandler) updateDNSSettings(w http.ResponseWriter, r *http.Request) {
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
|
||||
var req api.PutApiDnsSettingsJSONRequestBody
|
||||
err = json.NewDecoder(r.Body).Decode(&req)
|
||||
if err != nil {
|
||||
|
@ -17,7 +17,8 @@ import (
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/mock_server"
|
||||
)
|
||||
|
||||
@ -52,19 +53,7 @@ func initDNSSettingsTestData() *dnsSettingsHandler {
|
||||
}
|
||||
return status.Errorf(status.InvalidArgument, "the dns settings provided are nil")
|
||||
},
|
||||
GetAccountIDFromTokenFunc: func(ctx context.Context, _ jwtclaims.AuthorizationClaims) (string, string, error) {
|
||||
return testingDNSSettingsAccount.Id, testingDNSSettingsAccount.Users[testDNSSettingsUserID].Id, nil
|
||||
},
|
||||
},
|
||||
claimsExtractor: jwtclaims.NewClaimsExtractor(
|
||||
jwtclaims.WithFromRequestContext(func(r *http.Request) jwtclaims.AuthorizationClaims {
|
||||
return jwtclaims.AuthorizationClaims{
|
||||
UserId: "test_user",
|
||||
Domain: "hotmail.com",
|
||||
AccountId: testDNSSettingsAccountID,
|
||||
}
|
||||
}),
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
@ -118,6 +107,11 @@ func TestDNSSettingsHandlers(t *testing.T) {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
recorder := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
|
||||
req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{
|
||||
UserId: testingDNSSettingsAccount.Users[testDNSSettingsUserID].Id,
|
||||
AccountId: testingDNSSettingsAccount.Id,
|
||||
Domain: testingDNSSettingsAccount.Domain,
|
||||
})
|
||||
|
||||
router := mux.NewRouter()
|
||||
router.HandleFunc("/api/dns/settings", p.getDNSSettings).Methods("GET")
|
||||
|
@ -10,21 +10,19 @@ import (
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/management/server"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/http/api"
|
||||
"github.com/netbirdio/netbird/management/server/http/configs"
|
||||
"github.com/netbirdio/netbird/management/server/http/util"
|
||||
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
)
|
||||
|
||||
// nameserversHandler is the nameserver group handler of the account
|
||||
type nameserversHandler struct {
|
||||
accountManager server.AccountManager
|
||||
claimsExtractor *jwtclaims.ClaimsExtractor
|
||||
accountManager server.AccountManager
|
||||
}
|
||||
|
||||
func addDNSNameserversEndpoint(accountManager server.AccountManager, authCfg configs.AuthCfg, router *mux.Router) {
|
||||
nameserversHandler := newNameserversHandler(accountManager, authCfg)
|
||||
func addDNSNameserversEndpoint(accountManager server.AccountManager, router *mux.Router) {
|
||||
nameserversHandler := newNameserversHandler(accountManager)
|
||||
router.HandleFunc("/dns/nameservers", nameserversHandler.getAllNameservers).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/dns/nameservers", nameserversHandler.createNameserverGroup).Methods("POST", "OPTIONS")
|
||||
router.HandleFunc("/dns/nameservers/{nsgroupId}", nameserversHandler.updateNameserverGroup).Methods("PUT", "OPTIONS")
|
||||
@ -33,26 +31,21 @@ func addDNSNameserversEndpoint(accountManager server.AccountManager, authCfg con
|
||||
}
|
||||
|
||||
// newNameserversHandler returns a new instance of nameserversHandler handler
|
||||
func newNameserversHandler(accountManager server.AccountManager, authCfg configs.AuthCfg) *nameserversHandler {
|
||||
return &nameserversHandler{
|
||||
accountManager: accountManager,
|
||||
claimsExtractor: jwtclaims.NewClaimsExtractor(
|
||||
jwtclaims.WithAudience(authCfg.Audience),
|
||||
jwtclaims.WithUserIDClaim(authCfg.UserIDClaim),
|
||||
),
|
||||
}
|
||||
func newNameserversHandler(accountManager server.AccountManager) *nameserversHandler {
|
||||
return &nameserversHandler{accountManager: accountManager}
|
||||
}
|
||||
|
||||
// getAllNameservers returns the list of nameserver groups for the account
|
||||
func (h *nameserversHandler) getAllNameservers(w http.ResponseWriter, r *http.Request) {
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
log.WithContext(r.Context()).Error(err)
|
||||
http.Redirect(w, r, "/", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
|
||||
nsGroups, err := h.accountManager.ListNameServerGroups(r.Context(), accountID, userID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
@ -69,13 +62,14 @@ func (h *nameserversHandler) getAllNameservers(w http.ResponseWriter, r *http.Re
|
||||
|
||||
// createNameserverGroup handles nameserver group creation request
|
||||
func (h *nameserversHandler) createNameserverGroup(w http.ResponseWriter, r *http.Request) {
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
|
||||
var req api.PostApiDnsNameserversJSONRequestBody
|
||||
err = json.NewDecoder(r.Body).Decode(&req)
|
||||
if err != nil {
|
||||
@ -102,13 +96,14 @@ func (h *nameserversHandler) createNameserverGroup(w http.ResponseWriter, r *htt
|
||||
|
||||
// updateNameserverGroup handles update to a nameserver group identified by a given ID
|
||||
func (h *nameserversHandler) updateNameserverGroup(w http.ResponseWriter, r *http.Request) {
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
|
||||
nsGroupID := mux.Vars(r)["nsgroupId"]
|
||||
if len(nsGroupID) == 0 {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid nameserver group ID"), w)
|
||||
@ -153,13 +148,14 @@ func (h *nameserversHandler) updateNameserverGroup(w http.ResponseWriter, r *htt
|
||||
|
||||
// deleteNameserverGroup handles nameserver group deletion request
|
||||
func (h *nameserversHandler) deleteNameserverGroup(w http.ResponseWriter, r *http.Request) {
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
|
||||
nsGroupID := mux.Vars(r)["nsgroupId"]
|
||||
if len(nsGroupID) == 0 {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid nameserver group ID"), w)
|
||||
@ -177,14 +173,14 @@ func (h *nameserversHandler) deleteNameserverGroup(w http.ResponseWriter, r *htt
|
||||
|
||||
// getNameserverGroup handles a nameserver group Get request identified by ID
|
||||
func (h *nameserversHandler) getNameserverGroup(w http.ResponseWriter, r *http.Request) {
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
log.WithContext(r.Context()).Error(err)
|
||||
http.Redirect(w, r, "/", http.StatusInternalServerError)
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
|
||||
nsGroupID := mux.Vars(r)["nsgroupId"]
|
||||
if len(nsGroupID) == 0 {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid nameserver group ID"), w)
|
||||
|
@ -18,7 +18,8 @@ import (
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/mock_server"
|
||||
)
|
||||
|
||||
@ -81,19 +82,7 @@ func initNameserversTestData() *nameserversHandler {
|
||||
}
|
||||
return status.Errorf(status.NotFound, "nameserver group with ID %s was not found", nsGroupToSave.ID)
|
||||
},
|
||||
GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) {
|
||||
return claims.AccountId, claims.UserId, nil
|
||||
},
|
||||
},
|
||||
claimsExtractor: jwtclaims.NewClaimsExtractor(
|
||||
jwtclaims.WithFromRequestContext(func(r *http.Request) jwtclaims.AuthorizationClaims {
|
||||
return jwtclaims.AuthorizationClaims{
|
||||
UserId: "test_user",
|
||||
Domain: "hotmail.com",
|
||||
AccountId: testNSGroupAccountID,
|
||||
}
|
||||
}),
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
@ -204,6 +193,11 @@ func TestNameserversHandlers(t *testing.T) {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
recorder := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
|
||||
req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{
|
||||
UserId: "test_user",
|
||||
AccountId: testNSGroupAccountID,
|
||||
Domain: "hotmail.com",
|
||||
})
|
||||
|
||||
router := mux.NewRouter()
|
||||
router.HandleFunc("/api/dns/nameservers/{nsgroupId}", p.getNameserverGroup).Methods("GET")
|
||||
|
@ -10,44 +10,37 @@ import (
|
||||
|
||||
"github.com/netbirdio/netbird/management/server"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/http/api"
|
||||
"github.com/netbirdio/netbird/management/server/http/configs"
|
||||
"github.com/netbirdio/netbird/management/server/http/util"
|
||||
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
||||
)
|
||||
|
||||
// handler HTTP handler
|
||||
type handler struct {
|
||||
accountManager server.AccountManager
|
||||
claimsExtractor *jwtclaims.ClaimsExtractor
|
||||
accountManager server.AccountManager
|
||||
}
|
||||
|
||||
func AddEndpoints(accountManager server.AccountManager, authCfg configs.AuthCfg, router *mux.Router) {
|
||||
eventsHandler := newHandler(accountManager, authCfg)
|
||||
func AddEndpoints(accountManager server.AccountManager, router *mux.Router) {
|
||||
eventsHandler := newHandler(accountManager)
|
||||
router.HandleFunc("/events", eventsHandler.getAllEvents).Methods("GET", "OPTIONS")
|
||||
}
|
||||
|
||||
// newHandler creates a new events handler
|
||||
func newHandler(accountManager server.AccountManager, authCfg configs.AuthCfg) *handler {
|
||||
return &handler{
|
||||
accountManager: accountManager,
|
||||
claimsExtractor: jwtclaims.NewClaimsExtractor(
|
||||
jwtclaims.WithAudience(authCfg.Audience),
|
||||
jwtclaims.WithUserIDClaim(authCfg.UserIDClaim),
|
||||
),
|
||||
}
|
||||
func newHandler(accountManager server.AccountManager) *handler {
|
||||
return &handler{accountManager: accountManager}
|
||||
}
|
||||
|
||||
// getAllEvents list of the given account
|
||||
func (h *handler) getAllEvents(w http.ResponseWriter, r *http.Request) {
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
log.WithContext(r.Context()).Error(err)
|
||||
http.Redirect(w, r, "/", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
|
||||
accountEvents, err := h.accountManager.GetEvents(r.Context(), accountID, userID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
|
@ -13,9 +13,10 @@ import (
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/management/server/http/api"
|
||||
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
||||
"github.com/netbirdio/netbird/management/server/mock_server"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
)
|
||||
@ -29,22 +30,10 @@ func initEventsTestData(account string, events ...*activity.Event) *handler {
|
||||
}
|
||||
return []*activity.Event{}, nil
|
||||
},
|
||||
GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) {
|
||||
return claims.AccountId, claims.UserId, nil
|
||||
},
|
||||
GetUsersFromAccountFunc: func(_ context.Context, accountID, userID string) (map[string]*types.UserInfo, error) {
|
||||
return make(map[string]*types.UserInfo), nil
|
||||
},
|
||||
},
|
||||
claimsExtractor: jwtclaims.NewClaimsExtractor(
|
||||
jwtclaims.WithFromRequestContext(func(r *http.Request) jwtclaims.AuthorizationClaims {
|
||||
return jwtclaims.AuthorizationClaims{
|
||||
UserId: "test_user",
|
||||
Domain: "hotmail.com",
|
||||
AccountId: "test_account",
|
||||
}
|
||||
}),
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
@ -199,6 +188,11 @@ func TestEvents_GetEvents(t *testing.T) {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
recorder := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
|
||||
req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{
|
||||
UserId: "test_user",
|
||||
Domain: "hotmail.com",
|
||||
AccountId: "test_account",
|
||||
})
|
||||
|
||||
router := mux.NewRouter()
|
||||
router.HandleFunc("/api/events/", handler.getAllEvents).Methods("GET")
|
||||
|
@ -7,24 +7,23 @@ import (
|
||||
"github.com/gorilla/mux"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server"
|
||||
"github.com/netbirdio/netbird/management/server/http/api"
|
||||
"github.com/netbirdio/netbird/management/server/http/configs"
|
||||
"github.com/netbirdio/netbird/management/server/http/util"
|
||||
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
)
|
||||
|
||||
// handler is a handler that returns groups of the account
|
||||
type handler struct {
|
||||
accountManager server.AccountManager
|
||||
claimsExtractor *jwtclaims.ClaimsExtractor
|
||||
accountManager server.AccountManager
|
||||
}
|
||||
|
||||
func AddEndpoints(accountManager server.AccountManager, authCfg configs.AuthCfg, router *mux.Router) {
|
||||
groupsHandler := newHandler(accountManager, authCfg)
|
||||
func AddEndpoints(accountManager server.AccountManager, router *mux.Router) {
|
||||
groupsHandler := newHandler(accountManager)
|
||||
router.HandleFunc("/groups", groupsHandler.getAllGroups).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/groups", groupsHandler.createGroup).Methods("POST", "OPTIONS")
|
||||
router.HandleFunc("/groups/{groupId}", groupsHandler.updateGroup).Methods("PUT", "OPTIONS")
|
||||
@ -33,25 +32,21 @@ func AddEndpoints(accountManager server.AccountManager, authCfg configs.AuthCfg,
|
||||
}
|
||||
|
||||
// newHandler creates a new groups handler
|
||||
func newHandler(accountManager server.AccountManager, authCfg configs.AuthCfg) *handler {
|
||||
func newHandler(accountManager server.AccountManager) *handler {
|
||||
return &handler{
|
||||
accountManager: accountManager,
|
||||
claimsExtractor: jwtclaims.NewClaimsExtractor(
|
||||
jwtclaims.WithAudience(authCfg.Audience),
|
||||
jwtclaims.WithUserIDClaim(authCfg.UserIDClaim),
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
// getAllGroups list for the account
|
||||
func (h *handler) getAllGroups(w http.ResponseWriter, r *http.Request) {
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
log.WithContext(r.Context()).Error(err)
|
||||
http.Redirect(w, r, "/", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
|
||||
groups, err := h.accountManager.GetAllGroups(r.Context(), accountID, userID)
|
||||
if err != nil {
|
||||
@ -75,13 +70,14 @@ func (h *handler) getAllGroups(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
// updateGroup handles update to a group identified by a given ID
|
||||
func (h *handler) updateGroup(w http.ResponseWriter, r *http.Request) {
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
|
||||
vars := mux.Vars(r)
|
||||
groupID, ok := vars["groupId"]
|
||||
if !ok {
|
||||
@ -164,13 +160,14 @@ func (h *handler) updateGroup(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
// createGroup handles group creation request
|
||||
func (h *handler) createGroup(w http.ResponseWriter, r *http.Request) {
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
|
||||
var req api.PostApiGroupsJSONRequestBody
|
||||
err = json.NewDecoder(r.Body).Decode(&req)
|
||||
if err != nil {
|
||||
@ -223,13 +220,14 @@ func (h *handler) createGroup(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
// deleteGroup handles group deletion request
|
||||
func (h *handler) deleteGroup(w http.ResponseWriter, r *http.Request) {
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
|
||||
groupID := mux.Vars(r)["groupId"]
|
||||
if len(groupID) == 0 {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid group ID"), w)
|
||||
@ -253,12 +251,13 @@ func (h *handler) deleteGroup(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
// getGroup returns a group
|
||||
func (h *handler) getGroup(w http.ResponseWriter, r *http.Request) {
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
groupID := mux.Vars(r)["groupId"]
|
||||
if len(groupID) == 0 {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid group ID"), w)
|
||||
|
@ -18,9 +18,9 @@ import (
|
||||
"golang.org/x/exp/maps"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/http/api"
|
||||
"github.com/netbirdio/netbird/management/server/http/util"
|
||||
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
||||
"github.com/netbirdio/netbird/management/server/mock_server"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
@ -59,9 +59,6 @@ func initGroupTestData(initGroups ...*types.Group) *handler {
|
||||
|
||||
return group, nil
|
||||
},
|
||||
GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) {
|
||||
return claims.AccountId, claims.UserId, nil
|
||||
},
|
||||
GetGroupByNameFunc: func(ctx context.Context, groupName, _ string) (*types.Group, error) {
|
||||
if groupName == "All" {
|
||||
return &types.Group{ID: "id-all", Name: "All", Issued: types.GroupIssuedAPI}, nil
|
||||
@ -87,15 +84,6 @@ func initGroupTestData(initGroups ...*types.Group) *handler {
|
||||
return nil
|
||||
},
|
||||
},
|
||||
claimsExtractor: jwtclaims.NewClaimsExtractor(
|
||||
jwtclaims.WithFromRequestContext(func(r *http.Request) jwtclaims.AuthorizationClaims {
|
||||
return jwtclaims.AuthorizationClaims{
|
||||
UserId: "test_user",
|
||||
Domain: "hotmail.com",
|
||||
AccountId: "test_id",
|
||||
}
|
||||
}),
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
@ -134,6 +122,11 @@ func TestGetGroup(t *testing.T) {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
recorder := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
|
||||
req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{
|
||||
UserId: "test_user",
|
||||
Domain: "hotmail.com",
|
||||
AccountId: "test_id",
|
||||
})
|
||||
|
||||
router := mux.NewRouter()
|
||||
router.HandleFunc("/api/groups/{groupId}", p.getGroup).Methods("GET")
|
||||
@ -255,6 +248,11 @@ func TestWriteGroup(t *testing.T) {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
recorder := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
|
||||
req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{
|
||||
UserId: "test_user",
|
||||
Domain: "hotmail.com",
|
||||
AccountId: "test_id",
|
||||
})
|
||||
|
||||
router := mux.NewRouter()
|
||||
router.HandleFunc("/api/groups", p.createGroup).Methods("POST")
|
||||
@ -332,7 +330,11 @@ func TestDeleteGroup(t *testing.T) {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
recorder := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(tc.requestType, tc.requestPath, nil)
|
||||
|
||||
req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{
|
||||
UserId: "test_user",
|
||||
Domain: "hotmail.com",
|
||||
AccountId: "test_id",
|
||||
})
|
||||
router := mux.NewRouter()
|
||||
router.HandleFunc("/api/groups/{groupId}", p.deleteGroup).Methods("DELETE")
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
@ -10,11 +10,10 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
s "github.com/netbirdio/netbird/management/server"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/groups"
|
||||
"github.com/netbirdio/netbird/management/server/http/api"
|
||||
"github.com/netbirdio/netbird/management/server/http/configs"
|
||||
"github.com/netbirdio/netbird/management/server/http/util"
|
||||
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
||||
"github.com/netbirdio/netbird/management/server/networks"
|
||||
"github.com/netbirdio/netbird/management/server/networks/resources"
|
||||
"github.com/netbirdio/netbird/management/server/networks/routers"
|
||||
@ -31,16 +30,14 @@ type handler struct {
|
||||
routerManager routers.Manager
|
||||
accountManager s.AccountManager
|
||||
|
||||
groupsManager groups.Manager
|
||||
extractFromToken func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error)
|
||||
claimsExtractor *jwtclaims.ClaimsExtractor
|
||||
groupsManager groups.Manager
|
||||
}
|
||||
|
||||
func AddEndpoints(networksManager networks.Manager, resourceManager resources.Manager, routerManager routers.Manager, groupsManager groups.Manager, accountManager s.AccountManager, extractFromToken func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error), authCfg configs.AuthCfg, router *mux.Router) {
|
||||
addRouterEndpoints(routerManager, extractFromToken, authCfg, router)
|
||||
addResourceEndpoints(resourceManager, groupsManager, extractFromToken, authCfg, router)
|
||||
func AddEndpoints(networksManager networks.Manager, resourceManager resources.Manager, routerManager routers.Manager, groupsManager groups.Manager, accountManager s.AccountManager, router *mux.Router) {
|
||||
addRouterEndpoints(routerManager, router)
|
||||
addResourceEndpoints(resourceManager, groupsManager, router)
|
||||
|
||||
networksHandler := newHandler(networksManager, resourceManager, routerManager, groupsManager, accountManager, extractFromToken, authCfg)
|
||||
networksHandler := newHandler(networksManager, resourceManager, routerManager, groupsManager, accountManager)
|
||||
router.HandleFunc("/networks", networksHandler.getAllNetworks).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/networks", networksHandler.createNetwork).Methods("POST", "OPTIONS")
|
||||
router.HandleFunc("/networks/{networkId}", networksHandler.getNetwork).Methods("GET", "OPTIONS")
|
||||
@ -48,29 +45,25 @@ func AddEndpoints(networksManager networks.Manager, resourceManager resources.Ma
|
||||
router.HandleFunc("/networks/{networkId}", networksHandler.deleteNetwork).Methods("DELETE", "OPTIONS")
|
||||
}
|
||||
|
||||
func newHandler(networksManager networks.Manager, resourceManager resources.Manager, routerManager routers.Manager, groupsManager groups.Manager, accountManager s.AccountManager, extractFromToken func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error), authCfg configs.AuthCfg) *handler {
|
||||
func newHandler(networksManager networks.Manager, resourceManager resources.Manager, routerManager routers.Manager, groupsManager groups.Manager, accountManager s.AccountManager) *handler {
|
||||
return &handler{
|
||||
networksManager: networksManager,
|
||||
resourceManager: resourceManager,
|
||||
routerManager: routerManager,
|
||||
groupsManager: groupsManager,
|
||||
accountManager: accountManager,
|
||||
extractFromToken: extractFromToken,
|
||||
claimsExtractor: jwtclaims.NewClaimsExtractor(
|
||||
jwtclaims.WithAudience(authCfg.Audience),
|
||||
jwtclaims.WithUserIDClaim(authCfg.UserIDClaim),
|
||||
),
|
||||
networksManager: networksManager,
|
||||
resourceManager: resourceManager,
|
||||
routerManager: routerManager,
|
||||
groupsManager: groupsManager,
|
||||
accountManager: accountManager,
|
||||
}
|
||||
}
|
||||
|
||||
func (h *handler) getAllNetworks(w http.ResponseWriter, r *http.Request) {
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
accountID, userID, err := h.extractFromToken(r.Context(), claims)
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
|
||||
networks, err := h.networksManager.GetAllNetworks(r.Context(), accountID, userID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
@ -105,12 +98,12 @@ func (h *handler) getAllNetworks(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
func (h *handler) createNetwork(w http.ResponseWriter, r *http.Request) {
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
accountID, userID, err := h.extractFromToken(r.Context(), claims)
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
|
||||
var req api.NetworkRequest
|
||||
err = json.NewDecoder(r.Body).Decode(&req)
|
||||
@ -141,12 +134,12 @@ func (h *handler) createNetwork(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
func (h *handler) getNetwork(w http.ResponseWriter, r *http.Request) {
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
accountID, userID, err := h.extractFromToken(r.Context(), claims)
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
|
||||
vars := mux.Vars(r)
|
||||
networkID := vars["networkId"]
|
||||
@ -179,13 +172,13 @@ func (h *handler) getNetwork(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
func (h *handler) updateNetwork(w http.ResponseWriter, r *http.Request) {
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
accountID, userID, err := h.extractFromToken(r.Context(), claims)
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
vars := mux.Vars(r)
|
||||
networkID := vars["networkId"]
|
||||
if len(networkID) == 0 {
|
||||
@ -229,13 +222,13 @@ func (h *handler) updateNetwork(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
func (h *handler) deleteNetwork(w http.ResponseWriter, r *http.Request) {
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
accountID, userID, err := h.extractFromToken(r.Context(), claims)
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
vars := mux.Vars(r)
|
||||
networkID := vars["networkId"]
|
||||
if len(networkID) == 0 {
|
||||
|
@ -1,30 +1,26 @@
|
||||
package networks
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/groups"
|
||||
"github.com/netbirdio/netbird/management/server/http/api"
|
||||
"github.com/netbirdio/netbird/management/server/http/configs"
|
||||
"github.com/netbirdio/netbird/management/server/http/util"
|
||||
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
||||
"github.com/netbirdio/netbird/management/server/networks/resources"
|
||||
"github.com/netbirdio/netbird/management/server/networks/resources/types"
|
||||
)
|
||||
|
||||
type resourceHandler struct {
|
||||
resourceManager resources.Manager
|
||||
groupsManager groups.Manager
|
||||
extractFromToken func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error)
|
||||
claimsExtractor *jwtclaims.ClaimsExtractor
|
||||
resourceManager resources.Manager
|
||||
groupsManager groups.Manager
|
||||
}
|
||||
|
||||
func addResourceEndpoints(resourcesManager resources.Manager, groupsManager groups.Manager, extractFromToken func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error), authCfg configs.AuthCfg, router *mux.Router) {
|
||||
resourceHandler := newResourceHandler(resourcesManager, groupsManager, extractFromToken, authCfg)
|
||||
func addResourceEndpoints(resourcesManager resources.Manager, groupsManager groups.Manager, router *mux.Router) {
|
||||
resourceHandler := newResourceHandler(resourcesManager, groupsManager)
|
||||
router.HandleFunc("/networks/resources", resourceHandler.getAllResourcesInAccount).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/networks/{networkId}/resources", resourceHandler.getAllResourcesInNetwork).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/networks/{networkId}/resources", resourceHandler.createResource).Methods("POST", "OPTIONS")
|
||||
@ -33,26 +29,21 @@ func addResourceEndpoints(resourcesManager resources.Manager, groupsManager grou
|
||||
router.HandleFunc("/networks/{networkId}/resources/{resourceId}", resourceHandler.deleteResource).Methods("DELETE", "OPTIONS")
|
||||
}
|
||||
|
||||
func newResourceHandler(resourceManager resources.Manager, groupsManager groups.Manager, extractFromToken func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error), authCfg configs.AuthCfg) *resourceHandler {
|
||||
func newResourceHandler(resourceManager resources.Manager, groupsManager groups.Manager) *resourceHandler {
|
||||
return &resourceHandler{
|
||||
resourceManager: resourceManager,
|
||||
groupsManager: groupsManager,
|
||||
extractFromToken: extractFromToken,
|
||||
claimsExtractor: jwtclaims.NewClaimsExtractor(
|
||||
jwtclaims.WithAudience(authCfg.Audience),
|
||||
jwtclaims.WithUserIDClaim(authCfg.UserIDClaim),
|
||||
),
|
||||
resourceManager: resourceManager,
|
||||
groupsManager: groupsManager,
|
||||
}
|
||||
}
|
||||
|
||||
func (h *resourceHandler) getAllResourcesInNetwork(w http.ResponseWriter, r *http.Request) {
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
accountID, userID, err := h.extractFromToken(r.Context(), claims)
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
networkID := mux.Vars(r)["networkId"]
|
||||
resources, err := h.resourceManager.GetAllResourcesInNetwork(r.Context(), accountID, userID, networkID)
|
||||
if err != nil {
|
||||
@ -76,13 +67,14 @@ func (h *resourceHandler) getAllResourcesInNetwork(w http.ResponseWriter, r *htt
|
||||
util.WriteJSONObject(r.Context(), w, resourcesResponse)
|
||||
}
|
||||
func (h *resourceHandler) getAllResourcesInAccount(w http.ResponseWriter, r *http.Request) {
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
accountID, userID, err := h.extractFromToken(r.Context(), claims)
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
|
||||
resources, err := h.resourceManager.GetAllResourcesInAccount(r.Context(), accountID, userID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
@ -106,13 +98,14 @@ func (h *resourceHandler) getAllResourcesInAccount(w http.ResponseWriter, r *htt
|
||||
}
|
||||
|
||||
func (h *resourceHandler) createResource(w http.ResponseWriter, r *http.Request) {
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
accountID, userID, err := h.extractFromToken(r.Context(), claims)
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
|
||||
var req api.NetworkResourceRequest
|
||||
err = json.NewDecoder(r.Body).Decode(&req)
|
||||
if err != nil {
|
||||
@ -144,13 +137,13 @@ func (h *resourceHandler) createResource(w http.ResponseWriter, r *http.Request)
|
||||
}
|
||||
|
||||
func (h *resourceHandler) getResource(w http.ResponseWriter, r *http.Request) {
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
accountID, userID, err := h.extractFromToken(r.Context(), claims)
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
networkID := mux.Vars(r)["networkId"]
|
||||
resourceID := mux.Vars(r)["resourceId"]
|
||||
resource, err := h.resourceManager.GetResource(r.Context(), accountID, userID, networkID, resourceID)
|
||||
@ -171,13 +164,13 @@ func (h *resourceHandler) getResource(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
func (h *resourceHandler) updateResource(w http.ResponseWriter, r *http.Request) {
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
accountID, userID, err := h.extractFromToken(r.Context(), claims)
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
var req api.NetworkResourceRequest
|
||||
err = json.NewDecoder(r.Body).Decode(&req)
|
||||
if err != nil {
|
||||
@ -209,12 +202,12 @@ func (h *resourceHandler) updateResource(w http.ResponseWriter, r *http.Request)
|
||||
}
|
||||
|
||||
func (h *resourceHandler) deleteResource(w http.ResponseWriter, r *http.Request) {
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
accountID, userID, err := h.extractFromToken(r.Context(), claims)
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
|
||||
networkID := mux.Vars(r)["networkId"]
|
||||
resourceID := mux.Vars(r)["resourceId"]
|
||||
|
@ -1,28 +1,24 @@
|
||||
package networks
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/http/api"
|
||||
"github.com/netbirdio/netbird/management/server/http/configs"
|
||||
"github.com/netbirdio/netbird/management/server/http/util"
|
||||
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
||||
"github.com/netbirdio/netbird/management/server/networks/routers"
|
||||
"github.com/netbirdio/netbird/management/server/networks/routers/types"
|
||||
)
|
||||
|
||||
type routersHandler struct {
|
||||
routersManager routers.Manager
|
||||
extractFromToken func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error)
|
||||
claimsExtractor *jwtclaims.ClaimsExtractor
|
||||
routersManager routers.Manager
|
||||
}
|
||||
|
||||
func addRouterEndpoints(routersManager routers.Manager, extractFromToken func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error), authCfg configs.AuthCfg, router *mux.Router) {
|
||||
routersHandler := newRoutersHandler(routersManager, extractFromToken, authCfg)
|
||||
func addRouterEndpoints(routersManager routers.Manager, router *mux.Router) {
|
||||
routersHandler := newRoutersHandler(routersManager)
|
||||
router.HandleFunc("/networks/{networkId}/routers", routersHandler.getAllRouters).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/networks/{networkId}/routers", routersHandler.createRouter).Methods("POST", "OPTIONS")
|
||||
router.HandleFunc("/networks/{networkId}/routers/{routerId}", routersHandler.getRouter).Methods("GET", "OPTIONS")
|
||||
@ -30,25 +26,21 @@ func addRouterEndpoints(routersManager routers.Manager, extractFromToken func(ct
|
||||
router.HandleFunc("/networks/{networkId}/routers/{routerId}", routersHandler.deleteRouter).Methods("DELETE", "OPTIONS")
|
||||
}
|
||||
|
||||
func newRoutersHandler(routersManager routers.Manager, extractFromToken func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error), authCfg configs.AuthCfg) *routersHandler {
|
||||
func newRoutersHandler(routersManager routers.Manager) *routersHandler {
|
||||
return &routersHandler{
|
||||
routersManager: routersManager,
|
||||
extractFromToken: extractFromToken,
|
||||
claimsExtractor: jwtclaims.NewClaimsExtractor(
|
||||
jwtclaims.WithAudience(authCfg.Audience),
|
||||
jwtclaims.WithUserIDClaim(authCfg.UserIDClaim),
|
||||
),
|
||||
routersManager: routersManager,
|
||||
}
|
||||
}
|
||||
|
||||
func (h *routersHandler) getAllRouters(w http.ResponseWriter, r *http.Request) {
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
accountID, userID, err := h.extractFromToken(r.Context(), claims)
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
|
||||
networkID := mux.Vars(r)["networkId"]
|
||||
routers, err := h.routersManager.GetAllRoutersInNetwork(r.Context(), accountID, userID, networkID)
|
||||
if err != nil {
|
||||
@ -65,13 +57,14 @@ func (h *routersHandler) getAllRouters(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
func (h *routersHandler) createRouter(w http.ResponseWriter, r *http.Request) {
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
accountID, userID, err := h.extractFromToken(r.Context(), claims)
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
|
||||
networkID := mux.Vars(r)["networkId"]
|
||||
var req api.NetworkRouterRequest
|
||||
err = json.NewDecoder(r.Body).Decode(&req)
|
||||
@ -96,13 +89,14 @@ func (h *routersHandler) createRouter(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
func (h *routersHandler) getRouter(w http.ResponseWriter, r *http.Request) {
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
accountID, userID, err := h.extractFromToken(r.Context(), claims)
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
|
||||
routerID := mux.Vars(r)["routerId"]
|
||||
networkID := mux.Vars(r)["networkId"]
|
||||
router, err := h.routersManager.GetRouter(r.Context(), accountID, userID, networkID, routerID)
|
||||
@ -115,13 +109,14 @@ func (h *routersHandler) getRouter(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
func (h *routersHandler) updateRouter(w http.ResponseWriter, r *http.Request) {
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
accountID, userID, err := h.extractFromToken(r.Context(), claims)
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
|
||||
var req api.NetworkRouterRequest
|
||||
err = json.NewDecoder(r.Body).Decode(&req)
|
||||
if err != nil {
|
||||
@ -146,13 +141,13 @@ func (h *routersHandler) updateRouter(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
func (h *routersHandler) deleteRouter(w http.ResponseWriter, r *http.Request) {
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
accountID, userID, err := h.extractFromToken(r.Context(), claims)
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
routerID := mux.Vars(r)["routerId"]
|
||||
networkID := mux.Vars(r)["networkId"]
|
||||
err = h.routersManager.DeleteRouter(r.Context(), accountID, userID, networkID, routerID)
|
||||
|
@ -10,11 +10,10 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/groups"
|
||||
"github.com/netbirdio/netbird/management/server/http/api"
|
||||
"github.com/netbirdio/netbird/management/server/http/configs"
|
||||
"github.com/netbirdio/netbird/management/server/http/util"
|
||||
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
@ -22,12 +21,11 @@ import (
|
||||
|
||||
// Handler is a handler that returns peers of the account
|
||||
type Handler struct {
|
||||
accountManager server.AccountManager
|
||||
claimsExtractor *jwtclaims.ClaimsExtractor
|
||||
accountManager server.AccountManager
|
||||
}
|
||||
|
||||
func AddEndpoints(accountManager server.AccountManager, authCfg configs.AuthCfg, router *mux.Router) {
|
||||
peersHandler := NewHandler(accountManager, authCfg)
|
||||
func AddEndpoints(accountManager server.AccountManager, router *mux.Router) {
|
||||
peersHandler := NewHandler(accountManager)
|
||||
router.HandleFunc("/peers", peersHandler.GetAllPeers).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/peers/{peerId}", peersHandler.HandlePeer).
|
||||
Methods("GET", "PUT", "DELETE", "OPTIONS")
|
||||
@ -35,13 +33,9 @@ func AddEndpoints(accountManager server.AccountManager, authCfg configs.AuthCfg,
|
||||
}
|
||||
|
||||
// NewHandler creates a new peers Handler
|
||||
func NewHandler(accountManager server.AccountManager, authCfg configs.AuthCfg) *Handler {
|
||||
func NewHandler(accountManager server.AccountManager) *Handler {
|
||||
return &Handler{
|
||||
accountManager: accountManager,
|
||||
claimsExtractor: jwtclaims.NewClaimsExtractor(
|
||||
jwtclaims.WithAudience(authCfg.Audience),
|
||||
jwtclaims.WithUserIDClaim(authCfg.UserIDClaim),
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
@ -149,12 +143,13 @@ func (h *Handler) deletePeer(ctx context.Context, accountID, userID string, peer
|
||||
|
||||
// HandlePeer handles all peer requests for GET, PUT and DELETE operations
|
||||
func (h *Handler) HandlePeer(w http.ResponseWriter, r *http.Request) {
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
vars := mux.Vars(r)
|
||||
peerID := vars["peerId"]
|
||||
if len(peerID) == 0 {
|
||||
@ -179,13 +174,14 @@ func (h *Handler) HandlePeer(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
// GetAllPeers returns a list of all peers associated with a provided account
|
||||
func (h *Handler) GetAllPeers(w http.ResponseWriter, r *http.Request) {
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
|
||||
peers, err := h.accountManager.GetPeers(r.Context(), accountID, userID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
@ -230,13 +226,14 @@ func (h *Handler) setApprovalRequiredFlag(respBody []*api.PeerBatch, approvedPee
|
||||
|
||||
// GetAccessiblePeers returns a list of all peers that the specified peer can connect to within the network.
|
||||
func (h *Handler) GetAccessiblePeers(w http.ResponseWriter, r *http.Request) {
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
|
||||
vars := mux.Vars(r)
|
||||
peerID := vars["peerId"]
|
||||
if len(peerID) == 0 {
|
||||
|
@ -15,8 +15,8 @@ import (
|
||||
"github.com/gorilla/mux"
|
||||
"golang.org/x/exp/maps"
|
||||
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/http/api"
|
||||
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
|
||||
@ -25,16 +25,13 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/mock_server"
|
||||
)
|
||||
|
||||
type ctxKey string
|
||||
|
||||
const (
|
||||
testPeerID = "test_peer"
|
||||
noUpdateChannelTestPeerID = "no-update-channel"
|
||||
|
||||
adminUser = "admin_user"
|
||||
regularUser = "regular_user"
|
||||
serviceUser = "service_user"
|
||||
userIDKey ctxKey = "user_id"
|
||||
adminUser = "admin_user"
|
||||
regularUser = "regular_user"
|
||||
serviceUser = "service_user"
|
||||
)
|
||||
|
||||
func initTestMetaData(peers ...*nbpeer.Peer) *Handler {
|
||||
@ -146,9 +143,6 @@ func initTestMetaData(peers ...*nbpeer.Peer) *Handler {
|
||||
GetDNSDomainFunc: func() string {
|
||||
return "netbird.selfhosted"
|
||||
},
|
||||
GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) {
|
||||
return claims.AccountId, claims.UserId, nil
|
||||
},
|
||||
GetAccountFunc: func(ctx context.Context, accountID string) (*types.Account, error) {
|
||||
return account, nil
|
||||
},
|
||||
@ -167,16 +161,6 @@ func initTestMetaData(peers ...*nbpeer.Peer) *Handler {
|
||||
return ok
|
||||
},
|
||||
},
|
||||
claimsExtractor: jwtclaims.NewClaimsExtractor(
|
||||
jwtclaims.WithFromRequestContext(func(r *http.Request) jwtclaims.AuthorizationClaims {
|
||||
userID := r.Context().Value(userIDKey).(string)
|
||||
return jwtclaims.AuthorizationClaims{
|
||||
UserId: userID,
|
||||
Domain: "hotmail.com",
|
||||
AccountId: "test_id",
|
||||
}
|
||||
}),
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
@ -267,8 +251,11 @@ func TestGetPeers(t *testing.T) {
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
|
||||
ctx := context.WithValue(context.Background(), userIDKey, "admin_user")
|
||||
req = req.WithContext(ctx)
|
||||
req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{
|
||||
UserId: "admin_user",
|
||||
Domain: "hotmail.com",
|
||||
AccountId: "test_id",
|
||||
})
|
||||
|
||||
router := mux.NewRouter()
|
||||
router.HandleFunc("/api/peers/", p.GetAllPeers).Methods("GET")
|
||||
@ -412,8 +399,11 @@ func TestGetAccessiblePeers(t *testing.T) {
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, fmt.Sprintf("/api/peers/%s/accessible-peers", tc.peerID), nil)
|
||||
ctx := context.WithValue(context.Background(), userIDKey, tc.callerUserID)
|
||||
req = req.WithContext(ctx)
|
||||
req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{
|
||||
UserId: tc.callerUserID,
|
||||
Domain: "hotmail.com",
|
||||
AccountId: "test_id",
|
||||
})
|
||||
|
||||
router := mux.NewRouter()
|
||||
router.HandleFunc("/api/peers/{peerId}/accessible-peers", p.GetAccessiblePeers).Methods("GET")
|
||||
|
@ -13,9 +13,9 @@ import (
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/geolocation"
|
||||
"github.com/netbirdio/netbird/management/server/http/api"
|
||||
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
||||
"github.com/netbirdio/netbird/management/server/mock_server"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
@ -43,23 +43,11 @@ func initGeolocationTestData(t *testing.T) *geolocationsHandler {
|
||||
|
||||
return &geolocationsHandler{
|
||||
accountManager: &mock_server.MockAccountManager{
|
||||
GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) {
|
||||
return claims.AccountId, claims.UserId, nil
|
||||
},
|
||||
GetUserByIDFunc: func(ctx context.Context, id string) (*types.User, error) {
|
||||
return types.NewAdminUser(id), nil
|
||||
},
|
||||
},
|
||||
geolocationManager: geo,
|
||||
claimsExtractor: jwtclaims.NewClaimsExtractor(
|
||||
jwtclaims.WithFromRequestContext(func(r *http.Request) jwtclaims.AuthorizationClaims {
|
||||
return jwtclaims.AuthorizationClaims{
|
||||
UserId: "test_user",
|
||||
Domain: "hotmail.com",
|
||||
AccountId: "test_id",
|
||||
}
|
||||
}),
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
@ -112,6 +100,11 @@ func TestGetCitiesByCountry(t *testing.T) {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
recorder := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(tc.requestType, tc.requestPath, nil)
|
||||
req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{
|
||||
UserId: "test_user",
|
||||
Domain: "hotmail.com",
|
||||
AccountId: "test_id",
|
||||
})
|
||||
|
||||
router := mux.NewRouter()
|
||||
router.HandleFunc("/api/locations/countries/{country}/cities", geolocationHandler.getCitiesByCountry).Methods("GET")
|
||||
@ -200,6 +193,11 @@ func TestGetAllCountries(t *testing.T) {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
recorder := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(tc.requestType, tc.requestPath, nil)
|
||||
req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{
|
||||
UserId: "test_user",
|
||||
Domain: "hotmail.com",
|
||||
AccountId: "test_id",
|
||||
})
|
||||
|
||||
router := mux.NewRouter()
|
||||
router.HandleFunc("/api/locations/countries", geolocationHandler.getAllCountries).Methods("GET")
|
||||
|
@ -7,11 +7,10 @@ import (
|
||||
"github.com/gorilla/mux"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/geolocation"
|
||||
"github.com/netbirdio/netbird/management/server/http/api"
|
||||
"github.com/netbirdio/netbird/management/server/http/configs"
|
||||
"github.com/netbirdio/netbird/management/server/http/util"
|
||||
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
)
|
||||
|
||||
@ -23,24 +22,19 @@ var (
|
||||
type geolocationsHandler struct {
|
||||
accountManager server.AccountManager
|
||||
geolocationManager geolocation.Geolocation
|
||||
claimsExtractor *jwtclaims.ClaimsExtractor
|
||||
}
|
||||
|
||||
func addLocationsEndpoint(accountManager server.AccountManager, locationManager geolocation.Geolocation, authCfg configs.AuthCfg, router *mux.Router) {
|
||||
locationHandler := newGeolocationsHandlerHandler(accountManager, locationManager, authCfg)
|
||||
func addLocationsEndpoint(accountManager server.AccountManager, locationManager geolocation.Geolocation, router *mux.Router) {
|
||||
locationHandler := newGeolocationsHandlerHandler(accountManager, locationManager)
|
||||
router.HandleFunc("/locations/countries", locationHandler.getAllCountries).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/locations/countries/{country}/cities", locationHandler.getCitiesByCountry).Methods("GET", "OPTIONS")
|
||||
}
|
||||
|
||||
// newGeolocationsHandlerHandler creates a new Geolocations handler
|
||||
func newGeolocationsHandlerHandler(accountManager server.AccountManager, geolocationManager geolocation.Geolocation, authCfg configs.AuthCfg) *geolocationsHandler {
|
||||
func newGeolocationsHandlerHandler(accountManager server.AccountManager, geolocationManager geolocation.Geolocation) *geolocationsHandler {
|
||||
return &geolocationsHandler{
|
||||
accountManager: accountManager,
|
||||
geolocationManager: geolocationManager,
|
||||
claimsExtractor: jwtclaims.NewClaimsExtractor(
|
||||
jwtclaims.WithAudience(authCfg.Audience),
|
||||
jwtclaims.WithUserIDClaim(authCfg.UserIDClaim),
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
@ -104,12 +98,13 @@ func (l *geolocationsHandler) getCitiesByCountry(w http.ResponseWriter, r *http.
|
||||
}
|
||||
|
||||
func (l *geolocationsHandler) authenticateUser(r *http.Request) error {
|
||||
claims := l.claimsExtractor.FromRequestContext(r)
|
||||
_, userID, err := l.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, userID := userAuth.AccountId, userAuth.UserId
|
||||
|
||||
user, err := l.accountManager.GetUserByID(r.Context(), userID)
|
||||
if err != nil {
|
||||
return err
|
||||
|
@ -8,51 +8,46 @@ import (
|
||||
"github.com/gorilla/mux"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/geolocation"
|
||||
"github.com/netbirdio/netbird/management/server/http/api"
|
||||
"github.com/netbirdio/netbird/management/server/http/configs"
|
||||
"github.com/netbirdio/netbird/management/server/http/util"
|
||||
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
)
|
||||
|
||||
// handler is a handler that returns policy of the account
|
||||
type handler struct {
|
||||
accountManager server.AccountManager
|
||||
claimsExtractor *jwtclaims.ClaimsExtractor
|
||||
accountManager server.AccountManager
|
||||
}
|
||||
|
||||
func AddEndpoints(accountManager server.AccountManager, locationManager geolocation.Geolocation, authCfg configs.AuthCfg, router *mux.Router) {
|
||||
policiesHandler := newHandler(accountManager, authCfg)
|
||||
func AddEndpoints(accountManager server.AccountManager, locationManager geolocation.Geolocation, router *mux.Router) {
|
||||
policiesHandler := newHandler(accountManager)
|
||||
router.HandleFunc("/policies", policiesHandler.getAllPolicies).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/policies", policiesHandler.createPolicy).Methods("POST", "OPTIONS")
|
||||
router.HandleFunc("/policies/{policyId}", policiesHandler.updatePolicy).Methods("PUT", "OPTIONS")
|
||||
router.HandleFunc("/policies/{policyId}", policiesHandler.getPolicy).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/policies/{policyId}", policiesHandler.deletePolicy).Methods("DELETE", "OPTIONS")
|
||||
addPostureCheckEndpoint(accountManager, locationManager, authCfg, router)
|
||||
addPostureCheckEndpoint(accountManager, locationManager, router)
|
||||
}
|
||||
|
||||
// newHandler creates a new policies handler
|
||||
func newHandler(accountManager server.AccountManager, authCfg configs.AuthCfg) *handler {
|
||||
func newHandler(accountManager server.AccountManager) *handler {
|
||||
return &handler{
|
||||
accountManager: accountManager,
|
||||
claimsExtractor: jwtclaims.NewClaimsExtractor(
|
||||
jwtclaims.WithAudience(authCfg.Audience),
|
||||
jwtclaims.WithUserIDClaim(authCfg.UserIDClaim),
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
// getAllPolicies list for the account
|
||||
func (h *handler) getAllPolicies(w http.ResponseWriter, r *http.Request) {
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
|
||||
listPolicies, err := h.accountManager.ListPolicies(r.Context(), accountID, userID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
@ -80,13 +75,14 @@ func (h *handler) getAllPolicies(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
// updatePolicy handles update to a policy identified by a given ID
|
||||
func (h *handler) updatePolicy(w http.ResponseWriter, r *http.Request) {
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
|
||||
vars := mux.Vars(r)
|
||||
policyID := vars["policyId"]
|
||||
if len(policyID) == 0 {
|
||||
@ -105,13 +101,14 @@ func (h *handler) updatePolicy(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
// createPolicy handles policy creation request
|
||||
func (h *handler) createPolicy(w http.ResponseWriter, r *http.Request) {
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
|
||||
h.savePolicy(w, r, accountID, userID, "")
|
||||
}
|
||||
|
||||
@ -306,13 +303,13 @@ func (h *handler) savePolicy(w http.ResponseWriter, r *http.Request, accountID s
|
||||
|
||||
// deletePolicy handles policy deletion request
|
||||
func (h *handler) deletePolicy(w http.ResponseWriter, r *http.Request) {
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
vars := mux.Vars(r)
|
||||
policyID := vars["policyId"]
|
||||
if len(policyID) == 0 {
|
||||
@ -330,13 +327,14 @@ func (h *handler) deletePolicy(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
// getPolicy handles a group Get request identified by ID
|
||||
func (h *handler) getPolicy(w http.ResponseWriter, r *http.Request) {
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
|
||||
vars := mux.Vars(r)
|
||||
policyID := vars["policyId"]
|
||||
if len(policyID) == 0 {
|
||||
|
@ -13,8 +13,8 @@ import (
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/http/api"
|
||||
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
||||
"github.com/netbirdio/netbird/management/server/mock_server"
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
@ -44,9 +44,6 @@ func initPoliciesTestData(policies ...*types.Policy) *handler {
|
||||
GetAllGroupsFunc: func(ctx context.Context, accountID, userID string) ([]*types.Group, error) {
|
||||
return []*types.Group{{ID: "F"}, {ID: "G"}}, nil
|
||||
},
|
||||
GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) {
|
||||
return claims.AccountId, claims.UserId, nil
|
||||
},
|
||||
GetAccountByIDFunc: func(ctx context.Context, accountID string, userID string) (*types.Account, error) {
|
||||
user := types.NewAdminUser(userID)
|
||||
return &types.Account{
|
||||
@ -65,15 +62,6 @@ func initPoliciesTestData(policies ...*types.Policy) *handler {
|
||||
}, nil
|
||||
},
|
||||
},
|
||||
claimsExtractor: jwtclaims.NewClaimsExtractor(
|
||||
jwtclaims.WithFromRequestContext(func(r *http.Request) jwtclaims.AuthorizationClaims {
|
||||
return jwtclaims.AuthorizationClaims{
|
||||
UserId: "test_user",
|
||||
Domain: "hotmail.com",
|
||||
AccountId: "test_id",
|
||||
}
|
||||
}),
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
@ -115,6 +103,11 @@ func TestPoliciesGetPolicy(t *testing.T) {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
recorder := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
|
||||
req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{
|
||||
UserId: "test_user",
|
||||
Domain: "hotmail.com",
|
||||
AccountId: "test_id",
|
||||
})
|
||||
|
||||
router := mux.NewRouter()
|
||||
router.HandleFunc("/api/policies/{policyId}", p.getPolicy).Methods("GET")
|
||||
@ -274,6 +267,11 @@ func TestPoliciesWritePolicy(t *testing.T) {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
recorder := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
|
||||
req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{
|
||||
UserId: "test_user",
|
||||
Domain: "hotmail.com",
|
||||
AccountId: "test_id",
|
||||
})
|
||||
|
||||
router := mux.NewRouter()
|
||||
router.HandleFunc("/api/policies", p.createPolicy).Methods("POST")
|
||||
|
@ -7,11 +7,10 @@ import (
|
||||
"github.com/gorilla/mux"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/geolocation"
|
||||
"github.com/netbirdio/netbird/management/server/http/api"
|
||||
"github.com/netbirdio/netbird/management/server/http/configs"
|
||||
"github.com/netbirdio/netbird/management/server/http/util"
|
||||
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
||||
"github.com/netbirdio/netbird/management/server/posture"
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
)
|
||||
@ -20,40 +19,35 @@ import (
|
||||
type postureChecksHandler struct {
|
||||
accountManager server.AccountManager
|
||||
geolocationManager geolocation.Geolocation
|
||||
claimsExtractor *jwtclaims.ClaimsExtractor
|
||||
}
|
||||
|
||||
func addPostureCheckEndpoint(accountManager server.AccountManager, locationManager geolocation.Geolocation, authCfg configs.AuthCfg, router *mux.Router) {
|
||||
postureCheckHandler := newPostureChecksHandler(accountManager, locationManager, authCfg)
|
||||
func addPostureCheckEndpoint(accountManager server.AccountManager, locationManager geolocation.Geolocation, router *mux.Router) {
|
||||
postureCheckHandler := newPostureChecksHandler(accountManager, locationManager)
|
||||
router.HandleFunc("/posture-checks", postureCheckHandler.getAllPostureChecks).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/posture-checks", postureCheckHandler.createPostureCheck).Methods("POST", "OPTIONS")
|
||||
router.HandleFunc("/posture-checks/{postureCheckId}", postureCheckHandler.updatePostureCheck).Methods("PUT", "OPTIONS")
|
||||
router.HandleFunc("/posture-checks/{postureCheckId}", postureCheckHandler.getPostureCheck).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/posture-checks/{postureCheckId}", postureCheckHandler.deletePostureCheck).Methods("DELETE", "OPTIONS")
|
||||
addLocationsEndpoint(accountManager, locationManager, authCfg, router)
|
||||
addLocationsEndpoint(accountManager, locationManager, router)
|
||||
}
|
||||
|
||||
// newPostureChecksHandler creates a new PostureChecks handler
|
||||
func newPostureChecksHandler(accountManager server.AccountManager, geolocationManager geolocation.Geolocation, authCfg configs.AuthCfg) *postureChecksHandler {
|
||||
func newPostureChecksHandler(accountManager server.AccountManager, geolocationManager geolocation.Geolocation) *postureChecksHandler {
|
||||
return &postureChecksHandler{
|
||||
accountManager: accountManager,
|
||||
geolocationManager: geolocationManager,
|
||||
claimsExtractor: jwtclaims.NewClaimsExtractor(
|
||||
jwtclaims.WithAudience(authCfg.Audience),
|
||||
jwtclaims.WithUserIDClaim(authCfg.UserIDClaim),
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
// getAllPostureChecks list for the account
|
||||
func (p *postureChecksHandler) getAllPostureChecks(w http.ResponseWriter, r *http.Request) {
|
||||
claims := p.claimsExtractor.FromRequestContext(r)
|
||||
accountID, userID, err := p.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
listPostureChecks, err := p.accountManager.ListPostureChecks(r.Context(), accountID, userID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
@ -70,13 +64,14 @@ func (p *postureChecksHandler) getAllPostureChecks(w http.ResponseWriter, r *htt
|
||||
|
||||
// updatePostureCheck handles update to a posture check identified by a given ID
|
||||
func (p *postureChecksHandler) updatePostureCheck(w http.ResponseWriter, r *http.Request) {
|
||||
claims := p.claimsExtractor.FromRequestContext(r)
|
||||
accountID, userID, err := p.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
|
||||
vars := mux.Vars(r)
|
||||
postureChecksID := vars["postureCheckId"]
|
||||
if len(postureChecksID) == 0 {
|
||||
@ -95,25 +90,26 @@ func (p *postureChecksHandler) updatePostureCheck(w http.ResponseWriter, r *http
|
||||
|
||||
// createPostureCheck handles posture check creation request
|
||||
func (p *postureChecksHandler) createPostureCheck(w http.ResponseWriter, r *http.Request) {
|
||||
claims := p.claimsExtractor.FromRequestContext(r)
|
||||
accountID, userID, err := p.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
|
||||
p.savePostureChecks(w, r, accountID, userID, "")
|
||||
}
|
||||
|
||||
// getPostureCheck handles a posture check Get request identified by ID
|
||||
func (p *postureChecksHandler) getPostureCheck(w http.ResponseWriter, r *http.Request) {
|
||||
claims := p.claimsExtractor.FromRequestContext(r)
|
||||
accountID, userID, err := p.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
vars := mux.Vars(r)
|
||||
postureChecksID := vars["postureCheckId"]
|
||||
if len(postureChecksID) == 0 {
|
||||
@ -132,13 +128,13 @@ func (p *postureChecksHandler) getPostureCheck(w http.ResponseWriter, r *http.Re
|
||||
|
||||
// deletePostureCheck handles posture check deletion request
|
||||
func (p *postureChecksHandler) deletePostureCheck(w http.ResponseWriter, r *http.Request) {
|
||||
claims := p.claimsExtractor.FromRequestContext(r)
|
||||
accountID, userID, err := p.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
vars := mux.Vars(r)
|
||||
postureChecksID := vars["postureCheckId"]
|
||||
if len(postureChecksID) == 0 {
|
||||
|
@ -14,9 +14,9 @@ import (
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/geolocation"
|
||||
"github.com/netbirdio/netbird/management/server/http/api"
|
||||
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
||||
"github.com/netbirdio/netbird/management/server/mock_server"
|
||||
"github.com/netbirdio/netbird/management/server/posture"
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
@ -66,20 +66,8 @@ func initPostureChecksTestData(postureChecks ...*posture.Checks) *postureChecksH
|
||||
}
|
||||
return accountPostureChecks, nil
|
||||
},
|
||||
GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) {
|
||||
return claims.AccountId, claims.UserId, nil
|
||||
},
|
||||
},
|
||||
geolocationManager: &geolocation.Mock{},
|
||||
claimsExtractor: jwtclaims.NewClaimsExtractor(
|
||||
jwtclaims.WithFromRequestContext(func(r *http.Request) jwtclaims.AuthorizationClaims {
|
||||
return jwtclaims.AuthorizationClaims{
|
||||
UserId: "test_user",
|
||||
Domain: "hotmail.com",
|
||||
AccountId: "test_id",
|
||||
}
|
||||
}),
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
@ -187,6 +175,11 @@ func TestGetPostureCheck(t *testing.T) {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
recorder := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/posture-checks/"+tc.id, tc.requestBody)
|
||||
req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{
|
||||
UserId: "test_user",
|
||||
Domain: "hotmail.com",
|
||||
AccountId: "test_id",
|
||||
})
|
||||
|
||||
router := mux.NewRouter()
|
||||
router.HandleFunc("/api/posture-checks/{postureCheckId}", p.getPostureCheck).Methods("GET")
|
||||
@ -835,6 +828,11 @@ func TestPostureCheckUpdate(t *testing.T) {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
recorder := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
|
||||
req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{
|
||||
UserId: "test_user",
|
||||
Domain: "hotmail.com",
|
||||
AccountId: "test_id",
|
||||
})
|
||||
|
||||
defaultHandler := *p
|
||||
if tc.setupHandlerFunc != nil {
|
||||
|
@ -10,10 +10,9 @@ import (
|
||||
|
||||
"github.com/netbirdio/netbird/management/domain"
|
||||
"github.com/netbirdio/netbird/management/server"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/http/api"
|
||||
"github.com/netbirdio/netbird/management/server/http/configs"
|
||||
"github.com/netbirdio/netbird/management/server/http/util"
|
||||
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
)
|
||||
@ -22,12 +21,11 @@ const failedToConvertRoute = "failed to convert route to response: %v"
|
||||
|
||||
// handler is the routes handler of the account
|
||||
type handler struct {
|
||||
accountManager server.AccountManager
|
||||
claimsExtractor *jwtclaims.ClaimsExtractor
|
||||
accountManager server.AccountManager
|
||||
}
|
||||
|
||||
func AddEndpoints(accountManager server.AccountManager, authCfg configs.AuthCfg, router *mux.Router) {
|
||||
routesHandler := newHandler(accountManager, authCfg)
|
||||
func AddEndpoints(accountManager server.AccountManager, router *mux.Router) {
|
||||
routesHandler := newHandler(accountManager)
|
||||
router.HandleFunc("/routes", routesHandler.getAllRoutes).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/routes", routesHandler.createRoute).Methods("POST", "OPTIONS")
|
||||
router.HandleFunc("/routes/{routeId}", routesHandler.updateRoute).Methods("PUT", "OPTIONS")
|
||||
@ -36,25 +34,22 @@ func AddEndpoints(accountManager server.AccountManager, authCfg configs.AuthCfg,
|
||||
}
|
||||
|
||||
// newHandler returns a new instance of routes handler
|
||||
func newHandler(accountManager server.AccountManager, authCfg configs.AuthCfg) *handler {
|
||||
func newHandler(accountManager server.AccountManager) *handler {
|
||||
return &handler{
|
||||
accountManager: accountManager,
|
||||
claimsExtractor: jwtclaims.NewClaimsExtractor(
|
||||
jwtclaims.WithAudience(authCfg.Audience),
|
||||
jwtclaims.WithUserIDClaim(authCfg.UserIDClaim),
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
// getAllRoutes returns the list of routes for the account
|
||||
func (h *handler) getAllRoutes(w http.ResponseWriter, r *http.Request) {
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
|
||||
routes, err := h.accountManager.ListRoutes(r.Context(), accountID, userID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
@ -75,13 +70,14 @@ func (h *handler) getAllRoutes(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
// createRoute handles route creation request
|
||||
func (h *handler) createRoute(w http.ResponseWriter, r *http.Request) {
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
|
||||
var req api.PostApiRoutesJSONRequestBody
|
||||
err = json.NewDecoder(r.Body).Decode(&req)
|
||||
if err != nil {
|
||||
@ -172,13 +168,13 @@ func (h *handler) validateRoute(req api.PostApiRoutesJSONRequestBody) error {
|
||||
|
||||
// updateRoute handles update to a route identified by a given ID
|
||||
func (h *handler) updateRoute(w http.ResponseWriter, r *http.Request) {
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
vars := mux.Vars(r)
|
||||
routeID := vars["routeId"]
|
||||
if len(routeID) == 0 {
|
||||
@ -265,13 +261,13 @@ func (h *handler) updateRoute(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
// deleteRoute handles route deletion request
|
||||
func (h *handler) deleteRoute(w http.ResponseWriter, r *http.Request) {
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
routeID := mux.Vars(r)["routeId"]
|
||||
if len(routeID) == 0 {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid route ID"), w)
|
||||
@ -289,13 +285,14 @@ func (h *handler) deleteRoute(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
// getRoute handles a route Get request identified by ID
|
||||
func (h *handler) getRoute(w http.ResponseWriter, r *http.Request) {
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
|
||||
routeID := mux.Vars(r)["routeId"]
|
||||
if len(routeID) == 0 {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid route ID"), w)
|
||||
|
@ -16,12 +16,10 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/management/domain"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/http/api"
|
||||
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
||||
"github.com/netbirdio/netbird/management/server/mock_server"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/management/server/util"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
)
|
||||
@ -60,32 +58,6 @@ var baseExistingRoute = &route.Route{
|
||||
Groups: []string{existingGroupID},
|
||||
}
|
||||
|
||||
var testingAccount = &types.Account{
|
||||
Id: testAccountID,
|
||||
Domain: "hotmail.com",
|
||||
Peers: map[string]*nbpeer.Peer{
|
||||
existingPeerID: {
|
||||
Key: existingPeerKey,
|
||||
IP: netip.MustParseAddr(existingPeerIP1).AsSlice(),
|
||||
ID: existingPeerID,
|
||||
Meta: nbpeer.PeerSystemMeta{
|
||||
GoOS: "linux",
|
||||
},
|
||||
},
|
||||
nonLinuxExistingPeerID: {
|
||||
Key: nonLinuxExistingPeerID,
|
||||
IP: netip.MustParseAddr(existingPeerIP2).AsSlice(),
|
||||
ID: nonLinuxExistingPeerID,
|
||||
Meta: nbpeer.PeerSystemMeta{
|
||||
GoOS: "darwin",
|
||||
},
|
||||
},
|
||||
},
|
||||
Users: map[string]*types.User{
|
||||
"test_user": types.NewAdminUser("test_user"),
|
||||
},
|
||||
}
|
||||
|
||||
func initRoutesTestData() *handler {
|
||||
return &handler{
|
||||
accountManager: &mock_server.MockAccountManager{
|
||||
@ -150,20 +122,7 @@ func initRoutesTestData() *handler {
|
||||
}
|
||||
return nil
|
||||
},
|
||||
GetAccountIDFromTokenFunc: func(_ context.Context, _ jwtclaims.AuthorizationClaims) (string, string, error) {
|
||||
// return testingAccount, testingAccount.Users["test_user"], nil
|
||||
return testingAccount.Id, testingAccount.Users["test_user"].Id, nil
|
||||
},
|
||||
},
|
||||
claimsExtractor: jwtclaims.NewClaimsExtractor(
|
||||
jwtclaims.WithFromRequestContext(func(r *http.Request) jwtclaims.AuthorizationClaims {
|
||||
return jwtclaims.AuthorizationClaims{
|
||||
UserId: "test_user",
|
||||
Domain: "hotmail.com",
|
||||
AccountId: testAccountID,
|
||||
}
|
||||
}),
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
@ -526,6 +485,11 @@ func TestRoutesHandlers(t *testing.T) {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
recorder := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
|
||||
req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{
|
||||
UserId: "test_user",
|
||||
Domain: "hotmail.com",
|
||||
AccountId: testAccountID,
|
||||
})
|
||||
|
||||
router := mux.NewRouter()
|
||||
router.HandleFunc("/api/routes/{routeId}", p.getRoute).Methods("GET")
|
||||
|
@ -10,22 +10,20 @@ import (
|
||||
"github.com/gorilla/mux"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/http/api"
|
||||
"github.com/netbirdio/netbird/management/server/http/configs"
|
||||
"github.com/netbirdio/netbird/management/server/http/util"
|
||||
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
)
|
||||
|
||||
// handler is a handler that returns a list of setup keys of the account
|
||||
type handler struct {
|
||||
accountManager server.AccountManager
|
||||
claimsExtractor *jwtclaims.ClaimsExtractor
|
||||
accountManager server.AccountManager
|
||||
}
|
||||
|
||||
func AddEndpoints(accountManager server.AccountManager, authCfg configs.AuthCfg, router *mux.Router) {
|
||||
keysHandler := newHandler(accountManager, authCfg)
|
||||
func AddEndpoints(accountManager server.AccountManager, router *mux.Router) {
|
||||
keysHandler := newHandler(accountManager)
|
||||
router.HandleFunc("/setup-keys", keysHandler.getAllSetupKeys).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/setup-keys", keysHandler.createSetupKey).Methods("POST", "OPTIONS")
|
||||
router.HandleFunc("/setup-keys/{keyId}", keysHandler.getSetupKey).Methods("GET", "OPTIONS")
|
||||
@ -34,25 +32,21 @@ func AddEndpoints(accountManager server.AccountManager, authCfg configs.AuthCfg,
|
||||
}
|
||||
|
||||
// newHandler creates a new setup key handler
|
||||
func newHandler(accountManager server.AccountManager, authCfg configs.AuthCfg) *handler {
|
||||
func newHandler(accountManager server.AccountManager) *handler {
|
||||
return &handler{
|
||||
accountManager: accountManager,
|
||||
claimsExtractor: jwtclaims.NewClaimsExtractor(
|
||||
jwtclaims.WithAudience(authCfg.Audience),
|
||||
jwtclaims.WithUserIDClaim(authCfg.UserIDClaim),
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
// createSetupKey is a POST requests that creates a new SetupKey
|
||||
func (h *handler) createSetupKey(w http.ResponseWriter, r *http.Request) {
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
req := &api.PostApiSetupKeysJSONRequestBody{}
|
||||
err = json.NewDecoder(r.Body).Decode(&req)
|
||||
if err != nil {
|
||||
@ -108,12 +102,12 @@ func (h *handler) createSetupKey(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
// getSetupKey is a GET request to get a SetupKey by ID
|
||||
func (h *handler) getSetupKey(w http.ResponseWriter, r *http.Request) {
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
|
||||
vars := mux.Vars(r)
|
||||
keyID := vars["keyId"]
|
||||
@ -133,13 +127,13 @@ func (h *handler) getSetupKey(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
// updateSetupKey is a PUT request to update server.SetupKey
|
||||
func (h *handler) updateSetupKey(w http.ResponseWriter, r *http.Request) {
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
vars := mux.Vars(r)
|
||||
keyID := vars["keyId"]
|
||||
if len(keyID) == 0 {
|
||||
@ -174,13 +168,13 @@ func (h *handler) updateSetupKey(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
// getAllSetupKeys is a GET request that returns a list of SetupKey
|
||||
func (h *handler) getAllSetupKeys(w http.ResponseWriter, r *http.Request) {
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
setupKeys, err := h.accountManager.ListSetupKeys(r.Context(), accountID, userID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
@ -196,13 +190,13 @@ func (h *handler) getAllSetupKeys(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
func (h *handler) deleteSetupKey(w http.ResponseWriter, r *http.Request) {
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
vars := mux.Vars(r)
|
||||
keyID := vars["keyId"]
|
||||
if len(keyID) == 0 {
|
||||
|
@ -14,8 +14,8 @@ import (
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/http/api"
|
||||
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
||||
"github.com/netbirdio/netbird/management/server/mock_server"
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
@ -28,14 +28,9 @@ const (
|
||||
notFoundSetupKeyID = "notFoundSetupKeyID"
|
||||
)
|
||||
|
||||
func initSetupKeysTestMetaData(defaultKey *types.SetupKey, newKey *types.SetupKey, updatedSetupKey *types.SetupKey,
|
||||
user *types.User,
|
||||
) *handler {
|
||||
func initSetupKeysTestMetaData(defaultKey *types.SetupKey, newKey *types.SetupKey, updatedSetupKey *types.SetupKey) *handler {
|
||||
return &handler{
|
||||
accountManager: &mock_server.MockAccountManager{
|
||||
GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) {
|
||||
return claims.AccountId, claims.UserId, nil
|
||||
},
|
||||
CreateSetupKeyFunc: func(_ context.Context, _ string, keyName string, typ types.SetupKeyType, _ time.Duration, _ []string,
|
||||
_ int, _ string, ephemeral bool, allowExtraDNSLabels bool,
|
||||
) (*types.SetupKey, error) {
|
||||
@ -76,15 +71,6 @@ func initSetupKeysTestMetaData(defaultKey *types.SetupKey, newKey *types.SetupKe
|
||||
return status.Errorf(status.NotFound, "key %s not found", keyID)
|
||||
},
|
||||
},
|
||||
claimsExtractor: jwtclaims.NewClaimsExtractor(
|
||||
jwtclaims.WithFromRequestContext(func(r *http.Request) jwtclaims.AuthorizationClaims {
|
||||
return jwtclaims.AuthorizationClaims{
|
||||
UserId: user.Id,
|
||||
Domain: "hotmail.com",
|
||||
AccountId: "testAccountId",
|
||||
}
|
||||
}),
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
@ -171,12 +157,17 @@ func TestSetupKeysHandlers(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
handler := initSetupKeysTestMetaData(defaultSetupKey, newSetupKey, updatedDefaultSetupKey, adminUser)
|
||||
handler := initSetupKeysTestMetaData(defaultSetupKey, newSetupKey, updatedDefaultSetupKey)
|
||||
|
||||
for _, tc := range tt {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
recorder := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
|
||||
req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{
|
||||
UserId: adminUser.Id,
|
||||
Domain: "hotmail.com",
|
||||
AccountId: "testAccountId",
|
||||
})
|
||||
|
||||
router := mux.NewRouter()
|
||||
router.HandleFunc("/api/setup-keys", handler.getAllSetupKeys).Methods("GET", "OPTIONS")
|
||||
|
@ -7,22 +7,20 @@ import (
|
||||
"github.com/gorilla/mux"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/http/api"
|
||||
"github.com/netbirdio/netbird/management/server/http/configs"
|
||||
"github.com/netbirdio/netbird/management/server/http/util"
|
||||
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
)
|
||||
|
||||
// patHandler is the nameserver group handler of the account
|
||||
type patHandler struct {
|
||||
accountManager server.AccountManager
|
||||
claimsExtractor *jwtclaims.ClaimsExtractor
|
||||
accountManager server.AccountManager
|
||||
}
|
||||
|
||||
func addUsersTokensEndpoint(accountManager server.AccountManager, authCfg configs.AuthCfg, router *mux.Router) {
|
||||
tokenHandler := newPATsHandler(accountManager, authCfg)
|
||||
func addUsersTokensEndpoint(accountManager server.AccountManager, router *mux.Router) {
|
||||
tokenHandler := newPATsHandler(accountManager)
|
||||
router.HandleFunc("/users/{userId}/tokens", tokenHandler.getAllTokens).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/users/{userId}/tokens", tokenHandler.createToken).Methods("POST", "OPTIONS")
|
||||
router.HandleFunc("/users/{userId}/tokens/{tokenId}", tokenHandler.getToken).Methods("GET", "OPTIONS")
|
||||
@ -30,25 +28,21 @@ func addUsersTokensEndpoint(accountManager server.AccountManager, authCfg config
|
||||
}
|
||||
|
||||
// newPATsHandler creates a new patHandler HTTP handler
|
||||
func newPATsHandler(accountManager server.AccountManager, authCfg configs.AuthCfg) *patHandler {
|
||||
func newPATsHandler(accountManager server.AccountManager) *patHandler {
|
||||
return &patHandler{
|
||||
accountManager: accountManager,
|
||||
claimsExtractor: jwtclaims.NewClaimsExtractor(
|
||||
jwtclaims.WithAudience(authCfg.Audience),
|
||||
jwtclaims.WithUserIDClaim(authCfg.UserIDClaim),
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
// getAllTokens is HTTP GET handler that returns a list of all personal access tokens for the given user
|
||||
func (h *patHandler) getAllTokens(w http.ResponseWriter, r *http.Request) {
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
vars := mux.Vars(r)
|
||||
targetUserID := vars["userId"]
|
||||
if len(userID) == 0 {
|
||||
@ -72,13 +66,13 @@ func (h *patHandler) getAllTokens(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
// getToken is HTTP GET handler that returns a personal access token for the given user
|
||||
func (h *patHandler) getToken(w http.ResponseWriter, r *http.Request) {
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
vars := mux.Vars(r)
|
||||
targetUserID := vars["userId"]
|
||||
if len(targetUserID) == 0 {
|
||||
@ -103,13 +97,13 @@ func (h *patHandler) getToken(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
// createToken is HTTP POST handler that creates a personal access token for the given user
|
||||
func (h *patHandler) createToken(w http.ResponseWriter, r *http.Request) {
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
vars := mux.Vars(r)
|
||||
targetUserID := vars["userId"]
|
||||
if len(targetUserID) == 0 {
|
||||
@ -135,13 +129,13 @@ func (h *patHandler) createToken(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
// deleteToken is HTTP DELETE handler that deletes a personal access token for the given user
|
||||
func (h *patHandler) deleteToken(w http.ResponseWriter, r *http.Request) {
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
vars := mux.Vars(r)
|
||||
targetUserID := vars["userId"]
|
||||
if len(targetUserID) == 0 {
|
||||
|
@ -12,11 +12,12 @@ import (
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/netbirdio/netbird/management/server/util"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/util"
|
||||
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/http/api"
|
||||
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
||||
"github.com/netbirdio/netbird/management/server/mock_server"
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
@ -77,10 +78,6 @@ func initPATTestData() *patHandler {
|
||||
PersonalAccessToken: types.PersonalAccessToken{},
|
||||
}, nil
|
||||
},
|
||||
|
||||
GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) {
|
||||
return claims.AccountId, claims.UserId, nil
|
||||
},
|
||||
DeletePATFunc: func(_ context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) error {
|
||||
if accountID != existingAccountID {
|
||||
return status.Errorf(status.NotFound, "account with ID %s not found", accountID)
|
||||
@ -115,15 +112,6 @@ func initPATTestData() *patHandler {
|
||||
return []*types.PersonalAccessToken{testAccount.Users[existingUserID].PATs[existingTokenID], testAccount.Users[existingUserID].PATs["token2"]}, nil
|
||||
},
|
||||
},
|
||||
claimsExtractor: jwtclaims.NewClaimsExtractor(
|
||||
jwtclaims.WithFromRequestContext(func(r *http.Request) jwtclaims.AuthorizationClaims {
|
||||
return jwtclaims.AuthorizationClaims{
|
||||
UserId: existingUserID,
|
||||
Domain: testDomain,
|
||||
AccountId: existingAccountID,
|
||||
}
|
||||
}),
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
@ -185,6 +173,11 @@ func TestTokenHandlers(t *testing.T) {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
recorder := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
|
||||
req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{
|
||||
UserId: existingUserID,
|
||||
Domain: testDomain,
|
||||
AccountId: existingAccountID,
|
||||
})
|
||||
|
||||
router := mux.NewRouter()
|
||||
router.HandleFunc("/api/users/{userId}/tokens", p.getAllTokens).Methods("GET")
|
||||
|
@ -9,39 +9,33 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/http/api"
|
||||
"github.com/netbirdio/netbird/management/server/http/configs"
|
||||
"github.com/netbirdio/netbird/management/server/http/util"
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server"
|
||||
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
)
|
||||
|
||||
// handler is a handler that returns users of the account
|
||||
type handler struct {
|
||||
accountManager server.AccountManager
|
||||
claimsExtractor *jwtclaims.ClaimsExtractor
|
||||
accountManager server.AccountManager
|
||||
}
|
||||
|
||||
func AddEndpoints(accountManager server.AccountManager, authCfg configs.AuthCfg, router *mux.Router) {
|
||||
userHandler := newHandler(accountManager, authCfg)
|
||||
func AddEndpoints(accountManager server.AccountManager, router *mux.Router) {
|
||||
userHandler := newHandler(accountManager)
|
||||
router.HandleFunc("/users", userHandler.getAllUsers).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/users/{userId}", userHandler.updateUser).Methods("PUT", "OPTIONS")
|
||||
router.HandleFunc("/users/{userId}", userHandler.deleteUser).Methods("DELETE", "OPTIONS")
|
||||
router.HandleFunc("/users", userHandler.createUser).Methods("POST", "OPTIONS")
|
||||
router.HandleFunc("/users/{userId}/invite", userHandler.inviteUser).Methods("POST", "OPTIONS")
|
||||
addUsersTokensEndpoint(accountManager, authCfg, router)
|
||||
addUsersTokensEndpoint(accountManager, router)
|
||||
}
|
||||
|
||||
// newHandler creates a new UsersHandler HTTP handler
|
||||
func newHandler(accountManager server.AccountManager, authCfg configs.AuthCfg) *handler {
|
||||
func newHandler(accountManager server.AccountManager) *handler {
|
||||
return &handler{
|
||||
accountManager: accountManager,
|
||||
claimsExtractor: jwtclaims.NewClaimsExtractor(
|
||||
jwtclaims.WithAudience(authCfg.Audience),
|
||||
jwtclaims.WithUserIDClaim(authCfg.UserIDClaim),
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
@ -52,13 +46,13 @@ func (h *handler) updateUser(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
vars := mux.Vars(r)
|
||||
targetUserID := vars["userId"]
|
||||
if len(targetUserID) == 0 {
|
||||
@ -103,7 +97,7 @@ func (h *handler) updateUser(w http.ResponseWriter, r *http.Request) {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
util.WriteJSONObject(r.Context(), w, toUserResponse(newUser, claims.UserId))
|
||||
util.WriteJSONObject(r.Context(), w, toUserResponse(newUser, userID))
|
||||
}
|
||||
|
||||
// deleteUser is a DELETE request to delete a user
|
||||
@ -113,13 +107,13 @@ func (h *handler) deleteUser(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
vars := mux.Vars(r)
|
||||
targetUserID := vars["userId"]
|
||||
if len(targetUserID) == 0 {
|
||||
@ -143,12 +137,12 @@ func (h *handler) createUser(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
|
||||
req := &api.PostApiUsersJSONRequestBody{}
|
||||
err = json.NewDecoder(r.Body).Decode(&req)
|
||||
@ -184,7 +178,7 @@ func (h *handler) createUser(w http.ResponseWriter, r *http.Request) {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
util.WriteJSONObject(r.Context(), w, toUserResponse(newUser, claims.UserId))
|
||||
util.WriteJSONObject(r.Context(), w, toUserResponse(newUser, userID))
|
||||
}
|
||||
|
||||
// getAllUsers returns a list of users of the account this user belongs to.
|
||||
@ -195,13 +189,13 @@ func (h *handler) getAllUsers(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
data, err := h.accountManager.GetUsersFromAccount(r.Context(), accountID, userID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
@ -216,7 +210,7 @@ func (h *handler) getAllUsers(w http.ResponseWriter, r *http.Request) {
|
||||
continue
|
||||
}
|
||||
if serviceUser == "" {
|
||||
users = append(users, toUserResponse(d, claims.UserId))
|
||||
users = append(users, toUserResponse(d, userID))
|
||||
continue
|
||||
}
|
||||
|
||||
@ -227,7 +221,7 @@ func (h *handler) getAllUsers(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
if includeServiceUser == d.IsServiceUser {
|
||||
users = append(users, toUserResponse(d, claims.UserId))
|
||||
users = append(users, toUserResponse(d, userID))
|
||||
}
|
||||
}
|
||||
|
||||
@ -242,12 +236,12 @@ func (h *handler) inviteUser(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
|
||||
vars := mux.Vars(r)
|
||||
targetUserID := vars["userId"]
|
||||
|
@ -13,8 +13,8 @@ import (
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/http/api"
|
||||
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
||||
"github.com/netbirdio/netbird/management/server/mock_server"
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
@ -64,9 +64,6 @@ var usersTestAccount = &types.Account{
|
||||
func initUsersTestData() *handler {
|
||||
return &handler{
|
||||
accountManager: &mock_server.MockAccountManager{
|
||||
GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) {
|
||||
return usersTestAccount.Id, claims.UserId, nil
|
||||
},
|
||||
GetUserByIDFunc: func(ctx context.Context, id string) (*types.User, error) {
|
||||
return usersTestAccount.Users[id], nil
|
||||
},
|
||||
@ -127,15 +124,6 @@ func initUsersTestData() *handler {
|
||||
return nil
|
||||
},
|
||||
},
|
||||
claimsExtractor: jwtclaims.NewClaimsExtractor(
|
||||
jwtclaims.WithFromRequestContext(func(r *http.Request) jwtclaims.AuthorizationClaims {
|
||||
return jwtclaims.AuthorizationClaims{
|
||||
UserId: existingUserID,
|
||||
Domain: testDomain,
|
||||
AccountId: existingAccountID,
|
||||
}
|
||||
}),
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
@ -158,6 +146,11 @@ func TestGetUsers(t *testing.T) {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
recorder := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(tc.requestType, tc.requestPath, nil)
|
||||
req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{
|
||||
UserId: existingUserID,
|
||||
Domain: testDomain,
|
||||
AccountId: existingAccountID,
|
||||
})
|
||||
|
||||
userHandler.getAllUsers(recorder, req)
|
||||
|
||||
@ -263,6 +256,11 @@ func TestUpdateUser(t *testing.T) {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
recorder := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
|
||||
req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{
|
||||
UserId: existingUserID,
|
||||
Domain: testDomain,
|
||||
AccountId: existingAccountID,
|
||||
})
|
||||
|
||||
router := mux.NewRouter()
|
||||
router.HandleFunc("/api/users/{userId}", userHandler.updateUser).Methods("PUT")
|
||||
@ -355,6 +353,11 @@ func TestCreateUser(t *testing.T) {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
|
||||
rr := httptest.NewRecorder()
|
||||
req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{
|
||||
UserId: existingUserID,
|
||||
Domain: testDomain,
|
||||
AccountId: existingAccountID,
|
||||
})
|
||||
|
||||
userHandler.createUser(rr, req)
|
||||
|
||||
@ -399,6 +402,12 @@ func TestInviteUser(t *testing.T) {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
req := httptest.NewRequest(tc.requestType, tc.requestPath, nil)
|
||||
req = mux.SetURLVars(req, tc.requestVars)
|
||||
req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{
|
||||
UserId: existingUserID,
|
||||
Domain: testDomain,
|
||||
AccountId: existingAccountID,
|
||||
})
|
||||
|
||||
rr := httptest.NewRecorder()
|
||||
|
||||
userHandler.inviteUser(rr, req)
|
||||
@ -452,6 +461,12 @@ func TestDeleteUser(t *testing.T) {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
req := httptest.NewRequest(tc.requestType, tc.requestPath, nil)
|
||||
req = mux.SetURLVars(req, tc.requestVars)
|
||||
req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{
|
||||
UserId: existingUserID,
|
||||
Domain: testDomain,
|
||||
AccountId: existingAccountID,
|
||||
})
|
||||
|
||||
rr := httptest.NewRecorder()
|
||||
|
||||
userHandler.deleteUser(rr, req)
|
||||
|
@ -7,30 +7,24 @@ import (
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/http/middleware/bypass"
|
||||
"github.com/netbirdio/netbird/management/server/http/util"
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
||||
)
|
||||
|
||||
// GetUser function defines a function to fetch user from Account by jwtclaims.AuthorizationClaims
|
||||
type GetUser func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*types.User, error)
|
||||
type GetUser func(ctx context.Context, userAuth nbcontext.UserAuth) (*types.User, error)
|
||||
|
||||
// AccessControl middleware to restrict to make POST/PUT/DELETE requests by admin only
|
||||
type AccessControl struct {
|
||||
claimsExtract jwtclaims.ClaimsExtractor
|
||||
getUser GetUser
|
||||
getUser GetUser
|
||||
}
|
||||
|
||||
// NewAccessControl instance constructor
|
||||
func NewAccessControl(audience, userIDClaim string, getUser GetUser) *AccessControl {
|
||||
func NewAccessControl(getUser GetUser) *AccessControl {
|
||||
return &AccessControl{
|
||||
claimsExtract: *jwtclaims.NewClaimsExtractor(
|
||||
jwtclaims.WithAudience(audience),
|
||||
jwtclaims.WithUserIDClaim(userIDClaim),
|
||||
),
|
||||
getUser: getUser,
|
||||
}
|
||||
}
|
||||
@ -45,12 +39,16 @@ func (a *AccessControl) Handler(h http.Handler) http.Handler {
|
||||
return
|
||||
}
|
||||
|
||||
claims := a.claimsExtract.FromRequestContext(r)
|
||||
|
||||
user, err := a.getUser(r.Context(), claims)
|
||||
userAuth, err := nbcontext.GetUserAuthFromRequest(r)
|
||||
if err != nil {
|
||||
log.WithContext(r.Context()).Errorf("failed to get user from claims: %s", err)
|
||||
util.WriteError(r.Context(), status.Errorf(status.Unauthorized, "invalid JWT"), w)
|
||||
log.WithContext(r.Context()).Errorf("failed to get user auth from request: %s", err)
|
||||
util.WriteError(r.Context(), status.Errorf(status.Unauthorized, "invalid user auth"), w)
|
||||
}
|
||||
|
||||
user, err := a.getUser(r.Context(), userAuth)
|
||||
if err != nil {
|
||||
log.WithContext(r.Context()).Errorf("failed to get user: %s", err)
|
||||
util.WriteError(r.Context(), status.Errorf(status.Unauthorized, "invalid user auth"), w)
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -8,67 +8,41 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
nbContext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/auth"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/http/middleware/bypass"
|
||||
"github.com/netbirdio/netbird/management/server/http/util"
|
||||
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
)
|
||||
|
||||
// GetAccountInfoFromPATFunc function
|
||||
type GetAccountInfoFromPATFunc func(ctx context.Context, token string) (user *types.User, pat *types.PersonalAccessToken, domain string, category string, err error)
|
||||
|
||||
// ValidateAndParseTokenFunc function
|
||||
type ValidateAndParseTokenFunc func(ctx context.Context, token string) (*jwt.Token, error)
|
||||
|
||||
// MarkPATUsedFunc function
|
||||
type MarkPATUsedFunc func(ctx context.Context, token string) error
|
||||
|
||||
// CheckUserAccessByJWTGroupsFunc function
|
||||
type CheckUserAccessByJWTGroupsFunc func(ctx context.Context, claims jwtclaims.AuthorizationClaims) error
|
||||
type EnsureAccountFunc func(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error)
|
||||
type SyncUserJWTGroupsFunc func(ctx context.Context, userAuth nbcontext.UserAuth) error
|
||||
|
||||
// AuthMiddleware middleware to verify personal access tokens (PAT) and JWT tokens
|
||||
type AuthMiddleware struct {
|
||||
getAccountInfoFromPAT GetAccountInfoFromPATFunc
|
||||
validateAndParseToken ValidateAndParseTokenFunc
|
||||
markPATUsed MarkPATUsedFunc
|
||||
checkUserAccessByJWTGroups CheckUserAccessByJWTGroupsFunc
|
||||
claimsExtractor *jwtclaims.ClaimsExtractor
|
||||
audience string
|
||||
userIDClaim string
|
||||
authManager auth.Manager
|
||||
ensureAccount EnsureAccountFunc
|
||||
syncUserJWTGroups SyncUserJWTGroupsFunc
|
||||
}
|
||||
|
||||
const (
|
||||
userProperty = "user"
|
||||
)
|
||||
|
||||
// NewAuthMiddleware instance constructor
|
||||
func NewAuthMiddleware(getAccountInfoFromPAT GetAccountInfoFromPATFunc, validateAndParseToken ValidateAndParseTokenFunc,
|
||||
markPATUsed MarkPATUsedFunc, checkUserAccessByJWTGroups CheckUserAccessByJWTGroupsFunc, claimsExtractor *jwtclaims.ClaimsExtractor,
|
||||
audience string, userIdClaim string) *AuthMiddleware {
|
||||
if userIdClaim == "" {
|
||||
userIdClaim = jwtclaims.UserIDClaim
|
||||
}
|
||||
|
||||
func NewAuthMiddleware(
|
||||
authManager auth.Manager,
|
||||
ensureAccount EnsureAccountFunc,
|
||||
syncUserJWTGroups SyncUserJWTGroupsFunc,
|
||||
) *AuthMiddleware {
|
||||
return &AuthMiddleware{
|
||||
getAccountInfoFromPAT: getAccountInfoFromPAT,
|
||||
validateAndParseToken: validateAndParseToken,
|
||||
markPATUsed: markPATUsed,
|
||||
checkUserAccessByJWTGroups: checkUserAccessByJWTGroups,
|
||||
claimsExtractor: claimsExtractor,
|
||||
audience: audience,
|
||||
userIDClaim: userIdClaim,
|
||||
authManager: authManager,
|
||||
ensureAccount: ensureAccount,
|
||||
syncUserJWTGroups: syncUserJWTGroups,
|
||||
}
|
||||
}
|
||||
|
||||
// Handler method of the middleware which authenticates a user either by JWT claims or by PAT
|
||||
func (m *AuthMiddleware) Handler(h http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
if bypass.ShouldBypass(r.URL.Path, h, w, r) {
|
||||
return
|
||||
}
|
||||
@ -84,108 +58,111 @@ func (m *AuthMiddleware) Handler(h http.Handler) http.Handler {
|
||||
|
||||
switch authType {
|
||||
case "bearer":
|
||||
err := m.checkJWTFromRequest(w, r, auth)
|
||||
request, err := m.checkJWTFromRequest(r, auth)
|
||||
if err != nil {
|
||||
log.WithContext(r.Context()).Errorf("Error when validating JWT claims: %s", err.Error())
|
||||
log.WithContext(r.Context()).Errorf("Error when validating JWT: %s", err.Error())
|
||||
util.WriteError(r.Context(), status.Errorf(status.Unauthorized, "token invalid"), w)
|
||||
return
|
||||
}
|
||||
|
||||
h.ServeHTTP(w, request)
|
||||
case "token":
|
||||
err := m.checkPATFromRequest(w, r, auth)
|
||||
request, err := m.checkPATFromRequest(r, auth)
|
||||
if err != nil {
|
||||
log.WithContext(r.Context()).Debugf("Error when validating PAT claims: %s", err.Error())
|
||||
log.WithContext(r.Context()).Debugf("Error when validating PAT: %s", err.Error())
|
||||
util.WriteError(r.Context(), status.Errorf(status.Unauthorized, "token invalid"), w)
|
||||
return
|
||||
}
|
||||
h.ServeHTTP(w, request)
|
||||
default:
|
||||
util.WriteError(r.Context(), status.Errorf(status.Unauthorized, "no valid authentication provided"), w)
|
||||
return
|
||||
}
|
||||
claims := m.claimsExtractor.FromRequestContext(r)
|
||||
//nolint
|
||||
ctx := context.WithValue(r.Context(), nbContext.UserIDKey, claims.UserId)
|
||||
//nolint
|
||||
ctx = context.WithValue(ctx, nbContext.AccountIDKey, claims.AccountId)
|
||||
h.ServeHTTP(w, r.WithContext(ctx))
|
||||
})
|
||||
}
|
||||
|
||||
// CheckJWTFromRequest checks if the JWT is valid
|
||||
func (m *AuthMiddleware) checkJWTFromRequest(w http.ResponseWriter, r *http.Request, auth []string) error {
|
||||
func (m *AuthMiddleware) checkJWTFromRequest(r *http.Request, auth []string) (*http.Request, error) {
|
||||
token, err := getTokenFromJWTRequest(auth)
|
||||
|
||||
// If an error occurs, call the error handler and return an error
|
||||
if err != nil {
|
||||
return fmt.Errorf("Error extracting token: %w", err)
|
||||
return r, fmt.Errorf("error extracting token: %w", err)
|
||||
}
|
||||
|
||||
validatedToken, err := m.validateAndParseToken(r.Context(), token)
|
||||
ctx := r.Context()
|
||||
|
||||
userAuth, validatedToken, err := m.authManager.ValidateAndParseToken(ctx, token)
|
||||
if err != nil {
|
||||
return err
|
||||
return r, err
|
||||
}
|
||||
|
||||
if validatedToken == nil {
|
||||
return nil
|
||||
if impersonate, ok := r.URL.Query()["account"]; ok && len(impersonate) == 1 {
|
||||
userAuth.AccountId = impersonate[0]
|
||||
userAuth.IsChild = ok
|
||||
}
|
||||
|
||||
if err := m.verifyUserAccess(r.Context(), validatedToken); err != nil {
|
||||
return err
|
||||
// we need to call this method because if user is new, we will automatically add it to existing or create a new account
|
||||
accountId, _, err := m.ensureAccount(ctx, userAuth)
|
||||
if err != nil {
|
||||
return r, err
|
||||
}
|
||||
|
||||
// If we get here, everything worked and we can set the
|
||||
// user property in context.
|
||||
newRequest := r.WithContext(context.WithValue(r.Context(), userProperty, validatedToken)) //nolint
|
||||
// Update the current request with the new context information.
|
||||
*r = *newRequest
|
||||
return nil
|
||||
}
|
||||
if userAuth.AccountId != accountId {
|
||||
log.WithContext(ctx).Debugf("Auth middleware sets accountId from ensure, before %s, now %s", userAuth.AccountId, accountId)
|
||||
userAuth.AccountId = accountId
|
||||
}
|
||||
|
||||
// verifyUserAccess checks if a user, based on a validated JWT token,
|
||||
// is allowed access, particularly in cases where the admin enabled JWT
|
||||
// group propagation and designated certain groups with access permissions.
|
||||
func (m *AuthMiddleware) verifyUserAccess(ctx context.Context, validatedToken *jwt.Token) error {
|
||||
authClaims := m.claimsExtractor.FromToken(validatedToken)
|
||||
return m.checkUserAccessByJWTGroups(ctx, authClaims)
|
||||
userAuth, err = m.authManager.EnsureUserAccessByJWTGroups(ctx, userAuth, validatedToken)
|
||||
if err != nil {
|
||||
return r, err
|
||||
}
|
||||
|
||||
err = m.syncUserJWTGroups(ctx, userAuth)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("HTTP server failed to sync user JWT groups: %s", err)
|
||||
}
|
||||
|
||||
return nbcontext.SetUserAuthInRequest(r, userAuth), nil
|
||||
}
|
||||
|
||||
// CheckPATFromRequest checks if the PAT is valid
|
||||
func (m *AuthMiddleware) checkPATFromRequest(w http.ResponseWriter, r *http.Request, auth []string) error {
|
||||
func (m *AuthMiddleware) checkPATFromRequest(r *http.Request, auth []string) (*http.Request, error) {
|
||||
token, err := getTokenFromPATRequest(auth)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error extracting token: %w", err)
|
||||
return r, fmt.Errorf("error extracting token: %w", err)
|
||||
}
|
||||
|
||||
user, pat, accDomain, accCategory, err := m.getAccountInfoFromPAT(r.Context(), token)
|
||||
ctx := r.Context()
|
||||
user, pat, accDomain, accCategory, err := m.authManager.GetPATInfo(ctx, token)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid Token: %w", err)
|
||||
return r, fmt.Errorf("invalid Token: %w", err)
|
||||
}
|
||||
if time.Now().After(pat.GetExpirationDate()) {
|
||||
return fmt.Errorf("token expired")
|
||||
return r, fmt.Errorf("token expired")
|
||||
}
|
||||
|
||||
err = m.markPATUsed(r.Context(), pat.ID)
|
||||
err = m.authManager.MarkPATUsed(ctx, pat.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
return r, err
|
||||
}
|
||||
|
||||
claimMaps := jwt.MapClaims{}
|
||||
claimMaps[m.userIDClaim] = user.Id
|
||||
claimMaps[m.audience+jwtclaims.AccountIDSuffix] = user.AccountID
|
||||
claimMaps[m.audience+jwtclaims.DomainIDSuffix] = accDomain
|
||||
claimMaps[m.audience+jwtclaims.DomainCategorySuffix] = accCategory
|
||||
claimMaps[jwtclaims.IsToken] = true
|
||||
jwtToken := jwt.NewWithClaims(jwt.SigningMethodHS256, claimMaps)
|
||||
newRequest := r.WithContext(context.WithValue(r.Context(), jwtclaims.TokenUserProperty, jwtToken)) //nolint
|
||||
// Update the current request with the new context information.
|
||||
*r = *newRequest
|
||||
return nil
|
||||
userAuth := nbcontext.UserAuth{
|
||||
UserId: user.Id,
|
||||
AccountId: user.AccountID,
|
||||
Domain: accDomain,
|
||||
DomainCategory: accCategory,
|
||||
IsPAT: true,
|
||||
}
|
||||
|
||||
return nbcontext.SetUserAuthInRequest(r, userAuth), nil
|
||||
}
|
||||
|
||||
// getTokenFromJWTRequest is a "TokenExtractor" that takes auth header parts and extracts
|
||||
// the JWT token from the Authorization header.
|
||||
func getTokenFromJWTRequest(authHeaderParts []string) (string, error) {
|
||||
if len(authHeaderParts) != 2 || strings.ToLower(authHeaderParts[0]) != "bearer" {
|
||||
return "", errors.New("Authorization header format must be Bearer {token}")
|
||||
return "", errors.New("authorization header format must be Bearer {token}")
|
||||
}
|
||||
|
||||
return authHeaderParts[1], nil
|
||||
@ -195,7 +172,7 @@ func getTokenFromJWTRequest(authHeaderParts []string) (string, error) {
|
||||
// the PAT token from the Authorization header.
|
||||
func getTokenFromPATRequest(authHeaderParts []string) (string, error) {
|
||||
if len(authHeaderParts) != 2 || strings.ToLower(authHeaderParts[0]) != "token" {
|
||||
return "", errors.New("Authorization header format must be Token {token}")
|
||||
return "", errors.New("authorization header format must be Token {token}")
|
||||
}
|
||||
|
||||
return authHeaderParts[1], nil
|
||||
|
@ -9,10 +9,14 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/auth"
|
||||
nbjwt "github.com/netbirdio/netbird/management/server/auth/jwt"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/util"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/http/middleware/bypass"
|
||||
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
)
|
||||
|
||||
@ -58,17 +62,23 @@ func mockGetAccountInfoFromPAT(_ context.Context, token string) (user *types.Use
|
||||
return nil, nil, "", "", fmt.Errorf("PAT invalid")
|
||||
}
|
||||
|
||||
func mockValidateAndParseToken(_ context.Context, token string) (*jwt.Token, error) {
|
||||
func mockValidateAndParseToken(_ context.Context, token string) (nbcontext.UserAuth, *jwt.Token, error) {
|
||||
if token == JWT {
|
||||
return &jwt.Token{
|
||||
Claims: jwt.MapClaims{
|
||||
userIDClaim: userID,
|
||||
audience + jwtclaims.AccountIDSuffix: accountID,
|
||||
return nbcontext.UserAuth{
|
||||
UserId: userID,
|
||||
AccountId: accountID,
|
||||
Domain: testAccount.Domain,
|
||||
DomainCategory: testAccount.DomainCategory,
|
||||
},
|
||||
Valid: true,
|
||||
}, nil
|
||||
&jwt.Token{
|
||||
Claims: jwt.MapClaims{
|
||||
userIDClaim: userID,
|
||||
audience + nbjwt.AccountIDSuffix: accountID,
|
||||
},
|
||||
Valid: true,
|
||||
}, nil
|
||||
}
|
||||
return nil, fmt.Errorf("JWT invalid")
|
||||
return nbcontext.UserAuth{}, nil, fmt.Errorf("JWT invalid")
|
||||
}
|
||||
|
||||
func mockMarkPATUsed(_ context.Context, token string) error {
|
||||
@ -78,16 +88,20 @@ func mockMarkPATUsed(_ context.Context, token string) error {
|
||||
return fmt.Errorf("Should never get reached")
|
||||
}
|
||||
|
||||
func mockCheckUserAccessByJWTGroups(_ context.Context, claims jwtclaims.AuthorizationClaims) error {
|
||||
if testAccount.Id != claims.AccountId {
|
||||
return fmt.Errorf("account with id %s does not exist", claims.AccountId)
|
||||
func mockEnsureUserAccessByJWTGroups(_ context.Context, userAuth nbcontext.UserAuth, token *jwt.Token) (nbcontext.UserAuth, error) {
|
||||
if userAuth.IsChild || userAuth.IsPAT {
|
||||
return userAuth, nil
|
||||
}
|
||||
|
||||
if _, ok := testAccount.Users[claims.UserId]; !ok {
|
||||
return fmt.Errorf("user with id %s does not exist", claims.UserId)
|
||||
if testAccount.Id != userAuth.AccountId {
|
||||
return userAuth, fmt.Errorf("account with id %s does not exist", userAuth.AccountId)
|
||||
}
|
||||
|
||||
return nil
|
||||
if _, ok := testAccount.Users[userAuth.UserId]; !ok {
|
||||
return userAuth, fmt.Errorf("user with id %s does not exist", userAuth.UserId)
|
||||
}
|
||||
|
||||
return userAuth, nil
|
||||
}
|
||||
|
||||
func TestAuthMiddleware_Handler(t *testing.T) {
|
||||
@ -158,22 +172,24 @@ func TestAuthMiddleware_Handler(t *testing.T) {
|
||||
}
|
||||
|
||||
nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// do nothing
|
||||
|
||||
})
|
||||
|
||||
claimsExtractor := jwtclaims.NewClaimsExtractor(
|
||||
jwtclaims.WithAudience(audience),
|
||||
jwtclaims.WithUserIDClaim(userIDClaim),
|
||||
)
|
||||
mockAuth := &auth.MockManager{
|
||||
ValidateAndParseTokenFunc: mockValidateAndParseToken,
|
||||
EnsureUserAccessByJWTGroupsFunc: mockEnsureUserAccessByJWTGroups,
|
||||
MarkPATUsedFunc: mockMarkPATUsed,
|
||||
GetPATInfoFunc: mockGetAccountInfoFromPAT,
|
||||
}
|
||||
|
||||
authMiddleware := NewAuthMiddleware(
|
||||
mockGetAccountInfoFromPAT,
|
||||
mockValidateAndParseToken,
|
||||
mockMarkPATUsed,
|
||||
mockCheckUserAccessByJWTGroups,
|
||||
claimsExtractor,
|
||||
audience,
|
||||
userIDClaim,
|
||||
mockAuth,
|
||||
func(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error) {
|
||||
return userAuth.AccountId, userAuth.UserId, nil
|
||||
},
|
||||
func(ctx context.Context, userAuth nbcontext.UserAuth) error {
|
||||
return nil
|
||||
},
|
||||
)
|
||||
|
||||
handlerToTest := authMiddleware.Handler(nextHandler)
|
||||
@ -195,9 +211,115 @@ func TestAuthMiddleware_Handler(t *testing.T) {
|
||||
|
||||
result := rec.Result()
|
||||
defer result.Body.Close()
|
||||
|
||||
if result.StatusCode != tc.expectedStatusCode {
|
||||
t.Errorf("expected status code %d, got %d", tc.expectedStatusCode, result.StatusCode)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthMiddleware_Handler_Child(t *testing.T) {
|
||||
tt := []struct {
|
||||
name string
|
||||
path string
|
||||
authHeader string
|
||||
expectedUserAuth *nbcontext.UserAuth // nil expects 401 response status
|
||||
}{
|
||||
{
|
||||
name: "Valid PAT Token",
|
||||
path: "/test",
|
||||
authHeader: "Token " + PAT,
|
||||
expectedUserAuth: &nbcontext.UserAuth{
|
||||
AccountId: accountID,
|
||||
UserId: userID,
|
||||
Domain: testAccount.Domain,
|
||||
DomainCategory: testAccount.DomainCategory,
|
||||
IsPAT: true,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Valid PAT Token ignores child",
|
||||
path: "/test?account=xyz",
|
||||
authHeader: "Token " + PAT,
|
||||
expectedUserAuth: &nbcontext.UserAuth{
|
||||
AccountId: accountID,
|
||||
UserId: userID,
|
||||
Domain: testAccount.Domain,
|
||||
DomainCategory: testAccount.DomainCategory,
|
||||
IsPAT: true,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Valid JWT Token",
|
||||
path: "/test",
|
||||
authHeader: "Bearer " + JWT,
|
||||
expectedUserAuth: &nbcontext.UserAuth{
|
||||
AccountId: accountID,
|
||||
UserId: userID,
|
||||
Domain: testAccount.Domain,
|
||||
DomainCategory: testAccount.DomainCategory,
|
||||
},
|
||||
},
|
||||
|
||||
{
|
||||
name: "Valid JWT Token with child",
|
||||
path: "/test?account=xyz",
|
||||
authHeader: "Bearer " + JWT,
|
||||
expectedUserAuth: &nbcontext.UserAuth{
|
||||
AccountId: "xyz",
|
||||
UserId: userID,
|
||||
Domain: testAccount.Domain,
|
||||
DomainCategory: testAccount.DomainCategory,
|
||||
IsChild: true,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
mockAuth := &auth.MockManager{
|
||||
ValidateAndParseTokenFunc: mockValidateAndParseToken,
|
||||
EnsureUserAccessByJWTGroupsFunc: mockEnsureUserAccessByJWTGroups,
|
||||
MarkPATUsedFunc: mockMarkPATUsed,
|
||||
GetPATInfoFunc: mockGetAccountInfoFromPAT,
|
||||
}
|
||||
|
||||
authMiddleware := NewAuthMiddleware(
|
||||
mockAuth,
|
||||
func(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error) {
|
||||
return userAuth.AccountId, userAuth.UserId, nil
|
||||
},
|
||||
func(ctx context.Context, userAuth nbcontext.UserAuth) error {
|
||||
return nil
|
||||
},
|
||||
)
|
||||
|
||||
for _, tc := range tt {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
handlerToTest := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromRequest(r)
|
||||
if tc.expectedUserAuth != nil {
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, *tc.expectedUserAuth, userAuth)
|
||||
} else {
|
||||
assert.Error(t, err)
|
||||
assert.Empty(t, userAuth)
|
||||
}
|
||||
}))
|
||||
|
||||
req := httptest.NewRequest("GET", "http://testing"+tc.path, nil)
|
||||
req.Header.Set("Authorization", tc.authHeader)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
handlerToTest.ServeHTTP(rec, req)
|
||||
|
||||
result := rec.Result()
|
||||
defer result.Body.Close()
|
||||
|
||||
if tc.expectedUserAuth != nil {
|
||||
assert.Equal(t, 200, result.StatusCode)
|
||||
} else {
|
||||
assert.Equal(t, 401, result.StatusCode)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
@ -77,13 +77,13 @@ func BenchmarkUpdatePeer(b *testing.B) {
|
||||
|
||||
func BenchmarkGetOnePeer(b *testing.B) {
|
||||
var expectedMetrics = map[string]testing_tools.PerformanceMetrics{
|
||||
"Peers - XS": {MinMsPerOpLocal: 15, MaxMsPerOpLocal: 40, MinMsPerOpCICD: 30, MaxMsPerOpCICD: 70},
|
||||
"Peers - S": {MinMsPerOpLocal: 1, MaxMsPerOpLocal: 5, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 30},
|
||||
"Peers - M": {MinMsPerOpLocal: 9, MaxMsPerOpLocal: 18, MinMsPerOpCICD: 15, MaxMsPerOpCICD: 50},
|
||||
"Peers - L": {MinMsPerOpLocal: 40, MaxMsPerOpLocal: 90, MinMsPerOpCICD: 50, MaxMsPerOpCICD: 130},
|
||||
"Groups - L": {MinMsPerOpLocal: 80, MaxMsPerOpLocal: 130, MinMsPerOpCICD: 30, MaxMsPerOpCICD: 200},
|
||||
"Users - L": {MinMsPerOpLocal: 40, MaxMsPerOpLocal: 90, MinMsPerOpCICD: 50, MaxMsPerOpCICD: 130},
|
||||
"Setup Keys - L": {MinMsPerOpLocal: 40, MaxMsPerOpLocal: 90, MinMsPerOpCICD: 50, MaxMsPerOpCICD: 130},
|
||||
"Peers - XS": {MinMsPerOpLocal: 15, MaxMsPerOpLocal: 40, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 70},
|
||||
"Peers - S": {MinMsPerOpLocal: 1, MaxMsPerOpLocal: 5, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 70},
|
||||
"Peers - M": {MinMsPerOpLocal: 9, MaxMsPerOpLocal: 18, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 70},
|
||||
"Peers - L": {MinMsPerOpLocal: 40, MaxMsPerOpLocal: 90, MinMsPerOpCICD: 5, MaxMsPerOpCICD: 200},
|
||||
"Groups - L": {MinMsPerOpLocal: 80, MaxMsPerOpLocal: 130, MinMsPerOpCICD: 5, MaxMsPerOpCICD: 200},
|
||||
"Users - L": {MinMsPerOpLocal: 40, MaxMsPerOpLocal: 90, MinMsPerOpCICD: 5, MaxMsPerOpCICD: 200},
|
||||
"Setup Keys - L": {MinMsPerOpLocal: 40, MaxMsPerOpLocal: 90, MinMsPerOpCICD: 5, MaxMsPerOpCICD: 200},
|
||||
"Peers - XL": {MinMsPerOpLocal: 200, MaxMsPerOpLocal: 400, MinMsPerOpCICD: 200, MaxMsPerOpCICD: 750},
|
||||
}
|
||||
|
||||
@ -111,9 +111,9 @@ func BenchmarkGetOnePeer(b *testing.B) {
|
||||
|
||||
func BenchmarkGetAllPeers(b *testing.B) {
|
||||
var expectedMetrics = map[string]testing_tools.PerformanceMetrics{
|
||||
"Peers - XS": {MinMsPerOpLocal: 40, MaxMsPerOpLocal: 70, MinMsPerOpCICD: 50, MaxMsPerOpCICD: 150},
|
||||
"Peers - S": {MinMsPerOpLocal: 2, MaxMsPerOpLocal: 10, MinMsPerOpCICD: 5, MaxMsPerOpCICD: 30},
|
||||
"Peers - M": {MinMsPerOpLocal: 20, MaxMsPerOpLocal: 50, MinMsPerOpCICD: 20, MaxMsPerOpCICD: 70},
|
||||
"Peers - XS": {MinMsPerOpLocal: 40, MaxMsPerOpLocal: 70, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 100},
|
||||
"Peers - S": {MinMsPerOpLocal: 2, MaxMsPerOpLocal: 10, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 100},
|
||||
"Peers - M": {MinMsPerOpLocal: 20, MaxMsPerOpLocal: 50, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 100},
|
||||
"Peers - L": {MinMsPerOpLocal: 110, MaxMsPerOpLocal: 150, MinMsPerOpCICD: 100, MaxMsPerOpCICD: 300},
|
||||
"Groups - L": {MinMsPerOpLocal: 150, MaxMsPerOpLocal: 200, MinMsPerOpCICD: 130, MaxMsPerOpCICD: 500},
|
||||
"Users - L": {MinMsPerOpLocal: 100, MaxMsPerOpLocal: 170, MinMsPerOpCICD: 100, MaxMsPerOpCICD: 400},
|
||||
|
@ -48,13 +48,12 @@ func BenchmarkUpdateUser(b *testing.B) {
|
||||
log.SetOutput(io.Discard)
|
||||
defer log.SetOutput(os.Stderr)
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
|
||||
for name, bc := range benchCasesUsers {
|
||||
b.Run(name, func(b *testing.B) {
|
||||
apiHandler, am, _ := testing_tools.BuildApiBlackBoxWithDBState(b, "../testdata/users.sql", nil, false)
|
||||
testing_tools.PopulateTestData(b, am.(*server.DefaultAccountManager), bc.Peers, bc.Groups, bc.Users, bc.SetupKeys)
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
b.ResetTimer()
|
||||
start := time.Now()
|
||||
for i := 0; i < b.N; i++ {
|
||||
@ -97,13 +96,12 @@ func BenchmarkGetOneUser(b *testing.B) {
|
||||
log.SetOutput(io.Discard)
|
||||
defer log.SetOutput(os.Stderr)
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
|
||||
for name, bc := range benchCasesUsers {
|
||||
b.Run(name, func(b *testing.B) {
|
||||
apiHandler, am, _ := testing_tools.BuildApiBlackBoxWithDBState(b, "../testdata/users.sql", nil, false)
|
||||
testing_tools.PopulateTestData(b, am.(*server.DefaultAccountManager), bc.Peers, bc.Groups, bc.Users, bc.SetupKeys)
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
b.ResetTimer()
|
||||
start := time.Now()
|
||||
for i := 0; i < b.N; i++ {
|
||||
@ -118,26 +116,25 @@ func BenchmarkGetOneUser(b *testing.B) {
|
||||
|
||||
func BenchmarkGetAllUsers(b *testing.B) {
|
||||
var expectedMetrics = map[string]testing_tools.PerformanceMetrics{
|
||||
"Users - XS": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 0, MaxMsPerOpCICD: 10},
|
||||
"Users - S": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 0, MaxMsPerOpCICD: 10},
|
||||
"Users - M": {MinMsPerOpLocal: 3, MaxMsPerOpLocal: 10, MinMsPerOpCICD: 5, MaxMsPerOpCICD: 15},
|
||||
"Users - L": {MinMsPerOpLocal: 10, MaxMsPerOpLocal: 20, MinMsPerOpCICD: 20, MaxMsPerOpCICD: 50},
|
||||
"Peers - L": {MinMsPerOpLocal: 15, MaxMsPerOpLocal: 25, MinMsPerOpCICD: 20, MaxMsPerOpCICD: 55},
|
||||
"Groups - L": {MinMsPerOpLocal: 15, MaxMsPerOpLocal: 25, MinMsPerOpCICD: 25, MaxMsPerOpCICD: 55},
|
||||
"Setup Keys - L": {MinMsPerOpLocal: 15, MaxMsPerOpLocal: 25, MinMsPerOpCICD: 25, MaxMsPerOpCICD: 55},
|
||||
"Users - XL": {MinMsPerOpLocal: 80, MaxMsPerOpLocal: 120, MinMsPerOpCICD: 100, MaxMsPerOpCICD: 300},
|
||||
"Users - XS": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 0, MaxMsPerOpCICD: 75},
|
||||
"Users - S": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 0, MaxMsPerOpCICD: 75},
|
||||
"Users - M": {MinMsPerOpLocal: 3, MaxMsPerOpLocal: 10, MinMsPerOpCICD: 0, MaxMsPerOpCICD: 75},
|
||||
"Users - L": {MinMsPerOpLocal: 10, MaxMsPerOpLocal: 20, MinMsPerOpCICD: 10, MaxMsPerOpCICD: 100},
|
||||
"Peers - L": {MinMsPerOpLocal: 15, MaxMsPerOpLocal: 25, MinMsPerOpCICD: 10, MaxMsPerOpCICD: 100},
|
||||
"Groups - L": {MinMsPerOpLocal: 15, MaxMsPerOpLocal: 25, MinMsPerOpCICD: 10, MaxMsPerOpCICD: 100},
|
||||
"Setup Keys - L": {MinMsPerOpLocal: 15, MaxMsPerOpLocal: 25, MinMsPerOpCICD: 10, MaxMsPerOpCICD: 100},
|
||||
"Users - XL": {MinMsPerOpLocal: 80, MaxMsPerOpLocal: 120, MinMsPerOpCICD: 50, MaxMsPerOpCICD: 300},
|
||||
}
|
||||
|
||||
log.SetOutput(io.Discard)
|
||||
defer log.SetOutput(os.Stderr)
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
|
||||
for name, bc := range benchCasesUsers {
|
||||
b.Run(name, func(b *testing.B) {
|
||||
apiHandler, am, _ := testing_tools.BuildApiBlackBoxWithDBState(b, "../testdata/users.sql", nil, false)
|
||||
testing_tools.PopulateTestData(b, am.(*server.DefaultAccountManager), bc.Peers, bc.Groups, bc.Users, bc.SetupKeys)
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
b.ResetTimer()
|
||||
start := time.Now()
|
||||
for i := 0; i < b.N; i++ {
|
||||
@ -152,26 +149,25 @@ func BenchmarkGetAllUsers(b *testing.B) {
|
||||
|
||||
func BenchmarkDeleteUsers(b *testing.B) {
|
||||
var expectedMetrics = map[string]testing_tools.PerformanceMetrics{
|
||||
"Users - XS": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 5, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 15},
|
||||
"Users - S": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 5, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 15},
|
||||
"Users - M": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 5, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 15},
|
||||
"Users - L": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 5, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 15},
|
||||
"Peers - L": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 5, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 15},
|
||||
"Groups - L": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 5, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 15},
|
||||
"Setup Keys - L": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 5, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 15},
|
||||
"Users - XL": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 5, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 15},
|
||||
"Users - XS": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 5, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 50},
|
||||
"Users - S": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 5, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 50},
|
||||
"Users - M": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 5, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 50},
|
||||
"Users - L": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 5, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 50},
|
||||
"Peers - L": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 5, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 50},
|
||||
"Groups - L": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 5, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 50},
|
||||
"Setup Keys - L": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 5, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 50},
|
||||
"Users - XL": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 5, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 50},
|
||||
}
|
||||
|
||||
log.SetOutput(io.Discard)
|
||||
defer log.SetOutput(os.Stderr)
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
|
||||
for name, bc := range benchCasesUsers {
|
||||
b.Run(name, func(b *testing.B) {
|
||||
apiHandler, am, _ := testing_tools.BuildApiBlackBoxWithDBState(b, "../testdata/users.sql", nil, false)
|
||||
testing_tools.PopulateTestData(b, am.(*server.DefaultAccountManager), bc.Peers, bc.Groups, 1000, bc.SetupKeys)
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
b.ResetTimer()
|
||||
start := time.Now()
|
||||
for i := 0; i < b.N; i++ {
|
||||
|
@ -3,6 +3,7 @@ package testing_tools
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
@ -13,17 +14,17 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/util"
|
||||
"github.com/golang-jwt/jwt"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/management/server/auth"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/geolocation"
|
||||
"github.com/netbirdio/netbird/management/server/groups"
|
||||
nbhttp "github.com/netbirdio/netbird/management/server/http"
|
||||
"github.com/netbirdio/netbird/management/server/http/configs"
|
||||
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
||||
"github.com/netbirdio/netbird/management/server/networks"
|
||||
"github.com/netbirdio/netbird/management/server/networks/resources"
|
||||
"github.com/netbirdio/netbird/management/server/networks/routers"
|
||||
@ -32,6 +33,7 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/management/server/util"
|
||||
)
|
||||
|
||||
const (
|
||||
@ -115,11 +117,20 @@ func BuildApiBlackBoxWithDBState(t TB, sqlFile string, expectedPeerUpdate *serve
|
||||
t.Fatalf("Failed to create manager: %v", err)
|
||||
}
|
||||
|
||||
// @note this is required so that PAT's validate from store, but JWT's are mocked
|
||||
authManager := auth.NewManager(store, "", "", "", "", []string{}, false)
|
||||
authManagerMock := &auth.MockManager{
|
||||
ValidateAndParseTokenFunc: mockValidateAndParseToken,
|
||||
EnsureUserAccessByJWTGroupsFunc: authManager.EnsureUserAccessByJWTGroups,
|
||||
MarkPATUsedFunc: authManager.MarkPATUsed,
|
||||
GetPATInfoFunc: authManager.GetPATInfo,
|
||||
}
|
||||
|
||||
networksManagerMock := networks.NewManagerMock()
|
||||
resourcesManagerMock := resources.NewManagerMock()
|
||||
routersManagerMock := routers.NewManagerMock()
|
||||
groupsManagerMock := groups.NewManagerMock()
|
||||
apiHandler, err := nbhttp.NewAPIHandler(context.Background(), am, networksManagerMock, resourcesManagerMock, routersManagerMock, groupsManagerMock, geoMock, &jwtclaims.JwtValidatorMock{}, metrics, configs.AuthCfg{}, validatorMock)
|
||||
apiHandler, err := nbhttp.NewAPIHandler(context.Background(), am, networksManagerMock, resourcesManagerMock, routersManagerMock, groupsManagerMock, geoMock, authManagerMock, metrics, &server.Config{}, validatorMock)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create API handler: %v", err)
|
||||
}
|
||||
@ -309,3 +320,25 @@ func EvaluateBenchmarkResults(b *testing.B, name string, duration time.Duration,
|
||||
b.Fatalf("Benchmark %s failed: too slow (%.2f ms/op, maximum %.2f ms/op)", name, msPerOp, maxExpected)
|
||||
}
|
||||
}
|
||||
|
||||
func mockValidateAndParseToken(_ context.Context, token string) (nbcontext.UserAuth, *jwt.Token, error) {
|
||||
userAuth := nbcontext.UserAuth{}
|
||||
|
||||
switch token {
|
||||
case "testUserId", "testAdminId", "testOwnerId", "testServiceUserId", "testServiceAdminId", "blockedUserId":
|
||||
userAuth.UserId = token
|
||||
userAuth.AccountId = "testAccountId"
|
||||
userAuth.Domain = "test.com"
|
||||
userAuth.DomainCategory = "private"
|
||||
case "otherUserId":
|
||||
userAuth.UserId = "otherUserId"
|
||||
userAuth.AccountId = "otherAccountId"
|
||||
userAuth.Domain = "other.com"
|
||||
userAuth.DomainCategory = "private"
|
||||
case "invalidToken":
|
||||
return userAuth, nil, errors.New("invalid token")
|
||||
}
|
||||
|
||||
jwtToken := jwt.New(jwt.SigningMethodHS256)
|
||||
return userAuth, jwtToken, nil
|
||||
}
|
||||
|
@ -1,19 +0,0 @@
|
||||
package jwtclaims
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt"
|
||||
)
|
||||
|
||||
// AuthorizationClaims stores authorization information from JWTs
|
||||
type AuthorizationClaims struct {
|
||||
UserId string
|
||||
AccountId string
|
||||
Domain string
|
||||
DomainCategory string
|
||||
LastLogin time.Time
|
||||
Invited bool
|
||||
|
||||
Raw jwt.MapClaims
|
||||
}
|
@ -1,227 +0,0 @@
|
||||
package jwtclaims
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func newTestRequestWithJWT(t *testing.T, claims AuthorizationClaims, audience string) *http.Request {
|
||||
t.Helper()
|
||||
const layout = "2006-01-02T15:04:05.999Z"
|
||||
|
||||
claimMaps := jwt.MapClaims{}
|
||||
if claims.UserId != "" {
|
||||
claimMaps[UserIDClaim] = claims.UserId
|
||||
}
|
||||
if claims.AccountId != "" {
|
||||
claimMaps[audience+AccountIDSuffix] = claims.AccountId
|
||||
}
|
||||
if claims.Domain != "" {
|
||||
claimMaps[audience+DomainIDSuffix] = claims.Domain
|
||||
}
|
||||
if claims.DomainCategory != "" {
|
||||
claimMaps[audience+DomainCategorySuffix] = claims.DomainCategory
|
||||
}
|
||||
if claims.LastLogin != (time.Time{}) {
|
||||
claimMaps[audience+LastLoginSuffix] = claims.LastLogin.Format(layout)
|
||||
}
|
||||
|
||||
if claims.Invited {
|
||||
claimMaps[audience+Invited] = true
|
||||
}
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claimMaps)
|
||||
r, err := http.NewRequest(http.MethodGet, "http://localhost", nil)
|
||||
require.NoError(t, err, "creating testing request failed")
|
||||
testRequest := r.WithContext(context.WithValue(r.Context(), TokenUserProperty, token)) // nolint
|
||||
|
||||
return testRequest
|
||||
}
|
||||
|
||||
func TestExtractClaimsFromRequestContext(t *testing.T) {
|
||||
type test struct {
|
||||
name string
|
||||
inputAuthorizationClaims AuthorizationClaims
|
||||
inputAudiance string
|
||||
testingFunc require.ComparisonAssertionFunc
|
||||
expectedMSG string
|
||||
}
|
||||
|
||||
const layout = "2006-01-02T15:04:05.999Z"
|
||||
lastLogin, _ := time.Parse(layout, "2023-08-17T09:30:40.465Z")
|
||||
|
||||
testCase1 := test{
|
||||
name: "All Claim Fields",
|
||||
inputAudiance: "https://login/",
|
||||
inputAuthorizationClaims: AuthorizationClaims{
|
||||
UserId: "test",
|
||||
Domain: "test.com",
|
||||
AccountId: "testAcc",
|
||||
LastLogin: lastLogin,
|
||||
DomainCategory: "public",
|
||||
Invited: true,
|
||||
Raw: jwt.MapClaims{
|
||||
"https://login/wt_account_domain": "test.com",
|
||||
"https://login/wt_account_domain_category": "public",
|
||||
"https://login/wt_account_id": "testAcc",
|
||||
"https://login/nb_last_login": lastLogin.Format(layout),
|
||||
"sub": "test",
|
||||
"https://login/" + Invited: true,
|
||||
},
|
||||
},
|
||||
testingFunc: require.EqualValues,
|
||||
expectedMSG: "extracted claims should match input claims",
|
||||
}
|
||||
|
||||
testCase2 := test{
|
||||
name: "Domain Is Empty",
|
||||
inputAudiance: "https://login/",
|
||||
inputAuthorizationClaims: AuthorizationClaims{
|
||||
UserId: "test",
|
||||
AccountId: "testAcc",
|
||||
Raw: jwt.MapClaims{
|
||||
"https://login/wt_account_id": "testAcc",
|
||||
"sub": "test",
|
||||
},
|
||||
},
|
||||
testingFunc: require.EqualValues,
|
||||
expectedMSG: "extracted claims should match input claims",
|
||||
}
|
||||
|
||||
testCase3 := test{
|
||||
name: "Account ID Is Empty",
|
||||
inputAudiance: "https://login/",
|
||||
inputAuthorizationClaims: AuthorizationClaims{
|
||||
UserId: "test",
|
||||
Domain: "test.com",
|
||||
Raw: jwt.MapClaims{
|
||||
"https://login/wt_account_domain": "test.com",
|
||||
"sub": "test",
|
||||
},
|
||||
},
|
||||
testingFunc: require.EqualValues,
|
||||
expectedMSG: "extracted claims should match input claims",
|
||||
}
|
||||
|
||||
testCase4 := test{
|
||||
name: "Category Is Empty",
|
||||
inputAudiance: "https://login/",
|
||||
inputAuthorizationClaims: AuthorizationClaims{
|
||||
UserId: "test",
|
||||
Domain: "test.com",
|
||||
AccountId: "testAcc",
|
||||
Raw: jwt.MapClaims{
|
||||
"https://login/wt_account_domain": "test.com",
|
||||
"https://login/wt_account_id": "testAcc",
|
||||
"sub": "test",
|
||||
},
|
||||
},
|
||||
testingFunc: require.EqualValues,
|
||||
expectedMSG: "extracted claims should match input claims",
|
||||
}
|
||||
|
||||
testCase5 := test{
|
||||
name: "Only User ID Is set",
|
||||
inputAudiance: "https://login/",
|
||||
inputAuthorizationClaims: AuthorizationClaims{
|
||||
UserId: "test",
|
||||
Raw: jwt.MapClaims{
|
||||
"sub": "test",
|
||||
},
|
||||
},
|
||||
testingFunc: require.EqualValues,
|
||||
expectedMSG: "extracted claims should match input claims",
|
||||
}
|
||||
|
||||
for _, testCase := range []test{testCase1, testCase2, testCase3, testCase4, testCase5} {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
request := newTestRequestWithJWT(t, testCase.inputAuthorizationClaims, testCase.inputAudiance)
|
||||
|
||||
extractor := NewClaimsExtractor(WithAudience(testCase.inputAudiance))
|
||||
extractedClaims := extractor.FromRequestContext(request)
|
||||
|
||||
testCase.testingFunc(t, testCase.inputAuthorizationClaims, extractedClaims, testCase.expectedMSG)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractClaimsSetOptions(t *testing.T) {
|
||||
t.Helper()
|
||||
type test struct {
|
||||
name string
|
||||
extractor *ClaimsExtractor
|
||||
check func(t *testing.T, c test)
|
||||
}
|
||||
|
||||
testCase1 := test{
|
||||
name: "No custom options",
|
||||
extractor: NewClaimsExtractor(),
|
||||
check: func(t *testing.T, c test) {
|
||||
t.Helper()
|
||||
if c.extractor.authAudience != "" {
|
||||
t.Error("audience should be empty")
|
||||
return
|
||||
}
|
||||
if c.extractor.userIDClaim != UserIDClaim {
|
||||
t.Errorf("user id claim should be default, expected %s, got %s", UserIDClaim, c.extractor.userIDClaim)
|
||||
return
|
||||
}
|
||||
if c.extractor.FromRequestContext == nil {
|
||||
t.Error("from request context should not be nil")
|
||||
return
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
testCase2 := test{
|
||||
name: "Custom audience",
|
||||
extractor: NewClaimsExtractor(WithAudience("https://login/")),
|
||||
check: func(t *testing.T, c test) {
|
||||
t.Helper()
|
||||
if c.extractor.authAudience != "https://login/" {
|
||||
t.Errorf("audience expected %s, got %s", "https://login/", c.extractor.authAudience)
|
||||
return
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
testCase3 := test{
|
||||
name: "Custom user id claim",
|
||||
extractor: NewClaimsExtractor(WithUserIDClaim("customUserId")),
|
||||
check: func(t *testing.T, c test) {
|
||||
t.Helper()
|
||||
if c.extractor.userIDClaim != "customUserId" {
|
||||
t.Errorf("user id claim expected %s, got %s", "customUserId", c.extractor.userIDClaim)
|
||||
return
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
testCase4 := test{
|
||||
name: "Custom extractor from request context",
|
||||
extractor: NewClaimsExtractor(
|
||||
WithFromRequestContext(func(r *http.Request) AuthorizationClaims {
|
||||
return AuthorizationClaims{
|
||||
UserId: "testCustomRequest",
|
||||
}
|
||||
})),
|
||||
check: func(t *testing.T, c test) {
|
||||
t.Helper()
|
||||
claims := c.extractor.FromRequestContext(&http.Request{})
|
||||
if claims.UserId != "testCustomRequest" {
|
||||
t.Errorf("user id claim expected %s, got %s", "testCustomRequest", claims.UserId)
|
||||
return
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
for _, testCase := range []test{testCase1, testCase2, testCase3, testCase4} {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
testCase.check(t, testCase)
|
||||
})
|
||||
}
|
||||
}
|
@ -1,349 +0,0 @@
|
||||
package jwtclaims
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/ecdsa"
|
||||
"crypto/elliptic"
|
||||
"crypto/rsa"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math/big"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// Options is a struct for specifying configuration options for the middleware.
|
||||
type Options struct {
|
||||
// The function that will return the Key to validate the JWT.
|
||||
// It can be either a shared secret or a public key.
|
||||
// Default value: nil
|
||||
ValidationKeyGetter jwt.Keyfunc
|
||||
// The name of the property in the request where the user information
|
||||
// from the JWT will be stored.
|
||||
// Default value: "user"
|
||||
UserProperty string
|
||||
// The function that will be called when there's an error validating the token
|
||||
// Default value:
|
||||
CredentialsOptional bool
|
||||
// A function that extracts the token from the request
|
||||
// Default: FromAuthHeader (i.e., from Authorization header as bearer token)
|
||||
Debug bool
|
||||
// When set, all requests with the OPTIONS method will use authentication
|
||||
// Default: false
|
||||
EnableAuthOnOptions bool
|
||||
}
|
||||
|
||||
// Jwks is a collection of JSONWebKey obtained from Config.HttpServerConfig.AuthKeysLocation
|
||||
type Jwks struct {
|
||||
Keys []JSONWebKey `json:"keys"`
|
||||
expiresInTime time.Time
|
||||
}
|
||||
|
||||
// The supported elliptic curves types
|
||||
const (
|
||||
// p256 represents a cryptographic elliptical curve type.
|
||||
p256 = "P-256"
|
||||
|
||||
// p384 represents a cryptographic elliptical curve type.
|
||||
p384 = "P-384"
|
||||
|
||||
// p521 represents a cryptographic elliptical curve type.
|
||||
p521 = "P-521"
|
||||
)
|
||||
|
||||
// JSONWebKey is a representation of a Jason Web Key
|
||||
type JSONWebKey struct {
|
||||
Kty string `json:"kty"`
|
||||
Kid string `json:"kid"`
|
||||
Use string `json:"use"`
|
||||
N string `json:"n"`
|
||||
E string `json:"e"`
|
||||
Crv string `json:"crv"`
|
||||
X string `json:"x"`
|
||||
Y string `json:"y"`
|
||||
X5c []string `json:"x5c"`
|
||||
}
|
||||
|
||||
type JWTValidator interface {
|
||||
ValidateAndParse(ctx context.Context, token string) (*jwt.Token, error)
|
||||
}
|
||||
|
||||
// jwtValidatorImpl struct to handle token validation and parsing
|
||||
type jwtValidatorImpl struct {
|
||||
options Options
|
||||
}
|
||||
|
||||
var keyNotFound = errors.New("unable to find appropriate key")
|
||||
|
||||
// NewJWTValidator constructor
|
||||
func NewJWTValidator(ctx context.Context, issuer string, audienceList []string, keysLocation string, idpSignkeyRefreshEnabled bool) (JWTValidator, error) {
|
||||
keys, err := getPemKeys(ctx, keysLocation)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var lock sync.Mutex
|
||||
options := Options{
|
||||
ValidationKeyGetter: func(token *jwt.Token) (interface{}, error) {
|
||||
// Verify 'aud' claim
|
||||
var checkAud bool
|
||||
for _, audience := range audienceList {
|
||||
checkAud = token.Claims.(jwt.MapClaims).VerifyAudience(audience, false)
|
||||
if checkAud {
|
||||
break
|
||||
}
|
||||
}
|
||||
if !checkAud {
|
||||
return token, errors.New("invalid audience")
|
||||
}
|
||||
// Verify 'issuer' claim
|
||||
checkIss := token.Claims.(jwt.MapClaims).VerifyIssuer(issuer, false)
|
||||
if !checkIss {
|
||||
return token, errors.New("invalid issuer")
|
||||
}
|
||||
|
||||
// If keys are rotated, verify the keys prior to token validation
|
||||
if idpSignkeyRefreshEnabled {
|
||||
// If the keys are invalid, retrieve new ones
|
||||
if !keys.stillValid() {
|
||||
lock.Lock()
|
||||
defer lock.Unlock()
|
||||
|
||||
refreshedKeys, err := getPemKeys(ctx, keysLocation)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Debugf("cannot get JSONWebKey: %v, falling back to old keys", err)
|
||||
refreshedKeys = keys
|
||||
}
|
||||
|
||||
log.WithContext(ctx).Debugf("keys refreshed, new UTC expiration time: %s", refreshedKeys.expiresInTime.UTC())
|
||||
|
||||
keys = refreshedKeys
|
||||
}
|
||||
}
|
||||
|
||||
publicKey, err := getPublicKey(ctx, token, keys)
|
||||
if err == nil {
|
||||
return publicKey, nil
|
||||
}
|
||||
|
||||
msg := fmt.Sprintf("getPublicKey error: %s", err)
|
||||
if errors.Is(err, keyNotFound) && !idpSignkeyRefreshEnabled {
|
||||
msg = fmt.Sprintf("getPublicKey error: %s. You can enable key refresh by setting HttpServerConfig.IdpSignKeyRefreshEnabled to true in your management.json file and restart the service", err)
|
||||
}
|
||||
|
||||
log.WithContext(ctx).Error(msg)
|
||||
|
||||
return nil, err
|
||||
},
|
||||
EnableAuthOnOptions: false,
|
||||
}
|
||||
|
||||
if options.UserProperty == "" {
|
||||
options.UserProperty = "user"
|
||||
}
|
||||
|
||||
return &jwtValidatorImpl{
|
||||
options: options,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// ValidateAndParse validates the token and returns the parsed token
|
||||
func (m *jwtValidatorImpl) ValidateAndParse(ctx context.Context, token string) (*jwt.Token, error) {
|
||||
// If the token is empty...
|
||||
if token == "" {
|
||||
// Check if it was required
|
||||
if m.options.CredentialsOptional {
|
||||
log.WithContext(ctx).Debugf("no credentials found (CredentialsOptional=true)")
|
||||
// No error, just no token (and that is ok given that CredentialsOptional is true)
|
||||
return nil, nil //nolint:nilnil
|
||||
}
|
||||
|
||||
// If we get here, the required token is missing
|
||||
errorMsg := "required authorization token not found"
|
||||
log.WithContext(ctx).Debugf(" Error: No credentials found (CredentialsOptional=false)")
|
||||
return nil, errors.New(errorMsg)
|
||||
}
|
||||
|
||||
// Now parse the token
|
||||
parsedToken, err := jwt.Parse(token, m.options.ValidationKeyGetter)
|
||||
|
||||
// Check if there was an error in parsing...
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("error parsing token: %v", err)
|
||||
return nil, fmt.Errorf("error parsing token: %w", err)
|
||||
}
|
||||
|
||||
// Check if the parsed token is valid...
|
||||
if !parsedToken.Valid {
|
||||
errorMsg := "token is invalid"
|
||||
log.WithContext(ctx).Debug(errorMsg)
|
||||
return nil, errors.New(errorMsg)
|
||||
}
|
||||
|
||||
return parsedToken, nil
|
||||
}
|
||||
|
||||
// stillValid returns true if the JSONWebKey still valid and have enough time to be used
|
||||
func (jwks *Jwks) stillValid() bool {
|
||||
return !jwks.expiresInTime.IsZero() && time.Now().Add(5*time.Second).Before(jwks.expiresInTime)
|
||||
}
|
||||
|
||||
func getPemKeys(ctx context.Context, keysLocation string) (*Jwks, error) {
|
||||
resp, err := http.Get(keysLocation)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
jwks := &Jwks{}
|
||||
err = json.NewDecoder(resp.Body).Decode(jwks)
|
||||
if err != nil {
|
||||
return jwks, err
|
||||
}
|
||||
|
||||
cacheControlHeader := resp.Header.Get("Cache-Control")
|
||||
expiresIn := getMaxAgeFromCacheHeader(ctx, cacheControlHeader)
|
||||
jwks.expiresInTime = time.Now().Add(time.Duration(expiresIn) * time.Second)
|
||||
|
||||
return jwks, err
|
||||
}
|
||||
|
||||
func getPublicKey(ctx context.Context, token *jwt.Token, jwks *Jwks) (interface{}, error) {
|
||||
// todo as we load the jkws when the server is starting, we should build a JKS map with the pem cert at the boot time
|
||||
|
||||
for k := range jwks.Keys {
|
||||
if token.Header["kid"] != jwks.Keys[k].Kid {
|
||||
continue
|
||||
}
|
||||
|
||||
if len(jwks.Keys[k].X5c) != 0 {
|
||||
cert := "-----BEGIN CERTIFICATE-----\n" + jwks.Keys[k].X5c[0] + "\n-----END CERTIFICATE-----"
|
||||
return jwt.ParseRSAPublicKeyFromPEM([]byte(cert))
|
||||
}
|
||||
|
||||
if jwks.Keys[k].Kty == "RSA" {
|
||||
log.WithContext(ctx).Debugf("generating PublicKey from RSA JWK")
|
||||
return getPublicKeyFromRSA(jwks.Keys[k])
|
||||
}
|
||||
if jwks.Keys[k].Kty == "EC" {
|
||||
log.WithContext(ctx).Debugf("generating PublicKey from ECDSA JWK")
|
||||
return getPublicKeyFromECDSA(jwks.Keys[k])
|
||||
}
|
||||
|
||||
log.WithContext(ctx).Debugf("Key Type: %s not yet supported, please raise ticket!", jwks.Keys[k].Kty)
|
||||
}
|
||||
|
||||
return nil, keyNotFound
|
||||
}
|
||||
|
||||
func getPublicKeyFromECDSA(jwk JSONWebKey) (publicKey *ecdsa.PublicKey, err error) {
|
||||
|
||||
if jwk.X == "" || jwk.Y == "" || jwk.Crv == "" {
|
||||
return nil, fmt.Errorf("ecdsa key incomplete")
|
||||
}
|
||||
|
||||
var xCoordinate []byte
|
||||
if xCoordinate, err = base64.RawURLEncoding.DecodeString(jwk.X); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var yCoordinate []byte
|
||||
if yCoordinate, err = base64.RawURLEncoding.DecodeString(jwk.Y); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
publicKey = &ecdsa.PublicKey{}
|
||||
|
||||
var curve elliptic.Curve
|
||||
switch jwk.Crv {
|
||||
case p256:
|
||||
curve = elliptic.P256()
|
||||
case p384:
|
||||
curve = elliptic.P384()
|
||||
case p521:
|
||||
curve = elliptic.P521()
|
||||
}
|
||||
|
||||
publicKey.Curve = curve
|
||||
publicKey.X = big.NewInt(0).SetBytes(xCoordinate)
|
||||
publicKey.Y = big.NewInt(0).SetBytes(yCoordinate)
|
||||
|
||||
return publicKey, nil
|
||||
}
|
||||
|
||||
func getPublicKeyFromRSA(jwk JSONWebKey) (*rsa.PublicKey, error) {
|
||||
|
||||
decodedE, err := base64.RawURLEncoding.DecodeString(jwk.E)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
decodedN, err := base64.RawURLEncoding.DecodeString(jwk.N)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var n, e big.Int
|
||||
e.SetBytes(decodedE)
|
||||
n.SetBytes(decodedN)
|
||||
|
||||
return &rsa.PublicKey{
|
||||
E: int(e.Int64()),
|
||||
N: &n,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// getMaxAgeFromCacheHeader extracts max-age directive from the Cache-Control header
|
||||
func getMaxAgeFromCacheHeader(ctx context.Context, cacheControl string) int {
|
||||
// Split into individual directives
|
||||
directives := strings.Split(cacheControl, ",")
|
||||
|
||||
for _, directive := range directives {
|
||||
directive = strings.TrimSpace(directive)
|
||||
if strings.HasPrefix(directive, "max-age=") {
|
||||
// Extract the max-age value
|
||||
maxAgeStr := strings.TrimPrefix(directive, "max-age=")
|
||||
maxAge, err := strconv.Atoi(maxAgeStr)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Debugf("error parsing max-age: %v", err)
|
||||
return 0
|
||||
}
|
||||
|
||||
return maxAge
|
||||
}
|
||||
}
|
||||
|
||||
return 0
|
||||
}
|
||||
|
||||
type JwtValidatorMock struct{}
|
||||
|
||||
func (j *JwtValidatorMock) ValidateAndParse(ctx context.Context, token string) (*jwt.Token, error) {
|
||||
claimMaps := jwt.MapClaims{}
|
||||
|
||||
switch token {
|
||||
case "testUserId", "testAdminId", "testOwnerId", "testServiceUserId", "testServiceAdminId", "blockedUserId":
|
||||
claimMaps[UserIDClaim] = token
|
||||
claimMaps[AccountIDSuffix] = "testAccountId"
|
||||
claimMaps[DomainIDSuffix] = "test.com"
|
||||
claimMaps[DomainCategorySuffix] = "private"
|
||||
case "otherUserId":
|
||||
claimMaps[UserIDClaim] = "otherUserId"
|
||||
claimMaps[AccountIDSuffix] = "otherAccountId"
|
||||
claimMaps[DomainIDSuffix] = "other.com"
|
||||
claimMaps[DomainCategorySuffix] = "private"
|
||||
case "invalidToken":
|
||||
return nil, errors.New("invalid token")
|
||||
}
|
||||
|
||||
jwtToken := jwt.NewWithClaims(jwt.SigningMethodHS256, claimMaps)
|
||||
return jwtToken, nil
|
||||
}
|
||||
|
@ -440,7 +440,7 @@ func startManagementForTest(t *testing.T, testFile string, config *Config) (*grp
|
||||
secretsManager := NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay)
|
||||
|
||||
ephemeralMgr := NewEphemeralManager(store, accountManager)
|
||||
mgmtServer, err := NewServer(context.Background(), config, accountManager, settings.NewManager(store), peersUpdateManager, secretsManager, nil, ephemeralMgr)
|
||||
mgmtServer, err := NewServer(context.Background(), config, accountManager, settings.NewManager(store), peersUpdateManager, secretsManager, nil, ephemeralMgr, nil)
|
||||
if err != nil {
|
||||
return nil, nil, "", cleanup, err
|
||||
}
|
||||
|
@ -204,6 +204,7 @@ func startServer(
|
||||
secretsManager,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("failed creating management server: %v", err)
|
||||
|
@ -13,14 +13,16 @@ import (
|
||||
"github.com/netbirdio/netbird/management/domain"
|
||||
"github.com/netbirdio/netbird/management/server"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/idp"
|
||||
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/posture"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
)
|
||||
|
||||
var _ server.AccountManager = (*MockAccountManager)(nil)
|
||||
|
||||
type MockAccountManager struct {
|
||||
GetOrCreateAccountByUserFunc func(ctx context.Context, userId, domain string) (*types.Account, error)
|
||||
GetAccountFunc func(ctx context.Context, accountID string) (*types.Account, error)
|
||||
@ -29,7 +31,7 @@ type MockAccountManager struct {
|
||||
GetSetupKeyFunc func(ctx context.Context, accountID, userID, keyID string) (*types.SetupKey, error)
|
||||
AccountExistsFunc func(ctx context.Context, accountID string) (bool, error)
|
||||
GetAccountIDByUserIdFunc func(ctx context.Context, userId, domain string) (string, error)
|
||||
GetUserFunc func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*types.User, error)
|
||||
GetUserFromUserAuthFunc func(ctx context.Context, userAuth nbcontext.UserAuth) (*types.User, error)
|
||||
ListUsersFunc func(ctx context.Context, accountID string) ([]*types.User, error)
|
||||
GetPeersFunc func(ctx context.Context, accountID, userID string) ([]*nbpeer.Peer, error)
|
||||
MarkPeerConnectedFunc func(ctx context.Context, peerKey string, connected bool, realIP net.IP) error
|
||||
@ -54,8 +56,6 @@ type MockAccountManager struct {
|
||||
DeletePolicyFunc func(ctx context.Context, accountID, policyID, userID string) error
|
||||
ListPoliciesFunc func(ctx context.Context, accountID, userID string) ([]*types.Policy, error)
|
||||
GetUsersFromAccountFunc func(ctx context.Context, accountID, userID string) (map[string]*types.UserInfo, error)
|
||||
GetPATInfoFunc func(ctx context.Context, token string) (*types.User, *types.PersonalAccessToken, string, string, error)
|
||||
MarkPATUsedFunc func(ctx context.Context, pat string) error
|
||||
UpdatePeerMetaFunc func(ctx context.Context, peerID string, meta nbpeer.PeerSystemMeta) error
|
||||
UpdatePeerFunc func(ctx context.Context, accountID, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, error)
|
||||
CreateRouteFunc func(ctx context.Context, accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peer string, peerGroups []string, description string, netID route.NetID, masquerade bool, metric int, groups, accessControlGroupIDs []string, enabled bool, userID string, keepRoute bool) (*route.Route, error)
|
||||
@ -80,8 +80,7 @@ type MockAccountManager struct {
|
||||
DeleteNameServerGroupFunc func(ctx context.Context, accountID, nsGroupID, userID string) error
|
||||
ListNameServerGroupsFunc func(ctx context.Context, accountID string, userID string) ([]*nbdns.NameServerGroup, error)
|
||||
CreateUserFunc func(ctx context.Context, accountID, userID string, key *types.UserInfo) (*types.UserInfo, error)
|
||||
GetAccountIDFromTokenFunc func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error)
|
||||
CheckUserAccessByJWTGroupsFunc func(ctx context.Context, claims jwtclaims.AuthorizationClaims) error
|
||||
GetAccountIDFromUserAuthFunc func(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error)
|
||||
DeleteAccountFunc func(ctx context.Context, accountID, userID string) error
|
||||
GetDNSDomainFunc func() string
|
||||
StoreEventFunc func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any)
|
||||
@ -240,14 +239,6 @@ func (am *MockAccountManager) MarkPeerConnected(ctx context.Context, peerKey str
|
||||
return status.Errorf(codes.Unimplemented, "method MarkPeerConnected is not implemented")
|
||||
}
|
||||
|
||||
// GetPATInfo mock implementation of GetPATInfo from server.AccountManager interface
|
||||
func (am *MockAccountManager) GetPATInfo(ctx context.Context, pat string) (*types.User, *types.PersonalAccessToken, string, string, error) {
|
||||
if am.GetPATInfoFunc != nil {
|
||||
return am.GetPATInfoFunc(ctx, pat)
|
||||
}
|
||||
return nil, nil, "", "", status.Errorf(codes.Unimplemented, "method GetPATInfo is not implemented")
|
||||
}
|
||||
|
||||
// DeleteAccount mock implementation of DeleteAccount from server.AccountManager interface
|
||||
func (am *MockAccountManager) DeleteAccount(ctx context.Context, accountID, userID string) error {
|
||||
if am.DeleteAccountFunc != nil {
|
||||
@ -256,14 +247,6 @@ func (am *MockAccountManager) DeleteAccount(ctx context.Context, accountID, user
|
||||
return status.Errorf(codes.Unimplemented, "method DeleteAccount is not implemented")
|
||||
}
|
||||
|
||||
// MarkPATUsed mock implementation of MarkPATUsed from server.AccountManager interface
|
||||
func (am *MockAccountManager) MarkPATUsed(ctx context.Context, pat string) error {
|
||||
if am.MarkPATUsedFunc != nil {
|
||||
return am.MarkPATUsedFunc(ctx, pat)
|
||||
}
|
||||
return status.Errorf(codes.Unimplemented, "method MarkPATUsed is not implemented")
|
||||
}
|
||||
|
||||
// CreatePAT mock implementation of GetPAT from server.AccountManager interface
|
||||
func (am *MockAccountManager) CreatePAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, name string, expiresIn int) (*types.PersonalAccessTokenGenerated, error) {
|
||||
if am.CreatePATFunc != nil {
|
||||
@ -430,11 +413,11 @@ func (am *MockAccountManager) UpdatePeerMeta(ctx context.Context, peerID string,
|
||||
}
|
||||
|
||||
// GetUser mock implementation of GetUser from server.AccountManager interface
|
||||
func (am *MockAccountManager) GetUser(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*types.User, error) {
|
||||
if am.GetUserFunc != nil {
|
||||
return am.GetUserFunc(ctx, claims)
|
||||
func (am *MockAccountManager) GetUserFromUserAuth(ctx context.Context, userAuth nbcontext.UserAuth) (*types.User, error) {
|
||||
if am.GetUserFromUserAuthFunc != nil {
|
||||
return am.GetUserFromUserAuthFunc(ctx, userAuth)
|
||||
}
|
||||
return nil, status.Errorf(codes.Unimplemented, "method GetUser is not implemented")
|
||||
return nil, status.Errorf(codes.Unimplemented, "method GetUserFromUserAuth is not implemented")
|
||||
}
|
||||
|
||||
func (am *MockAccountManager) ListUsers(ctx context.Context, accountID string) ([]*types.User, error) {
|
||||
@ -614,19 +597,11 @@ func (am *MockAccountManager) CreateUser(ctx context.Context, accountID, userID
|
||||
return nil, status.Errorf(codes.Unimplemented, "method CreateUser is not implemented")
|
||||
}
|
||||
|
||||
// GetAccountIDFromToken mocks GetAccountIDFromToken of the AccountManager interface
|
||||
func (am *MockAccountManager) GetAccountIDFromToken(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) {
|
||||
if am.GetAccountIDFromTokenFunc != nil {
|
||||
return am.GetAccountIDFromTokenFunc(ctx, claims)
|
||||
func (am *MockAccountManager) GetAccountIDFromUserAuth(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error) {
|
||||
if am.GetAccountIDFromUserAuthFunc != nil {
|
||||
return am.GetAccountIDFromUserAuthFunc(ctx, userAuth)
|
||||
}
|
||||
return "", "", status.Errorf(codes.Unimplemented, "method GetAccountIDFromToken is not implemented")
|
||||
}
|
||||
|
||||
func (am *MockAccountManager) CheckUserAccessByJWTGroups(ctx context.Context, claims jwtclaims.AuthorizationClaims) error {
|
||||
if am.CheckUserAccessByJWTGroupsFunc != nil {
|
||||
return am.CheckUserAccessByJWTGroupsFunc(ctx, claims)
|
||||
}
|
||||
return status.Errorf(codes.Unimplemented, "method CheckUserAccessByJWTGroups is not implemented")
|
||||
return "", "", status.Errorf(codes.Unimplemented, "method GetAccountIDFromUserAuth is not implemented")
|
||||
}
|
||||
|
||||
// GetPeers mocks GetPeers of the AccountManager interface
|
||||
@ -859,3 +834,7 @@ func (am *MockAccountManager) BuildUserInfosForAccount(ctx context.Context, acco
|
||||
}
|
||||
return nil, status.Errorf(codes.Unimplemented, "method BuildUserInfosForAccount is not implemented")
|
||||
}
|
||||
|
||||
func (am *MockAccountManager) SyncUserJWTGroups(ctx context.Context, userAuth nbcontext.UserAuth) error {
|
||||
return status.Errorf(codes.Unimplemented, "method SyncUserJWTGroups is not implemented")
|
||||
}
|
||||
|
@ -8,16 +8,16 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
nbContext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/idp"
|
||||
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/management/server/util"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// createServiceUser creates a new service user under the given account.
|
||||
@ -174,31 +174,26 @@ func (am *DefaultAccountManager) GetUserByID(ctx context.Context, id string) (*t
|
||||
return am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, id)
|
||||
}
|
||||
|
||||
// GetUser looks up a user by provided authorization claims.
|
||||
// It will also create an account if didn't exist for this user before.
|
||||
func (am *DefaultAccountManager) GetUser(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*types.User, error) {
|
||||
accountID, userID, err := am.GetAccountIDFromToken(ctx, claims)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get account with token claims %v", err)
|
||||
}
|
||||
|
||||
user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID)
|
||||
// GetUser looks up a user by provided nbContext.UserAuths.
|
||||
// Expects account to have been created already.
|
||||
func (am *DefaultAccountManager) GetUserFromUserAuth(ctx context.Context, userAuth nbContext.UserAuth) (*types.User, error) {
|
||||
user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userAuth.UserId)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// this code should be outside of the am.GetAccountIDFromToken(claims) because this method is called also by the gRPC
|
||||
// server when user authenticates a device. And we need to separate the Dashboard login event from the Device login event.
|
||||
newLogin := user.LastDashboardLoginChanged(claims.LastLogin)
|
||||
newLogin := user.LastDashboardLoginChanged(userAuth.LastLogin)
|
||||
|
||||
err = am.Store.SaveUserLastLogin(ctx, accountID, userID, claims.LastLogin)
|
||||
err = am.Store.SaveUserLastLogin(ctx, userAuth.AccountId, userAuth.UserId, userAuth.LastLogin)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed saving user last login: %v", err)
|
||||
}
|
||||
|
||||
if newLogin {
|
||||
meta := map[string]any{"timestamp": claims.LastLogin}
|
||||
am.StoreEvent(ctx, claims.UserId, claims.UserId, accountID, activity.DashboardLogin, meta)
|
||||
meta := map[string]any{"timestamp": userAuth.LastLogin}
|
||||
am.StoreEvent(ctx, userAuth.UserId, userAuth.UserId, userAuth.AccountId, activity.DashboardLogin, meta)
|
||||
}
|
||||
|
||||
return user, nil
|
||||
|
@ -10,6 +10,8 @@ import (
|
||||
"github.com/eko/gocache/v3/cache"
|
||||
cacheStore "github.com/eko/gocache/v3/store"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/util"
|
||||
"golang.org/x/exp/maps"
|
||||
|
||||
@ -25,7 +27,6 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/management/server/idp"
|
||||
"github.com/netbirdio/netbird/management/server/integration_reference"
|
||||
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
||||
)
|
||||
|
||||
const (
|
||||
@ -925,11 +926,12 @@ func TestDefaultAccountManager_GetUser(t *testing.T) {
|
||||
eventStore: &activity.InMemoryEventStore{},
|
||||
}
|
||||
|
||||
claims := jwtclaims.AuthorizationClaims{
|
||||
UserId: mockUserID,
|
||||
claims := nbcontext.UserAuth{
|
||||
UserId: mockUserID,
|
||||
AccountId: mockAccountID,
|
||||
}
|
||||
|
||||
user, err := am.GetUser(context.Background(), claims)
|
||||
user, err := am.GetUserFromUserAuth(context.Background(), claims)
|
||||
if err != nil {
|
||||
t.Fatalf("Error when checking user role: %s", err)
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user