mirror of
https://github.com/zrepl/zrepl.git
synced 2025-01-25 15:48:40 +01:00
148 lines
3.4 KiB
Go
148 lines
3.4 KiB
Go
package serve
|
|
|
|
import (
|
|
"github.com/pkg/errors"
|
|
"github.com/zrepl/zrepl/config"
|
|
"github.com/zrepl/zrepl/daemon/transport"
|
|
"net"
|
|
"github.com/zrepl/zrepl/daemon/streamrpcconfig"
|
|
"github.com/problame/go-streamrpc"
|
|
"context"
|
|
"github.com/zrepl/zrepl/logger"
|
|
"github.com/zrepl/zrepl/zfs"
|
|
"time"
|
|
)
|
|
|
|
// contextKey is the unexported type used for the context keys defined by this
// package. Using a dedicated type keeps these keys distinct from keys of any
// other type stored in the same context.
type contextKey int

const (
	// contextKeyLog is the key under which WithLogger stores the Logger.
	contextKeyLog contextKey = iota
)
|
|
|
|
// Logger is an alias for logger.Logger, so that users of this package can
// refer to the logger type without naming the logger package themselves.
type Logger = logger.Logger
|
|
|
|
func WithLogger(ctx context.Context, log Logger) context.Context {
|
|
return context.WithValue(ctx, contextKeyLog, log)
|
|
}
|
|
|
|
func getLogger(ctx context.Context) Logger {
|
|
if log, ok := ctx.Value(contextKeyLog).(Logger); ok {
|
|
return log
|
|
}
|
|
return logger.NewNullLogger()
|
|
}
|
|
|
|
// AuthenticatedConn is a net.Conn that additionally exposes a client identity.
type AuthenticatedConn interface {
	net.Conn

	// ClientIdentity must be a string that satisfies ValidateClientIdentity.
	ClientIdentity() string
}
|
|
|
|
// A client identity must be a single component in a ZFS filesystem path
|
|
func ValidateClientIdentity(in string) (err error) {
|
|
path, err := zfs.NewDatasetPath(in)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if path.Length() != 1 {
|
|
return errors.New("client identity must be a single path comonent (not empty, no '/')")
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// authConn pairs a net.Conn with a client identity string.
type authConn struct {
	net.Conn

	// clientIdentity is the value returned by ClientIdentity.
	clientIdentity string
}
|
|
|
|
// Compile-time check that authConn implements AuthenticatedConn.
var _ AuthenticatedConn = authConn{}
|
|
|
|
func (c authConn) ClientIdentity() string {
|
|
if err := ValidateClientIdentity(c.clientIdentity); err != nil {
|
|
panic(err)
|
|
}
|
|
return c.clientIdentity
|
|
}
|
|
|
|
// like net.Listener, but with an AuthenticatedConn instead of net.Conn
|
|
type AuthenticatedListener interface {
|
|
Addr() (net.Addr)
|
|
Accept(ctx context.Context) (AuthenticatedConn, error)
|
|
Close() error
|
|
}
|
|
|
|
type ListenerFactory interface {
|
|
Listen() (AuthenticatedListener, error)
|
|
}
|
|
|
|
type HandshakeListenerFactory struct {
|
|
lf ListenerFactory
|
|
}
|
|
|
|
func (lf HandshakeListenerFactory) Listen() (AuthenticatedListener, error) {
|
|
l, err := lf.lf.Listen()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return HandshakeListener{l}, nil
|
|
}
|
|
|
|
type HandshakeListener struct {
|
|
l AuthenticatedListener
|
|
}
|
|
|
|
// Addr returns the address reported by the wrapped listener.
func (l HandshakeListener) Addr() net.Addr {
	return l.l.Addr()
}
|
|
|
|
// Close closes the wrapped listener and returns its error, if any.
func (l HandshakeListener) Close() error {
	return l.l.Close()
}
|
|
|
|
func (l HandshakeListener) Accept(ctx context.Context) (AuthenticatedConn, error) {
|
|
conn, err := l.l.Accept(ctx)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
dl, ok := ctx.Deadline()
|
|
if !ok {
|
|
dl = time.Now().Add(10*time.Second) // FIXME constant
|
|
}
|
|
if err := transport.DoHandshakeCurrentVersion(conn, dl); err != nil {
|
|
conn.Close()
|
|
return nil, err
|
|
}
|
|
return conn, nil
|
|
}
|
|
|
|
func FromConfig(g *config.Global, in config.ServeEnum) (lf ListenerFactory, conf *streamrpc.ConnConfig, _ error) {
|
|
|
|
var (
|
|
lfError, rpcErr error
|
|
)
|
|
switch v := in.Ret.(type) {
|
|
case *config.TCPServe:
|
|
lf, lfError = TCPListenerFactoryFromConfig(g, v)
|
|
conf, rpcErr = streamrpcconfig.FromDaemonConfig(g, v.RPC)
|
|
case *config.TLSServe:
|
|
lf, lfError = TLSListenerFactoryFromConfig(g, v)
|
|
conf, rpcErr = streamrpcconfig.FromDaemonConfig(g, v.RPC)
|
|
case *config.StdinserverServer:
|
|
lf, lfError = MultiStdinserverListenerFactoryFromConfig(g, v)
|
|
conf, rpcErr = streamrpcconfig.FromDaemonConfig(g, v.RPC)
|
|
case *config.LocalServe:
|
|
lf, lfError = LocalListenerFactoryFromConfig(g, v)
|
|
conf, rpcErr = streamrpcconfig.FromDaemonConfig(g, v.RPC)
|
|
default:
|
|
return nil, nil, errors.Errorf("internal error: unknown serve type %T", v)
|
|
}
|
|
|
|
if lfError != nil {
|
|
return nil, nil, lfError
|
|
}
|
|
if rpcErr != nil {
|
|
return nil, nil, rpcErr
|
|
}
|
|
|
|
lf = HandshakeListenerFactory{lf}
|
|
|
|
return lf, conf, nil
|
|
|
|
}
|
|
|
|
|