diff --git a/src/llama-kv-cache-unified-iswa.cpp b/src/llama-kv-cache-unified-iswa.cpp index b9169299c..d1f839b63 100644 --- a/src/llama-kv-cache-unified-iswa.cpp +++ b/src/llama-kv-cache-unified-iswa.cpp @@ -246,7 +246,7 @@ bool llama_kv_cache_unified_iswa_context::next() { } bool llama_kv_cache_unified_iswa_context::apply() { - assert(status == LLAMA_MEMORY_STATUS_SUCCESS); + assert(!llama_memory_status_is_fail(status)); bool res = true; diff --git a/src/llama-kv-cache-unified.cpp b/src/llama-kv-cache-unified.cpp index 8517b722a..7f7b162ff 100644 --- a/src/llama-kv-cache-unified.cpp +++ b/src/llama-kv-cache-unified.cpp @@ -1776,7 +1776,7 @@ bool llama_kv_cache_unified_context::next() { } bool llama_kv_cache_unified_context::apply() { - assert(status == LLAMA_MEMORY_STATUS_SUCCESS); + assert(!llama_memory_status_is_fail(status)); // no ubatches -> this is a KV cache update if (ubatches.empty()) { diff --git a/src/llama-memory-hybrid.cpp b/src/llama-memory-hybrid.cpp index 15cde98d1..67cbf9554 100644 --- a/src/llama-memory-hybrid.cpp +++ b/src/llama-memory-hybrid.cpp @@ -218,7 +218,7 @@ bool llama_memory_hybrid_context::next() { } bool llama_memory_hybrid_context::apply() { - assert(status == LLAMA_MEMORY_STATUS_SUCCESS); + assert(!llama_memory_status_is_fail(status)); bool res = true; diff --git a/src/llama-memory-recurrent.cpp b/src/llama-memory-recurrent.cpp index e52156bf3..6ed84057c 100644 --- a/src/llama-memory-recurrent.cpp +++ b/src/llama-memory-recurrent.cpp @@ -1071,7 +1071,15 @@ bool llama_memory_recurrent_context::next() { } bool llama_memory_recurrent_context::apply() { - assert(status == LLAMA_MEMORY_STATUS_SUCCESS); + assert(!llama_memory_status_is_fail(status)); + + // no ubatches -> this is an update + if (ubatches.empty()) { + // recurrent cache never performs updates + assert(status == LLAMA_MEMORY_STATUS_NO_UPDATE); + + return true; + } mem->find_slot(ubatches[i_next]); diff --git a/src/llama-memory.cpp b/src/llama-memory.cpp index f1107672c..ca6844c32 100644 --- a/src/llama-memory.cpp +++ b/src/llama-memory.cpp @@ -40,3 +40,20 @@ llama_memory_status llama_memory_status_combine(llama_memory_status s0, llama_me // if either status has an update, then the combined status has an update return has_update ? LLAMA_MEMORY_STATUS_SUCCESS : LLAMA_MEMORY_STATUS_NO_UPDATE; } + +bool llama_memory_status_is_fail(llama_memory_status status) { + switch (status) { + case LLAMA_MEMORY_STATUS_SUCCESS: + case LLAMA_MEMORY_STATUS_NO_UPDATE: + { + return false; + } + case LLAMA_MEMORY_STATUS_FAILED_PREPARE: + case LLAMA_MEMORY_STATUS_FAILED_COMPUTE: + { + return true; + } + } + + return false; +} diff --git a/src/llama-memory.h b/src/llama-memory.h index 16b7e5ee2..e8ba336e8 100644 --- a/src/llama-memory.h +++ b/src/llama-memory.h @@ -31,6 +31,9 @@ enum llama_memory_status { // useful for implementing hybrid memory types (e.g. iSWA) llama_memory_status llama_memory_status_combine(llama_memory_status s0, llama_memory_status s1); +// helper function for checking if a memory status indicates a failure +bool llama_memory_status_is_fail(llama_memory_status status); + // the interface for managing the memory context during batch processing // this interface is implemented per memory type. see: // - llama_kv_cache_unified_context