package main import ( "bytes" "go/ast" "go/format" "go/parser" "go/token" "io/ioutil" "os" "sort" "strings" "text/template" "golang.org/x/tools/go/packages" ) func check(err error) { if err != nil { panic(err) } } type platformtestFuncDeclFinder struct { pkg *packages.Package testFuncs []*ast.FuncDecl } func isPlatformtestFunc(n *ast.FuncDecl) bool { if !n.Name.IsExported() { return false } if n.Recv != nil { return false } if n.Type.Results.NumFields() != 0 { return false } if n.Type.Params.NumFields() != 1 { return false } se, ok := n.Type.Params.List[0].Type.(*ast.StarExpr) if !ok { return false } sel, ok := se.X.(*ast.SelectorExpr) if !ok { return false } x, ok := sel.X.(*ast.Ident) if !ok { return false } if x.Name != "platformtest" || sel.Sel.Name != "Context" { return false } return true } func (e *platformtestFuncDeclFinder) Visit(n2 ast.Node) ast.Visitor { switch n := n2.(type) { case *ast.File: return e case *ast.FuncDecl: if isPlatformtestFunc(n) { e.testFuncs = append(e.testFuncs, n) } return nil default: return nil } } func main() { // TODO safeguards that prevent us from deleting non-generated generated_cases.go os.Remove("generated_cases.go") // (no error handling to easily cover the case where the file doesn't exist) pkgs, err := packages.Load( &packages.Config{ Mode: packages.LoadFiles, Tests: false, }, os.Args[1], ) check(err) if len(pkgs) != 1 { panic(pkgs) } p := pkgs[0] var tests []*ast.FuncDecl for _, f := range p.GoFiles { s := token.NewFileSet() a, err := parser.ParseFile(s, f, nil, parser.AllErrors) check(err) finder := &platformtestFuncDeclFinder{ pkg: p, } ast.Walk(finder, a) tests = append(tests, finder.testFuncs...) } sort.Slice(tests, func(i, j int) bool { return strings.Compare(tests[i].Name.Name, tests[j].Name.Name) < 0 }) { casesTemplate := ` // Code generated by zrepl tooling; DO NOT EDIT. package tests var Cases = []Case { {{- range . -}} {{ .Name }}, {{ end -}} } ` t, err := template.New("CaseFunc").Parse(casesTemplate) check(err) var buf bytes.Buffer err = t.Execute(&buf, tests) check(err) formatted, err := format.Source(buf.Bytes()) check(err) err = ioutil.WriteFile("generated_cases.go", formatted, 0664) check(err) } }