@@ -871,8 +871,6 @@ int llama_context::decode(llama_batch & inp_batch) {
     const int64_t n_tokens_all = batch.n_tokens;
     const int64_t n_embd       = hparams.n_embd;

-    llama_kv_cache_guard kv_guard(kv_self);
-
     GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT

     if (batch.token) {
@@ -912,8 +910,6 @@ int llama_context::decode(llama_batch & inp_batch) {
         n_outputs_all = 1;
     }

-    llama_sbatch sbatch = kv_self->sbatch_init(batch, /* logits_all */ n_outputs_all == n_tokens_all);
-
     // reserve output buffer
     if (output_reserve(n_outputs_all) < n_outputs_all) {
         LLAMA_LOG_ERROR("%s: could not reserve space for batch with %" PRId64 " outputs\n", __func__, n_outputs_all);
@@ -923,11 +919,59 @@ int llama_context::decode(llama_batch & inp_batch) {
     // handle any pending defrags/shifts
     kv_self_update();

-    int64_t n_outputs_prev = 0;
+    llama_kv_cache_guard kv_guard(kv_self);
+
+    // this is the sequence-aware batch that we construct based on the input batch
+    llama_sbatch sbatch;
+
+    // we then split the sbatch into a set of ubatches. the split logic is delegated to the KV cache
+    std::vector<llama_ubatch> ubatches;
+
+    // if we fail to find a slot for the batch, we can retry after applying a defrag.
+    // in some cases, this can free up some space, which would be enough to fit the ubatches
+    // ref: https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2881412612
+    bool retry = true;

-    while (sbatch.n_tokens > 0) {
-        llama_ubatch ubatch = kv_self->ubatch_next(sbatch, cparams.n_ubatch, embd_pooled);
+    while (true) {
+        bool success = true;
+
+        sbatch = kv_self->sbatch_init(batch, /* logits_all */ n_outputs_all == n_tokens_all);
+
+        while (sbatch.n_tokens > 0) {
+            ubatches.emplace_back(kv_self->ubatch_next(sbatch, cparams.n_ubatch, embd_pooled));
+
+            // find an empty KV slot that can fit the current ubatch
+            if (!kv_self->find_slot(ubatches.back())) {
+                success = false;
+                break;
+            }
+        }
+
+        if (success) {
+            break;
+        }

+        if (!retry) {
+            LLAMA_LOG_WARN("%s: failed to fit the batch in the KV cache, batch size = %d\n", __func__, (int) n_tokens_all);
+            return 1;
+        }
+
+        // we failed to fit the sbatch once, and now we will try to defrag the KV cache and try to fit it again
+        retry = false;
+
+        kv_self->restore();
+        kv_self->defrag_sched(-1.0f);
+
+        kv_self_update();
+
+        ubatches.clear();
+    }
+
+    // we now have prepared the ubatches for this llama_decode and are ready to start processing
+
+    int64_t n_outputs_prev = 0;
+
+    for (const auto & ubatch : ubatches) {
         // count the outputs in this u_batch
         {
             int32_t n_outputs_new = 0;
@@ -945,13 +989,6 @@ int llama_context::decode(llama_batch & inp_batch) {
             n_outputs = n_outputs_new;
         }

-        // find KV slot
-        if (!kv_self->find_slot(ubatch)) {
-            LLAMA_LOG_WARN("%s: failed to find KV cache slot for ubatch of size %d\n", __func__, ubatch.n_tokens);
-
-            return 1;
-        }
-
         ggml_backend_sched_reset(sched.get());
         ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data);

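The key behavioral change above is that all ubatches are now prepared, and their KV cache slots reserved, before any of them is processed, with a single defrag-and-retry pass if the reservation fails. Below is a minimal, self-contained sketch of that control flow only; `toy_cache`, `toy_ubatch`, `split_into_ubatches`, and `prepare_batch` are illustrative stand-ins for this example and are not part of the llama.cpp API.

```cpp
// Sketch of the "prepare all ubatches up front, retry once after defrag" pattern.
// All types and helpers here are simplified stand-ins, not the real llama.cpp code.
#include <algorithm>
#include <cstdio>
#include <vector>

struct toy_ubatch {
    int n_tokens;
};

struct toy_cache {
    int capacity;
    int used       = 0;
    int fragmented = 0; // space that becomes usable only after a defrag

    // try to reserve space for one ubatch
    bool find_slot(const toy_ubatch & ub) {
        if (used + ub.n_tokens > capacity - fragmented) {
            return false;
        }
        used += ub.n_tokens;
        return true;
    }

    void restore() { used = 0;       } // roll back the reservations made above
    void defrag()  { fragmented = 0; } // compaction reclaims the fragmented space
};

// split a flat token count into fixed-size ubatches (stand-in for sbatch_init/ubatch_next)
static std::vector<toy_ubatch> split_into_ubatches(int n_tokens, int n_ubatch) {
    std::vector<toy_ubatch> res;
    for (int i = 0; i < n_tokens; i += n_ubatch) {
        res.push_back({ std::min(n_ubatch, n_tokens - i) });
    }
    return res;
}

// returns 0 on success, 1 if the batch cannot fit even after a defrag
static int prepare_batch(toy_cache & cache, int n_tokens, int n_ubatch) {
    std::vector<toy_ubatch> ubatches;

    bool retry = true;

    while (true) {
        bool success = true;

        ubatches = split_into_ubatches(n_tokens, n_ubatch);

        for (const auto & ub : ubatches) {
            if (!cache.find_slot(ub)) {
                success = false;
                break;
            }
        }

        if (success) {
            break;
        }

        if (!retry) {
            std::printf("failed to fit the batch, n_tokens = %d\n", n_tokens);
            return 1;
        }

        // first failure: undo the partial reservations, defrag once, then try again
        retry = false;

        cache.restore();
        cache.defrag();

        ubatches.clear();
    }

    std::printf("prepared %zu ubatches\n", ubatches.size());
    return 0;
}

int main() {
    toy_cache cache = { /*capacity =*/ 512, /*used =*/ 0, /*fragmented =*/ 128 };

    // 448 tokens do not fit alongside 128 fragmented cells, but do fit after the defrag
    return prepare_batch(cache, /*n_tokens =*/ 448, /*n_ubatch =*/ 64);
}
```

One consequence of this structure, visible in the diff, is that a batch that cannot fit is rejected before any ubatch is evaluated, instead of failing partway through the old per-ubatch `find_slot` check inside the processing loop.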