mirror of
https://github.com/zrepl/zrepl.git
synced 2024-11-25 01:44:43 +01:00
105 lines
2.4 KiB
Go
105 lines
2.4 KiB
Go
package serve
|
|
|
|
import (
|
|
"github.com/pkg/errors"
|
|
"github.com/zrepl/zrepl/config"
|
|
"net"
|
|
"github.com/zrepl/zrepl/daemon/streamrpcconfig"
|
|
"github.com/problame/go-streamrpc"
|
|
"context"
|
|
"github.com/zrepl/zrepl/logger"
|
|
"github.com/zrepl/zrepl/zfs"
|
|
)
|
|
|
|
// contextKey is the type of the keys this package stores in a context.Context.
// Because the type is unexported, code outside this package cannot construct a
// key that collides with these.
type contextKey int

// Keys used with context.WithValue / ctx.Value in this package.
const (
	// contextKeyLog is the key under which WithLogger stores the Logger.
	contextKeyLog contextKey = iota
)
|
|
|
|
// Logger is the logger type used by this package (see WithLogger and
// getLogger). It is a type alias, so Logger and logger.Logger are the same type.
type Logger = logger.Logger
|
|
|
|
func WithLogger(ctx context.Context, log Logger) context.Context {
|
|
return context.WithValue(ctx, contextKeyLog, log)
|
|
}
|
|
|
|
func getLogger(ctx context.Context) Logger {
|
|
if log, ok := ctx.Value(contextKeyLog).(Logger); ok {
|
|
return log
|
|
}
|
|
return logger.NewNullLogger()
|
|
}
|
|
|
|
// AuthenticatedConn is a net.Conn that additionally carries the identity of
// the client on the other end. AuthenticatedListener.Accept returns values of
// this type.
type AuthenticatedConn interface {
	// ClientIdentity returns the client's identity. The returned string must
	// satisfy ValidateClientIdentity.
	ClientIdentity() string

	net.Conn
}
|
|
|
|
// A client identity must be a single component in a ZFS filesystem path
|
|
func ValidateClientIdentity(in string) (err error) {
|
|
path, err := zfs.NewDatasetPath(in)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if path.Length() != 1 {
|
|
return errors.New("client identity must be a single path comonent (not empty, no '/')")
|
|
}
|
|
return nil
|
|
}
|
|
|
|
type authConn struct {
|
|
net.Conn
|
|
clientIdentity string
|
|
}
|
|
|
|
var _ AuthenticatedConn = authConn{}
|
|
|
|
func (c authConn) ClientIdentity() string {
|
|
if err := ValidateClientIdentity(c.clientIdentity); err != nil {
|
|
panic(err)
|
|
}
|
|
return c.clientIdentity
|
|
}
|
|
|
|
// like net.Listener, but with an AuthenticatedConn instead of net.Conn
|
|
type AuthenticatedListener interface {
|
|
Addr() (net.Addr)
|
|
Accept(ctx context.Context) (AuthenticatedConn, error)
|
|
Close() error
|
|
}
|
|
|
|
type ListenerFactory interface {
|
|
Listen() (AuthenticatedListener, error)
|
|
}
|
|
|
|
func FromConfig(g *config.Global, in config.ServeEnum) (lf ListenerFactory, conf *streamrpc.ConnConfig, _ error) {
|
|
|
|
var (
|
|
lfError, rpcErr error
|
|
)
|
|
switch v := in.Ret.(type) {
|
|
case *config.TCPServe:
|
|
lf, lfError = TCPListenerFactoryFromConfig(g, v)
|
|
conf, rpcErr = streamrpcconfig.FromDaemonConfig(g, v.RPC)
|
|
case *config.TLSServe:
|
|
lf, lfError = TLSListenerFactoryFromConfig(g, v)
|
|
conf, rpcErr = streamrpcconfig.FromDaemonConfig(g, v.RPC)
|
|
case *config.StdinserverServer:
|
|
lf, lfError = MultiStdinserverListenerFactoryFromConfig(g, v)
|
|
conf, rpcErr = streamrpcconfig.FromDaemonConfig(g, v.RPC)
|
|
default:
|
|
return nil, nil, errors.Errorf("internal error: unknown serve type %T", v)
|
|
}
|
|
|
|
if lfError != nil {
|
|
return nil, nil, lfError
|
|
}
|
|
if rpcErr != nil {
|
|
return nil, nil, rpcErr
|
|
}
|
|
|
|
return lf, conf, nil
|
|
|
|
}
|
|
|
|
|