kv-cache : add SWA support #13194

Open · wants to merge 13 commits into master
include/llama.h (9 changes: 6 additions & 3 deletions)
@@ -943,9 +943,12 @@ extern "C" {
// Requires KV cache.
// For encoder-decoder contexts, processes the batch using the decoder.
// Positive return values do not mean a fatal error, but rather a warning.
// 0 - success
// 1 - could not find a KV slot for the batch (try reducing the size of the batch or increase the context)
// < 0 - error. the KV cache state is restored to the state before this call
// 0 - success
// 1 - could not find a KV slot for the batch (try reducing the size of the batch or increase the context)
// the KV cache is restored to the state before this call
// 2 - aborted. the KV cache is in undefined state
// -1 - invalid input batch. the KV cache is unmodified
// < -1 - error. the KV cache is in undefined state
LLAMA_API int32_t llama_decode(
struct llama_context * ctx,
struct llama_batch batch);
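
As a reference for API consumers, here is a minimal sketch of how a caller might branch on the updated return codes. It is illustrative only and not part of this diff; it assumes `ctx` and `batch` are an already-initialized `llama_context *` and `llama_batch`.

```cpp
// illustrative handling of the llama_decode() return codes documented above
const int32_t ret = llama_decode(ctx, batch);

if (ret == 0) {
    // success - outputs for the requested tokens are available
} else if (ret == 1) {
    // no KV slot found - the cache was restored, so it is safe to retry
    // with a smaller batch or a larger context
} else if (ret == 2) {
    // aborted - the KV cache is in an undefined state
} else if (ret == -1) {
    // invalid input batch - the KV cache is unmodified
} else {
    // ret < -1: error - the KV cache is in an undefined state
}
```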
src/llama-context.cpp (65 changes: 51 additions & 14 deletions)
@@ -871,8 +871,6 @@ int llama_context::decode(llama_batch & inp_batch) {
const int64_t n_tokens_all = batch.n_tokens;
const int64_t n_embd = hparams.n_embd;

llama_kv_cache_guard kv_guard(kv_self);

GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT

if (batch.token) {
@@ -912,8 +910,6 @@ int llama_context::decode(llama_batch & inp_batch) {
n_outputs_all = 1;
}

llama_sbatch sbatch = kv_self->sbatch_init(batch, /* logits_all */ n_outputs_all == n_tokens_all);

// reserve output buffer
if (output_reserve(n_outputs_all) < n_outputs_all) {
LLAMA_LOG_ERROR("%s: could not reserve space for batch with %" PRId64 " outputs\n", __func__, n_outputs_all);
@@ -923,11 +919,59 @@ int llama_context::decode(llama_batch & inp_batch) {
// handle any pending defrags/shifts
kv_self_update();

int64_t n_outputs_prev = 0;
llama_kv_cache_guard kv_guard(kv_self);

// this is the sequence-aware batch that we construct based on the input batch
llama_sbatch sbatch;

// we then split the sbatch into a set of ubatches. the split logic is delegated to the KV cache
std::vector<llama_ubatch> ubatches;

// if we fail to find a slot for the batch, we can retry after applying a defrag.
// in some cases, this can free up some space, which would be enough to fit the ubatches
// ref: https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2881412612
bool retry = true;

while (sbatch.n_tokens > 0) {
llama_ubatch ubatch = kv_self->ubatch_next(sbatch, cparams.n_ubatch, embd_pooled);
while (true) {
bool success = true;

sbatch = kv_self->sbatch_init(batch, /* logits_all */ n_outputs_all == n_tokens_all);

while (sbatch.n_tokens > 0) {
ubatches.emplace_back(kv_self->ubatch_next(sbatch, cparams.n_ubatch, embd_pooled));

// find an empty KV slot that can fit the current ubatch
if (!kv_self->find_slot(ubatches.back())) {
success = false;
break;
}
}

if (success) {
break;
}

if (!retry) {
LLAMA_LOG_WARN("%s: failed to fit the batch in the KV cache, batch size = %d\n", __func__, (int) n_tokens_all);
return 1;
}

// we failed to fit the sbatch once, and now we will try to defrag the KV cache and try to fit it again
retry = false;

kv_self->restore();
kv_self->defrag_sched(-1.0f);

kv_self_update();

ubatches.clear();
}

// we now have prepared the ubatches for this llama_decode and are ready to start processing

int64_t n_outputs_prev = 0;

for (const auto & ubatch : ubatches) {
// count the outputs in this u_batch
{
int32_t n_outputs_new = 0;
@@ -945,13 +989,6 @@ int llama_context::decode(llama_batch & inp_batch) {
n_outputs = n_outputs_new;
}

// find KV slot
if (!kv_self->find_slot(ubatch)) {
LLAMA_LOG_WARN("%s: failed to find KV cache slot for ubatch of size %d\n", __func__, ubatch.n_tokens);

return 1;
}

ggml_backend_sched_reset(sched.get());
ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data);

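
For readers navigating the interleaved hunks above, the reworked `llama_context::decode` flow condenses to the sketch below. It is a simplified illustration, not a literal excerpt: the KV-cache calls (`sbatch_init`, `ubatch_next`, `find_slot`, `restore`, `defrag_sched`, `kv_self_update`) and the locals (`batch`, `n_outputs_all`, `n_tokens_all`, `embd_pooled`) are taken from the diff, while error handling and output bookkeeping are omitted.

```cpp
// condensed sketch of the new decode flow: reserve KV slots for all ubatches
// up front (with one defrag-and-retry attempt), then evaluate them in order
llama_kv_cache_guard kv_guard(kv_self);

std::vector<llama_ubatch> ubatches;

bool retry = true;

while (true) {
    bool success = true;

    llama_sbatch sbatch = kv_self->sbatch_init(batch, /* logits_all */ n_outputs_all == n_tokens_all);

    while (sbatch.n_tokens > 0) {
        ubatches.emplace_back(kv_self->ubatch_next(sbatch, cparams.n_ubatch, embd_pooled));

        if (!kv_self->find_slot(ubatches.back())) {
            success = false; // the cache is too full or too fragmented
            break;
        }
    }

    if (success) {
        break; // every ubatch has a reserved KV slot
    }

    if (!retry) {
        return 1; // still does not fit - let the caller shrink the batch
    }

    retry = false;

    kv_self->restore();           // roll back the partial slot allocation
    kv_self->defrag_sched(-1.0f); // force a defrag on the next cache update
    kv_self_update();

    ubatches.clear();
}

for (const auto & ubatch : ubatches) {
    // ... count outputs, build the graph and evaluate this ubatch ...
}
```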