mirror of
https://github.com/zrepl/zrepl.git
synced 2025-01-07 06:49:15 +01:00
44bd354eae
Signed-off-by: InsanePrawn <insane.prawny@gmail.com>
195 lines
5.6 KiB
Go
195 lines
5.6 KiB
Go
package local
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"net"
|
|
"sync"
|
|
|
|
"github.com/zrepl/zrepl/config"
|
|
"github.com/zrepl/zrepl/transport"
|
|
"github.com/zrepl/zrepl/util/socketpair"
|
|
)
|
|
|
|
var localListeners struct {
|
|
m map[string]*LocalListener // listenerName -> listener
|
|
init sync.Once
|
|
mtx sync.Mutex
|
|
}
|
|
|
|
func GetLocalListener(listenerName string) *LocalListener {
|
|
|
|
localListeners.init.Do(func() {
|
|
localListeners.m = make(map[string]*LocalListener)
|
|
})
|
|
|
|
localListeners.mtx.Lock()
|
|
defer localListeners.mtx.Unlock()
|
|
|
|
l, ok := localListeners.m[listenerName]
|
|
if !ok {
|
|
l = newLocalListener()
|
|
localListeners.m[listenerName] = l
|
|
}
|
|
return l
|
|
|
|
}
|
|
|
|
type connectRequest struct {
|
|
clientIdentity string
|
|
callback chan connectResult
|
|
}
|
|
|
|
type connectResult struct {
|
|
conn transport.Wire
|
|
err error
|
|
}
|
|
|
|
type LocalListener struct {
|
|
connects chan connectRequest
|
|
}
|
|
|
|
func newLocalListener() *LocalListener {
|
|
return &LocalListener{
|
|
connects: make(chan connectRequest),
|
|
}
|
|
}
|
|
|
|
// Connect to the LocalListener from a client with identity clientIdentity
|
|
func (l *LocalListener) Connect(dialCtx context.Context, clientIdentity string) (conn transport.Wire, err error) {
|
|
|
|
// place request
|
|
req := connectRequest{
|
|
clientIdentity: clientIdentity,
|
|
// ensure non-blocking send in Accept, we don't necessarily read from callback before
|
|
// Accept writes to it
|
|
callback: make(chan connectResult, 1),
|
|
}
|
|
select {
|
|
case l.connects <- req:
|
|
case <-dialCtx.Done():
|
|
return nil, dialCtx.Err()
|
|
}
|
|
|
|
// wait for listener response
|
|
select {
|
|
case connRes := <-req.callback:
|
|
conn, err = connRes.conn, connRes.err
|
|
case <-dialCtx.Done():
|
|
close(req.callback) // sending to the channel afterwards will panic, the listener has to catch this
|
|
conn, err = nil, dialCtx.Err()
|
|
}
|
|
|
|
return conn, err
|
|
}
|
|
|
|
// localAddr is a string-backed address for the "local" transport.
// It provides the Network and String methods.
type localAddr struct {
	// S is the textual form returned by String.
	S string
}

// Network reports the fixed network name "local".
func (localAddr) Network() string {
	return "local"
}

