mirror of
https://github.com/zrepl/zrepl.git
synced 2024-11-28 19:34:58 +01:00
255 lines
6.0 KiB
Go
255 lines
6.0 KiB
Go
// Package transportmux wraps a transport.{Connecter,AuthenticatedListener}
// to distinguish different connection types based on a label
// that the client sends to the server on connection establishment.
//
// Labels are plain text with a fixed length on the wire (see LabelLen).
package transportmux
|
|
|
|
import (
|
|
"context"
|
|
"sync/atomic"
|
|
"syscall"
|
|
|
|
"fmt"
|
|
"io"
|
|
"net"
|
|
"time"
|
|
|
|
"github.com/zrepl/zrepl/internal/daemon/logging"
|
|
"github.com/zrepl/zrepl/internal/logger"
|
|
"github.com/zrepl/zrepl/internal/transport"
|
|
)
|
|
|
|
// Logger is the logger type used throughout this package.
// It is an alias for logger.Logger.
type Logger = logger.Logger
|
|
|
|
func getLog(ctx context.Context) Logger {
|
|
return logging.GetLogger(ctx, logging.SubsysTransportMux)
|
|
}
|
|
|
|
type acceptRes struct {
|
|
conn *transport.AuthConn
|
|
err error
|
|
}
|
|
|
|
type demuxListener struct {
|
|
closed int32
|
|
conns chan acceptRes
|
|
}
|
|
|
|
// ErrClosed is the error returned by Accept on a demultiplexed listener
// that has been closed. Source and Addr are left unset (nil).
var ErrClosed = &net.OpError{
	Op:  "accept",
	Net: "demux",
	Err: syscall.EINVAL,
}
|
|
|
|
func (l *demuxListener) Accept(ctx context.Context) (*transport.AuthConn, error) {
|
|
if atomic.LoadInt32(&l.closed) != 0 {
|
|
return nil, ErrClosed
|
|
}
|
|
select {
|
|
case r, ok := <-l.conns:
|
|
if !ok {
|
|
return nil, ErrClosed
|
|
}
|
|
return r.conn, r.err
|
|
case <-ctx.Done():
|
|
return nil, ctx.Err()
|
|
}
|
|
}
|
|
|
|
// demuxAddr is the placeholder address reported by demultiplexed listeners.
// It carries no data; both of its methods return the constant "demux".
type demuxAddr struct{}

// Network returns the name of the pseudo-network, "demux".
func (demuxAddr) Network() string {
	return "demux"
}

// String returns the textual form of the address, which is the same as the
// network name.
func (a demuxAddr) String() string {
	return a.Network()
}
|
|
|
|
func (l *demuxListener) Addr() net.Addr {
|
|
return demuxAddr{}
|
|
}
|
|
|
|
func (l *demuxListener) Close() error {
|
|
atomic.StoreInt32(&l.closed, 1)
|
|
return nil
|
|
}
|
|
|
|
// LabelLen is the exact length of a label in bytes (0-byte padded if it is
// shorter). This is a protocol constant, changing it breaks the wire protocol.
const LabelLen = 64

