package rpc

import (
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"reflect"

	. "github.com/zrepl/zrepl/model"
	. "github.com/zrepl/zrepl/util"
)

// RPCRequester is the client-side view of the RPC protocol: each method sends
// one request and returns its result.
type RPCRequester interface {
	FilesystemRequest(r FilesystemRequest) (roots []Filesystem, err error)
	InitialTransferRequest(r InitialTransferRequest) (io.Reader, error)
	IncrementalTransferRequest(r IncrementalTransferRequest) (io.Reader, error)
}

// RPCHandler is the server-side counterpart of RPCRequester. It is invoked by
// ListenByteStreamRPC for requests arriving over a byte stream, and directly
// by LocalRPC.
type RPCHandler interface {
	HandleFilesystemRequest(r FilesystemRequest) (roots []Filesystem, err error)
	HandleInitialTransferRequest(r InitialTransferRequest) (io.Reader, error)
	HandleIncrementalTransferRequest(r IncrementalTransferRequest) (io.Reader, error)
}

// ByteStreamRPCProtocolVersion is sent by the client during the handshake in
// ConnectByteStreamRPC; the server rejects any other value.
const ByteStreamRPCProtocolVersion = 1

// ByteStreamRPC implements RPCRequester over a bidirectional byte stream,
// using JSON-encoded headers/bodies and chunked binary streams for transfers.
type ByteStreamRPC struct {
	conn    io.ReadWriteCloser
	encoder *json.Encoder
	decoder *json.Decoder
}

// Compile-time check that ByteStreamRPC satisfies RPCRequester.
var _ RPCRequester = ByteStreamRPC{}

// ConnectByteStreamRPC wraps conn in a ByteStreamRPC and performs the protocol
// version handshake before returning it.
//
// conn is not closed if the handshake fails; the caller keeps ownership of it.
func ConnectByteStreamRPC(conn io.ReadWriteCloser) (RPCRequester, error) {
	// TODO do ssh connection to transport, establish TCP-like communication channel

	rpc := ByteStreamRPC{
		conn:    conn,
		encoder: json.NewEncoder(conn),
		decoder: json.NewDecoder(conn),
	}

	// Assert protocol versions are equal
	if err := rpc.ProtocolVersionRequest(); err != nil {
		return nil, err
	}

	return rpc, nil
}

