// +build codegen package main import ( "bytes" "encoding/json" "fmt" "net/url" "os" "os/exec" "reflect" "regexp" "sort" "strconv" "strings" "text/template" "github.com/aws/aws-sdk-go/private/model/api" "github.com/aws/aws-sdk-go/private/util" ) // TestSuiteTypeInput input test // TestSuiteTypeInput output test const ( TestSuiteTypeInput = iota TestSuiteTypeOutput ) type testSuite struct { *api.API Description string Cases []testCase Type uint title string } type testCase struct { TestSuite *testSuite Given *api.Operation Params interface{} `json:",omitempty"` Data interface{} `json:"result,omitempty"` InputTest testExpectation `json:"serialized"` OutputTest testExpectation `json:"response"` } type testExpectation struct { Body string URI string Headers map[string]string JSONValues map[string]string StatusCode uint `json:"status_code"` } const preamble = ` var _ bytes.Buffer // always import bytes var _ http.Request var _ json.Marshaler var _ time.Time var _ xmlutil.XMLNode var _ xml.Attr var _ = ioutil.Discard var _ = util.Trim("") var _ = url.Values{} var _ = io.EOF var _ = aws.String var _ = fmt.Println var _ = reflect.Value{} func init() { protocol.RandReader = &awstesting.ZeroReader{} } ` var reStripSpace = regexp.MustCompile(`\s(\w)`) var reImportRemoval = regexp.MustCompile(`(?s:import \((.+?)\))`) func removeImports(code string) string { return reImportRemoval.ReplaceAllString(code, "") } var extraImports = []string{ "bytes", "encoding/json", "encoding/xml", "fmt", "io", "io/ioutil", "net/http", "testing", "time", "reflect", "net/url", "", "github.com/aws/aws-sdk-go/awstesting", "github.com/aws/aws-sdk-go/awstesting/unit", "github.com/aws/aws-sdk-go/private/protocol", "github.com/aws/aws-sdk-go/private/protocol/xml/xmlutil", "github.com/aws/aws-sdk-go/private/util", } func addImports(code string) string { importNames := make([]string, len(extraImports)) for i, n := range extraImports { if n != "" { importNames[i] = fmt.Sprintf("%q", n) } } str := reImportRemoval.ReplaceAllString(code, "import (\n"+strings.Join(importNames, "\n")+"$1\n)") return str } func (t *testSuite) TestSuite() string { var buf bytes.Buffer t.title = reStripSpace.ReplaceAllStringFunc(t.Description, func(x string) string { return strings.ToUpper(x[1:]) }) t.title = regexp.MustCompile(`\W`).ReplaceAllString(t.title, "") for idx, c := range t.Cases { c.TestSuite = t buf.WriteString(c.TestCase(idx) + "\n") } return buf.String() } var tplInputTestCase = template.Must(template.New("inputcase").Parse(` func Test{{ .OpName }}(t *testing.T) { svc := New{{ .TestCase.TestSuite.API.StructName }}(unit.Session, &aws.Config{Endpoint: aws.String("https://test")}) {{ if ne .ParamsString "" }}input := {{ .ParamsString }} {{ range $k, $v := .JSONValues -}} input.{{ $k }} = {{ $v }} {{ end -}} req, _ := svc.{{ .TestCase.Given.ExportedName }}Request(input){{ else }}req, _ := svc.{{ .TestCase.Given.ExportedName }}Request(nil){{ end }} r := req.HTTPRequest // build request {{ .TestCase.TestSuite.API.ProtocolPackage }}.Build(req) if req.Error != nil { t.Errorf("expect no error, got %v", req.Error) } {{ if ne .TestCase.InputTest.Body "" }}// assert body if r.Body == nil { t.Errorf("expect body not to be nil") } {{ .BodyAssertions }}{{ end }} {{ if ne .TestCase.InputTest.URI "" }}// assert URL awstesting.AssertURL(t, "https://test{{ .TestCase.InputTest.URI }}", r.URL.String()){{ end }} // assert headers {{ range $k, $v := .TestCase.InputTest.Headers -}} if e, a := "{{ $v }}", r.Header.Get("{{ $k }}"); e != a { t.Errorf("expect %v to be %v", e, a) } {{ end }} } `)) type tplInputTestCaseData struct { TestCase *testCase JSONValues map[string]string OpName, ParamsString string } func (t tplInputTestCaseData) BodyAssertions() string { code := &bytes.Buffer{} protocol := t.TestCase.TestSuite.API.Metadata.Protocol // Extract the body bytes switch protocol { case "rest-xml": fmt.Fprintln(code, "body := util.SortXML(r.Body)") default: fmt.Fprintln(code, "body, _ := ioutil.ReadAll(r.Body)") } // Generate the body verification code expectedBody := util.Trim(t.TestCase.InputTest.Body) switch protocol { case "ec2", "query": fmt.Fprintf(code, "awstesting.AssertQuery(t, `%s`, util.Trim(string(body)))", expectedBody) case "rest-xml": if strings.HasPrefix(expectedBody, "<") { fmt.Fprintf(code, "awstesting.AssertXML(t, `%s`, util.Trim(string(body)), %s{})", expectedBody, t.TestCase.Given.InputRef.ShapeName) } else { code.WriteString(fmtAssertEqual(fmt.Sprintf("%q", expectedBody), "util.Trim(string(body))")) } case "json", "jsonrpc", "rest-json": if strings.HasPrefix(expectedBody, "{") { fmt.Fprintf(code, "awstesting.AssertJSON(t, `%s`, util.Trim(string(body)))", expectedBody) } else { code.WriteString(fmtAssertEqual(fmt.Sprintf("%q", expectedBody), "util.Trim(string(body))")) } default: code.WriteString(fmtAssertEqual(expectedBody, "util.Trim(string(body))")) } return code.String() } func fmtAssertEqual(e, a string) string { const format = `if e, a := %s, %s; e != a { t.Errorf("expect %%v, got %%v", e, a) } ` return fmt.Sprintf(format, e, a) } func fmtAssertNil(v string) string { const format = `if e := %s; e != nil { t.Errorf("expect nil, got %%v", e) } ` return fmt.Sprintf(format, v) } var tplOutputTestCase = template.Must(template.New("outputcase").Parse(` func Test{{ .OpName }}(t *testing.T) { svc := New{{ .TestCase.TestSuite.API.StructName }}(unit.Session, &aws.Config{Endpoint: aws.String("https://test")}) buf := bytes.NewReader([]byte({{ .Body }})) req, out := svc.{{ .TestCase.Given.ExportedName }}Request(nil) req.HTTPResponse = &http.Response{StatusCode: 200, Body: ioutil.NopCloser(buf), Header: http.Header{}} // set headers {{ range $k, $v := .TestCase.OutputTest.Headers }}req.HTTPResponse.Header.Set("{{ $k }}", "{{ $v }}") {{ end }} // unmarshal response {{ .TestCase.TestSuite.API.ProtocolPackage }}.UnmarshalMeta(req) {{ .TestCase.TestSuite.API.ProtocolPackage }}.Unmarshal(req) if req.Error != nil { t.Errorf("expect not error, got %v", req.Error) } // assert response if out == nil { t.Errorf("expect not to be nil") } {{ .Assertions }} } `)) type tplOutputTestCaseData struct { TestCase *testCase Body, OpName, Assertions string } func (i *testCase) TestCase(idx int) string { var buf bytes.Buffer opName := i.TestSuite.API.StructName() + i.TestSuite.title + "Case" + strconv.Itoa(idx+1) if i.TestSuite.Type == TestSuiteTypeInput { // input test // query test should sort body as form encoded values switch i.TestSuite.API.Metadata.Protocol { case "query", "ec2": m, _ := url.ParseQuery(i.InputTest.Body) i.InputTest.Body = m.Encode() case "rest-xml": i.InputTest.Body = util.SortXML(bytes.NewReader([]byte(i.InputTest.Body))) case "json", "rest-json": i.InputTest.Body = strings.Replace(i.InputTest.Body, " ", "", -1) } jsonValues := buildJSONValues(i.Given.InputRef.Shape) var params interface{} if m, ok := i.Params.(map[string]interface{}); ok { paramsMap := map[string]interface{}{} for k, v := range m { if _, ok := jsonValues[k]; !ok { paramsMap[k] = v } else { if i.InputTest.JSONValues == nil { i.InputTest.JSONValues = map[string]string{} } i.InputTest.JSONValues[k] = serializeJSONValue(v.(map[string]interface{})) } } params = paramsMap } else { params = i.Params } input := tplInputTestCaseData{ TestCase: i, OpName: strings.ToUpper(opName[0:1]) + opName[1:], ParamsString: api.ParamsStructFromJSON(params, i.Given.InputRef.Shape, false), JSONValues: i.InputTest.JSONValues, } if err := tplInputTestCase.Execute(&buf, input); err != nil { panic(err) } } else if i.TestSuite.Type == TestSuiteTypeOutput { output := tplOutputTestCaseData{ TestCase: i, Body: fmt.Sprintf("%q", i.OutputTest.Body), OpName: strings.ToUpper(opName[0:1]) + opName[1:], Assertions: GenerateAssertions(i.Data, i.Given.OutputRef.Shape, "out"), } if err := tplOutputTestCase.Execute(&buf, output); err != nil { panic(err) } } return buf.String() } func serializeJSONValue(m map[string]interface{}) string { str := "aws.JSONValue" str += walkMap(m) return str } func walkMap(m map[string]interface{}) string { str := "{" for k, v := range m { str += fmt.Sprintf("%q:", k) switch v.(type) { case bool: str += fmt.Sprintf("%b,\n", v.(bool)) case string: str += fmt.Sprintf("%q,\n", v.(string)) case int: str += fmt.Sprintf("%d,\n", v.(int)) case float64: str += fmt.Sprintf("%f,\n", v.(float64)) case map[string]interface{}: str += walkMap(v.(map[string]interface{})) } } str += "}" return str } func buildJSONValues(shape *api.Shape) map[string]struct{} { keys := map[string]struct{}{} for key, field := range shape.MemberRefs { if field.JSONValue { keys[key] = struct{}{} } } return keys } // generateTestSuite generates a protocol test suite for a given configuration // JSON protocol test file. func generateTestSuite(filename string) string { inout := "Input" if strings.Contains(filename, "output/") { inout = "Output" } var suites []testSuite f, err := os.Open(filename) if err != nil { panic(err) } err = json.NewDecoder(f).Decode(&suites) if err != nil { panic(err) } var buf bytes.Buffer buf.WriteString("package " + suites[0].ProtocolPackage() + "_test\n\n") var innerBuf bytes.Buffer innerBuf.WriteString("//\n// Tests begin here\n//\n\n\n") for i, suite := range suites { svcPrefix := inout + "Service" + strconv.Itoa(i+1) suite.API.Metadata.ServiceAbbreviation = svcPrefix + "ProtocolTest" suite.API.Operations = map[string]*api.Operation{} for idx, c := range suite.Cases { c.Given.ExportedName = svcPrefix + "TestCaseOperation" + strconv.Itoa(idx+1) suite.API.Operations[c.Given.ExportedName] = c.Given } suite.Type = getType(inout) suite.API.NoInitMethods = true // don't generate init methods suite.API.NoStringerMethods = true // don't generate stringer methods suite.API.NoConstServiceNames = true // don't generate service names suite.API.Setup() suite.API.Metadata.EndpointPrefix = suite.API.PackageName() // Sort in order for deterministic test generation names := make([]string, 0, len(suite.API.Shapes)) for n := range suite.API.Shapes { names = append(names, n) } sort.Strings(names) for _, name := range names { s := suite.API.Shapes[name] s.Rename(svcPrefix + "TestShape" + name) } svcCode := addImports(suite.API.ServiceGoCode()) if i == 0 { importMatch := reImportRemoval.FindStringSubmatch(svcCode) buf.WriteString(importMatch[0] + "\n\n") buf.WriteString(preamble + "\n\n") } svcCode = removeImports(svcCode) svcCode = strings.Replace(svcCode, "func New(", "func New"+suite.API.StructName()+"(", -1) svcCode = strings.Replace(svcCode, "func newClient(", "func new"+suite.API.StructName()+"Client(", -1) svcCode = strings.Replace(svcCode, "return newClient(", "return new"+suite.API.StructName()+"Client(", -1) buf.WriteString(svcCode + "\n\n") apiCode := removeImports(suite.API.APIGoCode()) apiCode = strings.Replace(apiCode, "var oprw sync.Mutex", "", -1) apiCode = strings.Replace(apiCode, "oprw.Lock()", "", -1) apiCode = strings.Replace(apiCode, "defer oprw.Unlock()", "", -1) buf.WriteString(apiCode + "\n\n") innerBuf.WriteString(suite.TestSuite() + "\n") } return buf.String() + innerBuf.String() } // findMember searches the shape for the member with the matching key name. func findMember(shape *api.Shape, key string) string { for actualKey := range shape.MemberRefs { if strings.ToLower(key) == strings.ToLower(actualKey) { return actualKey } } return "" } // GenerateAssertions builds assertions for a shape based on its type. // // The shape's recursive values also will have assertions generated for them. func GenerateAssertions(out interface{}, shape *api.Shape, prefix string) string { if shape == nil { return "" } switch t := out.(type) { case map[string]interface{}: keys := util.SortedKeys(t) code := "" if shape.Type == "map" { for _, k := range keys { v := t[k] s := shape.ValueRef.Shape code += GenerateAssertions(v, s, prefix+"[\""+k+"\"]") } } else if shape.Type == "jsonvalue" { code += fmt.Sprintf("reflect.DeepEqual(%s, map[string]interface{}%s)", prefix, walkMap(out.(map[string]interface{}))) } else { for _, k := range keys { v := t[k] m := findMember(shape, k) s := shape.MemberRefs[m].Shape code += GenerateAssertions(v, s, prefix+"."+m+"") } } return code case []interface{}: code := "" for i, v := range t { s := shape.MemberRef.Shape code += GenerateAssertions(v, s, prefix+"["+strconv.Itoa(i)+"]") } return code default: switch shape.Type { case "timestamp": return fmtAssertEqual( fmt.Sprintf("time.Unix(%#v, 0).UTC().String()", out), fmt.Sprintf("%s.String()", prefix), ) case "blob": return fmtAssertEqual( fmt.Sprintf("%#v", out), fmt.Sprintf("string(%s)", prefix), ) case "integer", "long": return fmtAssertEqual( fmt.Sprintf("int64(%#v)", out), fmt.Sprintf("*%s", prefix), ) default: if !reflect.ValueOf(out).IsValid() { return fmtAssertNil(prefix) } return fmtAssertEqual( fmt.Sprintf("%#v", out), fmt.Sprintf("*%s", prefix), ) } } } func getType(t string) uint { switch t { case "Input": return TestSuiteTypeInput case "Output": return TestSuiteTypeOutput default: panic("Invalid type for test suite") } } func main() { out := generateTestSuite(os.Args[1]) if len(os.Args) == 3 { f, err := os.Create(os.Args[2]) defer f.Close() if err != nil { panic(err) } f.WriteString(util.GoFmt(out)) f.Close() c := exec.Command("gofmt", "-s", "-w", os.Args[2]) if err := c.Run(); err != nil { panic(err) } } else { fmt.Println(out) } }