diff --git a/tools/llama-bench/llama-bench.cpp b/tools/llama-bench/llama-bench.cpp index 9457e6815..53dbdda2a 100644 --- a/tools/llama-bench/llama-bench.cpp +++ b/tools/llama-bench/llama-bench.cpp @@ -1736,7 +1736,7 @@ struct sql_printer : public printer { } }; -static void test_prompt(llama_context * ctx, int n_prompt, int n_batch, int n_threads) { +static bool test_prompt(llama_context * ctx, int n_prompt, int n_batch, int n_threads) { llama_set_n_threads(ctx, n_threads, n_threads); const llama_model * model = llama_get_model(ctx); @@ -1753,14 +1753,19 @@ static void test_prompt(llama_context * ctx, int n_prompt, int n_batch, int n_th for (int i = 1; i < n_tokens; i++) { tokens[i] = std::rand() % n_vocab; } - llama_decode(ctx, llama_batch_get_one(tokens.data(), n_tokens)); + int res = llama_decode(ctx, llama_batch_get_one(tokens.data(), n_tokens)); + if (res != 0) { + fprintf(stderr, "%s: failed to decode prompt batch, res = %d\n", __func__, res); + return false; + } n_processed += n_tokens; } llama_synchronize(ctx); + return true; } -static void test_gen(llama_context * ctx, int n_gen, int n_threads) { +static bool test_gen(llama_context * ctx, int n_gen, int n_threads) { llama_set_n_threads(ctx, n_threads, n_threads); const llama_model * model = llama_get_model(ctx); @@ -1770,10 +1775,15 @@ static void test_gen(llama_context * ctx, int n_gen, int n_threads) { llama_token token = llama_vocab_get_add_bos(vocab) ? llama_vocab_bos(vocab) : std::rand() % n_vocab; for (int i = 0; i < n_gen; i++) { - llama_decode(ctx, llama_batch_get_one(&token, 1)); + int res = llama_decode(ctx, llama_batch_get_one(&token, 1)); + if (res != 0) { + fprintf(stderr, "%s: failed to decode generation batch, res = %d\n", __func__, res); + return false; + } llama_synchronize(ctx); token = std::rand() % n_vocab; } + return true; } static void llama_null_log_callback(enum ggml_log_level level, const char * text, void * user_data) { @@ -1917,13 +1927,21 @@ int main(int argc, char ** argv) { fprintf(stderr, "llama-bench: benchmark %d/%zu: warmup prompt run\n", params_idx, params_count); } //test_prompt(ctx, std::min(t.n_batch, std::min(t.n_prompt, 32)), 0, t.n_batch, t.n_threads); - test_prompt(ctx, t.n_prompt, t.n_batch, t.n_threads); + bool res = test_prompt(ctx, t.n_prompt, t.n_batch, t.n_threads); + if (!res) { + fprintf(stderr, "%s: error: failed to run prompt warmup\n", __func__); + exit(1); + } } if (t.n_gen > 0) { if (params.progress) { fprintf(stderr, "llama-bench: benchmark %d/%zu: warmup generation run\n", params_idx, params_count); } - test_gen(ctx, 1, t.n_threads); + bool res = test_gen(ctx, 1, t.n_threads); + if (!res) { + fprintf(stderr, "%s: error: failed to run gen warmup\n", __func__); + exit(1); + } } for (int i = 0; i < params.reps; i++) { @@ -1934,7 +1952,11 @@ int main(int argc, char ** argv) { fprintf(stderr, "llama-bench: benchmark %d/%zu: depth run %d/%d\n", params_idx, params_count, i + 1, params.reps); } - test_prompt(ctx, t.n_depth, t.n_batch, t.n_threads); + bool res = test_prompt(ctx, t.n_depth, t.n_batch, t.n_threads); + if (!res) { + fprintf(stderr, "%s: error: failed to run depth\n", __func__); + exit(1); + } } uint64_t t_start = get_time_ns(); @@ -1944,14 +1966,22 @@ int main(int argc, char ** argv) { fprintf(stderr, "llama-bench: benchmark %d/%zu: prompt run %d/%d\n", params_idx, params_count, i + 1, params.reps); } - test_prompt(ctx, t.n_prompt, t.n_batch, t.n_threads); + bool res = test_prompt(ctx, t.n_prompt, t.n_batch, t.n_threads); + if (!res) { + fprintf(stderr, "%s: error: failed to run prompt\n", __func__); + exit(1); + } } if (t.n_gen > 0) { if (params.progress) { fprintf(stderr, "llama-bench: benchmark %d/%zu: generation run %d/%d\n", params_idx, params_count, i + 1, params.reps); } - test_gen(ctx, t.n_gen, t.n_threads); + bool res = test_gen(ctx, t.n_gen, t.n_threads); + if (!res) { + fprintf(stderr, "%s: error: failed to run gen\n", __func__); + exit(1); + } } uint64_t t_ns = get_time_ns() - t_start;