diff --git a/controller/grants.go b/controller/grants.go index 7bc3853a..94cfed4a 100644 --- a/controller/grants.go +++ b/controller/grants.go @@ -67,8 +67,8 @@ func (h *grantsHandler) Handle(params admin.GrantsParams, principal *rest_model_ } if shrCfg.Interstitial != !acctSkipInterstitial { - logrus.Infof("updating config for '%v'", shr.Token) - err := zrokEdgeSdk.UpdateConfig(cfgZId, shrCfg, edge) + shrCfg.Interstitial = !acctSkipInterstitial + err := zrokEdgeSdk.UpdateConfig(shr.Token, cfgZId, shrCfg, edge) if err != nil { logrus.Errorf("error updating config for '%v': %v", shr.Token, err) return admin.NewGrantsInternalServerError() diff --git a/controller/zrokEdgeSdk/config.go b/controller/zrokEdgeSdk/config.go index 24e8d46d..6c9b7b8c 100644 --- a/controller/zrokEdgeSdk/config.go +++ b/controller/zrokEdgeSdk/config.go @@ -87,7 +87,22 @@ func GetConfig(shrToken string, edge *rest_management_api_client.ZitiEdgeManagem return "", nil, fmt.Errorf("unknown data type '%v' unmarshaling config for '%v'", reflect.TypeOf(listResp.Payload.Data[0].Data), shrToken) } -func UpdateConfig(cfgZId string, cfg *sdk.FrontendConfig, edge *rest_management_api_client.ZitiEdgeManagement) error { +func UpdateConfig(shrToken, cfgZId string, cfg *sdk.FrontendConfig, edge *rest_management_api_client.ZitiEdgeManagement) error { + logrus.Infof("updating config for '%v' (%v)", shrToken, cfgZId) + req := &config.UpdateConfigParams{ + Config: &rest_model.ConfigUpdate{ + Data: cfg, + Name: &shrToken, + Tags: ZrokShareTags(shrToken), + }, + ID: cfgZId, + Context: context.Background(), + } + req.SetTimeout(30 * time.Second) + _, err := edge.Config.UpdateConfig(req, nil) + if err != nil { + return err + } return nil } diff --git a/sdk/golang/sdk/config.go b/sdk/golang/sdk/config.go index f701d5f7..cd9b3a53 100644 --- a/sdk/golang/sdk/config.go +++ b/sdk/golang/sdk/config.go @@ -59,11 +59,11 @@ type BasicAuthConfig struct { func BasicAuthConfigFromMap(m map[string]interface{}) (*BasicAuthConfig, error) { out := &BasicAuthConfig{} - if v, found := m["basic_auth"]; found { - if vArr, ok := v.([]interface{}); ok { - for _, vV := range vArr { - if v, ok := vV.(map[string]interface{}); ok { - if auc, err := AuthUserConfigFromMap(v); err == nil { + if v, found := m["users"]; found { + if subArr, ok := v.([]interface{}); ok { + for _, v := range subArr { + if subMap, ok := v.(map[string]interface{}); ok { + if auc, err := AuthUserConfigFromMap(subMap); err == nil { out.Users = append(out.Users, auc) } else { return nil, err @@ -75,9 +75,10 @@ func BasicAuthConfigFromMap(m map[string]interface{}) (*BasicAuthConfig, error) } else { return nil, errors.Errorf("unexpected type '%v'", reflect.TypeOf(v)) } - return out, nil + } else { + return nil, errors.New("missing 'users' field") } - return nil, nil + return out, nil } type AuthUserConfig struct { diff --git a/sdk/golang/sdk/config_test.go b/sdk/golang/sdk/config_test.go new file mode 100644 index 00000000..288f585d --- /dev/null +++ b/sdk/golang/sdk/config_test.go @@ -0,0 +1,71 @@ +package sdk + +import ( + "encoding/json" + "github.com/stretchr/testify/assert" + "testing" +) + +func TestBasicFrontendConfigFromMap(t *testing.T) { + inFec := &FrontendConfig{ + Interstitial: true, + AuthScheme: None, + } + m, err := frontendConfigToMap(inFec) + assert.NoError(t, err) + assert.NotNil(t, m) + outFec, err := FrontendConfigFromMap(m) + assert.NoError(t, err) + assert.NotNil(t, outFec) + assert.Equal(t, inFec, outFec) +} + +func TestBasicAuthFrontendConfigFromMap(t *testing.T) { + inFec := &FrontendConfig{ + Interstitial: false, + AuthScheme: Basic, + BasicAuth: &BasicAuthConfig{ + Users: []*AuthUserConfig{ + {Username: "nobody", Password: "password"}, + }, + }, + } + m, err := frontendConfigToMap(inFec) + assert.NoError(t, err) + assert.NotNil(t, m) + outFec, err := FrontendConfigFromMap(m) + assert.NoError(t, err) + assert.NotNil(t, outFec) + assert.Equal(t, inFec, outFec) +} + +func TestOauthAuthFrontendConfigFromMap(t *testing.T) { + inFec := &FrontendConfig{ + Interstitial: true, + AuthScheme: Oauth, + OauthAuth: &OauthConfig{ + Provider: "google", + EmailDomains: []string{"a@b.com", "c@d.com"}, + AuthorizationCheckInterval: "5m", + }, + } + m, err := frontendConfigToMap(inFec) + assert.NoError(t, err) + assert.NotNil(t, m) + outFec, err := FrontendConfigFromMap(m) + assert.NoError(t, err) + assert.NotNil(t, outFec) + assert.Equal(t, inFec, outFec) +} + +func frontendConfigToMap(fec *FrontendConfig) (map[string]interface{}, error) { + jsonData, err := json.Marshal(fec) + if err != nil { + return nil, err + } + m := make(map[string]interface{}) + if err := json.Unmarshal(jsonData, &m); err != nil { + return nil, err + } + return m, nil +}