zrepl/rpc/server.go
2017-09-30 16:31:55 +02:00

260 lines
5.9 KiB
Go

package rpc
import (
"bytes"
"encoding/json"
"io"
"reflect"
"github.com/pkg/errors"
)
type Server struct {
ml *MessageLayer
logger Logger
endpoints map[string]endpointDescr
}
type typeMap struct {
local reflect.Type
proto DataType
}
type endpointDescr struct {
inType typeMap
outType typeMap
handler reflect.Value
}
type MarshaledJSONEndpoint func(bodyJSON interface{})
func NewServer(rwc io.ReadWriteCloser) *Server {
ml := NewMessageLayer(rwc)
return &Server{
ml, noLogger{}, make(map[string]endpointDescr),
}
}
func (s *Server) SetLogger(logger Logger, logMessageLayer bool) {
s.logger = logger
if logMessageLayer {
s.ml.logger = logger
} else {
s.ml.logger = noLogger{}
}
}
func (s *Server) RegisterEndpoint(name string, handler interface{}) (err error) {
_, ok := s.endpoints[name]
if ok {
return errors.Errorf("already set up an endpoint for '%s'", name)
}
s.endpoints[name], err = makeEndpointDescr(handler)
return
}
func checkResponseHeader(h *Header) (err error) {
var statusNotSet Status
if h.Error == statusNotSet {
return errors.Errorf("status has zero-value")
}
return nil
}
func (s *Server) writeResponse(h *Header) (err error) {
// TODO validate
return s.ml.WriteHeader(h)
}
func (s *Server) recvRequest() (h *Header, err error) {
h, err = s.ml.ReadHeader()
if err != nil {
s.logger.Printf("error reading header: %s", err)
return nil, err
}
s.logger.Printf("validating request")
err = nil // TODO validate
if err == nil {
return h, nil
}
s.logger.Printf("request validation error: %s", err)
r := NewErrorHeader(StatusRequestError, "%s", err)
return nil, s.writeResponse(r)
}
var doneServeNext error = errors.New("this should not cause a HangUp() in the server")
var doneStopServing error = errors.New("this should cause the server to close the connection")
var ProtocolError error = errors.New("protocol error, server should hang up")
const ControlEndpointClose string = "Close"
// Serve the connection until failure or the client hangs up
func (s *Server) Serve() (err error) {
for {
err = s.ServeRequest()
if err == nil {
continue
}
if err == doneServeNext {
s.logger.Printf("subroutine returned pseudo-error indicating early-exit")
err = nil
continue
}
if err == doneStopServing {
s.logger.Printf("subroutine returned pseudo-error indicating close request")
err = nil
break
}
break
}
if err != nil {
s.logger.Printf("an error occurred that could not be handled on PRC protocol level: %+v", err)
}
s.logger.Printf("cloing MessageLayer")
if mlErr := s.ml.Close(); mlErr != nil {
s.logger.Printf("error closing MessageLayer: %+v", mlErr)
}
return err
}
// Serve a single request
// * wait for request to come in
// * call handler
// * reply
//
// The connection is left open, the next bytes on the conn should be
// the next request header.
//
// Returns an err != nil if the error is bad enough to hang up on the client.
// Examples: protocol version mismatches, protocol errors in general, ...
// Non-Examples: a handler error
func (s *Server) ServeRequest() (err error) {
ml := s.ml
s.logger.Printf("reading header")
h, err := s.recvRequest()
if err != nil {
return err
}
if h.DataType == DataTypeControl {
switch h.Endpoint {
case ControlEndpointClose:
ack := Header{Error: StatusOK, DataType: DataTypeControl}
err = s.writeResponse(&ack)
if err != nil {
return err
}
return doneStopServing
default:
r := NewErrorHeader(StatusRequestError, "unregistered control endpoint %s", h.Endpoint)
return s.writeResponse(r)
}
}
ep, ok := s.endpoints[h.Endpoint]
if !ok {
r := NewErrorHeader(StatusRequestError, "unregistered endpoint %s", h.Endpoint)
return s.writeResponse(r)
}
if ep.inType.proto != h.DataType {
r := NewErrorHeader(StatusRequestError, "wrong DataType for endpoint %s (has %s, you provided %s)", h.Endpoint, ep.inType.proto, h.DataType)
return s.writeResponse(r)
}
if ep.outType.proto != h.Accept {
r := NewErrorHeader(StatusRequestError, "wrong Accept for endpoint %s (has %s, you provided %s)", h.Endpoint, ep.outType.proto, h.Accept)
return s.writeResponse(r)
}
dr := ml.ReadData()
// Determine inval
var inval reflect.Value
switch ep.inType.proto {
case DataTypeMarshaledJSON:
// Unmarshal input
inval = reflect.New(ep.inType.local.Elem())
invalIface := inval.Interface()
err = json.NewDecoder(dr).Decode(invalIface)
if err != nil {
r := NewErrorHeader(StatusRequestError, "cannot decode marshaled JSON: %s", err)
return s.writeResponse(r)
}
case DataTypeOctets:
// Take data as is
inval = reflect.ValueOf(dr)
default:
panic("not implemented")
}
outval := reflect.New(ep.outType.local.Elem()) // outval is a double pointer
s.logger.Printf("before handler, inval=%v outval=%v", inval, outval)
// Call the handler
errs := ep.handler.Call([]reflect.Value{inval, outval})
if !errs[0].IsNil() {
he := errs[0].Interface().(error) // we checked that before...
s.logger.Printf("handler returned error: %s", err)
r := NewErrorHeader(StatusError, "%s", he.Error())
return s.writeResponse(r)
}
switch ep.outType.proto {
case DataTypeMarshaledJSON:
var dataBuf bytes.Buffer
// Marshal output
err = json.NewEncoder(&dataBuf).Encode(outval.Interface())
if err != nil {
r := NewErrorHeader(StatusServerError, "cannot marshal response: %s", err)
return s.writeResponse(r)
}
replyHeader := Header{
Error: StatusOK,
DataType: ep.outType.proto,
}
if err = s.writeResponse(&replyHeader); err != nil {
return err
}
if err = ml.WriteData(&dataBuf); err != nil {
return
}
case DataTypeOctets:
h := Header{
Error: StatusOK,
DataType: DataTypeOctets,
}
if err = s.writeResponse(&h); err != nil {
return
}
reader := outval.Interface().(*io.Reader) // we checked that when adding the endpoint
err = ml.WriteData(*reader)
if err != nil {
return err
}
}
return nil
}