mirror of
https://github.com/zrepl/zrepl.git
synced 2025-01-13 09:49:32 +01:00
260 lines
5.9 KiB
Go
260 lines
5.9 KiB
Go
package rpc
|
|
|
|
import (
|
|
"bytes"
|
|
"encoding/json"
|
|
"io"
|
|
"reflect"
|
|
|
|
"github.com/pkg/errors"
|
|
)
|
|
|
|
type Server struct {
|
|
ml *MessageLayer
|
|
logger Logger
|
|
endpoints map[string]endpointDescr
|
|
}
|
|
|
|
type typeMap struct {
|
|
local reflect.Type
|
|
proto DataType
|
|
}
|
|
type endpointDescr struct {
|
|
inType typeMap
|
|
outType typeMap
|
|
handler reflect.Value
|
|
}
|
|
|
|
type MarshaledJSONEndpoint func(bodyJSON interface{})
|
|
|
|
func NewServer(rwc io.ReadWriteCloser) *Server {
|
|
ml := NewMessageLayer(rwc)
|
|
return &Server{
|
|
ml, noLogger{}, make(map[string]endpointDescr),
|
|
}
|
|
}
|
|
|
|
func (s *Server) SetLogger(logger Logger, logMessageLayer bool) {
|
|
s.logger = logger
|
|
if logMessageLayer {
|
|
s.ml.logger = logger
|
|
} else {
|
|
s.ml.logger = noLogger{}
|
|
}
|
|
}
|
|
|
|
func (s *Server) RegisterEndpoint(name string, handler interface{}) (err error) {
|
|
_, ok := s.endpoints[name]
|
|
if ok {
|
|
return errors.Errorf("already set up an endpoint for '%s'", name)
|
|
}
|
|
s.endpoints[name], err = makeEndpointDescr(handler)
|
|
return
|
|
}
|
|
|
|
func checkResponseHeader(h *Header) (err error) {
|
|
var statusNotSet Status
|
|
if h.Error == statusNotSet {
|
|
return errors.Errorf("status has zero-value")
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (s *Server) writeResponse(h *Header) (err error) {
|
|
// TODO validate
|
|
return s.ml.WriteHeader(h)
|
|
}
|
|
|
|
func (s *Server) recvRequest() (h *Header, err error) {
|
|
h, err = s.ml.ReadHeader()
|
|
if err != nil {
|
|
s.logger.Printf("error reading header: %s", err)
|
|
return nil, err
|
|
}
|
|
|
|
s.logger.Printf("validating request")
|
|
err = nil // TODO validate
|
|
if err == nil {
|
|
return h, nil
|
|
}
|
|
s.logger.Printf("request validation error: %s", err)
|
|
|
|
r := NewErrorHeader(StatusRequestError, "%s", err)
|
|
return nil, s.writeResponse(r)
|
|
}
|
|
|
|
var doneServeNext error = errors.New("this should not cause a HangUp() in the server")
|
|
var doneStopServing error = errors.New("this should cause the server to close the connection")
|
|
|
|
var ProtocolError error = errors.New("protocol error, server should hang up")
|
|
|
|
const ControlEndpointClose string = "Close"
|
|
|
|
// Serve the connection until failure or the client hangs up
|
|
func (s *Server) Serve() (err error) {
|
|
for {
|
|
|
|
err = s.ServeRequest()
|
|
|
|
if err == nil {
|
|
continue
|
|
}
|
|
if err == doneServeNext {
|
|
s.logger.Printf("subroutine returned pseudo-error indicating early-exit")
|
|
err = nil
|
|
continue
|
|
}
|
|
|
|
if err == doneStopServing {
|
|
s.logger.Printf("subroutine returned pseudo-error indicating close request")
|
|
err = nil
|
|
break
|
|
}
|
|
|
|
break
|
|
}
|
|
|
|
if err != nil {
|
|
s.logger.Printf("an error occurred that could not be handled on PRC protocol level: %+v", err)
|
|
}
|
|
|
|
s.logger.Printf("cloing MessageLayer")
|
|
if mlErr := s.ml.Close(); mlErr != nil {
|
|
s.logger.Printf("error closing MessageLayer: %+v", mlErr)
|
|
}
|
|
|
|
return err
|
|
}
|
|
|
|
// Serve a single request
|
|
// * wait for request to come in
|
|
// * call handler
|
|
// * reply
|
|
//
|
|
// The connection is left open, the next bytes on the conn should be
|
|
// the next request header.
|
|
//
|
|
// Returns an err != nil if the error is bad enough to hang up on the client.
|
|
// Examples: protocol version mismatches, protocol errors in general, ...
|
|
// Non-Examples: a handler error
|
|
func (s *Server) ServeRequest() (err error) {
|
|
|
|
ml := s.ml
|
|
|
|
s.logger.Printf("reading header")
|
|
h, err := s.recvRequest()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if h.DataType == DataTypeControl {
|
|
switch h.Endpoint {
|
|
case ControlEndpointClose:
|
|
ack := Header{Error: StatusOK, DataType: DataTypeControl}
|
|
err = s.writeResponse(&ack)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
return doneStopServing
|
|
default:
|
|
r := NewErrorHeader(StatusRequestError, "unregistered control endpoint %s", h.Endpoint)
|
|
return s.writeResponse(r)
|
|
}
|
|
}
|
|
|
|
ep, ok := s.endpoints[h.Endpoint]
|
|
if !ok {
|
|
r := NewErrorHeader(StatusRequestError, "unregistered endpoint %s", h.Endpoint)
|
|
return s.writeResponse(r)
|
|
}
|
|
|
|
if ep.inType.proto != h.DataType {
|
|
r := NewErrorHeader(StatusRequestError, "wrong DataType for endpoint %s (has %s, you provided %s)", h.Endpoint, ep.inType.proto, h.DataType)
|
|
return s.writeResponse(r)
|
|
}
|
|
|
|
if ep.outType.proto != h.Accept {
|
|
r := NewErrorHeader(StatusRequestError, "wrong Accept for endpoint %s (has %s, you provided %s)", h.Endpoint, ep.outType.proto, h.Accept)
|
|
return s.writeResponse(r)
|
|
}
|
|
|
|
dr := ml.ReadData()
|
|
|
|
// Determine inval
|
|
var inval reflect.Value
|
|
switch ep.inType.proto {
|
|
case DataTypeMarshaledJSON:
|
|
// Unmarshal input
|
|
inval = reflect.New(ep.inType.local.Elem())
|
|
invalIface := inval.Interface()
|
|
err = json.NewDecoder(dr).Decode(invalIface)
|
|
if err != nil {
|
|
r := NewErrorHeader(StatusRequestError, "cannot decode marshaled JSON: %s", err)
|
|
return s.writeResponse(r)
|
|
}
|
|
case DataTypeOctets:
|
|
// Take data as is
|
|
inval = reflect.ValueOf(dr)
|
|
default:
|
|
panic("not implemented")
|
|
}
|
|
|
|
outval := reflect.New(ep.outType.local.Elem()) // outval is a double pointer
|
|
|
|
s.logger.Printf("before handler, inval=%v outval=%v", inval, outval)
|
|
|
|
// Call the handler
|
|
errs := ep.handler.Call([]reflect.Value{inval, outval})
|
|
|
|
if !errs[0].IsNil() {
|
|
he := errs[0].Interface().(error) // we checked that before...
|
|
s.logger.Printf("handler returned error: %s", err)
|
|
r := NewErrorHeader(StatusError, "%s", he.Error())
|
|
return s.writeResponse(r)
|
|
}
|
|
|
|
switch ep.outType.proto {
|
|
|
|
case DataTypeMarshaledJSON:
|
|
|
|
var dataBuf bytes.Buffer
|
|
// Marshal output
|
|
err = json.NewEncoder(&dataBuf).Encode(outval.Interface())
|
|
if err != nil {
|
|
r := NewErrorHeader(StatusServerError, "cannot marshal response: %s", err)
|
|
return s.writeResponse(r)
|
|
}
|
|
|
|
replyHeader := Header{
|
|
Error: StatusOK,
|
|
DataType: ep.outType.proto,
|
|
}
|
|
if err = s.writeResponse(&replyHeader); err != nil {
|
|
return err
|
|
}
|
|
|
|
if err = ml.WriteData(&dataBuf); err != nil {
|
|
return
|
|
}
|
|
|
|
case DataTypeOctets:
|
|
|
|
h := Header{
|
|
Error: StatusOK,
|
|
DataType: DataTypeOctets,
|
|
}
|
|
if err = s.writeResponse(&h); err != nil {
|
|
return
|
|
}
|
|
|
|
reader := outval.Interface().(*io.Reader) // we checked that when adding the endpoint
|
|
err = ml.WriteData(*reader)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
}
|
|
|
|
return nil
|
|
}
|