zrepl/internal/transport/ssh/serve_stdinserver.go
2024-10-18 19:21:17 +02:00

150 lines
3.5 KiB
Go

package ssh
import (
"context"
"fmt"
"net"
"path"
"sync/atomic"
"github.com/pkg/errors"
"github.com/problame/go-netssh"
"github.com/zrepl/zrepl/internal/config"
"github.com/zrepl/zrepl/internal/daemon/nethelpers"
"github.com/zrepl/zrepl/internal/transport"
)
// MultiStdinserverListenerFactoryFromConfig validates the stdinserver client
// identities from the config and returns a factory that, when invoked, creates
// one netssh listener per identity in the global stdinserver socket directory.
//
// It returns an error (wrapped with the offending identity) if any configured
// client identity fails transport.ValidateClientIdentity.
func MultiStdinserverListenerFactoryFromConfig(g *config.Global, in *config.StdinserverServer) (transport.AuthenticatedListenerFactory, error) {
	// Take a private copy before validating: the factory runs later, and the
	// identities are joined into filesystem paths, so the set we validated must
	// be exactly the set we use, independent of later changes to the config slice.
	clientIdentities := append([]string(nil), in.ClientIdentities...)
	for _, ci := range clientIdentities {
		if err := transport.ValidateClientIdentity(ci); err != nil {
			return nil, errors.Wrapf(err, "invalid client identity %q", ci)
		}
	}
	sockdir := g.Serve.StdinServer.SockDir
	lf := func() (transport.AuthenticatedListener, error) {
		l, err := multiStdinserverListenerFromClientIdentities(sockdir, clientIdentities)
		if err != nil {
			// Return an untyped nil: returning the typed nil pointer would
			// produce a non-nil transport.AuthenticatedListener interface value.
			return nil, err
		}
		return l, nil
	}
	return lf, nil
}
// multiStdinserverAcceptRes is the outcome of one Accept call on a single
// per-identity listener, handed from its accept goroutine to
// MultiStdinserverListener.Accept.
type multiStdinserverAcceptRes struct {
	conn *transport.AuthConn // accepted connection; nil when err is non-nil
	err  error               // error reported by the underlying Accept
}
// MultiStdinserverListener serves several stdinserver sockets, one per client
// identity, and exposes them through a single Accept/Addr/Close listener.
type MultiStdinserverListener struct {
	// listeners holds one listener per configured client identity.
	listeners []*stdinserverListener
	// accepts is created lazily by the first Accept call and receives the
	// results produced by the per-listener accept goroutines.
	accepts chan multiStdinserverAcceptRes
	// closed is set to 1 atomically by Close; the accept goroutines stop
	// looping once they observe it.
	closed int32
}
// multiStdinserverListenerFromClientIdentities creates one netssh listener per
// client identity, each bound to the socket path sockdir/<identity>.
//
// Precondition: every entry of cis must already have passed
// transport.ValidateClientIdentity, because the identity is joined into a
// filesystem path without further checks.
//
// If preparing or opening any socket fails, the listeners created so far are
// closed (best-effort) and the error is returned. The returned error is wrapped
// with the failing client identity and socket path.
func multiStdinserverListenerFromClientIdentities(sockdir string, cis []string) (*MultiStdinserverListener, error) {
	listeners := make([]*stdinserverListener, 0, len(cis))
	var err error
	for _, ci := range cis {
		sockpath := path.Join(sockdir, ci)
		l := &stdinserverListener{clientIdentity: ci}
		if err = nethelpers.PreparePrivateSockpath(sockpath); err != nil {
			err = fmt.Errorf("prepare socket path %q for client identity %q: %w", sockpath, ci, err)
			break
		}
		if l.l, err = netssh.Listen(sockpath); err != nil {
			err = fmt.Errorf("listen on %q for client identity %q: %w", sockpath, ci, err)
			break
		}
		listeners = append(listeners, l)
	}
	if err != nil {
		for _, l := range listeners {
			// Best-effort cleanup: the setup error above is the one the caller
			// needs to see, so a secondary Close failure is deliberately dropped.
			_ = l.Close()
		}
		return nil, err
	}
	return &MultiStdinserverListener{listeners: listeners}, nil
}
// Accept returns the next connection accepted by any of the per-identity
// listeners, or the error one of them reported.
//
// The first call lazily starts one goroutine per listener. Each goroutine
// keeps calling Accept on its listener and forwards every (conn, err) pair into
// m.accepts until it observes that Close has been called. Those goroutines
// outlive any single Accept call, so they do not use ctx.
//
// Accept returns ctx.Err() if ctx is done before a result arrives. It returns
// an error immediately if m was closed before Accept was first called, since
// no goroutine would ever deliver a result in that case.
//
// NOTE(review): the lazy initialization of m.accepts is unsynchronized, so
// Accept must not be called concurrently with itself.
func (m *MultiStdinserverListener) Accept(ctx context.Context) (*transport.AuthConn, error) {
	if m.accepts == nil {
		if atomic.LoadInt32(&m.closed) != 0 {
			return nil, errors.New("multi stdinserver listener is closed")
		}
		m.accepts = make(chan multiStdinserverAcceptRes, len(m.listeners))
		for i := range m.listeners {
			go func(i int) {
				for atomic.LoadInt32(&m.closed) == 0 {
					conn, err := m.listeners[i].Accept(context.TODO())
					m.accepts <- multiStdinserverAcceptRes{conn, err}
				}
			}(i)
		}
	}
	select {
	case res := <-m.accepts:
		return res.conn, res.err
	case <-ctx.Done():
		// Any result that arrives later stays buffered for the next Accept.
		return nil, ctx.Err()
	}
}
// multiListenerAddr is the net.Addr of a MultiStdinserverListener; it names
// every client identity the multi-listener serves.
type multiListenerAddr struct {
	clients []string // client identities, one per underlying listener
}

