zrepl/daemon/serve/serve_local.go
2018-10-11 13:06:47 +02:00

232 lines
5.8 KiB
Go

package serve
import (
"context"
"fmt"
"github.com/zrepl/zrepl/config"
"golang.org/x/sys/unix"
"net"
"os"
"sync"
)
var localListeners struct {
m map[string]*LocalListener // listenerName -> listener
init sync.Once
mtx sync.Mutex
}
func GetLocalListener(listenerName string) (*LocalListener) {
localListeners.init.Do(func() {
localListeners.m = make(map[string]*LocalListener)
})
localListeners.mtx.Lock()
defer localListeners.mtx.Unlock()
l, ok := localListeners.m[listenerName]
if !ok {
l = newLocalListener()
localListeners.m[listenerName] = l
}
return l
}
type connectRequest struct {
clientIdentity string
callback chan connectResult
}
type connectResult struct {
conn net.Conn
err error
}
type LocalListener struct {
connects chan connectRequest
}
func newLocalListener() *LocalListener {
return &LocalListener{
connects: make(chan connectRequest),
}
}
// Connect to the LocalListener from a client with identity clientIdentity
func (l *LocalListener) Connect(dialCtx context.Context, clientIdentity string) (conn net.Conn, err error) {
// place request
req := connectRequest{
clientIdentity: clientIdentity,
callback: make(chan connectResult),
}
select {
case l.connects <- req:
case <-dialCtx.Done():
return nil, dialCtx.Err()
}
// wait for listener response
select {
case connRes := <- req.callback:
conn, err = connRes.conn, connRes.err
case <-dialCtx.Done():
close(req.callback) // sending to the channel afterwards will panic, the listener has to catch this
conn, err = nil, dialCtx.Err()
}
return conn, err
}
type localAddr struct {
S string
}
func (localAddr) Network() string { return "local" }
func (a localAddr) String() string { return a.S }
func (l *LocalListener) Addr() (net.Addr) { return localAddr{"<listening>"} }
type localConn struct {
net.Conn
clientIdentity string
}
func (l localConn) ClientIdentity() string { return l.clientIdentity }
func (l *LocalListener) Accept(ctx context.Context) (AuthenticatedConn, error) {
respondToRequest := func(req connectRequest, res connectResult) (err error) {
getLogger(ctx).
WithField("res.conn", res.conn).WithField("res.err", res.err).
Debug("responding to client request")
defer func() {
errv := recover()
getLogger(ctx).WithField("recover_err", errv).
Debug("panic on send to client callback, likely a legitimate client-side timeout")
}()
select {
case req.callback <- res:
err = nil
default:
err = fmt.Errorf("client-provided callback did block on send")
}
close(req.callback)
return err
}
getLogger(ctx).Debug("waiting for local client connect requests")
var req connectRequest
select {
case req = <-l.connects:
case <-ctx.Done():
return nil, ctx.Err()
}
getLogger(ctx).WithField("client_identity", req.clientIdentity).Debug("got connect request")
if req.clientIdentity == "" {
res := connectResult{nil, fmt.Errorf("client identity must not be empty")}
if err := respondToRequest(req, res); err != nil {
return nil, err
}
return nil, fmt.Errorf("client connected with empty client identity")
}
getLogger(ctx).Debug("creating socketpair")
left, right, err := makeSocketpairConn()
if err != nil {
res := connectResult{nil, fmt.Errorf("server error: %s", err)}
if respErr := respondToRequest(req, res); respErr != nil {
// returning the socketpair error properly is more important than the error sent to the client
getLogger(ctx).WithError(respErr).Error("error responding to client")
}
return nil, err
}
getLogger(ctx).Debug("responding with left side of socketpair")
res := connectResult{left, nil}
if err := respondToRequest(req, res); err != nil {
getLogger(ctx).WithError(err).Error("error responding to client")
if err := left.Close(); err != nil {
getLogger(ctx).WithError(err).Error("cannot close left side of socketpair")
}
if err := right.Close(); err != nil {
getLogger(ctx).WithError(err).Error("cannot close right side of socketpair")
}
return nil, err
}
return localConn{right, req.clientIdentity}, nil
}
type fileConn struct {
net.Conn // net.FileConn
f *os.File
}
func (c fileConn) Close() error {
if err := c.Conn.Close(); err != nil {
return err
}
if err := c.f.Close(); err != nil {
return err
}
return nil
}
func makeSocketpairConn() (a, b net.Conn, err error) {
// don't use net.Pipe, as it doesn't implement things like lingering, which our code relies on
sockpair, err := unix.Socketpair(unix.AF_UNIX, unix.SOCK_STREAM, 0)
if err != nil {
return nil, nil, err
}
toConn := func(fd int) (net.Conn, error) {
f := os.NewFile(uintptr(fd), "fileconn")
if f == nil {
panic(fd)
}
c, err := net.FileConn(f)
if err != nil {
f.Close()
return nil, err
}
return fileConn{Conn: c, f: f}, nil
}
if a, err = toConn(sockpair[0]); err != nil { // shadowing
return nil, nil, err
}
if b, err = toConn(sockpair[1]); err != nil { // shadowing
a.Close()
return nil, nil, err
}
return a, b, nil
}
func (l *LocalListener) Close() error {
// FIXME: make sure concurrent Accepts return with error, and further Accepts return that error, too
// Example impl: for each accept, do context.WithCancel, and store the cancel in a list
// When closing, set a member variable to state=closed, make sure accept will exit early
// and then call all cancels in the list
// The code path from Accept entry over check if state=closed to list entry must be protected by a mutex.
return nil
}
type LocalListenerFactory struct {
listenerName string
}
func LocalListenerFactoryFromConfig(g *config.Global, in *config.LocalServe) (f *LocalListenerFactory, err error) {
if in.ListenerName == "" {
return nil, fmt.Errorf("ListenerName must not be empty")
}
return &LocalListenerFactory{listenerName: in.ListenerName}, nil
}
func (lf *LocalListenerFactory) Listen() (AuthenticatedListener, error) {
return GetLocalListener(lf.listenerName), nil
}