zrepl/daemon/transport/serve/serve.go

108 lines
2.5 KiB
Go
Raw Normal View History

package serve
import (
"github.com/pkg/errors"
"github.com/zrepl/zrepl/config"
"net"
"github.com/zrepl/zrepl/daemon/streamrpcconfig"
"github.com/problame/go-streamrpc"
"context"
"github.com/zrepl/zrepl/logger"
"github.com/zrepl/zrepl/zfs"
)
// contextKey is the unexported type of the keys this package uses with
// context.WithValue. Using a dedicated type keeps these keys distinct from
// keys of any other type, including plain ints.
type contextKey int

const (
	// contextKeyLog is the context key under which WithLogger stores the
	// Logger and under which getLogger looks it up. Its value is 0.
	contextKeyLog contextKey = iota
)
// Logger is an alias for logger.Logger, so the two types are interchangeable.
type Logger = logger.Logger
// WithLogger returns a context derived from ctx that carries log under this
// package's private logger key.
func WithLogger(ctx context.Context, log Logger) context.Context {
	withLog := context.WithValue(ctx, contextKeyLog, log)
	return withLog
}
// getLogger returns the Logger stored in ctx under this package's private
// logger key. If ctx holds no value of type Logger under that key, it
// returns logger.NewNullLogger() instead.
func getLogger(ctx context.Context) Logger {
	stored, found := ctx.Value(contextKeyLog).(Logger)
	if !found {
		return logger.NewNullLogger()
	}
	return stored
}
// AuthenticatedConn is a net.Conn that additionally exposes the identity of
// the client on the other end of the connection.
type AuthenticatedConn interface {
	net.Conn

	// ClientIdentity returns the identity of the connected client.
	// Implementations must return a string that satisfies
	// ValidateClientIdentity.
	ClientIdentity() string
}
// ValidateClientIdentity checks that in is usable as a client identity.
//
// A client identity must be a single component in a ZFS filesystem path:
// in is parsed with zfs.NewDatasetPath and the resulting path must have a
// Length of exactly 1.
//
// It returns nil if in is valid, the error from zfs.NewDatasetPath if in
// cannot be parsed, and a descriptive error if the parsed path does not
// consist of exactly one component (e.g. it is empty or contains '/').
func ValidateClientIdentity(in string) (err error) {
	path, err := zfs.NewDatasetPath(in)
	if err != nil {
		return err
	}
	if path.Length() != 1 {
		return errors.New("client identity must be a single path component (not empty, no '/')")
	}
	return nil
}
// authConn pairs a net.Conn with the identity of the client using it.
type authConn struct {
	net.Conn

	// clientIdentity is the value reported by ClientIdentity; it is
	// validated with ValidateClientIdentity each time it is read there.
	clientIdentity string
}

// Compile-time check that authConn (by value) implements AuthenticatedConn.
var _ AuthenticatedConn = authConn{}
// ClientIdentity returns the client identity stored in c.
//
// It panics with the validation error if the stored identity does not pass
// ValidateClientIdentity, because returning such a value would violate the
// AuthenticatedConn contract.
func (c authConn) ClientIdentity() string {
	identity := c.clientIdentity
	err := ValidateClientIdentity(identity)
	if err != nil {
		panic(err)
	}
	return identity
}
// AuthenticatedListener is like net.Listener, except that Accept takes a
// context and yields an AuthenticatedConn instead of a net.Conn.
type AuthenticatedListener interface {
	// Addr returns the listener's network address.
	Addr() net.Addr

	// Accept returns the next connection as an AuthenticatedConn, or an error.
	// NOTE(review): how implementations react to ctx (e.g. cancellation) is
	// not visible here; confirm against the concrete listeners.
	Accept(ctx context.Context) (AuthenticatedConn, error)

	// Close closes the listener.
	Close() error
}
// ListenerFactory creates AuthenticatedListeners.
type ListenerFactory interface {
	// Listen returns an AuthenticatedListener, or an error if none could be
	// created.
	Listen() (AuthenticatedListener, error)
}
// FromConfig builds the ListenerFactory and the streamrpc connection
// configuration described by the serve configuration in.
//
// The concrete type of in.Ret selects the listener factory constructor
// (TCP, TLS, stdinserver or local). For every supported type the RPC
// configuration is derived from that type's RPC field via
// streamrpcconfig.FromDaemonConfig.
//
// On failure both non-error results are nil. An unsupported type in in.Ret
// yields an "internal error". If both the listener factory and the RPC
// configuration fail, the listener factory error is the one returned.
func FromConfig(g *config.Global, in config.ServeEnum) (lf ListenerFactory, conf *streamrpc.ConnConfig, _ error) {
	var listenErr, confErr error

	// Both constructors are always invoked for a known type; their errors
	// are only inspected after the switch.
	switch serve := in.Ret.(type) {
	case *config.TCPServe:
		lf, listenErr = TCPListenerFactoryFromConfig(g, serve)
		conf, confErr = streamrpcconfig.FromDaemonConfig(g, serve.RPC)
	case *config.TLSServe:
		lf, listenErr = TLSListenerFactoryFromConfig(g, serve)
		conf, confErr = streamrpcconfig.FromDaemonConfig(g, serve.RPC)
	case *config.StdinserverServer:
		lf, listenErr = MultiStdinserverListenerFactoryFromConfig(g, serve)
		conf, confErr = streamrpcconfig.FromDaemonConfig(g, serve.RPC)
	case *config.LocalServe:
		lf, listenErr = LocalListenerFactoryFromConfig(g, serve)
		conf, confErr = streamrpcconfig.FromDaemonConfig(g, serve.RPC)
	default:
		return nil, nil, errors.Errorf("internal error: unknown serve type %T", serve)
	}

	// The listener factory error takes precedence over the RPC config error.
	switch {
	case listenErr != nil:
		return nil, nil, listenErr
	case confErr != nil:
		return nil, nil, confErr
	}
	return lf, conf, nil
}