// Mirror of https://github.com/zrepl/zrepl.git
// (synced 2024-11-24 17:35:01 +01:00; 150 lines, 3.5 KiB, Go).
package ssh
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"net"
|
|
"path"
|
|
"sync/atomic"
|
|
|
|
"github.com/pkg/errors"
|
|
"github.com/problame/go-netssh"
|
|
|
|
"github.com/zrepl/zrepl/internal/config"
|
|
"github.com/zrepl/zrepl/internal/daemon/nethelpers"
|
|
"github.com/zrepl/zrepl/internal/transport"
|
|
)
|
|
|
|
func MultiStdinserverListenerFactoryFromConfig(g *config.Global, in *config.StdinserverServer) (transport.AuthenticatedListenerFactory, error) {
|
|
|
|
for _, ci := range in.ClientIdentities {
|
|
if err := transport.ValidateClientIdentity(ci); err != nil {
|
|
return nil, errors.Wrapf(err, "invalid client identity %q", ci)
|
|
}
|
|
}
|
|
|
|
clientIdentities := in.ClientIdentities
|
|
sockdir := g.Serve.StdinServer.SockDir
|
|
|
|
lf := func() (transport.AuthenticatedListener, error) {
|
|
return multiStdinserverListenerFromClientIdentities(sockdir, clientIdentities)
|
|
}
|
|
|
|
return lf, nil
|
|
}
|
|
|
|
type multiStdinserverAcceptRes struct {
|
|
conn *transport.AuthConn
|
|
err error
|
|
}
|
|
|
|
type MultiStdinserverListener struct {
|
|
listeners []*stdinserverListener
|
|
accepts chan multiStdinserverAcceptRes
|
|
closed int32
|
|
}
|
|
|
|
// client identities must be validated
|
|
func multiStdinserverListenerFromClientIdentities(sockdir string, cis []string) (*MultiStdinserverListener, error) {
|
|
listeners := make([]*stdinserverListener, 0, len(cis))
|
|
var err error
|
|
for _, ci := range cis {
|
|
sockpath := path.Join(sockdir, ci)
|
|
l := &stdinserverListener{clientIdentity: ci}
|
|
if err = nethelpers.PreparePrivateSockpath(sockpath); err != nil {
|
|
break
|
|
}
|
|
if l.l, err = netssh.Listen(sockpath); err != nil {
|
|
break
|
|
}
|
|
listeners = append(listeners, l)
|
|
}
|
|
if err != nil {
|
|
for _, l := range listeners {
|
|
l.Close() // FIXME error reporting?
|
|
}
|
|
return nil, err
|
|
}
|
|
return &MultiStdinserverListener{listeners: listeners}, nil
|
|
}
|
|
|
|
func (m *MultiStdinserverListener) Accept(ctx context.Context) (*transport.AuthConn, error) {
|
|
|
|
if m.accepts == nil {
|
|
m.accepts = make(chan multiStdinserverAcceptRes, len(m.listeners))
|
|
for i := range m.listeners {
|
|
go func(i int) {
|
|
for atomic.LoadInt32(&m.closed) == 0 {
|
|
conn, err := m.listeners[i].Accept(context.TODO())
|
|
m.accepts <- multiStdinserverAcceptRes{conn, err}
|
|
}
|
|
}(i)
|
|
}
|
|
}
|
|
|
|
res := <-m.accepts
|
|
return res.conn, res.err
|
|
|
|
}
|
|
|
|
// multiListenerAddr is the address value reported by
// MultiStdinserverListener; it lists the client identities being served.
type multiListenerAddr struct {
	clients []string
}

// Network returns the fixed network name "netssh".
func (multiListenerAddr) Network() string {
	return "netssh"
}

// String renders the address as "netssh:clients=" followed by the client
// list in fmt's default (%v) slice formatting.
func (a multiListenerAddr) String() string {
	rendered := fmt.Sprintf("netssh:clients=%v", a.clients)
	return rendered
}
|
|
|
|
func (m *MultiStdinserverListener) Addr() net.Addr {
|
|
cis := make([]string, len(m.listeners))
|
|
for i := range cis {
|
|
cis[i] = m.listeners[i].clientIdentity
|
|
}
|
|
return multiListenerAddr{cis}
|
|
}
|
|
|
|
func (m *MultiStdinserverListener) Close() error {
|
|
atomic.StoreInt32(&m.closed, 1)
|
|
var oneErr error
|
|
for _, l := range m.listeners {
|
|
if err := l.Close(); err != nil && oneErr == nil {
|
|
oneErr = err
|
|
}
|
|
}
|
|
return oneErr
|
|
}
|
|
|
|
// a single stdinserverListener (part of multiStdinserverListener)
|
|
type stdinserverListener struct {
|
|
l *netssh.Listener
|
|
clientIdentity string
|
|
}
|
|
|
|
// listenerAddr is the address value reported by a single
// stdinserverListener; it names the client identity being served.
type listenerAddr struct {
	clientIdentity string
}

// Network returns the fixed network name "netssh".
func (listenerAddr) Network() string {
	return "netssh"
}

// String renders the address as "netssh:client=" followed by the quoted
// client identity.
func (addr listenerAddr) String() string {
	rendered := fmt.Sprintf("netssh:client=%q", addr.clientIdentity)
	return rendered
}
|
|
|
|
func (l stdinserverListener) Addr() net.Addr {
|
|
return listenerAddr{l.clientIdentity}
|
|
}
|
|
|
|
func (l stdinserverListener) Accept(ctx context.Context) (*transport.AuthConn, error) {
|
|
c, err := l.l.Accept()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return transport.NewAuthConn(c, l.clientIdentity), nil
|
|
}
|
|
|
|
func (l stdinserverListener) Close() (err error) {
|
|
return l.l.Close()
|
|
}
|