// String returns the address text stored in S.
func (addr localAddr) String() string {
	return addr.S
}
|
|
|
|
// Addr returns a placeholder address; every LocalListener reports the same
// "<listening>" string on the "local" network.
func (l *LocalListener) Addr() net.Addr {
	return localAddr{S: "<listening>"}
}
|
|
|
|
func (l *LocalListener) Accept(ctx context.Context) (*transport.AuthConn, error) {
|
|
respondToRequest := func(req connectRequest, res connectResult) (err error) {
|
|
|
|
transport.GetLogger(ctx).
|
|
WithField("res.conn", res.conn).WithField("res.err", res.err).
|
|
Debug("responding to client request")
|
|
|
|
// contract between Connect and Accept is that Connect sends a req.callback
|
|
// into which we can send one result non-blockingly.
|
|
// We want to panic if that contract is violated (impl error)
|
|
//
|
|
// However, Connect also supports timeouts through context cancellation:
|
|
// Connect closes the channel into which we might send, which will panic.
|
|
//
|
|
// ==> distinguish those cases in defer
|
|
const clientCallbackBlocked = "client-provided callback did block on send"
|
|
defer func() {
|
|
errv := recover()
|
|
if errv == clientCallbackBlocked {
|
|
// this would be a violation of contract between Connect and Accept, see above
|
|
panic(clientCallbackBlocked)
|
|
} else {
|
|
transport.GetLogger(ctx).WithField("recover_err", errv).
|
|
Debug("panic on send to client callback, likely a legitimate client-side timeout")
|
|
}
|
|
}()
|
|
select {
|
|
case req.callback <- res:
|
|
err = nil
|
|
default:
|
|
panic(clientCallbackBlocked)
|
|
}
|
|
close(req.callback)
|
|
return err
|
|
}
|
|
|
|
transport.GetLogger(ctx).Debug("waiting for local client connect requests")
|
|
var req connectRequest
|
|
select {
|
|
case req = <-l.connects:
|
|
case <-ctx.Done():
|
|
return nil, ctx.Err()
|
|
}
|
|
|
|
transport.GetLogger(ctx).WithField("client_identity", req.clientIdentity).Debug("got connect request")
|
|
if req.clientIdentity == "" {
|
|
res := connectResult{nil, fmt.Errorf("client identity must not be empty")}
|
|
if err := respondToRequest(req, res); err != nil {
|
|
return nil, err
|
|
}
|
|
return nil, fmt.Errorf("client connected with empty client identity")
|
|
}
|
|
|
|
transport.GetLogger(ctx).Debug("creating socketpair")
|
|
left, right, err := socketpair.SocketPair()
|
|
if err != nil {
|
|
res := connectResult{nil, fmt.Errorf("server error: %s", err)}
|
|
if respErr := respondToRequest(req, res); respErr != nil {
|
|
// returning the socketpair error properly is more important than the error sent to the client
|
|
transport.GetLogger(ctx).WithError(respErr).Error("error responding to client")
|
|
}
|
|
return nil, err
|
|
}
|
|
|
|
transport.GetLogger(ctx).Debug("responding with left side of socketpair")
|
|
res := connectResult{left, nil}
|
|
if err := respondToRequest(req, res); err != nil {
|
|
transport.GetLogger(ctx).WithError(err).Error("error responding to client")
|
|
if err := left.Close(); err != nil {
|
|
transport.GetLogger(ctx).WithError(err).Error("cannot close left side of socketpair")
|
|
}
|
|
if err := right.Close(); err != nil {
|
|
transport.GetLogger(ctx).WithError(err).Error("cannot close right side of socketpair")
|
|
}
|
|
return nil, err
|
|
}
|
|
|
|
return transport.NewAuthConn(right, req.clientIdentity), nil
|
|
}
|
|
|
|
func (l *LocalListener) Close() error {
|
|
// FIXME: make sure concurrent Accepts return with error, and further Accepts return that error, too
|
|
// Example impl: for each accept, do context.WithCancel, and store the cancel in a list
|
|
// When closing, set a member variable to state=closed, make sure accept will exit early
|
|
// and then call all cancels in the list
|
|
// The code path from Accept entry over check if state=closed to list entry must be protected by a mutex.
|
|
return nil
|
|
}
|
|
|
|
func LocalListenerFactoryFromConfig(g *config.Global, in *config.LocalServe) (transport.AuthenticatedListenerFactory, error) {
|
|
if in.ListenerName == "" {
|
|
return nil, fmt.Errorf("ListenerName must not be empty")
|
|
}
|
|
listenerName := in.ListenerName
|
|
lf := func() (transport.AuthenticatedListener, error) {
|
|
return GetLocalListener(listenerName), nil
|
|
}
|
|
return lf, nil
|
|
}
|