// ListenByteStreamRPC serves requests read from conn, dispatching each to
// handler, until an unrecoverable error occurs. It only returns on error, and
// conn is always closed before it returns.
//
// A request consists of two subsequent JSON objects:
//   - Object 1: RequestHeader => contains type of Request Body
//   - Object 2: RequestBody, e.g. IncrementalTransferRequest
//
// A response is always a ResponseHeader followed by bytes to be interpreted
// as indicated by the ResponseHeader.ResponseType, e.g.
//   - a chunked response (RChunkedStream), or
//   - another JSON object (RFilesystems).
//
// Errors returned by handler are passed to respondWithError and the loop
// continues; undecodable, unsupported or version-mismatched requests end it.
func ListenByteStreamRPC(conn io.ReadWriteCloser, handler RPCHandler) error {
	defer conn.Close()

	decoder := json.NewDecoder(conn)
	encoder := json.NewEncoder(conn)

	// streamChunked announces a chunked stream for the request identified by
	// rh and then copies r onto conn in chunked encoding.
	streamChunked := func(rh RequestHeader, r io.Reader) error {
		resp := ResponseHeader{RequestId: rh.Id, ResponseType: RChunkedStream}
		if err := encoder.Encode(&resp); err != nil {
			return err
		}
		chunker := NewChunker(r)
		_, err := io.Copy(conn, &chunker)
		return err
	}

	for {
		var header RequestHeader
		if err := decoder.Decode(&header); err != nil {
			respondWithError(conn, EDecodeHeader, err)
			return err
		}

		switch header.Type {
		case RTProtocolVersionRequest:
			var rq ByteStreamRPCProtocolVersionRequest
			if err := decoder.Decode(&rq); err != nil {
				respondWithError(conn, EDecodeRequestBody, nil)
				return err
			}
			if rq.ClientVersion != ByteStreamRPCProtocolVersion {
				respondWithError(conn, EProtocolVersionMismatch, nil)
				return fmt.Errorf("protocol version mismatch: client %v, server %v",
					rq.ClientVersion, ByteStreamRPCProtocolVersion)
			}
			// The client checks for ROK in ProtocolVersionRequest.
			resp := ResponseHeader{RequestId: header.Id, ResponseType: ROK}
			if err := encoder.Encode(&resp); err != nil {
				return err
			}

		case RTFilesystemRequest:
			var rq FilesystemRequest
			if err := decoder.Decode(&rq); err != nil {
				respondWithError(conn, EDecodeRequestBody, nil)
				return err
			}
			roots, err := handler.HandleFilesystemRequest(rq)
			if err != nil {
				respondWithError(conn, EHandler, err)
				continue
			}
			// The client expects an RFilesystems header before the roots.
			resp := ResponseHeader{RequestId: header.Id, ResponseType: RFilesystems}
			if err := encoder.Encode(&resp); err != nil {
				return err
			}
			if err := encoder.Encode(&roots); err != nil {
				return err
			}

		case RTInitialTransferRequest:
			var rq InitialTransferRequest
			if err := decoder.Decode(&rq); err != nil {
				respondWithError(conn, EDecodeRequestBody, nil)
				return err
			}
			snapReader, err := handler.HandleInitialTransferRequest(rq)
			if err != nil {
				respondWithError(conn, EHandler, err)
				continue
			}
			if err := streamChunked(header, snapReader); err != nil {
				return err
			}

		case RTIncrementalTransferRequest:
			var rq IncrementalTransferRequest
			if err := decoder.Decode(&rq); err != nil {
				respondWithError(conn, EDecodeRequestBody, nil)
				return err
			}
			snapReader, err := handler.HandleIncrementalTransferRequest(rq)
			if err != nil {
				respondWithError(conn, EHandler, err)
				continue
			}
			if err := streamChunked(header, snapReader); err != nil {
				return err
			}

		default:
			respondWithError(conn, EUnknownRequestType, nil)
			return fmt.Errorf("unknown request type %v", header.Type)
		}
	}
}

// respondWithError is intended to report the error id (and optionally err) to
// the peer on conn.
//
// TODO: not implemented — it currently writes nothing and always returns nil,
// so a peer waiting for a response header to a failed request gets none.
func respondWithError(conn io.Writer, id ErrorId, err error) error {
	return nil
}

// inferRequestType maps a request body value to the RequestType that
// announces it in the RequestHeader. It returns an error for unsupported
// types.
func inferRequestType(v interface{}) (RequestType, error) {
	switch v.(type) {
	case ByteStreamRPCProtocolVersionRequest:
		return RTProtocolVersionRequest, nil
	case FilesystemRequest:
		return RTFilesystemRequest, nil
	case InitialTransferRequest:
		return RTInitialTransferRequest, nil
	case IncrementalTransferRequest:
		return RTIncrementalTransferRequest, nil
	default:
		return 0, fmt.Errorf("cannot infer request type for type '%v'", reflect.TypeOf(v))
	}
}

// genUUID returns the id placed in outgoing RequestHeaders.
//
// TODO: currently always returns the all-zero id.
func genUUID() [16]byte {
	return [16]byte{} // TODO
}

// sendRequest writes a RequestHeader describing v, followed by v itself.
func (c ByteStreamRPC) sendRequest(v interface{}) error {
	rt, err := inferRequestType(v)
	if err != nil {
		return err
	}

	h := RequestHeader{
		Type: rt,
		Id:   genUUID(),
	}
	if err := c.encoder.Encode(h); err != nil {
		return err
	}
	return c.encoder.Encode(v)
}

// expectResponseType reads one ResponseHeader and fails unless its
// ResponseType equals rt.
func (c ByteStreamRPC) expectResponseType(rt ResponseType) error {
	var h ResponseHeader
	if err := c.decoder.Decode(&h); err != nil {
		return err
	}
	if h.ResponseType != rt {
		return errors.New("unexpected response type in response header")
	}
	return nil
}

