mirror of
https://github.com/zrepl/zrepl.git
synced 2024-11-25 09:54:47 +01:00
159 lines
3.9 KiB
Go
159 lines
3.9 KiB
Go
package serve
|
|
|
|
import (
|
|
"github.com/problame/go-netssh"
|
|
"github.com/zrepl/zrepl/config"
|
|
"github.com/zrepl/zrepl/daemon/nethelpers"
|
|
"io"
|
|
"net"
|
|
"path"
|
|
"time"
|
|
"context"
|
|
"github.com/pkg/errors"
|
|
"sync/atomic"
|
|
)
|
|
|
|
// StdinserverListenerFactory bundles the settings of a stdinserver-based
// listener: the accepted client identities and the socket directory.
//
// NOTE(review): nothing in the visible part of this file references this
// type; the factory actually constructed here is
// multiStdinserverListenerFactory. Confirm whether it is still needed.
type StdinserverListenerFactory struct {
	// ClientIdentities lists the client identities to serve.
	ClientIdentities []string
	// Sockdir is the directory that holds the per-identity sockets.
	Sockdir string
}
|
|
|
|
func MultiStdinserverListenerFactoryFromConfig(g *config.Global, in *config.StdinserverServer) (f *multiStdinserverListenerFactory, err error) {
|
|
|
|
for _, ci := range in.ClientIdentities {
|
|
if err := ValidateClientIdentity(ci); err != nil {
|
|
return nil, errors.Wrapf(err, "invalid client identity %q", ci)
|
|
}
|
|
}
|
|
|
|
f = &multiStdinserverListenerFactory{
|
|
ClientIdentities: in.ClientIdentities,
|
|
Sockdir: g.Serve.StdinServer.SockDir,
|
|
}
|
|
|
|
return
|
|
}
|
|
|
|
// multiStdinserverListenerFactory creates a MultiStdinserverListener that
// serves one socket per client identity below Sockdir.
type multiStdinserverListenerFactory struct {
	// ClientIdentities are the identities to listen for; they are validated
	// by MultiStdinserverListenerFactoryFromConfig before being stored here.
	ClientIdentities []string
	// Sockdir is the directory in which the per-identity sockets are created.
	Sockdir string
}
|
|
|
|
func (f *multiStdinserverListenerFactory) Listen() (AuthenticatedListener, error) {
|
|
return multiStdinserverListenerFromClientIdentities(f.Sockdir, f.ClientIdentities)
|
|
}
|
|
|
|
type multiStdinserverAcceptRes struct {
|
|
conn AuthenticatedConn
|
|
err error
|
|
}
|
|
|
|
type MultiStdinserverListener struct {
|
|
listeners []*stdinserverListener
|
|
accepts chan multiStdinserverAcceptRes
|
|
closed int32
|
|
}
|
|
|
|
// client identities must be validated
|
|
func multiStdinserverListenerFromClientIdentities(sockdir string, cis []string) (*MultiStdinserverListener, error) {
|
|
listeners := make([]*stdinserverListener, 0, len(cis))
|
|
var err error
|
|
for _, ci := range cis {
|
|
sockpath := path.Join(sockdir, ci)
|
|
l := &stdinserverListener{clientIdentity: ci}
|
|
if err = nethelpers.PreparePrivateSockpath(sockpath); err != nil {
|
|
break
|
|
}
|
|
if l.l, err = netssh.Listen(sockpath); err != nil {
|
|
break
|
|
}
|
|
listeners = append(listeners, l)
|
|
}
|
|
if err != nil {
|
|
for _, l := range listeners {
|
|
l.Close() // FIXME error reporting?
|
|
}
|
|
return nil, err
|
|
}
|
|
return &MultiStdinserverListener{listeners: listeners}, nil
|
|
}
|
|
|
|
func (m *MultiStdinserverListener) Accept(ctx context.Context) (AuthenticatedConn, error){
|
|
|
|
if m.accepts == nil {
|
|
m.accepts = make(chan multiStdinserverAcceptRes, len(m.listeners))
|
|
for i := range m.listeners {
|
|
go func(i int) {
|
|
for atomic.LoadInt32(&m.closed) == 0 {
|
|
conn, err := m.listeners[i].Accept(context.TODO())
|
|
m.accepts <- multiStdinserverAcceptRes{conn, err}
|
|
}
|
|
}(i)
|
|
}
|
|
}
|
|
|
|
res := <- m.accepts
|
|
return res.conn, res.err
|
|
|
|
}
|
|
|
|
func (m *MultiStdinserverListener) Addr() (net.Addr) {
|
|
return netsshAddr{}
|
|
}
|
|
|
|
func (m *MultiStdinserverListener) Close() error {
|
|
atomic.StoreInt32(&m.closed, 1)
|
|
var oneErr error
|
|
for _, l := range m.listeners {
|
|
if err := l.Close(); err != nil && oneErr == nil {
|
|
oneErr = err
|
|
}
|
|
}
|
|
return oneErr
|
|
}
|
|
|
|
// a single stdinserverListener (part of multiStinserverListener)
|
|
type stdinserverListener struct {
|
|
l *netssh.Listener
|
|
clientIdentity string
|
|
}
|
|
|
|
func (l stdinserverListener) Addr() net.Addr {
|
|
return netsshAddr{}
|
|
}
|
|
|
|
func (l stdinserverListener) Accept(ctx context.Context) (AuthenticatedConn, error) {
|
|
c, err := l.l.Accept()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return netsshConnToNetConnAdatper{c, l.clientIdentity}, nil
|
|
}
|
|
|
|
func (l stdinserverListener) Close() (err error) {
|
|
return l.l.Close()
|
|
}
|
|
|
|
// netsshAddr is a placeholder address type with the Network and String
// methods of a net.Addr. It carries no data; every value is identical.
type netsshAddr struct{}

// Network returns the fixed network name "netssh".
func (netsshAddr) Network() string {
	return "netssh"
}

// String returns the fixed placeholder "???"; no real address is available.
func (netsshAddr) String() string {
	return "???"
}
|
|
|
|
type netsshConnToNetConnAdatper struct {
|
|
io.ReadWriteCloser // works for both netssh.SSHConn and netssh.ServeConn
|
|
clientIdentity string
|
|
}
|
|
|
|
func (a netsshConnToNetConnAdatper) ClientIdentity() string { return a.clientIdentity }
|
|
|
|
func (netsshConnToNetConnAdatper) LocalAddr() net.Addr { return netsshAddr{} }
|
|
|
|
func (netsshConnToNetConnAdatper) RemoteAddr() net.Addr { return netsshAddr{} }
|
|
|
|
// FIXME log warning once!
|
|
func (netsshConnToNetConnAdatper) SetDeadline(t time.Time) error { return nil }
|
|
|
|
func (netsshConnToNetConnAdatper) SetReadDeadline(t time.Time) error { return nil }
|
|
|
|
func (netsshConnToNetConnAdatper) SetWriteDeadline(t time.Time) error { return nil }
|