package ssh

import (
	"context"
	"fmt"
	"net"
	"path/filepath"
	"sync"

	"github.com/pkg/errors"
	"github.com/problame/go-netssh"

	"github.com/zrepl/zrepl/internal/config"
	"github.com/zrepl/zrepl/internal/daemon/nethelpers"
	"github.com/zrepl/zrepl/internal/transport"
)

// MultiStdinserverListenerFactoryFromConfig validates the client identities
// configured in `in` and returns a factory that, when invoked, creates a
// MultiStdinserverListener with one socket per client identity inside the
// global stdinserver socket directory (g.Serve.StdinServer.SockDir).
//
// Client identities are validated eagerly, so invalid configuration is
// reported when the factory is built rather than when it is first invoked.
func MultiStdinserverListenerFactoryFromConfig(g *config.Global, in *config.StdinserverServer) (transport.AuthenticatedListenerFactory, error) {
	for _, ci := range in.ClientIdentities {
		if err := transport.ValidateClientIdentity(ci); err != nil {
			return nil, errors.Wrapf(err, "invalid client identity %q", ci)
		}
	}

	clientIdentities := in.ClientIdentities
	sockdir := g.Serve.StdinServer.SockDir

	lf := func() (transport.AuthenticatedListener, error) {
		l, err := multiStdinserverListenerFromClientIdentities(sockdir, clientIdentities)
		if err != nil {
			// Return a literal nil: returning the typed nil *MultiStdinserverListener
			// would yield a non-nil transport.AuthenticatedListener interface value.
			return nil, err
		}
		return l, nil
	}
	return lf, nil
}

// multiStdinserverAcceptRes is the result of one Accept call on an underlying
// stdinserverListener, forwarded to MultiStdinserverListener.Accept.
type multiStdinserverAcceptRes struct {
	conn *transport.AuthConn
	err  error
}

// MultiStdinserverListener multiplexes several per-client-identity stdinserver
// listeners into a single transport.AuthenticatedListener. Connections accepted
// on any underlying socket are returned by Accept, tagged with the client
// identity of the socket they arrived on.
//
// Instances must be created with multiStdinserverListenerFromClientIdentities;
// the zero value is not usable.
type MultiStdinserverListener struct {
	listeners []*stdinserverListener

	// startOnce guards the lazy start of the per-listener accept goroutines,
	// which forward their results into accepts.
	startOnce sync.Once
	accepts   chan multiStdinserverAcceptRes

	// done is closed by Close to stop the accept goroutines and to unblock
	// pending Accept calls.
	closeOnce sync.Once
	done      chan struct{}
}

var _ transport.AuthenticatedListener = (*MultiStdinserverListener)(nil)

// multiStdinserverListenerFromClientIdentities creates one stdinserver socket
// per client identity at sockdir/<identity>. If any socket cannot be set up,
// the sockets created so far are closed and the setup error is returned.
//
// The client identities must already have been validated with
// transport.ValidateClientIdentity, because they are used as path components.
func multiStdinserverListenerFromClientIdentities(sockdir string, cis []string) (*MultiStdinserverListener, error) {
	listeners := make([]*stdinserverListener, 0, len(cis))
	for _, ci := range cis {
		l, err := newStdinserverListener(sockdir, ci)
		if err != nil {
			for _, prev := range listeners {
				// Best-effort cleanup: the setup error is the one worth reporting.
				_ = prev.Close()
			}
			return nil, err
		}
		listeners = append(listeners, l)
	}
	return &MultiStdinserverListener{
		listeners: listeners,
		accepts:   make(chan multiStdinserverAcceptRes, len(listeners)),
		done:      make(chan struct{}),
	}, nil
}

// newStdinserverListener prepares the socket path sockdir/<ci> and starts a
// netssh listener on it for client identity ci.
func newStdinserverListener(sockdir, ci string) (*stdinserverListener, error) {
	sockpath := filepath.Join(sockdir, ci)
	if err := nethelpers.PreparePrivateSockpath(sockpath); err != nil {
		return nil, errors.Wrapf(err, "prepare stdinserver socket path %q", sockpath)
	}
	nl, err := netssh.Listen(sockpath)
	if err != nil {
		return nil, errors.Wrapf(err, "listen on stdinserver socket %q", sockpath)
	}
	return &stdinserverListener{l: nl, clientIdentity: ci}, nil
}

