rpc : check for null buffers in get/set/copy tensor endpoints (#14868)

This commit is contained in:
Chris Rohlf
2025-07-25 06:17:02 -04:00
committed by GitHub
parent c12bbde372
commit 64bf1c3744

View File

@ -1055,7 +1055,7 @@ bool rpc_server::set_tensor(const std::vector<uint8_t> & input) {
GGML_ASSERT(ctx_ptr != nullptr); GGML_ASSERT(ctx_ptr != nullptr);
ggml_context * ctx = ctx_ptr.get(); ggml_context * ctx = ctx_ptr.get();
ggml_tensor * tensor = deserialize_tensor(ctx, in_tensor); ggml_tensor * tensor = deserialize_tensor(ctx, in_tensor);
if (tensor == nullptr) { if (tensor == nullptr || tensor->buffer == nullptr) {
GGML_LOG_ERROR("[%s] error deserializing tensor\n", __func__); GGML_LOG_ERROR("[%s] error deserializing tensor\n", __func__);
return false; return false;
} }
@ -1124,7 +1124,7 @@ bool rpc_server::set_tensor_hash(const rpc_msg_set_tensor_hash_req & request, rp
GGML_ASSERT(ctx_ptr != nullptr); GGML_ASSERT(ctx_ptr != nullptr);
ggml_context * ctx = ctx_ptr.get(); ggml_context * ctx = ctx_ptr.get();
ggml_tensor * tensor = deserialize_tensor(ctx, &request.tensor); ggml_tensor * tensor = deserialize_tensor(ctx, &request.tensor);
if (tensor == nullptr) { if (tensor == nullptr || tensor->buffer == nullptr) {
GGML_LOG_ERROR("[%s] error deserializing tensor\n", __func__); GGML_LOG_ERROR("[%s] error deserializing tensor\n", __func__);
return false; return false;
} }
@ -1192,7 +1192,7 @@ bool rpc_server::get_tensor(const rpc_msg_get_tensor_req & request, std::vector<
GGML_ASSERT(ctx_ptr != nullptr); GGML_ASSERT(ctx_ptr != nullptr);
ggml_context * ctx = ctx_ptr.get(); ggml_context * ctx = ctx_ptr.get();
ggml_tensor * tensor = deserialize_tensor(ctx, &request.tensor); ggml_tensor * tensor = deserialize_tensor(ctx, &request.tensor);
if (tensor == nullptr) { if (tensor == nullptr || tensor->buffer == nullptr) {
GGML_LOG_ERROR("[%s] error deserializing tensor\n", __func__); GGML_LOG_ERROR("[%s] error deserializing tensor\n", __func__);
return false; return false;
} }
@ -1229,7 +1229,7 @@ bool rpc_server::copy_tensor(const rpc_msg_copy_tensor_req & request, rpc_msg_co
ggml_tensor * src = deserialize_tensor(ctx, &request.src); ggml_tensor * src = deserialize_tensor(ctx, &request.src);
ggml_tensor * dst = deserialize_tensor(ctx, &request.dst); ggml_tensor * dst = deserialize_tensor(ctx, &request.dst);
if (src == nullptr || dst == nullptr) { if (src == nullptr || dst == nullptr || src->buffer == nullptr || dst->buffer == nullptr) {
GGML_LOG_ERROR("[%s] error deserializing tensors\n", __func__); GGML_LOG_ERROR("[%s] error deserializing tensors\n", __func__);
return false; return false;
} }