implement test server and client to diagnose certificate issues

cc @asomers
This commit is contained in:
Christian Schwarz 2021-10-09 15:27:39 +02:00
parent 6f11e92801
commit fc9c9b184e

View File

@ -0,0 +1,133 @@
package main
import (
"context"
"flag"
"io"
"log"
"os"
"strings"
"github.com/zrepl/zrepl/config"
"github.com/zrepl/zrepl/transport/tls"
)
var servConf = config.TLSServe{
ServeCommon: config.ServeCommon{
Type: "tls",
},
}
var clientConf = config.TLSConnect{
ConnectCommon: config.ConnectCommon{
Type: "",
},
Address: "",
Ca: "",
Cert: "",
Key: "",
ServerCN: "",
DialTimeout: 0,
}
var ca string
var mode string
func main() {
flag.StringVar(&mode, "mode", "", "server|client")
flag.StringVar(&ca, "ca", "", "path")
flag.StringVar(&servConf.Listen, "serve.listen", "", "")
flag.StringVar(&servConf.Cert, "serve.cert", "", "path")
flag.StringVar(&servConf.Key, "serve.key", "", "path")
var clientCN string
flag.StringVar(&clientCN, "serve.client_cn", "", "")
flag.StringVar(&clientConf.Address, "client.address", "", "")
flag.StringVar(&clientConf.Cert, "client.cert", "", "path")
flag.StringVar(&clientConf.Key, "client.key", "", "path")
flag.StringVar(&clientConf.ServerCN, "client.server_cn", "", "")
flag.Parse()
servConf.ClientCNs = append(servConf.ClientCNs, clientCN)
servConf.Ca = ca
clientConf.Ca = ca
switch mode {
case "server":
server()
case "client":
client()
default:
panic(mode)
}
}
func server() {
servFactory, err := tls.TLSListenerFactoryFromConfig(nil, &servConf)
if err != nil {
panic(err)
}
listener, err := servFactory()
if err != nil {
panic(err)
}
ctx := context.Background()
for {
conn, err := listener.Accept(ctx)
if err != nil {
log.Printf("accept error: %s", err)
continue
}
go func() {
defer conn.Close()
log.Printf("handling connection %s", conn)
_, err = io.Copy(conn, strings.NewReader("here is the server\n"))
if err != nil {
log.Printf("%s: respond to client error: %s", conn, err)
return
}
log.Printf("%s: waiting for client to close connection", conn)
_, err = io.Copy(io.Discard, conn)
if err != nil {
log.Printf("%s: error draining client connection: %s", conn, err)
return
}
log.Printf("%s: done", conn)
return
}()
}
}
func client() {
connecter, err := tls.TLSConnecterFromConfig(&clientConf)
if err != nil {
panic(err)
}
ctx := context.Background()
conn, err := connecter.Connect(ctx)
if err != nil {
panic(err)
}
defer conn.Close()
_, err = io.Copy(os.Stdout, conn)
if err != nil {
panic(err)
}
}