Mirror of https://github.com/ggml-org/llama.cpp.git (synced 2025-08-14 20:29:41 -04:00)
server : add SWA checkpoints (#15293)
* server : add SWA checkpoints ggml-ci
* cont : server clean-up
* server : handle state restore fails
* llama : add extended llama_state_seq_ API
* server : do not make checkpoints if --swa-full ggml-ci
* llama : remove flags value for NONE
* server : configure number of SWA checkpoints with CLI arg ggml-ci
* args : fix scope of new argument
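Among the changes above, the "extended llama_state_seq_ API" shows up in the diff as three _ext calls that take an extra flags argument (LLAMA_STATE_SEQ_FLAGS_SWA_ONLY selects only the SWA portion of a sequence's state). A minimal sketch of how those calls compose, inferred from the call sites in this diff rather than from the header:

    // sketch: snapshot and restore only the SWA part of one sequence's state;
    // assumes `ctx` is a valid llama_context and `seq_id` an active sequence
    #include "llama.h"

    #include <cstdint>
    #include <vector>

    static std::vector<uint8_t> swa_snapshot(llama_context * ctx, llama_seq_id seq_id) {
        const size_t size = llama_state_seq_get_size_ext(ctx, seq_id, LLAMA_STATE_SEQ_FLAGS_SWA_ONLY);

        std::vector<uint8_t> buf(size);
        llama_state_seq_get_data_ext(ctx, buf.data(), size, seq_id, LLAMA_STATE_SEQ_FLAGS_SWA_ONLY);

        return buf;
    }

    static bool swa_restore(llama_context * ctx, llama_seq_id seq_id, const std::vector<uint8_t> & buf) {
        // the setter returns the number of bytes it consumed; a mismatch is a failure
        const size_t n = llama_state_seq_set_data_ext(ctx, buf.data(), buf.size(), seq_id, LLAMA_STATE_SEQ_FLAGS_SWA_ONLY);
        return n == buf.size();
    }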
@@ -692,6 +692,13 @@ struct completion_token_output {
     }
 };
 
+struct swa_checkpoint {
+    llama_pos pos_min;
+    llama_pos pos_max;
+
+    std::vector<uint8_t> data;
+};
+
 struct server_task_result_cmpl_final : server_task_result {
     int index = 0;
 
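The new swa_checkpoint struct holds a full serialized copy of the sequence's SWA state plus the position range [pos_min, pos_max] it covers, so keeping k checkpoints costs roughly k times the size of one snapshot. A standalone illustration of that bookkeeping, with made-up sizes and llama_pos spelled as its underlying int32_t so the snippet compiles on its own:

    #include <cstdint>
    #include <cstdio>
    #include <vector>

    // mirrors the struct added above (llama_pos is an int32_t typedef in llama.h)
    struct swa_checkpoint {
        int32_t pos_min;
        int32_t pos_max;

        std::vector<uint8_t> data;
    };

    int main() {
        std::vector<swa_checkpoint> cps;

        // two hypothetical 8 MiB snapshots covering different position ranges
        cps.push_back({ 0, 100, std::vector<uint8_t>(8u << 20)});
        cps.push_back({50, 200, std::vector<uint8_t>(8u << 20)});

        // the same accounting the server later uses when logging checkpoint totals
        float total_mib = 0.0f;
        for (const auto & c : cps) {
            total_mib += (float) c.data.size() / 1024 / 1024;
        }

        std::printf("checkpoints = %zu, total = %.3f MiB\n", cps.size(), total_mib);
    }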
@@ -1336,6 +1343,8 @@ struct server_slot {
 
     std::vector<completion_token_output> generated_token_probs;
 
+    std::vector<swa_checkpoint> swa_checkpoints;
+
     bool has_next_token = true;
     bool has_new_line = false;
     bool truncated = false;
@@ -3293,6 +3302,8 @@ struct server_context {
                     slot.n_past = 0;
                 }
 
+                const auto n_swa = llama_model_n_swa(model);
+
                 if (slot.n_past > 0 && slot.n_past < (int) slot.cache_tokens.size()) {
                     const auto pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx), slot.id);
                     if (pos_min == -1) {
@@ -3300,12 +3311,58 @@
                         GGML_ABORT("pos_min == -1, but n_past > 0 - should not happen: https://github.com/ggml-org/llama.cpp/pull/13833#discussion_r2116181237");
                     }
 
-                    const auto n_swa = llama_model_n_swa(model);
-                    if (pos_min > std::max(0, slot.n_past - n_swa)) {
+                    const auto pos_min_thold = std::max(0, slot.n_past - n_swa);
+
+                    if (pos_min > pos_min_thold) {
                         SLT_WRN(slot, "n_past = %d, cache_tokens.size() = %d, seq_id = %d, pos_min = %d, n_swa = %d\n", slot.n_past, (int) slot.cache_tokens.size(), slot.id, pos_min, n_swa);
-                        SLT_WRN(slot, "forcing full prompt re-processing due to lack of cache data (likely due to SWA, see %s)\n",
-                                "https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055");
-                        slot.n_past = 0;
+
+                        // search for a SWA checkpoint
+                        const auto it = std::find_if(
+                                slot.swa_checkpoints.rbegin(),
+                                slot.swa_checkpoints.rend(),
+                                [&](const auto & cur) {
+                                    return cur.pos_min <= pos_min_thold;
+                                }
+                                );
+
+                        bool do_reset = it == slot.swa_checkpoints.rend();
+
+                        if (!do_reset) {
+                            // restore the checkpoint
+                            const size_t swa_size = it->data.size();
+                            const size_t n = llama_state_seq_set_data_ext(ctx, it->data.data(), swa_size, slot.id, LLAMA_STATE_SEQ_FLAGS_SWA_ONLY);
+
+                            if (n != swa_size) {
+                                SLT_ERR(slot, "failed to restore SWA checkpoint, pos_min = %d, pos_max = %d, size = %.3f MiB\n", it->pos_min, it->pos_max, (float) swa_size / 1024 / 1024);
+                                do_reset = true;
+                            } else {
+                                slot.n_past = std::min(slot.n_past, it->pos_max);
+
+                                SLT_WRN(slot, "SWA checkpoint restore, pos_min = %d, pos_max = %d, size = %.3f MiB\n", it->pos_min, it->pos_max, (float) swa_size / 1024 / 1024);
+                            }
+                        }
+
+                        if (do_reset) {
+                            SLT_WRN(slot, "forcing full prompt re-processing due to lack of cache data (likely due to SWA, see %s)\n",
+                                    "https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055");
+
+                            slot.n_past = 0;
+                            slot.swa_checkpoints.clear();
+                        }
                     }
                 }
+
+                if (n_swa > 0) {
+                    const auto pos_min_thold = std::max(0, slot.n_past - n_swa);
+
+                    // erase any checkpoints with pos_min > pos_min_thold
+                    for (int i = (int) slot.swa_checkpoints.size() - 1; i >= 0; i--) {
+                        const auto & cur = slot.swa_checkpoints[i];
+                        if (cur.pos_min > pos_min_thold) {
+                            slot.swa_checkpoints.erase(slot.swa_checkpoints.begin() + i);
+
+                            SLT_WRN(slot, "SWA checkpoint erase, pos_min = %d, pos_max = %d, size = %.3f MiB\n", cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024);
+                        }
+                    }
+                }
             }
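Here pos_min_thold = std::max(0, slot.n_past - n_swa) is the oldest position the cache must still contain to keep generating from n_past with a window of n_swa tokens, and the restore path walks the checkpoint list newest-first, so the first match (the most recent checkpoint whose pos_min does not exceed the threshold) wins and the least amount of prompt has to be re-evaluated. A standalone illustration of the same reverse search, with made-up positions:

    #include <algorithm>
    #include <cstdio>
    #include <vector>

    struct ckpt {
        int pos_min;
        int pos_max;
    };

    int main() {
        // newest checkpoints sit at the back, as in server_slot::swa_checkpoints
        std::vector<ckpt> cps = { {0, 100}, {50, 200}, {150, 300} };

        const int n_past        = 376;
        const int n_swa         = 256;
        const int pos_min_thold = std::max(0, n_past - n_swa); // 120

        // reverse search: the first hit is the most recent usable checkpoint
        const auto it = std::find_if(cps.rbegin(), cps.rend(),
                [&](const ckpt & c) { return c.pos_min <= pos_min_thold; });

        if (it != cps.rend()) {
            // skips {150, 300} (150 > 120) and picks {50, 200};
            // generation would resume from min(n_past, 200) = 200
            std::printf("restore from [%d, %d]\n", it->pos_min, it->pos_max);
        }
    }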
@@ -3519,6 +3576,39 @@ struct server_context {
 
                     // prompt evaluated for next-token prediction
                     slot.state = SLOT_STATE_GENERATING;
+
+                    // make a checkpoint with the SWA memory
+                    // checkpoints are needed only if we are not using "--swa-full"
+                    if (llama_model_n_swa(model) > 0 && !params_base.swa_full && params_base.n_swa_checkpoints > 0) {
+                        if (slot.swa_checkpoints.size() >= (size_t) params_base.n_swa_checkpoints) {
+                            {
+                                const auto & cur = slot.swa_checkpoints.back();
+
+                                SLT_WRN(slot, "SWA checkpoint erase, pos_min = %d, pos_max = %d, size = %.3f MiB\n",
+                                        cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024);
+                            }
+
+                            slot.swa_checkpoints.erase(slot.swa_checkpoints.begin());
+                        }
+
+                        const size_t swa_size = llama_state_seq_get_size_ext(ctx, slot.id, LLAMA_STATE_SEQ_FLAGS_SWA_ONLY);
+
+                        auto & cur = slot.swa_checkpoints.emplace_back(swa_checkpoint{
+                            /*.pos_min = */ llama_memory_seq_pos_min(llama_get_memory(ctx), slot.id),
+                            /*.pos_max = */ llama_memory_seq_pos_max(llama_get_memory(ctx), slot.id),
+                            /*.data    = */ std::vector<uint8_t>(swa_size),
+                        });
+
+                        llama_state_seq_get_data_ext(ctx, cur.data.data(), swa_size, slot.id, LLAMA_STATE_SEQ_FLAGS_SWA_ONLY);
+
+                        float size_total = 0.0f;
+                        for (const auto & checkpoint : slot.swa_checkpoints) {
+                            size_total += (float) checkpoint.data.size() / 1024 / 1024;
+                        }
+
+                        SLT_WRN(slot, "SWA checkpoint create, pos_min = %d, pos_max = %d, size = %.3f MiB, total = %d/%d (%.3f MiB)\n",
+                                cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024, (int) slot.swa_checkpoints.size(), params_base.n_swa_checkpoints, size_total);
+                    }
                 } else if (slot.state != SLOT_STATE_GENERATING) {
                     continue; // continue loop of slots
                 }
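Checkpoint creation runs once the prompt has finished evaluating and the slot enters SLOT_STATE_GENERATING; it is skipped entirely under --swa-full or when the configured checkpoint count is 0. The list is capped at params_base.n_swa_checkpoints, evicting from the front (the oldest entry) before appending. A minimal sketch of that capped-append policy in isolation, using a hypothetical helper that is not part of the diff:

    #include <cstdint>
    #include <utility>
    #include <vector>

    struct checkpoint {
        std::vector<uint8_t> data;
    };

    // keep at most `cap` checkpoints: evict the oldest when full, then append,
    // matching the order of operations in the server code above
    static void push_capped(std::vector<checkpoint> & cps, checkpoint c, size_t cap) {
        if (cap == 0) {
            return; // checkpointing disabled
        }

        if (cps.size() >= cap) {
            cps.erase(cps.begin()); // drop the oldest entry (front of the vector)
        }

        cps.push_back(std::move(c));
    }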