mirror of
https://github.com/netbirdio/netbird.git
synced 2024-11-21 23:53:14 +01:00
Super user invites (#483)
This PR brings user invites logic to the Management service via HTTP API. The POST /users/ API endpoint creates a new user in the Idp and then in the local storage. Once the invited user signs ups, the account invitation is redeemed. There are a few limitations. This works only with an enabled IdP manager. Users that already have a registered account can't be invited.
This commit is contained in:
parent
abd1230a69
commit
06055af361
7
go.mod
7
go.mod
@ -32,7 +32,7 @@ require (
|
||||
github.com/c-robinson/iplib v1.0.3
|
||||
github.com/coreos/go-iptables v0.6.0
|
||||
github.com/creack/pty v1.1.18
|
||||
github.com/eko/gocache/v2 v2.3.1
|
||||
github.com/eko/gocache/v3 v3.1.1
|
||||
github.com/getlantern/systray v1.2.1
|
||||
github.com/gliderlabs/ssh v0.3.4
|
||||
github.com/google/nftables v0.0.0-20220808154552-2eca00135732
|
||||
@ -41,7 +41,7 @@ require (
|
||||
github.com/patrickmn/go-cache v2.1.0+incompatible
|
||||
github.com/rs/xid v1.3.0
|
||||
github.com/skratchdot/open-golang v0.0.0-20200116055534-eef842397966
|
||||
github.com/stretchr/testify v1.7.1
|
||||
github.com/stretchr/testify v1.8.0
|
||||
golang.org/x/net v0.0.0-20220630215102-69896b714898
|
||||
golang.org/x/term v0.0.0-20220526004731-065cf7ba2467
|
||||
)
|
||||
@ -99,6 +99,7 @@ require (
|
||||
github.com/srwiley/rasterx v0.0.0-20200120212402-85cb7272f5e9 // indirect
|
||||
github.com/vishvananda/netns v0.0.0-20191106174202-0a2b9b5464df // indirect
|
||||
github.com/yuin/goldmark v1.4.1 // indirect
|
||||
golang.org/x/exp v0.0.0-20220518171630-0b5c67f07fdf // indirect
|
||||
golang.org/x/image v0.0.0-20200430140353-33d19683fad8 // indirect
|
||||
golang.org/x/mod v0.6.0-dev.0.20220106191415-9b9b3d81d5e3 // indirect
|
||||
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c // indirect
|
||||
@ -112,7 +113,7 @@ require (
|
||||
gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 // indirect
|
||||
gopkg.in/tomb.v2 v2.0.0-20161208151619-d5d1b5820637 // indirect
|
||||
gopkg.in/yaml.v2 v2.4.0 // indirect
|
||||
gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b // indirect
|
||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||
honnef.co/go/tools v0.2.2 // indirect
|
||||
k8s.io/apimachinery v0.23.5 // indirect
|
||||
)
|
||||
|
13
go.sum
13
go.sum
@ -134,8 +134,8 @@ github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cu
|
||||
github.com/docker/spdystream v0.0.0-20160310174837-449fdfce4d96/go.mod h1:Qh8CwZgvJUkLughtfhJv5dyTYa91l1fOUCrgjqmcifM=
|
||||
github.com/docopt/docopt-go v0.0.0-20180111231733-ee0de3bc6815/go.mod h1:WwZ+bS3ebgob9U8Nd0kOddGdZWjyMGR8Wziv+TBNwSE=
|
||||
github.com/dustin/go-humanize v1.0.0 h1:VSnTsYCnlFHaM2/igO1h6X3HA71jcobQuxemgkq4zYo=
|
||||
github.com/eko/gocache/v2 v2.3.1 h1:8MMkfqGJ0KIA9OXT0rXevcEIrU16oghrGDiIDJDFCa0=
|
||||
github.com/eko/gocache/v2 v2.3.1/go.mod h1:l2z8OmpZHL0CpuzDJtxm267eF3mZW1NqUsMj+sKrbUs=
|
||||
github.com/eko/gocache/v3 v3.1.1 h1:r3CBwLnqPkcK56h9Do2CWw1kZ4TeKK0wDE1Oo/YZnhs=
|
||||
github.com/eko/gocache/v3 v3.1.1/go.mod h1:UpP/LyHAioP/a/dizgl0MpgZ3A3CkS4NbG/mWkGTQ9M=
|
||||
github.com/elazarl/goproxy v0.0.0-20170405201442-c4fc26588b6e/go.mod h1:/Zj4wYkgs4iZTTu3o/KG3Itv/qCCa8VVMlb3i9OVuzc=
|
||||
github.com/elazarl/goproxy v0.0.0-20180725130230-947c36da3153/go.mod h1:/Zj4wYkgs4iZTTu3o/KG3Itv/qCCa8VVMlb3i9OVuzc=
|
||||
github.com/emicklei/go-restful v0.0.0-20170410110728-ff4f55a20633/go.mod h1:otzb+WCGbkyDHkqmQmT5YD2WR4BBwUdeQoFo8l/7tVs=
|
||||
@ -609,6 +609,7 @@ github.com/srwiley/rasterx v0.0.0-20200120212402-85cb7272f5e9/go.mod h1:mvWM0+15
|
||||
github.com/stoewer/go-strcase v1.2.0/go.mod h1:IBiWB2sKIp3wVVQ3Y035++gc+knqhUQag1KpM8ahLw8=
|
||||
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||
github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
|
||||
github.com/stretchr/testify v0.0.0-20151208002404-e3a8ff8ce365/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
|
||||
github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
|
||||
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
|
||||
@ -616,8 +617,9 @@ github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81P
|
||||
github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA=
|
||||
github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||
github.com/stretchr/testify v1.7.1 h1:5TQK59W5E3v0r2duFAb7P95B6hEeOyEnHRa8MjYSMTY=
|
||||
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||
github.com/stretchr/testify v1.8.0 h1:pSgiaMZlXftHpm5L7V1+rVB+AZJydKsMxsQBIJw4PKk=
|
||||
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
|
||||
github.com/subosito/gotenv v1.2.0/go.mod h1:N0PQaV/YGNqwC0u51sEeR/aUtSLEXKX9iv69rRypqCw=
|
||||
github.com/tv42/httpunix v0.0.0-20150427012821-b75d8614f926/go.mod h1:9ESjWnEqriFuLhtthL60Sar/7RFoluCcXsuvEwTV5KM=
|
||||
github.com/ugorji/go v1.1.7/go.mod h1:kZn38zHttfInRq0xu/PH0az30d+z6vm202qpg1oXVMw=
|
||||
@ -676,6 +678,8 @@ golang.org/x/exp v0.0.0-20191227195350-da58074b4299/go.mod h1:2RIsYlXP63K8oxa1u0
|
||||
golang.org/x/exp v0.0.0-20200119233911-0405dc783f0a/go.mod h1:2RIsYlXP63K8oxa1u096TMicItID8zy7Y6sNkU49FU4=
|
||||
golang.org/x/exp v0.0.0-20200207192155-f17229e696bd/go.mod h1:J/WKrq2StrnmMY6+EHIKF9dgMWnmCNThgcyBT1FY9mM=
|
||||
golang.org/x/exp v0.0.0-20200224162631-6cc2880d07d6/go.mod h1:3jZMyOhIsHpP37uCMkUooju7aAi5cS1Q23tOzKc+0MU=
|
||||
golang.org/x/exp v0.0.0-20220518171630-0b5c67f07fdf h1:oXVg4h2qJDd9htKxb5SCpFBHLipW6hXmL3qpUixS2jw=
|
||||
golang.org/x/exp v0.0.0-20220518171630-0b5c67f07fdf/go.mod h1:yh0Ynu2b5ZUe3MQfp2nM0ecK7wsgouWTDN0FNeJuIys=
|
||||
golang.org/x/image v0.0.0-20190227222117-0694c2d4d067/go.mod h1:kZ7UVZpmo3dzQBMxlp+ypCbDeSB+sBbTgSJuh5dn5js=
|
||||
golang.org/x/image v0.0.0-20190802002840-cff245a6509b/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0=
|
||||
golang.org/x/image v0.0.0-20200430140353-33d19683fad8 h1:6WW6V3x1P/jokJBpRQYUJnMHRP6isStQwCozxnU7XQw=
|
||||
@ -1190,8 +1194,9 @@ gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY=
|
||||
gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ=
|
||||
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
gopkg.in/yaml.v3 v3.0.0-20200615113413-eeeca48fe776/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b h1:h8qDotaEPuJATrMmW04NCwg7v22aHH28wwpauUhK9Oo=
|
||||
gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=
|
||||
honnef.co/go/tools v0.0.0-20190106161140-3f1c8253044a/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=
|
||||
honnef.co/go/tools v0.0.0-20190418001031-e561f6794a2a/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=
|
||||
|
@ -3,8 +3,8 @@ package server
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"github.com/eko/gocache/v2/cache"
|
||||
cacheStore "github.com/eko/gocache/v2/store"
|
||||
"github.com/eko/gocache/v3/cache"
|
||||
cacheStore "github.com/eko/gocache/v3/store"
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/management/server/idp"
|
||||
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
||||
@ -30,6 +30,11 @@ const (
|
||||
CacheExpirationMin = 3 * 24 * 3600 * time.Second // 3 days
|
||||
)
|
||||
|
||||
func cacheEntryExpiration() time.Duration {
|
||||
r := rand.Intn(int(CacheExpirationMax.Milliseconds()-CacheExpirationMin.Milliseconds())) + int(CacheExpirationMin.Milliseconds())
|
||||
return time.Duration(r) * time.Millisecond
|
||||
}
|
||||
|
||||
type AccountManager interface {
|
||||
GetOrCreateAccountByUser(userId, domain string) (*Account, error)
|
||||
GetAccountByUser(userId string) (*Account, error)
|
||||
@ -41,12 +46,13 @@ type AccountManager interface {
|
||||
autoGroups []string,
|
||||
) (*SetupKey, error)
|
||||
SaveSetupKey(accountID string, key *SetupKey) (*SetupKey, error)
|
||||
CreateUser(accountID string, key *UserInfo) (*UserInfo, error)
|
||||
ListSetupKeys(accountID string) ([]*SetupKey, error)
|
||||
SaveUser(accountID string, key *User) (*UserInfo, error)
|
||||
GetSetupKey(accountID, keyID string) (*SetupKey, error)
|
||||
GetAccountById(accountId string) (*Account, error)
|
||||
GetAccountByUserOrAccountId(userId, accountId, domain string) (*Account, error)
|
||||
GetAccountWithAuthorizationClaims(claims jwtclaims.AuthorizationClaims) (*Account, error)
|
||||
GetAccountFromToken(claims jwtclaims.AuthorizationClaims) (*Account, error)
|
||||
IsUserAdmin(claims jwtclaims.AuthorizationClaims) (bool, error)
|
||||
AccountExists(accountId string) (*bool, error)
|
||||
GetPeer(peerKey string) (*Peer, error)
|
||||
@ -90,11 +96,15 @@ type AccountManager interface {
|
||||
|
||||
type DefaultAccountManager struct {
|
||||
Store Store
|
||||
// mutex to synchronise account operations (e.g. generating Peer IP address inside the Network)
|
||||
mux sync.Mutex
|
||||
// mux to synchronise account operations (e.g. generating Peer IP address inside the Network)
|
||||
mux sync.Mutex
|
||||
// cacheMux and cacheLoading helps to make sure that only a single cache reload runs at a time per accountID
|
||||
cacheMux sync.Mutex
|
||||
// cacheLoading keeps the accountIDs that are currently reloading. The accountID has to be removed once cache has been reloaded
|
||||
cacheLoading map[string]chan struct{}
|
||||
peersUpdateManager *PeersUpdateManager
|
||||
idpManager idp.Manager
|
||||
cacheManager cache.CacheInterface
|
||||
cacheManager cache.CacheInterface[[]*idp.UserData]
|
||||
ctx context.Context
|
||||
}
|
||||
|
||||
@ -122,6 +132,7 @@ type UserInfo struct {
|
||||
Name string `json:"name"`
|
||||
Role string `json:"role"`
|
||||
AutoGroups []string `json:"auto_groups"`
|
||||
Status string `json:"-"`
|
||||
}
|
||||
|
||||
func (a *Account) Copy() *Account {
|
||||
@ -193,6 +204,8 @@ func BuildManager(
|
||||
peersUpdateManager: peersUpdateManager,
|
||||
idpManager: idpManager,
|
||||
ctx: context.Background(),
|
||||
cacheMux: sync.Mutex{},
|
||||
cacheLoading: map[string]chan struct{}{},
|
||||
}
|
||||
|
||||
// if account has not default group
|
||||
@ -209,9 +222,9 @@ func BuildManager(
|
||||
}
|
||||
|
||||
gocacheClient := gocache.New(CacheExpirationMax, 30*time.Minute)
|
||||
gocacheStore := cacheStore.NewGoCache(gocacheClient, nil)
|
||||
gocacheStore := cacheStore.NewGoCache(gocacheClient)
|
||||
|
||||
am.cacheManager = cache.NewLoadable(am.loadFromCache, cache.New(gocacheStore))
|
||||
am.cacheManager = cache.NewLoadable[[]*idp.UserData](am.loadAccount, cache.New[[]*idp.UserData](gocacheStore))
|
||||
|
||||
if !isNil(am.idpManager) {
|
||||
go func() {
|
||||
@ -256,11 +269,7 @@ func (am *DefaultAccountManager) warmupIDPCache() error {
|
||||
}
|
||||
|
||||
for accountID, users := range userData {
|
||||
rand.Seed(time.Now().UnixNano())
|
||||
|
||||
r := rand.Intn(int(CacheExpirationMax.Milliseconds()-CacheExpirationMin.Milliseconds())) + int(CacheExpirationMin.Milliseconds())
|
||||
expiration := time.Duration(r) * time.Millisecond
|
||||
err = am.cacheManager.Set(am.ctx, accountID, users, &cacheStore.Options{Expiration: expiration})
|
||||
err = am.cacheManager.Set(am.ctx, accountID, users, cacheStore.WithExpiration(cacheEntryExpiration()))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -294,7 +303,7 @@ func (am *DefaultAccountManager) GetAccountByUserOrAccountId(
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.NotFound, "account not found using user id: %s", userId)
|
||||
}
|
||||
err = am.updateIDPMetadata(userId, account.Id)
|
||||
err = am.addAccountIDToIDPAppMeta(userId, account)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -308,10 +317,28 @@ func isNil(i idp.Manager) bool {
|
||||
return i == nil || reflect.ValueOf(i).IsNil()
|
||||
}
|
||||
|
||||
// updateIDPMetadata update user's app metadata in idp manager
|
||||
func (am *DefaultAccountManager) updateIDPMetadata(userId, accountID string) error {
|
||||
// addAccountIDToIDPAppMeta update user's app metadata in idp manager
|
||||
func (am *DefaultAccountManager) addAccountIDToIDPAppMeta(userID string, account *Account) error {
|
||||
if !isNil(am.idpManager) {
|
||||
err := am.idpManager.UpdateUserAppMetadata(userId, idp.AppMetadata{WTAccountId: accountID})
|
||||
|
||||
// user can be nil if it wasn't found (e.g., just created)
|
||||
user, err := am.lookupUserInCache(userID, account)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if user != nil && user.AppMetadata.WTAccountID == account.Id {
|
||||
// it was already set, so we skip the unnecessary update
|
||||
log.Debugf("skipping IDP App Meta update because accountID %s has been already set for user %s",
|
||||
account.Id, userID)
|
||||
return nil
|
||||
}
|
||||
|
||||
err = am.idpManager.UpdateUserAppMetadata(userID, idp.AppMetadata{WTAccountID: account.Id})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return status.Errorf(
|
||||
codes.Internal,
|
||||
@ -319,45 +346,113 @@ func (am *DefaultAccountManager) updateIDPMetadata(userId, accountID string) err
|
||||
err,
|
||||
)
|
||||
}
|
||||
// refresh cache to reflect the update
|
||||
_, err = am.refreshCache(account.Id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) loadFromCache(_ context.Context, accountID interface{}) (interface{}, error) {
|
||||
func (am *DefaultAccountManager) loadAccount(_ context.Context, accountID interface{}) ([]*idp.UserData, error) {
|
||||
log.Debugf("account %s not found in cache, reloading", accountID)
|
||||
return am.idpManager.GetAccount(fmt.Sprintf("%v", accountID))
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) lookupUserInCache(user *User, accountID string) (*idp.UserData, error) {
|
||||
userData, err := am.lookupCache(map[string]*User{user.Id: user}, accountID)
|
||||
func (am *DefaultAccountManager) lookupUserInCacheByEmail(email string, accountID string) (*idp.UserData, error) {
|
||||
data, err := am.getAccountFromCache(accountID, false)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, datum := range userData {
|
||||
if datum.ID == user.Id {
|
||||
for _, datum := range data {
|
||||
if datum.Email == email {
|
||||
return datum, nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil, status.Errorf(codes.NotFound, "user %s not found in the IdP", user.Id)
|
||||
return nil, nil
|
||||
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) lookupCache(accountUsers map[string]*User, accountID string) ([]*idp.UserData, error) {
|
||||
data, err := am.cacheManager.Get(am.ctx, accountID)
|
||||
// lookupUserInCache looks up user in the IdP cache and returns it. If the user wasn't found, the function returns nil
|
||||
func (am *DefaultAccountManager) lookupUserInCache(userID string, account *Account) (*idp.UserData, error) {
|
||||
users := make(map[string]struct{}, len(account.Users))
|
||||
for _, user := range account.Users {
|
||||
users[user.Id] = struct{}{}
|
||||
}
|
||||
log.Debugf("looking up user %s of account %s in cache", userID, account.Id)
|
||||
userData, err := am.lookupCache(users, account.Id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
userData := data.([]*idp.UserData)
|
||||
for _, datum := range userData {
|
||||
if datum.ID == userID {
|
||||
return datum, nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) refreshCache(accountID string) ([]*idp.UserData, error) {
|
||||
return am.getAccountFromCache(accountID, true)
|
||||
}
|
||||
|
||||
// getAccountFromCache returns user data for a given account ensuring that cache load happens only once
|
||||
func (am *DefaultAccountManager) getAccountFromCache(accountID string, forceReload bool) ([]*idp.UserData, error) {
|
||||
am.cacheMux.Lock()
|
||||
loadingChan := am.cacheLoading[accountID]
|
||||
if loadingChan == nil {
|
||||
loadingChan = make(chan struct{})
|
||||
am.cacheLoading[accountID] = loadingChan
|
||||
am.cacheMux.Unlock()
|
||||
|
||||
defer func() {
|
||||
am.cacheMux.Lock()
|
||||
delete(am.cacheLoading, accountID)
|
||||
close(loadingChan)
|
||||
am.cacheMux.Unlock()
|
||||
}()
|
||||
|
||||
if forceReload {
|
||||
err := am.cacheManager.Delete(am.ctx, accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return am.cacheManager.Get(am.ctx, accountID)
|
||||
}
|
||||
am.cacheMux.Unlock()
|
||||
|
||||
log.Debugf("one request to get account %s is already running", accountID)
|
||||
|
||||
select {
|
||||
case <-loadingChan:
|
||||
// channel has been closed meaning cache was loaded => simply return from cache
|
||||
return am.cacheManager.Get(am.ctx, accountID)
|
||||
case <-time.After(5 * time.Second):
|
||||
return nil, fmt.Errorf("timeout while waiting for account %s cache to reload", accountID)
|
||||
}
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) lookupCache(accountUsers map[string]struct{}, accountID string) ([]*idp.UserData, error) {
|
||||
data, err := am.getAccountFromCache(accountID, false)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
userDataMap := make(map[string]struct{})
|
||||
for _, datum := range userData {
|
||||
for _, datum := range data {
|
||||
userDataMap[datum.ID] = struct{}{}
|
||||
}
|
||||
|
||||
// check whether we need to reload the cache
|
||||
// the accountUsers ID list is the source of truth and all the users should be in the cache
|
||||
reload := len(accountUsers) != len(userData)
|
||||
reload := len(accountUsers) != len(data)
|
||||
for user := range accountUsers {
|
||||
if _, ok := userDataMap[user]; !ok {
|
||||
reload = true
|
||||
@ -366,19 +461,13 @@ func (am *DefaultAccountManager) lookupCache(accountUsers map[string]*User, acco
|
||||
|
||||
if reload {
|
||||
// reload cache once avoiding loops
|
||||
err := am.cacheManager.Delete(am.ctx, accountID)
|
||||
data, err = am.refreshCache(accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
data, err = am.cacheManager.Get(am.ctx, accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
userData = data.([]*idp.UserData)
|
||||
}
|
||||
|
||||
return userData, err
|
||||
return data, err
|
||||
}
|
||||
|
||||
// updateAccountDomainAttributes updates the account domain attributes and then, saves the account
|
||||
@ -433,7 +522,7 @@ func (am *DefaultAccountManager) handleExistingUserAccount(
|
||||
}
|
||||
|
||||
// we should register the account ID to this user's metadata in our IDP manager
|
||||
err = am.updateIDPMetadata(claims.UserId, existingAcc.Id)
|
||||
err = am.addAccountIDToIDPAppMeta(claims.UserId, existingAcc)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -471,7 +560,7 @@ func (am *DefaultAccountManager) handleNewUserAccount(
|
||||
}
|
||||
}
|
||||
|
||||
err = am.updateIDPMetadata(claims.UserId, account.Id)
|
||||
err = am.addAccountIDToIDPAppMeta(claims.UserId, account)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -479,7 +568,56 @@ func (am *DefaultAccountManager) handleNewUserAccount(
|
||||
return account, nil
|
||||
}
|
||||
|
||||
// GetAccountWithAuthorizationClaims retrievs an account using JWT Claims.
|
||||
// redeemInvite checks whether user has been invited and redeems the invite
|
||||
func (am *DefaultAccountManager) redeemInvite(account *Account, userID string) error {
|
||||
// only possible with the enabled IdP manager
|
||||
if am.idpManager == nil {
|
||||
log.Warnf("invites only work with enabled IdP manager")
|
||||
return nil
|
||||
}
|
||||
|
||||
user, err := am.lookupUserInCache(userID, account)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if user == nil {
|
||||
return status.Errorf(codes.NotFound, "user %s not found in the IdP", userID)
|
||||
}
|
||||
|
||||
if user.AppMetadata.WTPendingInvite {
|
||||
log.Infof("redeeming invite for user %s account %s", userID, account.Id)
|
||||
// User has already logged in, meaning that IdP should have set wt_pending_invite to false.
|
||||
// Our job is to just reload cache.
|
||||
go func() {
|
||||
_, err = am.refreshCache(account.Id)
|
||||
if err != nil {
|
||||
log.Warnf("failed reloading cache when redeeming user %s under account %s", userID, account.Id)
|
||||
return
|
||||
}
|
||||
log.Debugf("user %s of account %s redeemed invite", user.ID, account.Id)
|
||||
}()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetAccountFromToken returns an account associated with this token
|
||||
func (am *DefaultAccountManager) GetAccountFromToken(claims jwtclaims.AuthorizationClaims) (*Account, error) {
|
||||
account, err := am.getAccountWithAuthorizationClaims(claims)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = am.redeemInvite(account, claims.UserId)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return account, nil
|
||||
}
|
||||
|
||||
// getAccountWithAuthorizationClaims retrievs an account using JWT Claims.
|
||||
// if domain is of the PrivateCategory category, it will evaluate
|
||||
// if account is new, existing or if there is another account with the same domain
|
||||
//
|
||||
@ -496,7 +634,7 @@ func (am *DefaultAccountManager) handleNewUserAccount(
|
||||
// Existing user + Existing account + Existing Indexed Domain -> Nothing changes
|
||||
//
|
||||
// Existing user + Existing account + Existing domain reclassified Domain as private -> Nothing changes (index domain)
|
||||
func (am *DefaultAccountManager) GetAccountWithAuthorizationClaims(
|
||||
func (am *DefaultAccountManager) getAccountWithAuthorizationClaims(
|
||||
claims jwtclaims.AuthorizationClaims,
|
||||
) (*Account, error) {
|
||||
// if Account ID is part of the claims
|
||||
|
@ -127,7 +127,7 @@ func TestAccountManager_GetOrCreateAccountByUser(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaultAccountManager_GetAccountWithAuthorizationClaims(t *testing.T) {
|
||||
func TestDefaultAccountManager_GetAccountFromToken(t *testing.T) {
|
||||
type initUserParams jwtclaims.AuthorizationClaims
|
||||
|
||||
type test struct {
|
||||
@ -310,7 +310,7 @@ func TestDefaultAccountManager_GetAccountWithAuthorizationClaims(t *testing.T) {
|
||||
testCase.inputClaims.AccountId = initAccount.Id
|
||||
}
|
||||
|
||||
account, err := manager.GetAccountWithAuthorizationClaims(testCase.inputClaims)
|
||||
account, err := manager.GetAccountFromToken(testCase.inputClaims)
|
||||
require.NoError(t, err, "support function failed")
|
||||
verifyNewAccountHasDefaultFields(t, account, testCase.expectedCreatedBy, testCase.inputClaims.Domain, testCase.expectedUsers)
|
||||
verifyCanAddPeerToAccount(t, manager, account, testCase.expectedCreatedBy)
|
||||
|
52
management/server/error.go
Normal file
52
management/server/error.go
Normal file
@ -0,0 +1,52 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
)
|
||||
|
||||
const (
|
||||
// UserAlreadyExists indicates that user already exists
|
||||
UserAlreadyExists ErrorType = 1
|
||||
// AccountNotFound indicates that specified account hasn't been found
|
||||
AccountNotFound ErrorType = iota
|
||||
// PreconditionFailed indicates that some pre-condition for the operation hasn't been fulfilled
|
||||
PreconditionFailed ErrorType = iota
|
||||
)
|
||||
|
||||
// ErrorType is a type of the Error
|
||||
type ErrorType int32
|
||||
|
||||
// Error is an internal error
|
||||
type Error struct {
|
||||
errorType ErrorType
|
||||
message string
|
||||
}
|
||||
|
||||
// Type returns the Type of the error
|
||||
func (e *Error) Type() ErrorType {
|
||||
return e.errorType
|
||||
}
|
||||
|
||||
// Error is an error string
|
||||
func (e *Error) Error() string {
|
||||
return e.message
|
||||
}
|
||||
|
||||
// Errorf returns Error(errorType, fmt.Sprintf(format, a...)).
|
||||
func Errorf(errorType ErrorType, format string, a ...interface{}) error {
|
||||
return &Error{
|
||||
errorType: errorType,
|
||||
message: fmt.Sprintf(format, a...),
|
||||
}
|
||||
}
|
||||
|
||||
// FromError returns Error, true if the provided error is of type of Error. nil, false otherwise
|
||||
func FromError(err error) (s *Error, ok bool) {
|
||||
if err == nil {
|
||||
return nil, true
|
||||
}
|
||||
if e, ok := err.(*Error); ok {
|
||||
return e, true
|
||||
}
|
||||
return nil, false
|
||||
}
|
@ -181,7 +181,7 @@ func (s *GRPCServer) registerPeer(peerKey wgtypes.Key, req *proto.LoginRequest)
|
||||
return nil, status.Errorf(codes.Internal, "invalid jwt token, err: %v", err)
|
||||
}
|
||||
claims := jwtclaims.ExtractClaimsWithToken(token, s.config.HttpConfig.AuthAudience)
|
||||
_, err = s.accountManager.GetAccountWithAuthorizationClaims(claims)
|
||||
_, err = s.accountManager.GetAccountFromToken(claims)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "unable to fetch account with claims, err: %v", err)
|
||||
}
|
||||
|
@ -35,6 +35,10 @@ components:
|
||||
role:
|
||||
description: User's NetBird account role
|
||||
type: string
|
||||
status:
|
||||
description: User's status
|
||||
type: string
|
||||
enum: [ "active","invited","disabled" ]
|
||||
auto_groups:
|
||||
description: Groups to auto-assign to peers registered by this user
|
||||
type: array
|
||||
@ -46,6 +50,7 @@ components:
|
||||
- name
|
||||
- role
|
||||
- auto_groups
|
||||
- status
|
||||
UserRequest:
|
||||
type: object
|
||||
properties:
|
||||
@ -60,6 +65,27 @@ components:
|
||||
required:
|
||||
- role
|
||||
- auto_groups
|
||||
UserCreateRequest:
|
||||
type: object
|
||||
properties:
|
||||
role:
|
||||
description: User's NetBird account role
|
||||
type: string
|
||||
email:
|
||||
description: User's Email to send invite to
|
||||
type: string
|
||||
name:
|
||||
description: User's full name
|
||||
type: string
|
||||
auto_groups:
|
||||
description: Groups to auto-assign to peers registered by this user
|
||||
type: array
|
||||
items:
|
||||
type: string
|
||||
required:
|
||||
- role
|
||||
- auto_groups
|
||||
- email
|
||||
PeerMinimum:
|
||||
type: object
|
||||
properties:
|
||||
@ -499,6 +525,33 @@ paths:
|
||||
"$ref": "#/components/responses/forbidden"
|
||||
'500':
|
||||
"$ref": "#/components/responses/internal_error"
|
||||
/api/users/:
|
||||
post:
|
||||
summary: Create a User (invite)
|
||||
tags: [ Users]
|
||||
security:
|
||||
- BearerAuth: [ ]
|
||||
requestBody:
|
||||
description: User invite information
|
||||
content:
|
||||
'application/json':
|
||||
schema:
|
||||
$ref: '#/components/schemas/UserCreateRequest'
|
||||
responses:
|
||||
'200':
|
||||
description: A User object
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/User'
|
||||
'400':
|
||||
"$ref": "#/components/responses/bad_request"
|
||||
'401':
|
||||
"$ref": "#/components/responses/requires_authentication"
|
||||
'403':
|
||||
"$ref": "#/components/responses/forbidden"
|
||||
'500':
|
||||
"$ref": "#/components/responses/internal_error"
|
||||
/api/users/{id}:
|
||||
put:
|
||||
summary: Update information about a User
|
||||
|
@ -87,6 +87,13 @@ const (
|
||||
RulePatchOperationPathSources RulePatchOperationPath = "sources"
|
||||
)
|
||||
|
||||
// Defines values for UserStatus.
|
||||
const (
|
||||
UserStatusActive UserStatus = "active"
|
||||
UserStatusDisabled UserStatus = "disabled"
|
||||
UserStatusInvited UserStatus = "invited"
|
||||
)
|
||||
|
||||
// Group defines model for Group.
|
||||
type Group struct {
|
||||
// Id Group ID
|
||||
@ -466,6 +473,27 @@ type User struct {
|
||||
|
||||
// Role User's NetBird account role
|
||||
Role string `json:"role"`
|
||||
|
||||
// Status User's status
|
||||
Status UserStatus `json:"status"`
|
||||
}
|
||||
|
||||
// UserStatus User's status
|
||||
type UserStatus string
|
||||
|
||||
// UserCreateRequest defines model for UserCreateRequest.
|
||||
type UserCreateRequest struct {
|
||||
// AutoGroups Groups to auto-assign to peers registered by this user
|
||||
AutoGroups []string `json:"auto_groups"`
|
||||
|
||||
// Email User's Email to send invite to
|
||||
Email string `json:"email"`
|
||||
|
||||
// Name User's full name
|
||||
Name *string `json:"name,omitempty"`
|
||||
|
||||
// Role User's NetBird account role
|
||||
Role string `json:"role"`
|
||||
}
|
||||
|
||||
// UserRequest defines model for UserRequest.
|
||||
@ -586,5 +614,8 @@ type PostApiSetupKeysJSONRequestBody = SetupKeyRequest
|
||||
// PutApiSetupKeysIdJSONRequestBody defines body for PutApiSetupKeysId for application/json ContentType.
|
||||
type PutApiSetupKeysIdJSONRequestBody = SetupKeyRequest
|
||||
|
||||
// PostApiUsersJSONRequestBody defines body for PostApiUsers for application/json ContentType.
|
||||
type PostApiUsersJSONRequestBody = UserCreateRequest
|
||||
|
||||
// PutApiUsersIdJSONRequestBody defines body for PutApiUsersId for application/json ContentType.
|
||||
type PutApiUsersIdJSONRequestBody = UserRequest
|
||||
|
@ -67,14 +67,14 @@ func initGroupTestData(groups ...*server.Group) *Groups {
|
||||
}
|
||||
return nil, fmt.Errorf("peer not found")
|
||||
},
|
||||
GetAccountWithAuthorizationClaimsFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, error) {
|
||||
GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, error) {
|
||||
return &server.Account{
|
||||
Id: claims.AccountId,
|
||||
Domain: "hotmail.com",
|
||||
Peers: TestPeers,
|
||||
Groups: map[string]*server.Group{
|
||||
"id-existed": &server.Group{ID: "id-existed", Peers: []string{"A", "B"}},
|
||||
"id-all": &server.Group{ID: "id-all", Name: "All"}},
|
||||
"id-existed": {ID: "id-existed", Peers: []string{"A", "B"}},
|
||||
"id-all": {ID: "id-all", Name: "All"}},
|
||||
}, nil
|
||||
},
|
||||
},
|
||||
|
@ -41,6 +41,7 @@ func APIHandler(accountManager s.AccountManager, authIssuer string, authAudience
|
||||
Methods("GET", "PUT", "DELETE", "OPTIONS")
|
||||
apiHandler.HandleFunc("/api/users", userHandler.GetUsers).Methods("GET", "OPTIONS")
|
||||
apiHandler.HandleFunc("/api/users/{id}", userHandler.UpdateUser).Methods("PUT", "OPTIONS")
|
||||
apiHandler.HandleFunc("/api/users", userHandler.CreateUserHandler).Methods("POST", "OPTIONS")
|
||||
|
||||
apiHandler.HandleFunc("/api/setup-keys", keysHandler.GetAllSetupKeysHandler).Methods("GET", "OPTIONS")
|
||||
apiHandler.HandleFunc("/api/setup-keys", keysHandler.CreateSetupKeyHandler).Methods("POST", "OPTIONS")
|
||||
|
@ -104,7 +104,7 @@ func initNameserversTestData() *Nameservers {
|
||||
}
|
||||
return nsGroupToUpdate, nil
|
||||
},
|
||||
GetAccountWithAuthorizationClaimsFunc: func(_ jwtclaims.AuthorizationClaims) (*server.Account, error) {
|
||||
GetAccountFromTokenFunc: func(_ jwtclaims.AuthorizationClaims) (*server.Account, error) {
|
||||
return testingNSAccount, nil
|
||||
},
|
||||
},
|
||||
|
@ -19,7 +19,7 @@ import (
|
||||
func initTestMetaData(peer ...*server.Peer) *Peers {
|
||||
return &Peers{
|
||||
accountManager: &mock_server.MockAccountManager{
|
||||
GetAccountWithAuthorizationClaimsFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, error) {
|
||||
GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, error) {
|
||||
return &server.Account{
|
||||
Id: claims.AccountId,
|
||||
Domain: "hotmail.com",
|
||||
|
@ -120,7 +120,7 @@ func initRoutesTestData() *Routes {
|
||||
}
|
||||
return routeToUpdate, nil
|
||||
},
|
||||
GetAccountWithAuthorizationClaimsFunc: func(_ jwtclaims.AuthorizationClaims) (*server.Account, error) {
|
||||
GetAccountFromTokenFunc: func(_ jwtclaims.AuthorizationClaims) (*server.Account, error) {
|
||||
return testingAccount, nil
|
||||
},
|
||||
},
|
||||
|
@ -66,14 +66,14 @@ func initRulesTestData(rules ...*server.Rule) *Rules {
|
||||
}
|
||||
return &rule, nil
|
||||
},
|
||||
GetAccountWithAuthorizationClaimsFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, error) {
|
||||
GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, error) {
|
||||
return &server.Account{
|
||||
Id: claims.AccountId,
|
||||
Domain: "hotmail.com",
|
||||
Rules: map[string]*server.Rule{"id-existed": &server.Rule{ID: "id-existed"}},
|
||||
Groups: map[string]*server.Group{
|
||||
"F": &server.Group{ID: "F"},
|
||||
"G": &server.Group{ID: "G"},
|
||||
"F": {ID: "F"},
|
||||
"G": {ID: "G"},
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
|
@ -31,7 +31,7 @@ const (
|
||||
func initSetupKeysTestMetaData(defaultKey *server.SetupKey, newKey *server.SetupKey, updatedSetupKey *server.SetupKey) *SetupKeys {
|
||||
return &SetupKeys{
|
||||
accountManager: &mock_server.MockAccountManager{
|
||||
GetAccountWithAuthorizationClaimsFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, error) {
|
||||
GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, error) {
|
||||
return &server.Account{
|
||||
Id: testAccountID,
|
||||
Domain: "hotmail.com",
|
||||
|
@ -5,12 +5,11 @@ import (
|
||||
"fmt"
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/netbirdio/netbird/management/server/http/api"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
"net/http"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server"
|
||||
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
||||
)
|
||||
@ -82,6 +81,50 @@ func (h *UserHandler) UpdateUser(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
}
|
||||
|
||||
// CreateUserHandler creates a User in the system with a status "invited" (effectively this is a user invite).
|
||||
func (h *UserHandler) CreateUserHandler(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
http.Error(w, "", http.StatusNotFound)
|
||||
}
|
||||
|
||||
account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
}
|
||||
|
||||
req := &api.PostApiUsersJSONRequestBody{}
|
||||
err = json.NewDecoder(r.Body).Decode(&req)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
if server.StrRoleToUserRole(req.Role) == server.UserRoleUnknown {
|
||||
http.Error(w, "unknown user role "+req.Role, http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
newUser, err := h.accountManager.CreateUser(account.Id, &server.UserInfo{
|
||||
Email: req.Email,
|
||||
Name: *req.Name,
|
||||
Role: req.Role,
|
||||
AutoGroups: req.AutoGroups,
|
||||
})
|
||||
if err != nil {
|
||||
if e, ok := server.FromError(err); ok {
|
||||
switch e.Type() {
|
||||
case server.UserAlreadyExists:
|
||||
http.Error(w, "You can't invite users with an existing NetBird account.", http.StatusPreconditionFailed)
|
||||
return
|
||||
default:
|
||||
}
|
||||
}
|
||||
http.Error(w, "failed to invite", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
writeJSONObject(w, toUserResponse(newUser))
|
||||
}
|
||||
|
||||
// GetUsers returns a list of users of the account this user belongs to.
|
||||
// It also gathers additional user data (like email and name) from the IDP manager.
|
||||
func (h *UserHandler) GetUsers(w http.ResponseWriter, r *http.Request) {
|
||||
@ -101,7 +144,7 @@ func (h *UserHandler) GetUsers(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
users := []*api.User{}
|
||||
users := make([]*api.User, 0)
|
||||
for _, r := range data {
|
||||
users = append(users, toUserResponse(r))
|
||||
}
|
||||
@ -116,11 +159,22 @@ func toUserResponse(user *server.UserInfo) *api.User {
|
||||
autoGroups = []string{}
|
||||
}
|
||||
|
||||
var userStatus api.UserStatus
|
||||
switch user.Status {
|
||||
case "active":
|
||||
userStatus = api.UserStatusActive
|
||||
case "invited":
|
||||
userStatus = api.UserStatusInvited
|
||||
default:
|
||||
userStatus = api.UserStatusDisabled
|
||||
}
|
||||
|
||||
return &api.User{
|
||||
Id: user.ID,
|
||||
Name: user.Name,
|
||||
Email: user.Email,
|
||||
Role: user.Role,
|
||||
AutoGroups: autoGroups,
|
||||
Status: userStatus,
|
||||
}
|
||||
}
|
||||
|
@ -16,7 +16,7 @@ import (
|
||||
func initUsers(user ...*server.User) *UserHandler {
|
||||
return &UserHandler{
|
||||
accountManager: &mock_server.MockAccountManager{
|
||||
GetAccountWithAuthorizationClaimsFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, error) {
|
||||
GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, error) {
|
||||
users := make(map[string]*server.User, 0)
|
||||
for _, u := range user {
|
||||
users[u.Id] = u
|
||||
|
@ -60,7 +60,7 @@ func getJWTAccount(accountManager server.AccountManager,
|
||||
|
||||
jwtClaims := jwtExtractor.ExtractClaimsFromRequestContext(r, authAudience)
|
||||
|
||||
account, err := accountManager.GetAccountWithAuthorizationClaims(jwtClaims)
|
||||
account, err := accountManager.GetAccountFromToken(jwtClaims)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed getting account of a user %s: %v", jwtClaims.UserId, err)
|
||||
}
|
||||
|
@ -7,7 +7,6 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strconv"
|
||||
@ -54,6 +53,16 @@ type Auth0Credentials struct {
|
||||
mux sync.Mutex
|
||||
}
|
||||
|
||||
// createUserRequest is a user create request
|
||||
type createUserRequest struct {
|
||||
Email string `json:"email"`
|
||||
Name string `json:"name"`
|
||||
AppMeta AppMetadata `json:"app_metadata"`
|
||||
Connection string `json:"connection"`
|
||||
Password string `json:"password"`
|
||||
VerifyEmail bool `json:"verify_email"`
|
||||
}
|
||||
|
||||
// userExportJobRequest is a user export request struct
|
||||
type userExportJobRequest struct {
|
||||
Format string `json:"format"`
|
||||
@ -87,12 +96,13 @@ type userExportJobStatusResponse struct {
|
||||
|
||||
// auth0Profile represents an Auth0 user profile response
|
||||
type auth0Profile struct {
|
||||
AccountID string `json:"wt_account_id"`
|
||||
UserID string `json:"user_id"`
|
||||
Name string `json:"name"`
|
||||
Email string `json:"email"`
|
||||
CreatedAt string `json:"created_at"`
|
||||
LastLogin string `json:"last_login"`
|
||||
AccountID string `json:"wt_account_id"`
|
||||
PendingInvite bool `json:"wt_pending_invite"`
|
||||
UserID string `json:"user_id"`
|
||||
Name string `json:"name"`
|
||||
Email string `json:"email"`
|
||||
CreatedAt string `json:"created_at"`
|
||||
LastLogin string `json:"last_login"`
|
||||
}
|
||||
|
||||
// NewAuth0Manager creates a new instance of the Auth0Manager
|
||||
@ -172,7 +182,7 @@ func (c *Auth0Credentials) requestJWTToken() (*http.Response, error) {
|
||||
// parseRequestJWTResponse parses jwt raw response body and extracts token and expires in seconds
|
||||
func (c *Auth0Credentials) parseRequestJWTResponse(rawBody io.ReadCloser) (JWTToken, error) {
|
||||
jwtToken := JWTToken{}
|
||||
body, err := ioutil.ReadAll(rawBody)
|
||||
body, err := io.ReadAll(rawBody)
|
||||
if err != nil {
|
||||
return jwtToken, err
|
||||
}
|
||||
@ -230,7 +240,7 @@ func (c *Auth0Credentials) Authenticate() (JWTToken, error) {
|
||||
return c.jwtToken, nil
|
||||
}
|
||||
|
||||
func batchRequestUsersURL(authIssuer, accountID string, page int) (string, url.Values, error) {
|
||||
func batchRequestUsersURL(authIssuer, accountID string, page int, perPage int) (string, url.Values, error) {
|
||||
u, err := url.Parse(authIssuer + "/api/v2/users")
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
@ -238,6 +248,7 @@ func batchRequestUsersURL(authIssuer, accountID string, page int) (string, url.V
|
||||
q := u.Query()
|
||||
q.Set("page", strconv.Itoa(page))
|
||||
q.Set("search_engine", "v3")
|
||||
q.Set("per_page", strconv.Itoa(perPage))
|
||||
q.Set("q", "app_metadata.wt_account_id:"+accountID)
|
||||
u.RawQuery = q.Encode()
|
||||
|
||||
@ -259,8 +270,9 @@ func (am *Auth0Manager) GetAccount(accountID string) ([]*UserData, error) {
|
||||
|
||||
// https://auth0.com/docs/manage-users/user-search/retrieve-users-with-get-users-endpoint#limitations
|
||||
// auth0 limitation of 1000 users via this endpoint
|
||||
resultsPerPage := 50
|
||||
for page := 0; page < 20; page++ {
|
||||
reqURL, query, err := batchRequestUsersURL(am.authIssuer, accountID, page)
|
||||
reqURL, query, err := batchRequestUsersURL(am.authIssuer, accountID, page, resultsPerPage)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -283,30 +295,31 @@ func (am *Auth0Manager) GetAccount(accountID string) ([]*UserData, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if res.StatusCode != 200 {
|
||||
return nil, fmt.Errorf("failed requesting user data from IdP %s", string(body))
|
||||
}
|
||||
|
||||
var batch []UserData
|
||||
err = json.Unmarshal(body, &batch)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
log.Debugf("requested batch; %v", batch)
|
||||
log.Debugf("returned user batch for accountID %s on page %d, %v", accountID, page, batch)
|
||||
|
||||
err = res.Body.Close()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if res.StatusCode != 200 {
|
||||
return nil, fmt.Errorf("unable to request UserData from auth0, statusCode %d", res.StatusCode)
|
||||
}
|
||||
|
||||
if len(batch) == 0 {
|
||||
return list, nil
|
||||
}
|
||||
|
||||
for user := range batch {
|
||||
list = append(list, &batch[user])
|
||||
}
|
||||
|
||||
if len(batch) == 0 || len(batch) < resultsPerPage {
|
||||
log.Debugf("finished loading users for accountID %s", accountID)
|
||||
return list, nil
|
||||
}
|
||||
}
|
||||
|
||||
return list, nil
|
||||
@ -367,14 +380,12 @@ func (am *Auth0Manager) UpdateUserAppMetadata(userID string, appMetadata AppMeta
|
||||
|
||||
reqURL := am.authIssuer + "/api/v2/users/" + userID
|
||||
|
||||
data, err := am.helper.Marshal(appMetadata)
|
||||
data, err := am.helper.Marshal(map[string]any{"app_metadata": appMetadata})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
payloadString := fmt.Sprintf("{\"app_metadata\": %s}", string(data))
|
||||
|
||||
payload := strings.NewReader(payloadString)
|
||||
payload := strings.NewReader(string(data))
|
||||
|
||||
req, err := http.NewRequest("PATCH", reqURL, payload)
|
||||
if err != nil {
|
||||
@ -383,7 +394,7 @@ func (am *Auth0Manager) UpdateUserAppMetadata(userID string, appMetadata AppMeta
|
||||
req.Header.Add("authorization", "Bearer "+jwtToken.AccessToken)
|
||||
req.Header.Add("content-type", "application/json")
|
||||
|
||||
log.Debugf("updating metadata for user %s", userID)
|
||||
log.Debugf("updating IdP metadata for user %s", userID)
|
||||
|
||||
res, err := am.httpClient.Do(req)
|
||||
if err != nil {
|
||||
@ -404,6 +415,27 @@ func (am *Auth0Manager) UpdateUserAppMetadata(userID string, appMetadata AppMeta
|
||||
return nil
|
||||
}
|
||||
|
||||
func buildCreateUserRequestPayload(email string, name string, accountID string) (string, error) {
|
||||
req := &createUserRequest{
|
||||
Email: email,
|
||||
Name: name,
|
||||
AppMeta: AppMetadata{
|
||||
WTAccountID: accountID,
|
||||
WTPendingInvite: true,
|
||||
},
|
||||
Connection: "Username-Password-Authentication",
|
||||
Password: GeneratePassword(8, 1, 1, 1),
|
||||
VerifyEmail: true,
|
||||
}
|
||||
|
||||
str, err := json.Marshal(req)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return string(str), nil
|
||||
}
|
||||
|
||||
func buildUserExportRequest() (string, error) {
|
||||
req := &userExportJobRequest{}
|
||||
fields := make([]map[string]string, 0)
|
||||
@ -417,6 +449,11 @@ func buildUserExportRequest() (string, error) {
|
||||
"export_as": "wt_account_id",
|
||||
})
|
||||
|
||||
fields = append(fields, map[string]string{
|
||||
"name": "app_metadata.wt_pending_invite",
|
||||
"export_as": "wt_pending_invite",
|
||||
})
|
||||
|
||||
req.Format = "json"
|
||||
req.Fields = fields
|
||||
|
||||
@ -428,28 +465,39 @@ func buildUserExportRequest() (string, error) {
|
||||
return string(str), nil
|
||||
}
|
||||
|
||||
// GetAllAccounts gets all registered accounts with corresponding user data.
|
||||
// It returns a list of users indexed by accountID.
|
||||
func (am *Auth0Manager) GetAllAccounts() (map[string][]*UserData, error) {
|
||||
func (am *Auth0Manager) createPostRequest(endpoint string, payloadStr string) (*http.Request, error) {
|
||||
jwtToken, err := am.credentials.Authenticate()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
reqURL := am.authIssuer + "/api/v2/jobs/users-exports"
|
||||
reqURL := am.authIssuer + endpoint
|
||||
|
||||
payload := strings.NewReader(payloadStr)
|
||||
|
||||
req, err := http.NewRequest("POST", reqURL, payload)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Add("authorization", "Bearer "+jwtToken.AccessToken)
|
||||
req.Header.Add("content-type", "application/json")
|
||||
|
||||
return req, nil
|
||||
|
||||
}
|
||||
|
||||
// GetAllAccounts gets all registered accounts with corresponding user data.
|
||||
// It returns a list of users indexed by accountID.
|
||||
func (am *Auth0Manager) GetAllAccounts() (map[string][]*UserData, error) {
|
||||
payloadString, err := buildUserExportRequest()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
payload := strings.NewReader(payloadString)
|
||||
|
||||
exportJobReq, err := http.NewRequest("POST", reqURL, payload)
|
||||
exportJobReq, err := am.createPostRequest("/api/v2/jobs/users-exports", payloadString)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
exportJobReq.Header.Add("authorization", "Bearer "+jwtToken.AccessToken)
|
||||
exportJobReq.Header.Add("content-type", "application/json")
|
||||
|
||||
jobResp, err := am.httpClient.Do(exportJobReq)
|
||||
if err != nil {
|
||||
@ -469,7 +517,7 @@ func (am *Auth0Manager) GetAllAccounts() (map[string][]*UserData, error) {
|
||||
|
||||
var exportJobResp userExportJobResponse
|
||||
|
||||
body, err := ioutil.ReadAll(jobResp.Body)
|
||||
body, err := io.ReadAll(jobResp.Body)
|
||||
if err != nil {
|
||||
log.Debugf("Coudln't read export job response; %v", err)
|
||||
return nil, err
|
||||
@ -500,6 +548,82 @@ func (am *Auth0Manager) GetAllAccounts() (map[string][]*UserData, error) {
|
||||
return nil, fmt.Errorf("failed extracting user profiles from auth0")
|
||||
}
|
||||
|
||||
// GetUserByEmail searches users with a given email. If no users have been found, this function returns an empty list.
|
||||
// This function can return multiple users. This is due to the Auth0 internals - there could be multiple users with
|
||||
// the same email but different connections that are considered as separate accounts (e.g., Google and username/password).
|
||||
func (am *Auth0Manager) GetUserByEmail(email string) ([]*UserData, error) {
|
||||
jwtToken, err := am.credentials.Authenticate()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
reqURL := am.authIssuer + "/api/v2/users-by-email?email=" + email
|
||||
body, err := doGetReq(am.httpClient, reqURL, jwtToken.AccessToken)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
userResp := []*UserData{}
|
||||
|
||||
err = am.helper.Unmarshal(body, &userResp)
|
||||
if err != nil {
|
||||
log.Debugf("Coudln't unmarshal export job response; %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return userResp, nil
|
||||
}
|
||||
|
||||
// CreateUser creates a new user in Auth0 Idp and sends an invite
|
||||
func (am *Auth0Manager) CreateUser(email string, name string, accountID string) (*UserData, error) {
|
||||
|
||||
payloadString, err := buildCreateUserRequestPayload(email, name, accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req, err := am.createPostRequest("/api/v2/users", payloadString)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
resp, err := am.httpClient.Do(req)
|
||||
if err != nil {
|
||||
log.Debugf("Couldn't get job response %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
defer func() {
|
||||
err = resp.Body.Close()
|
||||
if err != nil {
|
||||
log.Errorf("error while closing create user response body: %v", err)
|
||||
}
|
||||
}()
|
||||
if !(resp.StatusCode == 200 || resp.StatusCode == 201) {
|
||||
return nil, fmt.Errorf("unable to create user, statusCode %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
var createResp UserData
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
log.Debugf("Coudln't read export job response; %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = am.helper.Unmarshal(body, &createResp)
|
||||
if err != nil {
|
||||
log.Debugf("Coudln't unmarshal export job response; %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if createResp.ID == "" {
|
||||
return nil, fmt.Errorf("couldn't create user: response %v", resp)
|
||||
}
|
||||
|
||||
log.Debugf("created user %s in account %s", createResp.ID, accountID)
|
||||
|
||||
return &createResp, nil
|
||||
}
|
||||
|
||||
// checkExportJobStatus checks the status of the job created at CreateExportUsersJob.
|
||||
// If the status is "completed", then return the downloadLink
|
||||
func (am *Auth0Manager) checkExportJobStatus(jobID string) (bool, string, error) {
|
||||
@ -572,6 +696,10 @@ func (am *Auth0Manager) downloadProfileExport(location string) (map[string][]*Us
|
||||
ID: profile.UserID,
|
||||
Name: profile.Name,
|
||||
Email: profile.Email,
|
||||
AppMetadata: AppMetadata{
|
||||
WTAccountID: profile.AccountID,
|
||||
WTPendingInvite: profile.PendingInvite,
|
||||
},
|
||||
})
|
||||
}
|
||||
}
|
||||
@ -605,7 +733,7 @@ func doGetReq(client ManagerHTTPClient, url, accessToken string) ([]byte, error)
|
||||
return nil, fmt.Errorf("unable to get %s, statusCode %d", url, res.StatusCode)
|
||||
}
|
||||
|
||||
body, err := ioutil.ReadAll(res.Body)
|
||||
body, err := io.ReadAll(res.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -4,7 +4,7 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"github.com/stretchr/testify/require"
|
||||
"io/ioutil"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
@ -22,13 +22,13 @@ type mockHTTPClient struct {
|
||||
}
|
||||
|
||||
func (c *mockHTTPClient) Do(req *http.Request) (*http.Response, error) {
|
||||
body, err := ioutil.ReadAll(req.Body)
|
||||
body, err := io.ReadAll(req.Body)
|
||||
if err == nil {
|
||||
c.reqBody = string(body)
|
||||
}
|
||||
return &http.Response{
|
||||
StatusCode: c.code,
|
||||
Body: ioutil.NopCloser(strings.NewReader(c.resBody)),
|
||||
Body: io.NopCloser(strings.NewReader(c.resBody)),
|
||||
}, c.err
|
||||
}
|
||||
|
||||
@ -130,7 +130,7 @@ func TestAuth0_RequestJWTToken(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
body, err := ioutil.ReadAll(res.Body)
|
||||
body, err := io.ReadAll(res.Body)
|
||||
assert.NoError(t, err, "unable to read the response body")
|
||||
|
||||
jwtToken := JWTToken{}
|
||||
@ -178,7 +178,7 @@ func TestAuth0_ParseRequestJWTResponse(t *testing.T) {
|
||||
for _, testCase := range []parseRequestJWTResponseTest{parseRequestJWTResponseTestCase1, parseRequestJWTResponseTestCase2} {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
|
||||
rawBody := ioutil.NopCloser(strings.NewReader(testCase.inputResBody))
|
||||
rawBody := io.NopCloser(strings.NewReader(testCase.inputResBody))
|
||||
|
||||
config := Auth0ClientConfig{}
|
||||
|
||||
@ -320,7 +320,7 @@ func TestAuth0_UpdateUserAppMetadata(t *testing.T) {
|
||||
|
||||
exp := 15
|
||||
token := newTestJWT(t, exp)
|
||||
appMetadata := AppMetadata{WTAccountId: "ok"}
|
||||
appMetadata := AppMetadata{WTAccountID: "ok"}
|
||||
|
||||
updateUserAppMetadataTestCase1 := updateUserAppMetadataTest{
|
||||
name: "Bad Authentication",
|
||||
@ -340,7 +340,7 @@ func TestAuth0_UpdateUserAppMetadata(t *testing.T) {
|
||||
updateUserAppMetadataTestCase2 := updateUserAppMetadataTest{
|
||||
name: "Bad Status Code",
|
||||
inputReqBody: fmt.Sprintf("{\"access_token\":\"%s\",\"scope\":\"read:users\",\"expires_in\":%d,\"token_type\":\"Bearer\"}", token, exp),
|
||||
expectedReqBody: fmt.Sprintf("{\"app_metadata\": {\"wt_account_id\":\"%s\"}}", appMetadata.WTAccountId),
|
||||
expectedReqBody: fmt.Sprintf("{\"app_metadata\":{\"wt_account_id\":\"%s\",\"wt_pending_invite\":false}}", appMetadata.WTAccountID),
|
||||
appMetadata: appMetadata,
|
||||
statusCode: 400,
|
||||
helper: JsonParser{},
|
||||
@ -363,7 +363,7 @@ func TestAuth0_UpdateUserAppMetadata(t *testing.T) {
|
||||
updateUserAppMetadataTestCase4 := updateUserAppMetadataTest{
|
||||
name: "Good request",
|
||||
inputReqBody: fmt.Sprintf("{\"access_token\":\"%s\",\"scope\":\"read:users\",\"expires_in\":%d,\"token_type\":\"Bearer\"}", token, exp),
|
||||
expectedReqBody: fmt.Sprintf("{\"app_metadata\": {\"wt_account_id\":\"%s\"}}", appMetadata.WTAccountId),
|
||||
expectedReqBody: fmt.Sprintf("{\"app_metadata\":{\"wt_account_id\":\"%s\",\"wt_pending_invite\":false}}", appMetadata.WTAccountID),
|
||||
appMetadata: appMetadata,
|
||||
statusCode: 200,
|
||||
helper: JsonParser{},
|
||||
|
@ -13,6 +13,8 @@ type Manager interface {
|
||||
GetUserDataByID(userId string, appMetadata AppMetadata) (*UserData, error)
|
||||
GetAccount(accountId string) ([]*UserData, error)
|
||||
GetAllAccounts() (map[string][]*UserData, error)
|
||||
CreateUser(email string, name string, accountID string) (*UserData, error)
|
||||
GetUserByEmail(email string) ([]*UserData, error)
|
||||
}
|
||||
|
||||
// Config an idp configuration struct to be loaded from management server's config file
|
||||
@ -38,16 +40,18 @@ type ManagerHelper interface {
|
||||
}
|
||||
|
||||
type UserData struct {
|
||||
Email string `json:"email"`
|
||||
Name string `json:"name"`
|
||||
ID string `json:"user_id"`
|
||||
Email string `json:"email"`
|
||||
Name string `json:"name"`
|
||||
ID string `json:"user_id"`
|
||||
AppMetadata AppMetadata `json:"app_metadata"`
|
||||
}
|
||||
|
||||
// AppMetadata user app metadata to associate with a profile
|
||||
type AppMetadata struct {
|
||||
// Wiretrustee account id to update in the IDP
|
||||
// WTAccountID is a NetBird (previously Wiretrustee) account id to update in the IDP
|
||||
// maps to wt_account_id when json.marshal
|
||||
WTAccountId string `json:"wt_account_id"`
|
||||
WTAccountID string `json:"wt_account_id,omitempty"`
|
||||
WTPendingInvite bool `json:"wt_pending_invite"`
|
||||
}
|
||||
|
||||
// JWTToken a JWT object that holds information of a token
|
||||
|
@ -1,6 +1,18 @@
|
||||
package idp
|
||||
|
||||
import "encoding/json"
|
||||
import (
|
||||
"encoding/json"
|
||||
"math/rand"
|
||||
"strings"
|
||||
)
|
||||
|
||||
var (
|
||||
lowerCharSet = "abcdedfghijklmnopqrst"
|
||||
upperCharSet = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
|
||||
specialCharSet = "!@#$%&*"
|
||||
numberSet = "0123456789"
|
||||
allCharSet = lowerCharSet + upperCharSet + specialCharSet + numberSet
|
||||
)
|
||||
|
||||
type JsonParser struct{}
|
||||
|
||||
@ -11,3 +23,37 @@ func (JsonParser) Marshal(v interface{}) ([]byte, error) {
|
||||
func (JsonParser) Unmarshal(data []byte, v interface{}) error {
|
||||
return json.Unmarshal(data, v)
|
||||
}
|
||||
|
||||
// GeneratePassword generates user password
|
||||
func GeneratePassword(passwordLength, minSpecialChar, minNum, minUpperCase int) string {
|
||||
var password strings.Builder
|
||||
|
||||
//Set special character
|
||||
for i := 0; i < minSpecialChar; i++ {
|
||||
random := rand.Intn(len(specialCharSet))
|
||||
password.WriteString(string(specialCharSet[random]))
|
||||
}
|
||||
|
||||
//Set numeric
|
||||
for i := 0; i < minNum; i++ {
|
||||
random := rand.Intn(len(numberSet))
|
||||
password.WriteString(string(numberSet[random]))
|
||||
}
|
||||
|
||||
//Set uppercase
|
||||
for i := 0; i < minUpperCase; i++ {
|
||||
random := rand.Intn(len(upperCharSet))
|
||||
password.WriteString(string(upperCharSet[random]))
|
||||
}
|
||||
|
||||
remainingLength := passwordLength - minSpecialChar - minNum - minUpperCase
|
||||
for i := 0; i < remainingLength; i++ {
|
||||
random := rand.Intn(len(allCharSet))
|
||||
password.WriteString(string(allCharSet[random]))
|
||||
}
|
||||
inRune := []rune(password.String())
|
||||
rand.Shuffle(len(inRune), func(i, j int) {
|
||||
inRune[i], inRune[j] = inRune[j], inRune[i]
|
||||
})
|
||||
return string(inRune)
|
||||
}
|
||||
|
@ -2,7 +2,6 @@ package server_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io/ioutil"
|
||||
"math/rand"
|
||||
"net"
|
||||
"os"
|
||||
@ -45,7 +44,7 @@ var _ = Describe("Management service", func() {
|
||||
level, _ := log.ParseLevel("Debug")
|
||||
log.SetLevel(level)
|
||||
var err error
|
||||
dataDir, err = ioutil.TempDir("", "wiretrustee_mgmt_test_tmp_*")
|
||||
dataDir, err = os.MkdirTemp("", "wiretrustee_mgmt_test_tmp_*")
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
|
||||
err = util.CopyFileContents("testdata/store.json", filepath.Join(dataDir, "store.json"))
|
||||
|
@ -11,55 +11,56 @@ import (
|
||||
)
|
||||
|
||||
type MockAccountManager struct {
|
||||
GetOrCreateAccountByUserFunc func(userId, domain string) (*server.Account, error)
|
||||
GetAccountByUserFunc func(userId string) (*server.Account, error)
|
||||
CreateSetupKeyFunc func(accountId string, keyName string, keyType server.SetupKeyType, expiresIn time.Duration, autoGroups []string) (*server.SetupKey, error)
|
||||
GetSetupKeyFunc func(accountID string, keyID string) (*server.SetupKey, error)
|
||||
GetAccountByIdFunc func(accountId string) (*server.Account, error)
|
||||
GetAccountByUserOrAccountIdFunc func(userId, accountId, domain string) (*server.Account, error)
|
||||
GetAccountWithAuthorizationClaimsFunc func(claims jwtclaims.AuthorizationClaims) (*server.Account, error)
|
||||
IsUserAdminFunc func(claims jwtclaims.AuthorizationClaims) (bool, error)
|
||||
AccountExistsFunc func(accountId string) (*bool, error)
|
||||
GetPeerFunc func(peerKey string) (*server.Peer, error)
|
||||
MarkPeerConnectedFunc func(peerKey string, connected bool) error
|
||||
RenamePeerFunc func(accountId string, peerKey string, newName string) (*server.Peer, error)
|
||||
DeletePeerFunc func(accountId string, peerKey string) (*server.Peer, error)
|
||||
GetPeerByIPFunc func(accountId string, peerIP string) (*server.Peer, error)
|
||||
GetNetworkMapFunc func(peerKey string) (*server.NetworkMap, error)
|
||||
GetPeerNetworkFunc func(peerKey string) (*server.Network, error)
|
||||
AddPeerFunc func(setupKey string, userId string, peer *server.Peer) (*server.Peer, error)
|
||||
GetGroupFunc func(accountID, groupID string) (*server.Group, error)
|
||||
SaveGroupFunc func(accountID string, group *server.Group) error
|
||||
UpdateGroupFunc func(accountID string, groupID string, operations []server.GroupUpdateOperation) (*server.Group, error)
|
||||
DeleteGroupFunc func(accountID, groupID string) error
|
||||
ListGroupsFunc func(accountID string) ([]*server.Group, error)
|
||||
GroupAddPeerFunc func(accountID, groupID, peerKey string) error
|
||||
GroupDeletePeerFunc func(accountID, groupID, peerKey string) error
|
||||
GroupListPeersFunc func(accountID, groupID string) ([]*server.Peer, error)
|
||||
GetRuleFunc func(accountID, ruleID string) (*server.Rule, error)
|
||||
SaveRuleFunc func(accountID string, rule *server.Rule) error
|
||||
UpdateRuleFunc func(accountID string, ruleID string, operations []server.RuleUpdateOperation) (*server.Rule, error)
|
||||
DeleteRuleFunc func(accountID, ruleID string) error
|
||||
ListRulesFunc func(accountID string) ([]*server.Rule, error)
|
||||
GetUsersFromAccountFunc func(accountID string) ([]*server.UserInfo, error)
|
||||
UpdatePeerMetaFunc func(peerKey string, meta server.PeerSystemMeta) error
|
||||
UpdatePeerSSHKeyFunc func(peerKey string, sshKey string) error
|
||||
UpdatePeerFunc func(accountID string, peer *server.Peer) (*server.Peer, error)
|
||||
CreateRouteFunc func(accountID string, prefix, peer, description, netID string, masquerade bool, metric int, enabled bool) (*route.Route, error)
|
||||
GetRouteFunc func(accountID, routeID string) (*route.Route, error)
|
||||
SaveRouteFunc func(accountID string, route *route.Route) error
|
||||
UpdateRouteFunc func(accountID string, routeID string, operations []server.RouteUpdateOperation) (*route.Route, error)
|
||||
DeleteRouteFunc func(accountID, routeID string) error
|
||||
ListRoutesFunc func(accountID string) ([]*route.Route, error)
|
||||
SaveSetupKeyFunc func(accountID string, key *server.SetupKey) (*server.SetupKey, error)
|
||||
ListSetupKeysFunc func(accountID string) ([]*server.SetupKey, error)
|
||||
SaveUserFunc func(accountID string, user *server.User) (*server.UserInfo, error)
|
||||
GetNameServerGroupFunc func(accountID, nsGroupID string) (*nbdns.NameServerGroup, error)
|
||||
CreateNameServerGroupFunc func(accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, enabled bool) (*nbdns.NameServerGroup, error)
|
||||
SaveNameServerGroupFunc func(accountID string, nsGroupToSave *nbdns.NameServerGroup) error
|
||||
UpdateNameServerGroupFunc func(accountID, nsGroupID string, operations []server.NameServerGroupUpdateOperation) (*nbdns.NameServerGroup, error)
|
||||
DeleteNameServerGroupFunc func(accountID, nsGroupID string) error
|
||||
ListNameServerGroupsFunc func(accountID string) ([]*nbdns.NameServerGroup, error)
|
||||
GetOrCreateAccountByUserFunc func(userId, domain string) (*server.Account, error)
|
||||
GetAccountByUserFunc func(userId string) (*server.Account, error)
|
||||
CreateSetupKeyFunc func(accountId string, keyName string, keyType server.SetupKeyType, expiresIn time.Duration, autoGroups []string) (*server.SetupKey, error)
|
||||
GetSetupKeyFunc func(accountID string, keyID string) (*server.SetupKey, error)
|
||||
GetAccountByIdFunc func(accountId string) (*server.Account, error)
|
||||
GetAccountByUserOrAccountIdFunc func(userId, accountId, domain string) (*server.Account, error)
|
||||
IsUserAdminFunc func(claims jwtclaims.AuthorizationClaims) (bool, error)
|
||||
AccountExistsFunc func(accountId string) (*bool, error)
|
||||
GetPeerFunc func(peerKey string) (*server.Peer, error)
|
||||
MarkPeerConnectedFunc func(peerKey string, connected bool) error
|
||||
RenamePeerFunc func(accountId string, peerKey string, newName string) (*server.Peer, error)
|
||||
DeletePeerFunc func(accountId string, peerKey string) (*server.Peer, error)
|
||||
GetPeerByIPFunc func(accountId string, peerIP string) (*server.Peer, error)
|
||||
GetNetworkMapFunc func(peerKey string) (*server.NetworkMap, error)
|
||||
GetPeerNetworkFunc func(peerKey string) (*server.Network, error)
|
||||
AddPeerFunc func(setupKey string, userId string, peer *server.Peer) (*server.Peer, error)
|
||||
GetGroupFunc func(accountID, groupID string) (*server.Group, error)
|
||||
SaveGroupFunc func(accountID string, group *server.Group) error
|
||||
UpdateGroupFunc func(accountID string, groupID string, operations []server.GroupUpdateOperation) (*server.Group, error)
|
||||
DeleteGroupFunc func(accountID, groupID string) error
|
||||
ListGroupsFunc func(accountID string) ([]*server.Group, error)
|
||||
GroupAddPeerFunc func(accountID, groupID, peerKey string) error
|
||||
GroupDeletePeerFunc func(accountID, groupID, peerKey string) error
|
||||
GroupListPeersFunc func(accountID, groupID string) ([]*server.Peer, error)
|
||||
GetRuleFunc func(accountID, ruleID string) (*server.Rule, error)
|
||||
SaveRuleFunc func(accountID string, rule *server.Rule) error
|
||||
UpdateRuleFunc func(accountID string, ruleID string, operations []server.RuleUpdateOperation) (*server.Rule, error)
|
||||
DeleteRuleFunc func(accountID, ruleID string) error
|
||||
ListRulesFunc func(accountID string) ([]*server.Rule, error)
|
||||
GetUsersFromAccountFunc func(accountID string) ([]*server.UserInfo, error)
|
||||
UpdatePeerMetaFunc func(peerKey string, meta server.PeerSystemMeta) error
|
||||
UpdatePeerSSHKeyFunc func(peerKey string, sshKey string) error
|
||||
UpdatePeerFunc func(accountID string, peer *server.Peer) (*server.Peer, error)
|
||||
CreateRouteFunc func(accountID string, prefix, peer, description, netID string, masquerade bool, metric int, enabled bool) (*route.Route, error)
|
||||
GetRouteFunc func(accountID, routeID string) (*route.Route, error)
|
||||
SaveRouteFunc func(accountID string, route *route.Route) error
|
||||
UpdateRouteFunc func(accountID string, routeID string, operations []server.RouteUpdateOperation) (*route.Route, error)
|
||||
DeleteRouteFunc func(accountID, routeID string) error
|
||||
ListRoutesFunc func(accountID string) ([]*route.Route, error)
|
||||
SaveSetupKeyFunc func(accountID string, key *server.SetupKey) (*server.SetupKey, error)
|
||||
ListSetupKeysFunc func(accountID string) ([]*server.SetupKey, error)
|
||||
SaveUserFunc func(accountID string, user *server.User) (*server.UserInfo, error)
|
||||
GetNameServerGroupFunc func(accountID, nsGroupID string) (*nbdns.NameServerGroup, error)
|
||||
CreateNameServerGroupFunc func(accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, enabled bool) (*nbdns.NameServerGroup, error)
|
||||
SaveNameServerGroupFunc func(accountID string, nsGroupToSave *nbdns.NameServerGroup) error
|
||||
UpdateNameServerGroupFunc func(accountID, nsGroupID string, operations []server.NameServerGroupUpdateOperation) (*nbdns.NameServerGroup, error)
|
||||
DeleteNameServerGroupFunc func(accountID, nsGroupID string) error
|
||||
ListNameServerGroupsFunc func(accountID string) ([]*nbdns.NameServerGroup, error)
|
||||
CreateUserFunc func(accountID string, key *server.UserInfo) (*server.UserInfo, error)
|
||||
GetAccountFromTokenFunc func(claims jwtclaims.AuthorizationClaims) (*server.Account, error)
|
||||
}
|
||||
|
||||
// GetUsersFromAccount mock implementation of GetUsersFromAccount from server.AccountManager interface
|
||||
@ -126,19 +127,6 @@ func (am *MockAccountManager) GetAccountByUserOrAccountId(
|
||||
)
|
||||
}
|
||||
|
||||
// GetAccountWithAuthorizationClaims mock implementation of GetAccountWithAuthorizationClaims from server.AccountManager interface
|
||||
func (am *MockAccountManager) GetAccountWithAuthorizationClaims(
|
||||
claims jwtclaims.AuthorizationClaims,
|
||||
) (*server.Account, error) {
|
||||
if am.GetAccountWithAuthorizationClaimsFunc != nil {
|
||||
return am.GetAccountWithAuthorizationClaimsFunc(claims)
|
||||
}
|
||||
return nil, status.Errorf(
|
||||
codes.Unimplemented,
|
||||
"method GetAccountWithAuthorizationClaims is not implemented",
|
||||
)
|
||||
}
|
||||
|
||||
// AccountExists mock implementation of AccountExists from server.AccountManager interface
|
||||
func (am *MockAccountManager) AccountExists(accountId string) (*bool, error) {
|
||||
if am.AccountExistsFunc != nil {
|
||||
@ -485,3 +473,19 @@ func (am *MockAccountManager) ListNameServerGroups(accountID string) ([]*nbdns.N
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// CreateUser mocks CreateUser of the AccountManager interface
|
||||
func (am *MockAccountManager) CreateUser(accountID string, invite *server.UserInfo) (*server.UserInfo, error) {
|
||||
if am.CreateUserFunc != nil {
|
||||
return am.CreateUserFunc(accountID, invite)
|
||||
}
|
||||
return nil, status.Errorf(codes.Unimplemented, "method CreateUser is not implemented")
|
||||
}
|
||||
|
||||
// GetAccountFromToken mocks GetAccountFromToken of the AccountManager interface
|
||||
func (am *MockAccountManager) GetAccountFromToken(claims jwtclaims.AuthorizationClaims) (*server.Account, error) {
|
||||
if am.GetAccountFromTokenFunc != nil {
|
||||
return am.GetAccountFromTokenFunc(claims)
|
||||
}
|
||||
return nil, status.Errorf(codes.Unimplemented, "method GetAccountFromToken is not implemented")
|
||||
}
|
||||
|
@ -14,6 +14,10 @@ const (
|
||||
UserRoleAdmin UserRole = "admin"
|
||||
UserRoleUser UserRole = "user"
|
||||
UserRoleUnknown UserRole = "unknown"
|
||||
|
||||
UserStatusActive UserStatus = "active"
|
||||
UserStatusDisabled UserStatus = "disabled"
|
||||
UserStatusInvited UserStatus = "invited"
|
||||
)
|
||||
|
||||
// StrRoleToUserRole returns UserRole for a given strRole or UserRoleUnknown if the specified role is unknown
|
||||
@ -28,7 +32,10 @@ func StrRoleToUserRole(strRole string) UserRole {
|
||||
}
|
||||
}
|
||||
|
||||
// UserRole is the role of the User
|
||||
// UserStatus is the status of a User
|
||||
type UserStatus string
|
||||
|
||||
// UserRole is the role of a User
|
||||
type UserRole string
|
||||
|
||||
// User represents a user of the system
|
||||
@ -53,24 +60,31 @@ func (u *User) toUserInfo(userData *idp.UserData) (*UserInfo, error) {
|
||||
Name: "",
|
||||
Role: string(u.Role),
|
||||
AutoGroups: u.AutoGroups,
|
||||
Status: string(UserStatusActive),
|
||||
}, nil
|
||||
}
|
||||
if userData.ID != u.Id {
|
||||
return nil, fmt.Errorf("wrong UserData provided for user %s", u.Id)
|
||||
}
|
||||
|
||||
userStatus := UserStatusActive
|
||||
if userData.AppMetadata.WTPendingInvite {
|
||||
userStatus = UserStatusInvited
|
||||
}
|
||||
|
||||
return &UserInfo{
|
||||
ID: u.Id,
|
||||
Email: userData.Email,
|
||||
Name: userData.Name,
|
||||
Role: string(u.Role),
|
||||
AutoGroups: autoGroups,
|
||||
Status: string(userStatus),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Copy the user
|
||||
func (u *User) Copy() *User {
|
||||
autoGroups := []string{}
|
||||
autoGroups := make([]string, 0)
|
||||
autoGroups = append(autoGroups, u.AutoGroups...)
|
||||
return &User{
|
||||
Id: u.Id,
|
||||
@ -98,6 +112,70 @@ func NewAdminUser(id string) *User {
|
||||
return NewUser(id, UserRoleAdmin)
|
||||
}
|
||||
|
||||
// CreateUser creates a new user under the given account. Effectively this is a user invite.
|
||||
func (am *DefaultAccountManager) CreateUser(accountID string, invite *UserInfo) (*UserInfo, error) {
|
||||
am.mux.Lock()
|
||||
defer am.mux.Unlock()
|
||||
|
||||
if am.idpManager == nil {
|
||||
return nil, Errorf(PreconditionFailed, "IdP manager must be enabled to send user invites")
|
||||
}
|
||||
|
||||
if invite == nil {
|
||||
return nil, fmt.Errorf("provided user update is nil")
|
||||
}
|
||||
|
||||
account, err := am.Store.GetAccount(accountID)
|
||||
if err != nil {
|
||||
return nil, Errorf(AccountNotFound, "account %s doesn't exist", accountID)
|
||||
}
|
||||
|
||||
// check if the user is already registered with this email => reject
|
||||
user, err := am.lookupUserInCacheByEmail(invite.Email, accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if user != nil {
|
||||
return nil, Errorf(UserAlreadyExists, "user has an existing account")
|
||||
}
|
||||
|
||||
users, err := am.idpManager.GetUserByEmail(invite.Email)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(users) > 0 {
|
||||
return nil, Errorf(UserAlreadyExists, "user has an existing account")
|
||||
}
|
||||
|
||||
idpUser, err := am.idpManager.CreateUser(invite.Email, invite.Name, accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
role := StrRoleToUserRole(invite.Role)
|
||||
newUser := &User{
|
||||
Id: idpUser.ID,
|
||||
Role: role,
|
||||
AutoGroups: invite.AutoGroups,
|
||||
}
|
||||
account.Users[idpUser.ID] = newUser
|
||||
|
||||
err = am.Store.SaveAccount(account)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
_, err = am.refreshCache(account.Id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return newUser.toUserInfo(idpUser)
|
||||
|
||||
}
|
||||
|
||||
// SaveUser saves updates a given user. If the user doesn't exit it will throw status.NotFound error.
|
||||
// Only User.AutoGroups field is allowed to be updated for now.
|
||||
func (am *DefaultAccountManager) SaveUser(accountID string, update *User) (*UserInfo, error) {
|
||||
@ -138,10 +216,13 @@ func (am *DefaultAccountManager) SaveUser(accountID string, update *User) (*User
|
||||
}
|
||||
|
||||
if !isNil(am.idpManager) {
|
||||
userData, err := am.lookupUserInCache(newUser, accountID)
|
||||
userData, err := am.lookupUserInCache(newUser.Id, account)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if userData == nil {
|
||||
return nil, status.Errorf(codes.NotFound, "user %s not found in the IdP", newUser.Id)
|
||||
}
|
||||
return newUser.toUserInfo(userData)
|
||||
}
|
||||
return newUser.toUserInfo(nil)
|
||||
@ -194,7 +275,7 @@ func (am *DefaultAccountManager) GetAccountByUser(userId string) (*Account, erro
|
||||
|
||||
// IsUserAdmin flag for current user authenticated by JWT token
|
||||
func (am *DefaultAccountManager) IsUserAdmin(claims jwtclaims.AuthorizationClaims) (bool, error) {
|
||||
account, err := am.GetAccountWithAuthorizationClaims(claims)
|
||||
account, err := am.GetAccountFromToken(claims)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("get account: %v", err)
|
||||
}
|
||||
@ -216,7 +297,11 @@ func (am *DefaultAccountManager) GetUsersFromAccount(accountID string) ([]*UserI
|
||||
|
||||
queriedUsers := make([]*idp.UserData, 0)
|
||||
if !isNil(am.idpManager) {
|
||||
queriedUsers, err = am.lookupCache(account.Users, accountID)
|
||||
users := make(map[string]struct{}, len(account.Users))
|
||||
for _, user := range account.Users {
|
||||
users[user.Id] = struct{}{}
|
||||
}
|
||||
queriedUsers, err = am.lookupCache(users, accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user