diff --git a/rpc/client.go b/rpc/client.go index 97b5546..262b9d1 100644 --- a/rpc/client.go +++ b/rpc/client.go @@ -28,10 +28,35 @@ func (c *Client) SetLogger(logger Logger, logMessageLayer bool) { } func (c *Client) Close() (err error) { - err = c.ml.HangUp() - if err == RST { - return nil + + c.logger.Printf("sending Close request") + header := Header{ + DataType: DataTypeControl, + Endpoint: ControlEndpointClose, + Accept: DataTypeControl, } + err = c.ml.WriteHeader(&header) + if err != nil { + return + } + + c.logger.Printf("reading Close ACK") + ack, err := c.ml.ReadHeader() + if err != nil { + return err + } + c.logger.Printf("received Close ACK: %#v", ack) + if ack.Error != StatusOK { + err = errors.Errorf("error hanging up: remote error (%s) %s", ack.Error, ack.ErrorMessage) + return + } + + c.logger.Printf("closing MessageLayer") + if err = c.ml.Close(); err != nil { + c.logger.Printf("error closing RWC: %+v", err) + return + } + return err } diff --git a/rpc/frame_layer.go b/rpc/frame_layer.go index 1316ec5..47e5d4b 100644 --- a/rpc/frame_layer.go +++ b/rpc/frame_layer.go @@ -60,6 +60,7 @@ type DataType uint8 const ( DataTypeNone DataType = 1 + iota + DataTypeControl DataTypeMarshaledJSON DataTypeOctets ) @@ -83,12 +84,14 @@ func NewFrameBridgingReader(l *MessageLayer, frameType FrameType, totalLimit int func (r *frameBridgingReader) Read(b []byte) (n int, err error) { if r.bytesLeftToLimit == 0 { + r.l.logger.Printf("limit reached, returning EOF") return 0, io.EOF } log := r.l.logger if r.f.PayloadLength == 0 { if r.f.NoMoreFrames { + r.l.logger.Printf("no more frames flag set, returning EOF") err = io.EOF return } @@ -96,6 +99,7 @@ func (r *frameBridgingReader) Read(b []byte) (n int, err error) { log.Printf("reading frame") r.f, err = r.l.readFrame() if err != nil { + log.Printf("error reading frame: %+v", err) return 0, err } log.Printf("read frame: %#v", r.f) @@ -197,22 +201,16 @@ func NewMessageLayer(rwc io.ReadWriteCloser) *MessageLayer { return &MessageLayer{rwc, noLogger{}} } -// Always returns an error, RST error if no error occurred while sending RST frame -func (l *MessageLayer) HangUp() (err error) { - l.logger.Printf("hanging up") +func (l *MessageLayer) Close() (err error) { f := Frame{ Type: FrameTypeRST, NoMoreFrames: true, } - rstFrameError := l.writeFrame(f) - closeErr := l.rwc.Close() - if rstFrameError != nil { - return errors.WithStack(rstFrameError) - } else if closeErr != nil { - return errors.WithStack(closeErr) - } else { - return RST + if err = l.writeFrame(f); err != nil { + l.logger.Printf("error sending RST frame: %s", err) + return errors.WithStack(err) } + return nil } var RST error = fmt.Errorf("reset frame observed on connection") @@ -234,6 +232,7 @@ func (l *MessageLayer) readFrame() (f Frame, err error) { return } if f.Type == FrameTypeRST { + l.logger.Printf("read RST frame") err = RST return } diff --git a/rpc/server.go b/rpc/server.go index f9f59e4..14b4521 100644 --- a/rpc/server.go +++ b/rpc/server.go @@ -84,9 +84,12 @@ func (s *Server) recvRequest() (h *Header, err error) { } var doneServeNext error = errors.New("this should not cause a HangUp() in the server") +var doneStopServing error = errors.New("this should cause the server to close the connection") var ProtocolError error = errors.New("protocol error, server should hang up") +const ControlEndpointClose string = "Close" + // Serve the connection until failure or the client hangs up func (s *Server) Serve() (err error) { for { @@ -96,16 +99,31 @@ func (s *Server) Serve() (err error) { if err == nil { continue } - if err == doneServeNext { s.logger.Printf("subroutine returned pseudo-error indicating early-exit") + err = nil continue } - s.logger.Printf("hanging up after ServeRequest returned error: %s", err) - s.ml.HangUp() - return err + if err == doneStopServing { + s.logger.Printf("subroutine returned pseudo-error indicating close request") + err = nil + break + } + + break } + + if err != nil { + s.logger.Printf("an error occurred that could not be handled on PRC protocol level: %+v", err) + } + + s.logger.Printf("cloing MessageLayer") + if mlErr := s.ml.Close(); mlErr != nil { + s.logger.Printf("error closing MessageLayer: %+v", mlErr) + } + + return err } // Serve a single request @@ -129,6 +147,22 @@ func (s *Server) ServeRequest() (err error) { return err } + if h.DataType == DataTypeControl { + switch h.Endpoint { + case ControlEndpointClose: + ack := Header{Error: StatusOK, DataType: DataTypeControl} + err = s.writeResponse(&ack) + if err != nil { + return err + } + return doneStopServing + default: + r := NewErrorHeader(StatusRequestError, "unregistered control endpoint %s", h.Endpoint) + return s.writeResponse(r) + } + panic("implementation error") + } + ep, ok := s.endpoints[h.Endpoint] if !ok { r := NewErrorHeader(StatusRequestError, "unregistered endpoint %s", h.Endpoint)