Commit 5923822

context : improve the batching logic in llama_decode
ggml-ci
1 parent: e7bef6b

File tree: 1 file changed, +51 -14 lines

src/llama-context.cpp

@@ -871,8 +871,6 @@ int llama_context::decode(llama_batch & inp_batch) {
     const int64_t n_tokens_all = batch.n_tokens;
     const int64_t n_embd       = hparams.n_embd;
 
-    llama_kv_cache_guard kv_guard(kv_self);
-
     GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT
 
     if (batch.token) {
@@ -912,8 +910,6 @@ int llama_context::decode(llama_batch & inp_batch) {
         n_outputs_all = 1;
     }
 
-    llama_sbatch sbatch = kv_self->sbatch_init(batch, /* logits_all */ n_outputs_all == n_tokens_all);
-
     // reserve output buffer
     if (output_reserve(n_outputs_all) < n_outputs_all) {
         LLAMA_LOG_ERROR("%s: could not reserve space for batch with %" PRId64 " outputs\n", __func__, n_outputs_all);
@@ -923,11 +919,59 @@ int llama_context::decode(llama_batch & inp_batch) {
     // handle any pending defrags/shifts
     kv_self_update();
 
-    int64_t n_outputs_prev = 0;
+    llama_kv_cache_guard kv_guard(kv_self);
+
+    // this is the sequence-aware batch that we construct based on the input batch
+    llama_sbatch sbatch;
+
+    // we then split the sbatch into a set of ubatches. the split logic is delegated to the KV cache
+    std::vector<llama_ubatch> ubatches;
+
+    // if we fail to find a slot for the batch, we can retry after applying a defrag.
+    // in some cases, this can free up some space, which would be enough to fit the ubatches
+    // ref: https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2881412612
+    bool retry = true;
 
-    while (sbatch.n_tokens > 0) {
-        llama_ubatch ubatch = kv_self->ubatch_next(sbatch, cparams.n_ubatch, embd_pooled);
+    while (true) {
+        bool success = true;
+
+        sbatch = kv_self->sbatch_init(batch, /* logits_all */ n_outputs_all == n_tokens_all);
+
+        while (sbatch.n_tokens > 0) {
+            ubatches.emplace_back(kv_self->ubatch_next(sbatch, cparams.n_ubatch, embd_pooled));
+
+            // find an empty KV slot that can fit the current ubatch
+            if (!kv_self->find_slot(ubatches.back())) {
+                success = false;
+                break;
+            }
+        }
+
+        if (success) {
+            break;
+        }
 
+        if (!retry) {
+            LLAMA_LOG_WARN("%s: failed to fit the batch in the KV cache, batch size = %d\n", __func__, (int) n_tokens_all);
+            return 1;
+        }
+
+        // we failed to fit the sbatch once, and now we will try to defrag the KV cache and try to fit it again
+        retry = false;
+
+        kv_self->restore();
+        kv_self->defrag_sched(-1.0f);
+
+        kv_self_update();
+
+        ubatches.clear();
+    }
+
+    // we now have prepared the ubatches for this llama_decode and are ready to start processing
+
+    int64_t n_outputs_prev = 0;
+
+    for (const auto & ubatch : ubatches) {
         // count the outputs in this u_batch
         {
             int32_t n_outputs_new = 0;
@@ -945,13 +989,6 @@ int llama_context::decode(llama_batch & inp_batch) {
            n_outputs = n_outputs_new;
         }
 
-        // find KV slot
-        if (!kv_self->find_slot(ubatch)) {
-            LLAMA_LOG_WARN("%s: failed to find KV cache slot for ubatch of size %d\n", __func__, ubatch.n_tokens);
-
-            return 1;
-        }
-
        ggml_backend_sched_reset(sched.get());
        ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data);