Skip to content

Commit 50a29b3

Browse files
matt23654slaren
authored andcommitted
[GGML][RPC] Support for models with non-512-aligned tensors over RPC. (ggml-org#11047)
* Added init tensor calling code * Added get_alloc_size forwarding * Cleaned up and improved type/error handling. * fix: remove trailing whitespaces. * Cleanup and use GGML error logging functions. * Handle potentially dangerous edge cases. * Apply suggestions from code review Co-authored-by: Diego Devesa <[email protected]> --------- Co-authored-by: Diego Devesa <[email protected]>
1 parent 2599750 commit 50a29b3

File tree

1 file changed

+134
-6
lines changed

1 file changed

+134
-6
lines changed

ggml/src/ggml-rpc/ggml-rpc.cpp

+134-6
Original file line numberDiff line numberDiff line change
@@ -93,9 +93,23 @@ enum rpc_cmd {
9393
RPC_CMD_COPY_TENSOR,
9494
RPC_CMD_GRAPH_COMPUTE,
9595
RPC_CMD_GET_DEVICE_MEMORY,
96+
RPC_CMD_INIT_TENSOR,
97+
RPC_CMD_GET_ALLOC_SIZE,
9698
RPC_CMD_COUNT,
9799
};
98100

101+
struct rpc_msg_get_alloc_size_req {
102+
rpc_tensor tensor;
103+
};
104+
105+
struct rpc_msg_get_alloc_size_rsp {
106+
uint64_t alloc_size;
107+
};
108+
109+
struct rpc_msg_init_tensor_req {
110+
rpc_tensor tensor;
111+
};
112+
99113
struct rpc_msg_alloc_buffer_req {
100114
uint64_t size;
101115
};
@@ -461,10 +475,18 @@ static rpc_tensor serialize_tensor(const ggml_tensor * tensor) {
461475
}
462476

463477
static void ggml_backend_rpc_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor) {
464-
UNUSED(buffer);
465-
if (ggml_is_quantized(tensor->type)) {
466-
// TODO: this check is due to MATRIX_ROW_PADDING in CUDA and should be generalized
467-
GGML_ASSERT(tensor->ne[0] % 512 == 0 && "unsupported quantized tensor");
478+
ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
479+
480+
// CUDA backend on the server pads everything to 512 due to CUDA limitations.
481+
// Due to bandwidth constraints, we only call the server init tensor functions if necessary.
482+
// In particular, only quantized tensors need padding
483+
if (ggml_is_quantized(tensor->type) && (tensor->ne[0] % 512 != 0) && (tensor->view_src == nullptr)) {
484+
rpc_msg_init_tensor_req request;
485+
486+
request.tensor = serialize_tensor(tensor);
487+
488+
bool status = send_rpc_cmd(ctx->sock, RPC_CMD_INIT_TENSOR, &request, sizeof(request), nullptr, 0);
489+
GGML_ASSERT(status);
468490
}
469491
}
470492

@@ -577,8 +599,23 @@ static size_t ggml_backend_rpc_get_max_size(ggml_backend_buffer_type_t buft) {
577599
}
578600

579601
static size_t ggml_backend_rpc_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor) {
580-
UNUSED(buft);
581-
return ggml_nbytes(tensor);
602+
// See comments in init_tensor.
603+
if (ggml_is_quantized(tensor->type) && (tensor->ne[0] % 512 != 0) && (tensor->view_src == nullptr)) {
604+
ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context;
605+
auto sock = get_socket(buft_ctx->endpoint);
606+
607+
rpc_msg_get_alloc_size_req request;
608+
609+
request.tensor = serialize_tensor(tensor);
610+
611+
rpc_msg_get_alloc_size_rsp response;
612+
bool status = send_rpc_cmd(sock, RPC_CMD_GET_ALLOC_SIZE, &request, sizeof(request), &response, sizeof(response));
613+
GGML_ASSERT(status);
614+
615+
return response.alloc_size;
616+
} else {
617+
return ggml_nbytes(tensor);
618+
}
582619
}
583620

