diff --git a/ggml/src/ggml-rpc/ggml-rpc.cpp b/ggml/src/ggml-rpc/ggml-rpc.cpp index 4f0abb5a..f468f796 100644 --- a/ggml/src/ggml-rpc/ggml-rpc.cpp +++ b/ggml/src/ggml-rpc/ggml-rpc.cpp @@ -53,6 +53,9 @@ struct socket_t { } }; +// macro for nicer error messages on server crash +#define RPC_STATUS_ASSERT(x) if (!(x)) GGML_ABORT("Remote RPC server crashed or returned malformed response") + // all RPC structures must be packed #pragma pack(push, 1) // ggml_tensor is serialized into rpc_tensor @@ -425,7 +428,7 @@ static bool send_rpc_cmd(const std::shared_ptr & sock, enum rpc_cmd cm static bool check_server_version(const std::shared_ptr & sock) { rpc_msg_hello_rsp response; bool status = send_rpc_cmd(sock, RPC_CMD_HELLO, nullptr, 0, &response, sizeof(response)); - GGML_ASSERT(status); + RPC_STATUS_ASSERT(status); if (response.major != RPC_PROTO_MAJOR_VERSION || response.minor > RPC_PROTO_MINOR_VERSION) { fprintf(stderr, "RPC server version mismatch: %d.%d.%d\n", response.major, response.minor, response.patch); return false; @@ -481,7 +484,7 @@ static void ggml_backend_rpc_buffer_free_buffer(ggml_backend_buffer_t buffer) { ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context; rpc_msg_free_buffer_req request = {ctx->remote_ptr}; bool status = send_rpc_cmd(ctx->sock, RPC_CMD_FREE_BUFFER, &request, sizeof(request), nullptr, 0); - GGML_ASSERT(status); + RPC_STATUS_ASSERT(status); delete ctx; } @@ -493,7 +496,7 @@ static void * ggml_backend_rpc_buffer_get_base(ggml_backend_buffer_t buffer) { rpc_msg_buffer_get_base_req request = {ctx->remote_ptr}; rpc_msg_buffer_get_base_rsp response; bool status = send_rpc_cmd(ctx->sock, RPC_CMD_BUFFER_GET_BASE, &request, sizeof(request), &response, sizeof(response)); - GGML_ASSERT(status); + RPC_STATUS_ASSERT(status); ctx->base_ptr = reinterpret_cast(response.base_ptr); return ctx->base_ptr; } @@ -545,7 +548,7 @@ static enum ggml_status ggml_backend_rpc_buffer_init_tensor(ggml_backend_buffer_ request.tensor = serialize_tensor(tensor); bool status = send_rpc_cmd(ctx->sock, RPC_CMD_INIT_TENSOR, &request, sizeof(request), nullptr, 0); - GGML_ASSERT(status); + RPC_STATUS_ASSERT(status); } return GGML_STATUS_SUCCESS; } @@ -560,7 +563,7 @@ static void ggml_backend_rpc_buffer_set_tensor(ggml_backend_buffer_t buffer, ggm request.hash = fnv_hash((const uint8_t*)data, size); rpc_msg_set_tensor_hash_rsp response; bool status = send_rpc_cmd(ctx->sock, RPC_CMD_SET_TENSOR_HASH, &request, sizeof(request), &response, sizeof(response)); - GGML_ASSERT(status); + RPC_STATUS_ASSERT(status); if (response.result) { // the server has the same data, no need to send it return; @@ -573,7 +576,7 @@ static void ggml_backend_rpc_buffer_set_tensor(ggml_backend_buffer_t buffer, ggm memcpy(input.data() + sizeof(rpc_tensor), &offset, sizeof(offset)); memcpy(input.data() + sizeof(rpc_tensor) + sizeof(offset), data, size); bool status = send_rpc_cmd(ctx->sock, RPC_CMD_SET_TENSOR, input.data(), input.size()); - GGML_ASSERT(status); + RPC_STATUS_ASSERT(status); } static void ggml_backend_rpc_buffer_get_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * tensor, void * data, size_t offset, size_t size) { @@ -583,7 +586,7 @@ static void ggml_backend_rpc_buffer_get_tensor(ggml_backend_buffer_t buffer, con request.offset = offset; request.size = size; bool status = send_rpc_cmd(ctx->sock, RPC_CMD_GET_TENSOR, &request, sizeof(request), data, size); - GGML_ASSERT(status); + RPC_STATUS_ASSERT(status); } static bool ggml_backend_rpc_buffer_cpy_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * src, ggml_tensor * dst) { @@ -601,7 +604,7 @@ static bool ggml_backend_rpc_buffer_cpy_tensor(ggml_backend_buffer_t buffer, con request.dst = serialize_tensor(dst); rpc_msg_copy_tensor_rsp response; bool status = send_rpc_cmd(ctx->sock, RPC_CMD_COPY_TENSOR, &request, sizeof(request), &response, sizeof(response)); - GGML_ASSERT(status); + RPC_STATUS_ASSERT(status); return response.result; } @@ -609,7 +612,7 @@ static void ggml_backend_rpc_buffer_clear(ggml_backend_buffer_t buffer, uint8_t ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context; rpc_msg_buffer_clear_req request = {ctx->remote_ptr, value}; bool status = send_rpc_cmd(ctx->sock, RPC_CMD_BUFFER_CLEAR, &request, sizeof(request), nullptr, 0); - GGML_ASSERT(status); + RPC_STATUS_ASSERT(status); } static ggml_backend_buffer_i ggml_backend_rpc_buffer_interface = { @@ -635,7 +638,7 @@ static ggml_backend_buffer_t ggml_backend_rpc_buffer_type_alloc_buffer(ggml_back rpc_msg_alloc_buffer_rsp response; auto sock = get_socket(buft_ctx->endpoint); bool status = send_rpc_cmd(sock, RPC_CMD_ALLOC_BUFFER, &request, sizeof(request), &response, sizeof(response)); - GGML_ASSERT(status); + RPC_STATUS_ASSERT(status); if (response.remote_ptr != 0) { ggml_backend_buffer_t buffer = ggml_backend_buffer_init(buft, ggml_backend_rpc_buffer_interface, @@ -650,7 +653,7 @@ static ggml_backend_buffer_t ggml_backend_rpc_buffer_type_alloc_buffer(ggml_back static size_t get_alignment(const std::shared_ptr & sock) { rpc_msg_get_alignment_rsp response; bool status = send_rpc_cmd(sock, RPC_CMD_GET_ALIGNMENT, nullptr, 0, &response, sizeof(response)); - GGML_ASSERT(status); + RPC_STATUS_ASSERT(status); return response.alignment; } @@ -662,7 +665,7 @@ static size_t ggml_backend_rpc_buffer_type_get_alignment(ggml_backend_buffer_typ static size_t get_max_size(const std::shared_ptr & sock) { rpc_msg_get_max_size_rsp response; bool status = send_rpc_cmd(sock, RPC_CMD_GET_MAX_SIZE, nullptr, 0, &response, sizeof(response)); - GGML_ASSERT(status); + RPC_STATUS_ASSERT(status); return response.max_size; } @@ -683,7 +686,7 @@ static size_t ggml_backend_rpc_buffer_type_get_alloc_size(ggml_backend_buffer_ty rpc_msg_get_alloc_size_rsp response; bool status = send_rpc_cmd(sock, RPC_CMD_GET_ALLOC_SIZE, &request, sizeof(request), &response, sizeof(response)); - GGML_ASSERT(status); + RPC_STATUS_ASSERT(status); return response.alloc_size; } else { @@ -761,7 +764,7 @@ static enum ggml_status ggml_backend_rpc_graph_compute(ggml_backend_t backend, g rpc_msg_graph_compute_rsp response; auto sock = get_socket(rpc_ctx->endpoint); bool status = send_rpc_cmd(sock, RPC_CMD_GRAPH_COMPUTE, input.data(), input.size(), &response, sizeof(response)); - GGML_ASSERT(status); + RPC_STATUS_ASSERT(status); return (enum ggml_status)response.result; } @@ -835,7 +838,7 @@ bool ggml_backend_is_rpc(ggml_backend_t backend) { static void get_device_memory(const std::shared_ptr & sock, size_t * free, size_t * total) { rpc_msg_get_device_memory_rsp response; bool status = send_rpc_cmd(sock, RPC_CMD_GET_DEVICE_MEMORY, nullptr, 0, &response, sizeof(response)); - GGML_ASSERT(status); + RPC_STATUS_ASSERT(status); *free = response.free_mem; *total = response.total_mem; }