mirror of
https://github.com/zrepl/zrepl.git
synced 2025-01-25 15:48:40 +01:00
141 lines
2.3 KiB
Go
141 lines
2.3 KiB
Go
|
package main
|
||
|
|
||
|
import (
|
||
|
"bytes"
|
||
|
"go/ast"
|
||
|
"go/format"
|
||
|
"go/parser"
|
||
|
"go/token"
|
||
|
"io/ioutil"
|
||
|
"os"
|
||
|
"sort"
|
||
|
"strings"
|
||
|
"text/template"
|
||
|
|
||
|
"golang.org/x/tools/go/packages"
|
||
|
)
|
||
|
|
||
|
// check aborts the generator by panicking with err when err is non-nil.
// A nil err is a no-op.
func check(err error) {
	if err == nil {
		return
	}
	panic(err)
}
|
||
|
|
||
|
// platformtestFuncDeclFinder is an ast.Visitor (see its Visit method) that
// collects the top-level function declarations of a walked file which
// isPlatformtestFunc accepts.
type platformtestFuncDeclFinder struct {
	// pkg is the package the walked file belongs to. main sets it, but
	// nothing in this file reads it.
	pkg *packages.Package
	// testFuncs accumulates the matching declarations in visitation order.
	testFuncs []*ast.FuncDecl
}
|
||
|
|
||
|
// isPlatformtestFunc reports whether n has the shape of a platformtest case.
// It must be an exported, receiver-less function with no results and exactly
// one parameter of type *platformtest.Context.
//
// The match is purely syntactic. The parameter type must literally be written
// as a pointer to the selector `platformtest.Context`. An import alias for the
// platformtest package is therefore not recognized.
func isPlatformtestFunc(n *ast.FuncDecl) bool {
	switch {
	case !n.Name.IsExported(), n.Recv != nil:
		return false
	case n.Type.Results.NumFields() != 0, n.Type.Params.NumFields() != 1:
		return false
	}

	ptr, isPtr := n.Type.Params.List[0].Type.(*ast.StarExpr)
	if !isPtr {
		return false
	}
	selector, isSel := ptr.X.(*ast.SelectorExpr)
	if !isSel {
		return false
	}
	pkgIdent, isIdent := selector.X.(*ast.Ident)
	return isIdent && pkgIdent.Name == "platformtest" && selector.Sel.Name == "Context"
}
|
||
|
|
||
|
func (e *platformtestFuncDeclFinder) Visit(n2 ast.Node) ast.Visitor {
|
||
|
switch n := n2.(type) {
|
||
|
case *ast.File:
|
||
|
return e
|
||
|
case *ast.FuncDecl:
|
||
|
if isPlatformtestFunc(n) {
|
||
|
e.testFuncs = append(e.testFuncs, n)
|
||
|
}
|
||
|
return nil
|
||
|
default:
|
||
|
return nil
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// main regenerates generated_cases.go in the current working directory.
//
// The single package matched by the pattern in os.Args[1] is scanned for
// top-level functions accepted by isPlatformtestFunc. Their names are written,
// sorted, as the elements of `var Cases = []Case{...}` in package tests.
// Arguments after the first are ignored.
func main() {
	const outputFile = "generated_cases.go"

	if len(os.Args) < 2 {
		os.Stderr.WriteString("missing argument: package pattern to scan for platformtest cases\n")
		os.Exit(2)
	}

	// TODO safeguards that prevent us from deleting non-generated generated_cases.go
	// A missing file is the expected first-run case and is ignored. Any other
	// failure (e.g. permission denied) is reported instead of being swallowed.
	if err := os.Remove(outputFile); err != nil && !os.IsNotExist(err) {
		check(err)
	}

	p := loadSinglePackage(os.Args[1])
	tests := findPlatformtestFuncs(p)
	check(ioutil.WriteFile(outputFile, renderCases(tests), 0664))
}

// loadSinglePackage loads the file list of the package matching pattern. It
// panics unless exactly one package matched and that package has no errors.
func loadSinglePackage(pattern string) *packages.Package {
	pkgs, err := packages.Load(
		&packages.Config{
			Mode:  packages.LoadFiles,
			Tests: false,
		},
		pattern,
	)
	check(err)

	if len(pkgs) != 1 {
		ids := make([]string, 0, len(pkgs))
		for _, pkg := range pkgs {
			ids = append(ids, pkg.ID)
		}
		panic("expected exactly one package for pattern " + pattern +
			", got [" + strings.Join(ids, ", ") + "]")
	}
	p := pkgs[0]

	// packages.Load records problems with a particular package in its Errors
	// field rather than returning them as err. Without this check, such a
	// package would silently produce an empty Cases list.
	if len(p.Errors) > 0 {
		msgs := make([]string, 0, len(p.Errors))
		for _, e := range p.Errors {
			msgs = append(msgs, e.Error())
		}
		panic("loading package " + p.ID + ": " + strings.Join(msgs, "; "))
	}
	return p
}

// findPlatformtestFuncs parses every file in p.GoFiles and returns the
// top-level declarations accepted by isPlatformtestFunc, sorted by name.
func findPlatformtestFuncs(p *packages.Package) []*ast.FuncDecl {
	var tests []*ast.FuncDecl
	fset := token.NewFileSet() // one FileSet for all files of the package
	for _, f := range p.GoFiles {
		a, err := parser.ParseFile(fset, f, nil, parser.AllErrors)
		check(err)
		finder := &platformtestFuncDeclFinder{
			pkg: p,
		}
		ast.Walk(finder, a)
		tests = append(tests, finder.testFuncs...)
	}

	// Sort so the generated file does not depend on file or declaration order.
	sort.Slice(tests, func(i, j int) bool {
		return tests[i].Name.Name < tests[j].Name.Name
	})
	return tests
}

// renderCases returns the gofmt-formatted source of the generated Cases file.
// The tests are listed in the order given.
func renderCases(tests []*ast.FuncDecl) []byte {
	casesTemplate := `
// Code generated by zrepl tooling; DO NOT EDIT.

package tests

var Cases = []Case {
{{- range . -}}
{{ .Name }},
{{ end -}}
}

`
	t, err := template.New("CaseFunc").Parse(casesTemplate)
	check(err)

	var buf bytes.Buffer
	check(t.Execute(&buf, tests))

	formatted, err := format.Source(buf.Bytes())
	check(err)
	return formatted
}
|