584621
static ggml_backend_buffer_type_i ggml_backend_rpc_buffer_type_interface = {
@@ -757,6 +794,8 @@ class rpc_server {
757794
bool get_tensor(const rpc_msg_get_tensor_req & request, std::vector<uint8_t> & response);
758795
bool copy_tensor(const rpc_msg_copy_tensor_req & request, rpc_msg_copy_tensor_rsp & response);
759796
bool graph_compute(const std::vector<uint8_t> & input, rpc_msg_graph_compute_rsp & response);
797+
bool init_tensor(const rpc_msg_init_tensor_req & request);
798+
bool get_alloc_size(const rpc_msg_get_alloc_size_req & request, rpc_msg_get_alloc_size_rsp & response);
760799

761800
private:
762801
ggml_tensor * deserialize_tensor(struct ggml_context * ctx, const rpc_tensor * tensor);
@@ -770,6 +809,36 @@ class rpc_server {
770809
std::unordered_set<ggml_backend_buffer_t> buffers;
771810
};
772811

812+
bool rpc_server::get_alloc_size(const rpc_msg_get_alloc_size_req & request, rpc_msg_get_alloc_size_rsp & response) {
813+
ggml_backend_buffer_type_t buft;
814+
struct ggml_init_params params {
815+
/*.mem_size =*/ ggml_tensor_overhead(),
816+
/*.mem_buffer =*/ NULL,
817+
/*.no_alloc =*/ true,
818+
};
819+
820+
struct ggml_context * ctx = ggml_init(params);
821+
ggml_tensor * tensor = deserialize_tensor(ctx, &request.tensor);
822+
823+
if (tensor == nullptr) {
824+
GGML_LOG_ERROR("Null tensor pointer passed to server get_alloc_size function.\n");
825+
ggml_free(ctx);
826+
return false;
827+
}
828+
829+
if (tensor->buffer == nullptr) {
830+
//No buffer allocated.
831+
buft = ggml_backend_get_default_buffer_type(backend);
832+
} else {
833+
buft = tensor->buffer->buft;
834+
}
835+
836+
response.alloc_size = ggml_backend_buft_get_alloc_size(buft,tensor);
837+
838+
ggml_free(ctx);
839+
return true;
840+
}
841+
773842
void rpc_server::alloc_buffer(const rpc_msg_alloc_buffer_req & request, rpc_msg_alloc_buffer_rsp & response) {
774843
ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(backend);
775844
ggml_backend_buffer_t buffer = ggml_backend_buft_alloc_buffer(buft, request.size);
@@ -905,6 +974,40 @@ bool rpc_server::set_tensor(const std::vector<uint8_t> & input) {
905974
return true;
906975
}
907976

977+
bool rpc_server::init_tensor(const rpc_msg_init_tensor_req & request) {
978+
struct ggml_init_params params {
979+
/*.mem_size =*/ ggml_tensor_overhead(),
980+
/*.mem_buffer =*/ NULL,
981+
/*.no_alloc =*/ true,
982+
};
983+
struct ggml_context * ctx = ggml_init(params);
984+
ggml_tensor * tensor = deserialize_tensor(ctx, &request.tensor);
985+
if (tensor == nullptr) {
986+
GGML_LOG_ERROR("Null tensor pointer passed to server init_tensor function.\n");
987+
ggml_free(ctx);
988+
return false;
989+
}
990+
991+
// Call the backend's buffer_init_tensor function
992+
ggml_backend_buffer_t buffer = tensor->buffer;
993+
if (buffer && buffer->iface.init_tensor) {
994+
buffer->iface.init_tensor(buffer, tensor);
995+
} else {
996+
GGML_LOG_ERROR("Null buffer for tensor passed to init_tensor function\n");
997+
}
998+
999+
if (tensor->extra != nullptr) {
1000+
// This pointer can either be passed around client/server, or probably better stored server-side and kept track of.
1001+
// Currently unimplemented.
1002+
GGML_LOG_ERROR("tensor->extra populated by the backend, this is currently unsupported.\n");
1003+
ggml_free(ctx);
1004+
return false;
1005+
}
1006+
1007+
ggml_free(ctx);
1008+
return true;
1009+
}
1010+
9081011
bool rpc_server::get_tensor(const rpc_msg_get_tensor_req & request, std::vector<uint8_t> & response) {
9091012
struct ggml_init_params params {
9101013
/*.mem_size =*/ ggml_tensor_overhead(),
@@ -1058,6 +1161,18 @@ static void rpc_serve_client(ggml_backend_t backend, sockfd_t sockfd, size_t fre
10581161
}
10591162
break;
10601163
}
1164+
case RPC_CMD_GET_ALLOC_SIZE: {
1165+
rpc_msg_get_alloc_size_req request;
1166+
if (!recv_msg(sockfd, &request, sizeof(request))) {
1167+
return;
1168+
}
1169+
rpc_msg_get_alloc_size_rsp response;
1170+
server.get_alloc_size(request, response);
1171+
if (!send_msg(sockfd, &response, sizeof(response))) {
1172+
return;
1173+
}
1174+
break;
1175+
}
10611176
case RPC_CMD_GET_ALIGNMENT: {
10621177
if (!recv_msg(sockfd, nullptr, 0)) {
10631178
return;
@@ -1133,6 +1248,19 @@ static void rpc_serve_client(ggml_backend_t backend, sockfd_t sockfd, size_t fre
11331248
}
11341249
break;
11351250
}
1251+
case RPC_CMD_INIT_TENSOR: {
1252+
rpc_msg_init_tensor_req request;
1253+
if (!recv_msg(sockfd, &request,sizeof(request))) {
1254+
return;
1255+
}
1256+
if (!server.init_tensor(request)) {
1257+
return;
1258+
}
1259+
if (!send_msg(sockfd, nullptr, 0)) {
1260+
return;
1261+
}
1262+
break;
1263+
}
11361264
case RPC_CMD_GET_TENSOR: {
11371265
rpc_msg_get_tensor_req request;
11381266
if (!recv_msg(sockfd, &request, sizeof(request))) {

0 commit comments

Comments
 (0)