// Network implements net.Addr and always reports "netssh".
func (a multiListenerAddr) Network() string {
	return "netssh"
}

// String implements net.Addr, rendering e.g. "netssh:clients=[alice bob]".
func (a multiListenerAddr) String() string {
	return fmt.Sprintf("netssh:clients=%v", a.clients)
}
// Addr reports a multiListenerAddr listing the client identity of every
// underlying listener, in listener order.
func (m *MultiStdinserverListener) Addr() net.Addr {
	clients := make([]string, 0, len(m.listeners))
	for _, sl := range m.listeners {
		clients = append(clients, sl.clientIdentity)
	}
	return multiListenerAddr{clients: clients}
}
// Close marks m as closed, so the accept goroutines stop looping, and then
// closes every underlying listener. All listeners are closed even if some of
// them fail; the first error encountered is returned.
func (m *MultiStdinserverListener) Close() error {
	atomic.StoreInt32(&m.closed, 1)
	var firstErr error
	for i := range m.listeners {
		err := m.listeners[i].Close()
		if err == nil || firstErr != nil {
			continue
		}
		firstErr = err
	}
	return firstErr
}
// stdinserverListener is one netssh listener bound to a single client
// identity; a MultiStdinserverListener holds one of these per identity.
type stdinserverListener struct {
	l              *netssh.Listener // underlying socket listener
	clientIdentity string           // identity attached to every accepted connection
}
// listenerAddr is the net.Addr of a single stdinserverListener.
type listenerAddr struct {
	clientIdentity string // the identity this listener serves
}

// Network implements net.Addr and always reports "netssh".
func (la listenerAddr) Network() string {
	return "netssh"
}

// String implements net.Addr, rendering e.g. `netssh:client="alice"`.
func (la listenerAddr) String() string {
	return fmt.Sprintf("netssh:client=%q", la.clientIdentity)
}
// Addr returns a listenerAddr carrying this listener's client identity.
func (l stdinserverListener) Addr() net.Addr {
	addr := listenerAddr{clientIdentity: l.clientIdentity}
	return addr
}
// Accept waits for the next connection on the underlying netssh listener and
// wraps it in a transport.AuthConn tagged with this listener's client identity.
//
// ctx is not consulted; the call blocks until the underlying Accept returns.
func (l stdinserverListener) Accept(ctx context.Context) (*transport.AuthConn, error) {
	conn, err := l.l.Accept()
	if err == nil {
		return transport.NewAuthConn(conn, l.clientIdentity), nil
	}
	return nil, err
}
// Close closes the underlying netssh listener and returns its error.
func (l stdinserverListener) Close() error {
	err := l.l.Close()
	return err
}