zrepl/platformtest/tests/gen/gen.go

141 lines
2.3 KiB
Go

package main
import (
"bytes"
"go/ast"
"go/format"
"go/parser"
"go/token"
"io/ioutil"
"os"
"sort"
"strings"
"text/template"
"golang.org/x/tools/go/packages"
)
// check aborts the generator by panicking with err when err is non-nil.
// A nil error is ignored.
func check(err error) {
	if err == nil {
		return
	}
	panic(err)
}
type platformtestFuncDeclFinder struct {
pkg *packages.Package
testFuncs []*ast.FuncDecl
}
// isPlatformtestFunc reports whether fn has the shape of a platform test
// case: an exported, receiver-less function that returns nothing and takes
// exactly one parameter whose type is written as *platformtest.Context.
// The match is purely syntactic; the "platformtest" identifier is not
// resolved to an import path.
func isPlatformtestFunc(fn *ast.FuncDecl) bool {
	if !fn.Name.IsExported() || fn.Recv != nil {
		return false
	}
	// NumFields counts names, so `a, b *T` counts as two parameters.
	if fn.Type.Results.NumFields() != 0 || fn.Type.Params.NumFields() != 1 {
		return false
	}

	ptr, isPtr := fn.Type.Params.List[0].Type.(*ast.StarExpr)
	if !isPtr {
		return false
	}
	qualified, isQualified := ptr.X.(*ast.SelectorExpr)
	if !isQualified {
		return false
	}
	pkgIdent, isIdent := qualified.X.(*ast.Ident)
	return isIdent && pkgIdent.Name == "platformtest" && qualified.Sel.Name == "Context"
}
// Visit implements ast.Visitor. The walk descends only into the *ast.File
// root, so only the file's top-level declarations are inspected. Each
// *ast.FuncDecl accepted by isPlatformtestFunc is appended to e.testFuncs.
// Function bodies and every other node (including the nil sent at the end of
// a subtree) are never descended into.
func (e *platformtestFuncDeclFinder) Visit(node ast.Node) ast.Visitor {
	if _, isFile := node.(*ast.File); isFile {
		return e
	}
	if fn, isFunc := node.(*ast.FuncDecl); isFunc && isPlatformtestFunc(fn) {
		e.testFuncs = append(e.testFuncs, fn)
	}
	return nil
}
// generatedCasesFile is the output file, relative to the current working
// directory.
const generatedCasesFile = "generated_cases.go"

// generatedCasesHeader is the leading comment of every file this generator
// writes. It follows the "// Code generated ... DO NOT EDIT." convention for
// generated Go files. removeGeneratedFile only deletes a file containing it.
const generatedCasesHeader = "// Code generated by zrepl tooling; DO NOT EDIT."

// casesTemplate renders `var Cases = []Case{...}` from a []*ast.FuncDecl,
// one element per function name. Its layout is normalized by go/format
// before the result is written.
const casesTemplate = `
` + generatedCasesHeader + `
package tests
var Cases = []Case {
{{- range . -}}
{{ .Name }},
{{ end -}}
}
`

// main regenerates generatedCasesFile for the single package matched by the
// pattern in os.Args[1]. Every top-level function accepted by
// isPlatformtestFunc is listed, sorted by name, in the generated Cases slice.
// Any failure aborts the program.
func main() {
	if len(os.Args) < 2 {
		os.Stderr.WriteString("usage: gen PACKAGE_PATTERN\n")
		os.Exit(2)
	}
	pattern := os.Args[1]

	// Delete the previous output before loading the package, so it is not
	// listed in the package's GoFiles.
	removeGeneratedFile(generatedCasesFile)

	p := loadPackage(pattern)
	tests := findPlatformtestFuncs(p)
	check(ioutil.WriteFile(generatedCasesFile, renderCases(tests), 0664))
}

// removeGeneratedFile deletes path if it exists. If the existing file does
// not contain generatedCasesHeader, it panics instead, so a hand-written file
// of the same name is never destroyed. A missing file is not an error.
func removeGeneratedFile(path string) {
	content, err := ioutil.ReadFile(path)
	if os.IsNotExist(err) {
		return
	}
	check(err)
	if !bytes.Contains(content, []byte(generatedCasesHeader)) {
		panic("refusing to delete " + path + ": it does not contain the header " + generatedCasesHeader)
	}
	check(os.Remove(path))
}

// loadPackage lists the files of the one package matched by pattern
// (packages.LoadFiles mode, tests excluded). It panics if the load fails, if
// the pattern does not match exactly one package, or if that package reports
// errors. Without the last check, a bad pattern would silently yield a
// package with no files and thus an empty Cases list.
func loadPackage(pattern string) *packages.Package {
	pkgs, err := packages.Load(
		&packages.Config{
			Mode:  packages.LoadFiles,
			Tests: false,
		},
		pattern,
	)
	check(err)
	if len(pkgs) != 1 {
		panic("pattern " + pattern + " must match exactly one package")
	}
	p := pkgs[0]
	if len(p.Errors) > 0 {
		msgs := make([]string, len(p.Errors))
		for i, e := range p.Errors {
			msgs[i] = e.Error()
		}
		panic("loading " + pattern + ": " + strings.Join(msgs, "; "))
	}
	return p
}

// findPlatformtestFuncs parses every file in p.GoFiles and returns the
// declarations accepted by isPlatformtestFunc, sorted by function name so the
// generated output is deterministic. It panics on any parse error.
func findPlatformtestFuncs(p *packages.Package) []*ast.FuncDecl {
	var tests []*ast.FuncDecl
	fset := token.NewFileSet()
	for _, filename := range p.GoFiles {
		file, err := parser.ParseFile(fset, filename, nil, parser.AllErrors)
		check(err)
		finder := &platformtestFuncDeclFinder{
			pkg: p,
		}
		ast.Walk(finder, file)
		tests = append(tests, finder.testFuncs...)
	}
	sort.Slice(tests, func(i, j int) bool {
		return strings.Compare(tests[i].Name.Name, tests[j].Name.Name) < 0
	})
	return tests
}

// renderCases executes casesTemplate over tests and returns the gofmt'ed
// source. It panics if templating or formatting fails.
func renderCases(tests []*ast.FuncDecl) []byte {
	t := template.Must(template.New("CaseFunc").Parse(casesTemplate))
	var buf bytes.Buffer
	check(t.Execute(&buf, tests))
	formatted, err := format.Source(buf.Bytes())
	check(err)
	return formatted
}