zrepl/daemon/transport/serve/serve_stdinserver.go
2018-10-11 21:21:46 +02:00

159 lines
3.9 KiB
Go

package serve
import (
"github.com/problame/go-netssh"
"github.com/zrepl/zrepl/config"
"github.com/zrepl/zrepl/daemon/nethelpers"
"io"
"net"
"path"
"time"
"context"
"github.com/pkg/errors"
"sync/atomic"
)
// StdinserverListenerFactory holds a set of client identities and the socket
// directory used for stdinserver listening.
// NOTE(review): no methods or constructor for this type are visible in this
// chunk; MultiStdinserverListenerFactoryFromConfig builds the unexported
// multiStdinserverListenerFactory instead — confirm this type is still used.
type StdinserverListenerFactory struct {
// client identities to listen for
ClientIdentities []string
// directory in which the per-identity sockets live
Sockdir string
}
// MultiStdinserverListenerFactoryFromConfig builds a listener factory from the
// stdinserver server config in and the global stdinserver socket directory in g.
//
// Every configured client identity is checked with ValidateClientIdentity
// before anything is constructed; the first invalid one aborts with an error
// naming that identity.
func MultiStdinserverListenerFactoryFromConfig(g *config.Global, in *config.StdinserverServer) (f *multiStdinserverListenerFactory, err error) {
	cis := in.ClientIdentities
	for i := range cis {
		if verr := ValidateClientIdentity(cis[i]); verr != nil {
			return nil, errors.Wrapf(verr, "invalid client identity %q", cis[i])
		}
	}
	return &multiStdinserverListenerFactory{
		Sockdir:          g.Serve.StdinServer.SockDir,
		ClientIdentities: cis,
	}, nil
}
// multiStdinserverListenerFactory creates a MultiStdinserverListener with one
// socket per client identity inside Sockdir (see its Listen method).
type multiStdinserverListenerFactory struct {
// client identities, validated by MultiStdinserverListenerFactoryFromConfig
ClientIdentities []string
// directory in which the per-identity sockets are created
Sockdir string
}
// Listen opens one stdinserver socket per configured client identity in
// f.Sockdir and returns a listener that accepts from all of them.
//
// On error the returned listener is an untyped nil. Returning the concrete
// *MultiStdinserverListener directly would wrap a nil pointer in a non-nil
// AuthenticatedListener interface, which callers cannot detect with `l == nil`.
func (f *multiStdinserverListenerFactory) Listen() (AuthenticatedListener, error) {
	l, err := multiStdinserverListenerFromClientIdentities(f.Sockdir, f.ClientIdentities)
	if err != nil {
		return nil, err
	}
	return l, nil
}
// multiStdinserverAcceptRes is the result of one Accept call on a single
// per-identity listener, forwarded to MultiStdinserverListener.Accept.
type multiStdinserverAcceptRes struct {
conn AuthenticatedConn
err error
}
// MultiStdinserverListener multiplexes several per-client-identity stdinserver
// listeners behind a single Accept method.
type MultiStdinserverListener struct {
// one listener per client identity, each listening on its own socket path
listeners []*stdinserverListener
// results forwarded by the per-listener accept goroutines; created lazily by Accept
accepts chan multiStdinserverAcceptRes
// set to 1 (atomically) by Close; polled by the accept goroutines to stop looping
closed int32
}
// multiStdinserverListenerFromClientIdentities creates one netssh listener per
// client identity, each on the socket path sockdir/<identity>, and bundles them
// into a MultiStdinserverListener.
//
// The client identities must already be validated (see ValidateClientIdentity),
// because each one is joined onto sockdir to form a socket path.
//
// If any socket cannot be prepared or listened on, the listeners created so far
// are closed and the error is returned, annotated with the offending socket
// path and client identity.
func multiStdinserverListenerFromClientIdentities(sockdir string, cis []string) (*MultiStdinserverListener, error) {
	listeners := make([]*stdinserverListener, 0, len(cis))

	// Best-effort cleanup on the error path: the setup error is what matters
	// to the caller, so close errors of the partial set are deliberately dropped.
	closeAll := func() {
		for _, l := range listeners {
			_ = l.Close()
		}
	}

	for _, ci := range cis {
		sockpath := path.Join(sockdir, ci)
		if err := nethelpers.PreparePrivateSockpath(sockpath); err != nil {
			closeAll()
			return nil, errors.Wrapf(err, "cannot prepare socket path %q for client identity %q", sockpath, ci)
		}
		nl, err := netssh.Listen(sockpath)
		if err != nil {
			closeAll()
			return nil, errors.Wrapf(err, "cannot listen on socket %q for client identity %q", sockpath, ci)
		}
		listeners = append(listeners, &stdinserverListener{l: nl, clientIdentity: ci})
	}
	return &MultiStdinserverListener{listeners: listeners}, nil
}
// Accept returns the next connection (or error) produced by any of the
// per-client-identity listeners.
//
// On first use it starts one goroutine per listener. Each goroutine loops on
// its listener's Accept and forwards every result into m.accepts until Close
// has been called. Accept then waits for the next forwarded result or for ctx
// to be done, whichever comes first.
//
// If ctx is done first, Accept returns ctx.Err(). A result that arrives later
// stays queued in m.accepts for the next Accept call. This also bounds the
// wait when there are no listeners, or when Close ran before the first Accept,
// since in those cases no goroutine ever delivers a result.
//
// NOTE(review): the lazy initialization of m.accepts is unsynchronized, so
// Accept must not be called concurrently (a single accept loop is assumed).
func (m *MultiStdinserverListener) Accept(ctx context.Context) (AuthenticatedConn, error) {
	if m.accepts == nil {
		m.accepts = make(chan multiStdinserverAcceptRes, len(m.listeners))
		for i := range m.listeners {
			go func(i int) {
				for atomic.LoadInt32(&m.closed) == 0 {
					// These goroutines outlive this Accept call, so they must
					// not capture the caller's per-call ctx.
					conn, err := m.listeners[i].Accept(context.Background())
					m.accepts <- multiStdinserverAcceptRes{conn, err}
				}
			}(i)
		}
	}
	select {
	case res := <-m.accepts:
		return res.conn, res.err
	case <-ctx.Done():
		return nil, ctx.Err()
	}
}
// Addr returns a placeholder address, because netssh sockets expose none.
func (m *MultiStdinserverListener) Addr() net.Addr {
	var a netsshAddr
	return a
}

