// Copyright 2017 Google Inc. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package rpcreplay import ( "bufio" "encoding/binary" "errors" "fmt" "io" "os" "sync" "golang.org/x/net/context" "google.golang.org/grpc" "google.golang.org/grpc/metadata" "google.golang.org/grpc/status" pb "cloud.google.com/go/internal/rpcreplay/proto/rpcreplay" "github.com/golang/protobuf/proto" "github.com/golang/protobuf/ptypes" "github.com/golang/protobuf/ptypes/any" spb "google.golang.org/genproto/googleapis/rpc/status" ) // A Recorder records RPCs for later playback. type Recorder struct { mu sync.Mutex w *bufio.Writer f *os.File next int err error } // NewRecorder creates a recorder that writes to filename. The file will // also store the initial bytes for retrieval during replay. // // You must call Close on the Recorder to ensure that all data is written. func NewRecorder(filename string, initial []byte) (*Recorder, error) { f, err := os.Create(filename) if err != nil { return nil, err } rec, err := NewRecorderWriter(f, initial) if err != nil { _ = f.Close() return nil, err } rec.f = f return rec, nil } // NewRecorderWriter creates a recorder that writes to w. The initial // bytes will also be written to w for retrieval during replay. // // You must call Close on the Recorder to ensure that all data is written. func NewRecorderWriter(w io.Writer, initial []byte) (*Recorder, error) { bw := bufio.NewWriter(w) if err := writeHeader(bw, initial); err != nil { return nil, err } return &Recorder{w: bw, next: 1}, nil } // DialOptions returns the options that must be passed to grpc.Dial // to enable recording. func (r *Recorder) DialOptions() []grpc.DialOption { return []grpc.DialOption{ grpc.WithUnaryInterceptor(r.interceptUnary), grpc.WithStreamInterceptor(r.interceptStream), } } // Close saves any unwritten information. func (r *Recorder) Close() error { r.mu.Lock() defer r.mu.Unlock() if r.err != nil { return r.err } err := r.w.Flush() if r.f != nil { if err2 := r.f.Close(); err == nil { err = err2 } } return err } // Intercepts all unary (non-stream) RPCs. func (r *Recorder) interceptUnary(ctx context.Context, method string, req, res interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error { ereq := &entry{ kind: pb.Entry_REQUEST, method: method, msg: message{msg: req.(proto.Message)}, } refIndex, err := r.writeEntry(ereq) if err != nil { return err } ierr := invoker(ctx, method, req, res, cc, opts...) eres := &entry{ kind: pb.Entry_RESPONSE, refIndex: refIndex, } // If the error is not a gRPC status, then something more // serious is wrong. More significantly, we have no way // of serializing an arbitrary error. So just return it // without recording the response. if _, ok := status.FromError(ierr); !ok { r.mu.Lock() r.err = fmt.Errorf("saw non-status error in %s response: %v (%T)", method, ierr, ierr) r.mu.Unlock() return ierr } eres.msg.set(res, ierr) if _, err := r.writeEntry(eres); err != nil { return err } return ierr } func (r *Recorder) writeEntry(e *entry) (int, error) { r.mu.Lock() defer r.mu.Unlock() if r.err != nil { return 0, r.err } err := writeEntry(r.w, e) if err != nil { r.err = err return 0, err } n := r.next r.next++ return n, nil } func (r *Recorder) interceptStream(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) { cstream, serr := streamer(ctx, desc, cc, method, opts...) e := &entry{ kind: pb.Entry_CREATE_STREAM, method: method, } e.msg.set(nil, serr) refIndex, err := r.writeEntry(e) if err != nil { return nil, err } return &recClientStream{ ctx: ctx, rec: r, cstream: cstream, refIndex: refIndex, }, serr } // A recClientStream implements the gprc.ClientStream interface. // It behaves exactly like the default ClientStream, but also // records all messages sent and received. type recClientStream struct { ctx context.Context rec *Recorder cstream grpc.ClientStream refIndex int } func (rcs *recClientStream) Context() context.Context { return rcs.ctx } func (rcs *recClientStream) SendMsg(m interface{}) error { serr := rcs.cstream.SendMsg(m) e := &entry{ kind: pb.Entry_SEND, refIndex: rcs.refIndex, } e.msg.set(m, serr) if _, err := rcs.rec.writeEntry(e); err != nil { return err } return serr } func (rcs *recClientStream) RecvMsg(m interface{}) error { serr := rcs.cstream.RecvMsg(m) e := &entry{ kind: pb.Entry_RECV, refIndex: rcs.refIndex, } e.msg.set(m, serr) if _, err := rcs.rec.writeEntry(e); err != nil { return err } return serr } func (rcs *recClientStream) Header() (metadata.MD, error) { // TODO(jba): record. return rcs.cstream.Header() } func (rcs *recClientStream) Trailer() metadata.MD { // TODO(jba): record. return rcs.cstream.Trailer() } func (rcs *recClientStream) CloseSend() error { // TODO(jba): record. return rcs.cstream.CloseSend() } // A Replayer replays a set of RPCs saved by a Recorder. type Replayer struct { initial []byte // initial state log func(format string, v ...interface{}) // for debugging mu sync.Mutex calls []*call } // A call represents a unary RPC, with a request and response (or error). type call struct { method string request proto.Message response message } // NewReplayer creates a Replayer that reads from filename. func NewReplayer(filename string) (*Replayer, error) { f, err := os.Open(filename) if err != nil { return nil, err } defer f.Close() return NewReplayerReader(f) } // NewReplayerReader creates a Replayer that reads from r. func NewReplayerReader(r io.Reader) (*Replayer, error) { rep := &Replayer{ log: func(string, ...interface{}) {}, } if err := rep.read(r); err != nil { return nil, err } return rep, nil } // read reads the stream of recorded entries. // It matches requests with responses, with each pair grouped // into a call struct. func (rep *Replayer) read(r io.Reader) error { r = bufio.NewReader(r) bytes, err := readHeader(r) if err != nil { return err } rep.initial = bytes callsByIndex := map[int]*call{} for i := 1; ; i++ { e, err := readEntry(r) if err != nil { return err } if e == nil { break } switch e.kind { case pb.Entry_REQUEST: callsByIndex[i] = &call{ method: e.method, request: e.msg.msg, } case pb.Entry_RESPONSE: call := callsByIndex[e.refIndex] if call == nil { return fmt.Errorf("replayer: no request for response #%d", i) } delete(callsByIndex, e.refIndex) call.response = e.msg rep.calls = append(rep.calls, call) default: return fmt.Errorf("replayer: unknown kind %s", e.kind) } } if len(callsByIndex) > 0 { return fmt.Errorf("replayer: %d unmatched requests", len(callsByIndex)) } return nil } // DialOptions returns the options that must be passed to grpc.Dial // to enable replaying. func (r *Replayer) DialOptions() []grpc.DialOption { return []grpc.DialOption{ // On replay, we make no RPCs, which means the connection may be closed // before the normally async Dial completes. Making the Dial synchronous // fixes that. grpc.WithBlock(), grpc.WithUnaryInterceptor(r.interceptUnary), } } // Initial returns the initial state saved by the Recorder. func (r *Replayer) Initial() []byte { return r.initial } // SetLogFunc sets a function to be used for debug logging. The function // should be safe to be called from multiple goroutines. func (r *Replayer) SetLogFunc(f func(format string, v ...interface{})) { r.log = f } // Close closes the Replayer. func (r *Replayer) Close() error { return nil } func (r *Replayer) interceptUnary(_ context.Context, method string, req, res interface{}, _ *grpc.ClientConn, _ grpc.UnaryInvoker, _ ...grpc.CallOption) error { mreq := req.(proto.Message) r.log("request %s (%s)", method, req) call := r.extractCall(method, mreq) if call == nil { return fmt.Errorf("replayer: request not found: %s", mreq) } r.log("returning %v", call.response) if call.response.err != nil { return call.response.err } proto.Merge(res.(proto.Message), call.response.msg) // copy msg into res return nil } // extractCall finds the first call in the list with the same method // and request. It returns nil if it can't find such a call. func (r *Replayer) extractCall(method string, req proto.Message) *call { r.mu.Lock() defer r.mu.Unlock() for i, call := range r.calls { if call == nil { continue } if method == call.method && proto.Equal(req, call.request) { r.calls[i] = nil // nil out this call so we don't reuse it return call } } return nil } // Fprint reads the entries from filename and writes them to w in human-readable form. // It is intended for debugging. func Fprint(w io.Writer, filename string) error { f, err := os.Open(filename) if err != nil { return err } defer f.Close() return FprintReader(w, f) } // FprintReader reads the entries from r and writes them to w in human-readable form. // It is intended for debugging. func FprintReader(w io.Writer, r io.Reader) error { initial, err := readHeader(r) if err != nil { return err } fmt.Fprintf(w, "initial state: %q\n", string(initial)) for i := 1; ; i++ { e, err := readEntry(r) if err != nil { return err } if e == nil { return nil } s := "message" if e.msg.err != nil { s = "error" } fmt.Fprintf(w, "#%d: kind: %s, method: %s, ref index: %d, %s:\n", i, e.kind, e.method, e.refIndex, s) if e.msg.err == nil { if err := proto.MarshalText(w, e.msg.msg); err != nil { return err } } else { fmt.Fprintf(w, "%v\n", e.msg.err) } } } // An entry holds one gRPC action (request, response, etc.). type entry struct { kind pb.Entry_Kind method string msg message refIndex int // index of corresponding request or create-stream } func (e1 *entry) equal(e2 *entry) bool { if e1 == nil && e2 == nil { return true } if e1 == nil || e2 == nil { return false } return e1.kind == e2.kind && e1.method == e2.method && proto.Equal(e1.msg.msg, e2.msg.msg) && errEqual(e1.msg.err, e2.msg.err) && e1.refIndex == e2.refIndex } func errEqual(e1, e2 error) bool { if e1 == e2 { return true } s1, ok1 := status.FromError(e1) s2, ok2 := status.FromError(e2) if !ok1 || !ok2 { return false } return proto.Equal(s1.Proto(), s2.Proto()) } // message holds either a single proto.Message or an error. type message struct { msg proto.Message err error } func (m *message) set(msg interface{}, err error) { m.err = err if err != io.EOF && msg != nil { m.msg = msg.(proto.Message) } } // File format: // header // sequence of Entry protos // // Header format: // magic string // a record containing the bytes of the initial state const magic = "RPCReplay" func writeHeader(w io.Writer, initial []byte) error { if _, err := io.WriteString(w, magic); err != nil { return err } return writeRecord(w, initial) } func readHeader(r io.Reader) ([]byte, error) { var buf [len(magic)]byte if _, err := io.ReadFull(r, buf[:]); err != nil { if err == io.EOF { err = errors.New("rpcreplay: empty replay file") } return nil, err } if string(buf[:]) != magic { return nil, errors.New("rpcreplay: not a replay file (does not begin with magic string)") } bytes, err := readRecord(r) if err == io.EOF { err = errors.New("rpcreplay: missing initial state") } return bytes, err } func writeEntry(w io.Writer, e *entry) error { var m proto.Message if e.msg.err != nil && e.msg.err != io.EOF { s, ok := status.FromError(e.msg.err) if !ok { return fmt.Errorf("rpcreplay: error %v is not a Status", e.msg.err) } m = s.Proto() } else { m = e.msg.msg } var a *any.Any var err error if m != nil { a, err = ptypes.MarshalAny(m) if err != nil { return err } } pe := &pb.Entry{ Kind: e.kind, Method: e.method, Message: a, IsError: e.msg.err != nil, RefIndex: int32(e.refIndex), } bytes, err := proto.Marshal(pe) if err != nil { return err } return writeRecord(w, bytes) } func readEntry(r io.Reader) (*entry, error) { buf, err := readRecord(r) if err == io.EOF { return nil, nil } if err != nil { return nil, err } var pe pb.Entry if err := proto.Unmarshal(buf, &pe); err != nil { return nil, err } var msg message if pe.Message != nil { var any ptypes.DynamicAny if err := ptypes.UnmarshalAny(pe.Message, &any); err != nil { return nil, err } if pe.IsError { msg.err = status.ErrorProto(any.Message.(*spb.Status)) } else { msg.msg = any.Message } } else if pe.IsError { msg.err = io.EOF } else if pe.Kind != pb.Entry_CREATE_STREAM { return nil, errors.New("rpcreplay: entry with nil message and false is_error") } return &entry{ kind: pe.Kind, method: pe.Method, msg: msg, refIndex: int(pe.RefIndex), }, nil } // A record consists of an unsigned 32-bit little-endian length L followed by L // bytes. func writeRecord(w io.Writer, data []byte) error { if err := binary.Write(w, binary.LittleEndian, uint32(len(data))); err != nil { return err } _, err := w.Write(data) return err } func readRecord(r io.Reader) ([]byte, error) { var size uint32 if err := binary.Read(r, binary.LittleEndian, &size); err != nil { return nil, err } buf := make([]byte, size) if _, err := io.ReadFull(r, buf); err != nil { return nil, err } return buf, nil }