rclone/vendor/storj.io/drpc/drpcmux/mux.go
2020-05-12 15:56:50 +00:00

117 lines
2.7 KiB
Go

// Copyright (C) 2019 Storj Labs, Inc.
// See LICENSE for copying information.
package drpcmux
import (
"reflect"
"github.com/zeebo/errs"
"storj.io/drpc"
)
// Mux is an implementation of Handler to serve drpc connections to the
// appropriate Receivers registered by Descriptions.
//
// A Mux must be created with New: the zero value has a nil rpcs map, and
// registering into it would panic on the map write. Registration writes to
// rpcs without any locking, so registering while HandleRPC runs concurrently
// is presumably unsupported — NOTE(review): confirm with callers.
type Mux struct {
	// rpcs maps an rpc name, as returned by drpc.Description.Method, to the
	// data HandleRPC needs to dispatch that rpc.
	rpcs map[string]rpcData
}
// New returns a ready-to-use Mux whose rpc table is empty.
func New() *Mux {
	mux := new(Mux)
	mux.rpcs = map[string]rpcData{}
	return mux
}
// streamType and messageType are the reflected interface types for
// drpc.Stream and drpc.Message. They are obtained through a pointer to the
// interface and Elem, because reflect.TypeOf on a nil interface value
// yields a nil Type rather than the interface type.
var (
	streamType  = reflect.TypeOf(new(drpc.Stream)).Elem()
	messageType = reflect.TypeOf(new(drpc.Message)).Elem()
)
// rpcData holds what HandleRPC needs to dispatch a single registered rpc.
type rpcData struct {
	// srv is the server value passed as the first argument to receiver.
	srv interface{}
	// receiver is invoked by HandleRPC with srv, the stream's context, the
	// decoded input (or the stream itself), and the stream.
	receiver drpc.Receiver
	// in1 is either the message type HandleRPC allocates and fills via
	// MsgRecv, or streamType when the rpc takes the stream as its input.
	in1 reflect.Type
	// in2 is set to streamType for unitary-input/stream-output rpcs; it is
	// not read by HandleRPC.
	in2 reflect.Type
}
// Register associates the rpcs described by the description in the server.
// Methods are registered in index order. It stops at the first failure and
// returns that error; methods with lower indexes remain registered.
func (m *Mux) Register(srv interface{}, desc drpc.Description) error {
	count := desc.NumMethods()
	for idx := 0; idx < count; idx++ {
		name, recv, method, valid := desc.Method(idx)
		if !valid {
			return errs.New("Description returned invalid method for index %d", idx)
		}
		err := m.registerOne(srv, name, recv, method)
		if err != nil {
			return err
		}
	}
	return nil
}
// registerOne does the work to register a single rpc.
//
// The shape of method's function signature selects how HandleRPC feeds the
// receiver:
//
//   - 2 results (unitary input, unitary output): the input message is
//     argument 2 (presumably after the server and a context).
//   - 3 arguments (unitary input, stream output): the input message is
//     argument 1, and the stream is also passed.
//   - 2 arguments (stream input): the receiver is given the stream itself.
//
// Message inputs must be pointer types implementing drpc.Message, because
// HandleRPC allocates them with reflect.New(in1.Elem()). A nil or non-func
// method, or a signature too short to index, is rejected with an error
// rather than panicking inside reflect.
func (m *Mux) registerOne(srv interface{}, rpc string, receiver drpc.Receiver, method interface{}) error {
	data := rpcData{srv: srv, receiver: receiver}

	mt := reflect.TypeOf(method)
	// NumIn/NumOut/In panic on a nil Type or a non-func kind.
	if mt == nil || mt.Kind() != reflect.Func {
		return errs.New("unknown method type: %v", mt)
	}

	switch {
	// unitary input, unitary output
	case mt.NumOut() == 2:
		// mt.In(2) would panic with fewer than three arguments.
		if mt.NumIn() < 3 {
			return errs.New("unknown method type: %v", mt)
		}
		data.in1 = mt.In(2)
		if err := checkMessageInput(data.in1); err != nil {
			return err
		}

	// unitary input, stream output
	case mt.NumIn() == 3:
		data.in1 = mt.In(1)
		if err := checkMessageInput(data.in1); err != nil {
			return err
		}
		data.in2 = streamType

	// stream input
	case mt.NumIn() == 2:
		data.in1 = streamType

	// code gen bug?
	default:
		return errs.New("unknown method type: %v", mt)
	}

	m.rpcs[rpc] = data
	return nil
}

// checkMessageInput returns an error unless t is a pointer type implementing
// drpc.Message, which HandleRPC needs in order to allocate and decode it.
func checkMessageInput(t reflect.Type) error {
	if !t.Implements(messageType) {
		return errs.New("input argument not a drpc message: %v", t)
	}
	if t.Kind() != reflect.Ptr {
		return errs.New("input argument not a pointer to a drpc message: %v", t)
	}
	return nil
}
// HandleRPC handles the rpc that has been requested by the stream.
//
// It looks rpc up in the registered table:
//
//   - An unknown rpc yields a drpc.ProtocolError.
//   - For rpcs whose first input is a message, it allocates a fresh message,
//     fills it with stream.MsgRecv, and passes it to the receiver. Otherwise
//     the stream itself is passed.
//   - An input type that cannot be allocated as a drpc.Message yields a
//     drpc.InternalError.
//
// A non-nil output from the receiver is sent back with stream.MsgSend.
func (m *Mux) HandleRPC(stream drpc.Stream, rpc string) (err error) {
	data, ok := m.rpcs[rpc]
	if !ok {
		return drpc.ProtocolError.New("unknown rpc: %q", rpc)
	}

	in := interface{}(stream)
	if data.in1 != streamType {
		// reflect.Type.Elem panics for non-pointer kinds. Report such a
		// registration through the same error as a non-Message type.
		if data.in1.Kind() != reflect.Ptr {
			return drpc.InternalError.New("invalid rpc input type")
		}
		msg, ok := reflect.New(data.in1.Elem()).Interface().(drpc.Message)
		if !ok {
			return drpc.InternalError.New("invalid rpc input type")
		}
		if err := stream.MsgRecv(msg); err != nil {
			return errs.Wrap(err)
		}
		in = msg
	}

	out, err := data.receiver(data.srv, stream.Context(), in, stream)
	switch {
	case err != nil:
		return errs.Wrap(err)
	case out != nil:
		return stream.MsgSend(out)
	default:
		return nil
	}
}