// Close marks the listener as closed and closes every per-identity listener.
// All listeners are closed even if some fail; the first error encountered is
// returned.
func (m *MultiStdinserverListener) Close() error {
	// Set the flag before closing, so accept goroutines stop looping once
	// their pending Accept fails.
	atomic.StoreInt32(&m.closed, 1)
	var first error
	for i := range m.listeners {
		err := m.listeners[i].Close()
		if first == nil && err != nil {
			first = err
		}
	}
	return first
}
// stdinserverListener is a single per-client-identity listener (part of a
// MultiStdinserverListener). Every connection it accepts is tagged with
// clientIdentity.
type stdinserverListener struct {
// underlying netssh listener on the identity's socket path
l *netssh.Listener
// identity reported by ClientIdentity() on every accepted connection
clientIdentity string
}
// Addr returns a placeholder address, because netssh sockets expose none.
func (l stdinserverListener) Addr() net.Addr {
	return netsshAddr{}
}

// Accept blocks until the underlying netssh listener yields a connection, and
// wraps it so that it reports this listener's client identity.
// ctx is currently ignored; cancellation happens via Close.
func (l stdinserverListener) Accept(ctx context.Context) (AuthenticatedConn, error) {
	conn, err := l.l.Accept()
	if err == nil {
		return netsshConnToNetConnAdatper{ReadWriteCloser: conn, clientIdentity: l.clientIdentity}, nil
	}
	return nil, err
}

// Close closes the underlying netssh listener.
func (l stdinserverListener) Close() error {
	return l.l.Close()
}
// netsshAddr is a placeholder net.Addr for netssh listeners and connections,
// which expose no meaningful address.
type netsshAddr struct{}

// Network reports the pseudo network name "netssh".
func (a netsshAddr) Network() string {
	return "netssh"
}

// String returns a fixed placeholder because no real address is available.
func (a netsshAddr) String() string {
	return "???"
}
// netsshConnToNetConnAdatper adapts a netssh connection to a net.Conn-like
// AuthenticatedConn. Reads, writes and Close go straight to the embedded
// ReadWriteCloser, and the connection reports the client identity of the
// listener that accepted it.
type netsshConnToNetConnAdatper struct {
	io.ReadWriteCloser // works for both netssh.SSHConn and netssh.ServeConn
	clientIdentity     string
}

// ClientIdentity returns the client identity this connection was accepted for.
func (a netsshConnToNetConnAdatper) ClientIdentity() string {
	id := a.clientIdentity
	return id
}
// LocalAddr returns a placeholder, because netssh connections carry no address.
func (netsshConnToNetConnAdatper) LocalAddr() net.Addr {
	return netsshAddr{}
}

// RemoteAddr returns a placeholder, because netssh connections carry no address.
func (netsshConnToNetConnAdatper) RemoteAddr() net.Addr {
	return netsshAddr{}
}

// The deadline setters below are silent no-ops that always report success:
// the deadline is not forwarded to the underlying connection.
// FIXME log warning once!

// SetDeadline ignores the deadline and returns nil.
func (netsshConnToNetConnAdatper) SetDeadline(time.Time) error {
	return nil
}

// SetReadDeadline ignores the deadline and returns nil.
func (netsshConnToNetConnAdatper) SetReadDeadline(time.Time) error {
	return nil
}

// SetWriteDeadline ignores the deadline and returns nil.
func (netsshConnToNetConnAdatper) SetWriteDeadline(time.Time) error {
	return nil
}