// padLabel writes label into out and fills the remainder of out with 0-bytes,
// producing the fixed-length wire representation of the label.
//
// It returns an error if label is longer than LabelLen bytes.
// It panics if out is not exactly LabelLen bytes long, because that indicates
// a bug in the caller rather than bad input.
func padLabel(out []byte, label string) error {
	if len(label) > LabelLen {
		return fmt.Errorf("label %q exceeds max length (is %d, max %d)", label, len(label), LabelLen)
	}
	if len(out) != LabelLen {
		panic(fmt.Sprintf("implementation error: output buffer has length %d, must be %d", len(out), LabelLen))
	}
	n := copy(out, label)
	// Zero the tail explicitly so that the padding does not depend on the
	// caller having passed a zeroed buffer.
	for i := n; i < len(out); i++ {
		out[i] = 0
	}
	return nil
}
|
|
|
|
func Demux(ctx context.Context, rawListener transport.AuthenticatedListener, labels []string, timeout time.Duration) (map[string]transport.AuthenticatedListener, error) {
|
|
|
|
padded := make(map[[64]byte]*demuxListener, len(labels))
|
|
ret := make(map[string]transport.AuthenticatedListener, len(labels))
|
|
for _, label := range labels {
|
|
var labelPadded [LabelLen]byte
|
|
err := padLabel(labelPadded[:], label)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if _, ok := padded[labelPadded]; ok {
|
|
return nil, fmt.Errorf("duplicate label %q", label)
|
|
}
|
|
dl := &demuxListener{
|
|
closed: 0,
|
|
conns: make(chan acceptRes, 1),
|
|
}
|
|
padded[labelPadded] = dl
|
|
ret[label] = dl
|
|
}
|
|
|
|
// invariant: padded contains same-length, non-duplicate labels
|
|
|
|
go func() {
|
|
<-ctx.Done()
|
|
getLog(ctx).Debug("context cancelled, closing listener")
|
|
if err := rawListener.Close(); err != nil {
|
|
getLog(ctx).WithError(err).Error("error closing listener")
|
|
}
|
|
|
|
drainConns := func(ch chan acceptRes) {
|
|
for c := range ch {
|
|
if c.conn != nil {
|
|
if err := c.conn.Close(); err != nil {
|
|
getLog(ctx).WithError(err).Error("error closing connection while draining after listener was closed")
|
|
}
|
|
}
|
|
}
|
|
}
|
|
for _, dl := range ret {
|
|
atomic.StoreInt32(&dl.(*demuxListener).closed, 1)
|
|
drainConns(dl.(*demuxListener).conns)
|
|
}
|
|
}()
|
|
|
|
go func() {
|
|
defer func() {
|
|
for _, dl := range ret {
|
|
close(dl.(*demuxListener).conns)
|
|
}
|
|
}()
|
|
|
|
for {
|
|
select {
|
|
case <-ctx.Done():
|
|
getLog(ctx).WithError(ctx.Err()).Info("stop accepting new connections after context done")
|
|
return
|
|
default:
|
|
}
|
|
|
|
rawConn, err := rawListener.Accept(ctx)
|
|
if err != nil {
|
|
if ctx.Err() != nil {
|
|
return
|
|
}
|
|
getLog(ctx).WithError(err).WithField("errType", fmt.Sprintf("%T", err)).Error("accept error")
|
|
continue
|
|
}
|
|
closeConn := func() {
|
|
if err := rawConn.Close(); err != nil {
|
|
getLog(ctx).WithError(err).Error("cannot close conn")
|
|
}
|
|
}
|
|
|
|
if err := rawConn.SetDeadline(time.Now().Add(timeout)); err != nil {
|
|
getLog(ctx).WithError(err).Error("SetDeadline failed")
|
|
closeConn()
|
|
continue
|
|
}
|
|
|
|
var labelBuf [LabelLen]byte
|
|
if _, err := io.ReadFull(rawConn, labelBuf[:]); err != nil {
|
|
getLog(ctx).WithError(err).Error("error reading label")
|
|
closeConn()
|
|
continue
|
|
}
|
|
|
|
demuxListener, ok := padded[labelBuf]
|
|
if !ok {
|
|
getLog(ctx).WithError(err).
|
|
WithField("client_label", fmt.Sprintf("%q", labelBuf)).
|
|
Error("unknown client label")
|
|
closeConn()
|
|
continue
|
|
}
|
|
|
|
err = rawConn.SetDeadline(time.Time{})
|
|
if err != nil {
|
|
getLog(ctx).WithError(err).Error("cannot reset deadline")
|
|
}
|
|
demuxListener.conns <- acceptRes{conn: rawConn, err: nil}
|
|
}
|
|
}()
|
|
|
|
return ret, nil
|
|
}
|
|
|
|
type labeledConnecter struct {
|
|
label []byte
|
|
transport.Connecter
|
|
}
|
|
|
|
func (c labeledConnecter) Connect(ctx context.Context) (transport.Wire, error) {
|
|
conn, err := c.Connecter.Connect(ctx)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
closeConn := func(why error) {
|
|
getLog(ctx).WithField("reason", why.Error()).Debug("closing connection")
|
|
if err := conn.Close(); err != nil {
|
|
getLog(ctx).WithError(err).Error("error closing connection after label write error")
|
|
}
|
|
}
|
|
|
|
if dl, ok := ctx.Deadline(); ok {
|
|
defer func() {
|
|
err := conn.SetDeadline(time.Time{})
|
|
if err != nil {
|
|
getLog(ctx).WithError(err).Error("cannot reset deadline")
|
|
}
|
|
}()
|
|
if err := conn.SetDeadline(dl); err != nil {
|
|
closeConn(err)
|
|
return nil, err
|
|
}
|
|
}
|
|
n, err := conn.Write(c.label)
|
|
if err != nil {
|
|
closeConn(err)
|
|
return nil, err
|
|
}
|
|
if n != len(c.label) {
|
|
closeConn(fmt.Errorf("short label write"))
|
|
return nil, io.ErrShortWrite
|
|
}
|
|
return conn, nil
|
|
}
|
|
|
|
func MuxConnecter(rawConnecter transport.Connecter, labels []string, timeout time.Duration) (map[string]transport.Connecter, error) {
|
|
ret := make(map[string]transport.Connecter, len(labels))
|
|
for _, label := range labels {
|
|
var paddedLabel [LabelLen]byte
|
|
if err := padLabel(paddedLabel[:], label); err != nil {
|
|
return nil, err
|
|
}
|
|
lc := &labeledConnecter{paddedLabel[:], rawConnecter}
|
|
if _, ok := ret[label]; ok {
|
|
return nil, fmt.Errorf("duplicate label %q", label)
|
|
}
|
|
ret[label] = lc
|
|
}
|
|
return ret, nil
|
|
}
|