// sendRequestReceiveHeader sends request and then reads the response header,
// requiring it to be of type rt.
func (c ByteStreamRPC) sendRequestReceiveHeader(request interface{}, rt ResponseType) error {
	if err := c.sendRequest(request); err != nil {
		return err
	}
	return c.expectResponseType(rt)
}

// ProtocolVersionRequest sends this side's protocol version and succeeds only
// if the peer answers with ROK.
func (c ByteStreamRPC) ProtocolVersionRequest() (err error) {
	b := ByteStreamRPCProtocolVersionRequest{
		ClientVersion: ByteStreamRPCProtocolVersion,
	}

	// OK response means the remote side can cope with our protocol version
	return c.sendRequestReceiveHeader(b, ROK)
}

// FilesystemRequest sends r and decodes the list of filesystems that follows
// the RFilesystems response header. roots is nil whenever err is non-nil.
func (c ByteStreamRPC) FilesystemRequest(r FilesystemRequest) (roots []Filesystem, err error) {
	if err = c.sendRequestReceiveHeader(r, RFilesystems); err != nil {
		return nil, err
	}

	roots = make([]Filesystem, 0)
	// Decode needs a pointer; decoding into the slice value always fails.
	if err = c.decoder.Decode(&roots); err != nil {
		return nil, err
	}
	return roots, nil
}

// chunkedStream returns a reader over the chunked stream that follows an
// RChunkedStream response header.
//
// The JSON decoder reads ahead, so bytes of the stream it has already
// buffered are consumed before reading further from conn.
// NOTE(review): the decoder does not discard those buffered bytes, so JSON
// requests issued on this connection after a transfer will be misaligned.
func (c ByteStreamRPC) chunkedStream() io.Reader {
	return io.MultiReader(c.decoder.Buffered(), c.conn)
}

// InitialTransferRequest sends r and returns a reader that un-chunks the
// stream following the RChunkedStream response header.
func (c ByteStreamRPC) InitialTransferRequest(r InitialTransferRequest) (unchunker io.Reader, err error) {
	if err = c.sendRequestReceiveHeader(r, RChunkedStream); err != nil {
		return nil, err
	}
	unchunker = NewUnchunker(c.chunkedStream())
	return unchunker, nil
}

// IncrementalTransferRequest sends r and returns a reader that un-chunks the
// stream following the RChunkedStream response header.
func (c ByteStreamRPC) IncrementalTransferRequest(r IncrementalTransferRequest) (unchunker io.Reader, err error) {
	if err = c.sendRequestReceiveHeader(r, RChunkedStream); err != nil {
		return nil, err
	}
	unchunker = NewUnchunker(c.chunkedStream())
	return unchunker, nil
}

// LocalRPC implements RPCRequester by calling an RPCHandler directly in the
// same process, without any serialization.
type LocalRPC struct {
	handler RPCHandler
}

// Compile-time check that LocalRPC satisfies RPCRequester.
var _ RPCRequester = LocalRPC{}

// ConnectLocalRPC returns an RPCRequester that forwards every request to
// handler.
func ConnectLocalRPC(handler RPCHandler) RPCRequester {
	return LocalRPC{handler}
}

// FilesystemRequest forwards r to the handler.
func (c LocalRPC) FilesystemRequest(r FilesystemRequest) (roots []Filesystem, err error) {
	return c.handler.HandleFilesystemRequest(r)
}

// InitialTransferRequest forwards r to the handler.
func (c LocalRPC) InitialTransferRequest(r InitialTransferRequest) (io.Reader, error) {
	return c.handler.HandleInitialTransferRequest(r)
}

// IncrementalTransferRequest forwards r to the handler.
func (c LocalRPC) IncrementalTransferRequest(r IncrementalTransferRequest) (reader io.Reader, err error) {
	return c.handler.HandleIncrementalTransferRequest(r)
}