// Accept returns the next connection accepted on any of the underlying
// stdinserver sockets, or an error reported by one of them.
//
// The per-socket accept goroutines are started on the first call. Accept
// returns ctx.Err() if ctx is done before a result is available, and
// net.ErrClosed once Close has been called.
func (m *MultiStdinserverListener) Accept(ctx context.Context) (*transport.AuthConn, error) {
	m.startOnce.Do(m.startAcceptLoops)
	select {
	case res := <-m.accepts:
		return res.conn, res.err
	case <-m.done:
		return nil, net.ErrClosed
	case <-ctx.Done():
		return nil, ctx.Err()
	}
}

// startAcceptLoops starts one goroutine per underlying listener; each runs
// until Close is called.
func (m *MultiStdinserverListener) startAcceptLoops() {
	for _, l := range m.listeners {
		go m.acceptLoop(l)
	}
}

// acceptLoop forwards the results of l.Accept into m.accepts until m.done is
// closed. Selecting on m.done for the send guarantees the goroutine exits after
// Close even if nobody is receiving from m.accepts any more.
func (m *MultiStdinserverListener) acceptLoop(l *stdinserverListener) {
	for {
		select {
		case <-m.done:
			return
		default:
		}
		// This loop outlives any single Accept call, so it cannot use a caller's
		// ctx; stdinserverListener.Accept does not consult its ctx anyway.
		conn, err := l.Accept(context.Background())
		select {
		case m.accepts <- multiStdinserverAcceptRes{conn, err}:
		case <-m.done:
			if conn != nil {
				// The listener is closed; nobody will receive this connection.
				_ = conn.Close()
			}
			return
		}
	}
}

// multiListenerAddr is the net.Addr of a MultiStdinserverListener: the list of
// client identities it serves.
type multiListenerAddr struct {
	clients []string
}

func (multiListenerAddr) Network() string { return "netssh" }

func (l multiListenerAddr) String() string {
	return fmt.Sprintf("netssh:clients=%v", l.clients)
}

// Addr reports the client identities served by this listener.
func (m *MultiStdinserverListener) Addr() net.Addr {
	cis := make([]string, len(m.listeners))
	for i, l := range m.listeners {
		cis[i] = l.clientIdentity
	}
	return multiListenerAddr{cis}
}

// Close stops the accept goroutines, unblocks pending Accept calls, and closes
// all underlying listeners. Every listener is closed even if an earlier one
// fails; the first close error encountered is returned.
func (m *MultiStdinserverListener) Close() error {
	m.closeOnce.Do(func() { close(m.done) })
	var firstErr error
	for _, l := range m.listeners {
		if err := l.Close(); err != nil && firstErr == nil {
			firstErr = err
		}
	}
	return firstErr
}

// stdinserverListener is a single netssh listener for one client identity;
// it is part of a MultiStdinserverListener.
type stdinserverListener struct {
	l              *netssh.Listener
	clientIdentity string
}

// listenerAddr is the net.Addr of a single stdinserverListener.
type listenerAddr struct {
	clientIdentity string
}

func (listenerAddr) Network() string { return "netssh" }

func (a listenerAddr) String() string {
	return fmt.Sprintf("netssh:client=%q", a.clientIdentity)
}

// Addr reports the client identity this listener serves.
func (l stdinserverListener) Addr() net.Addr {
	return listenerAddr{l.clientIdentity}
}

// Accept waits for the next netssh connection and wraps it in an AuthConn
// carrying this listener's client identity. ctx is currently not consulted.
func (l stdinserverListener) Accept(ctx context.Context) (*transport.AuthConn, error) {
	c, err := l.l.Accept()
	if err != nil {
		return nil, err
	}
	return transport.NewAuthConn(c, l.clientIdentity), nil
}

// Close closes the underlying netssh listener.
func (l stdinserverListener) Close() error {
	return l.l.Close()
}