
Commit 02d9a19

kv-cache : simplify SWA logic
ggml-ci
1 parent 65eee87 commit 02d9a19

5 files changed, +76 -46 lines changed

src/llama-graph.cpp (+3, -3)

@@ -362,17 +362,17 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
 
 void llm_graph_input_attn_kv_unified::set_input(const llama_ubatch * ubatch) {
     if (self_kq_mask) {
-        kv_self->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn, false);
+        kv_self->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
     }
 }
 
 void llm_graph_input_attn_kv_unified_iswa::set_input(const llama_ubatch * ubatch) {
     if (self_kq_mask) {
-        kv_self->get_kv_base()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn, false);
+        kv_self->get_kv_base()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
     }
 
     if (self_kq_mask_swa) {
-        kv_self->get_kv_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn, true);
+        kv_self->get_kv_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn);
     }
 }

src/llama-hparams.h (+4, -3)

@@ -15,8 +15,9 @@ enum llama_expert_gating_func_type {
 };
 
 enum llama_swa_type {
-    LLAMA_SWA_TYPE_STANDARD = 0,
-    LLAMA_SWA_TYPE_CHUNKED  = 1,
+    LLAMA_SWA_TYPE_NONE     = 0,
+    LLAMA_SWA_TYPE_STANDARD = 1,
+    LLAMA_SWA_TYPE_CHUNKED  = 2,
 };
 
 struct llama_hparams_posnet {
@@ -100,7 +101,7 @@ struct llama_hparams {
     std::array<int, 4> rope_sections;
 
     // Sliding Window Attention (SWA)
-    llama_swa_type swa_type = LLAMA_SWA_TYPE_STANDARD;
+    llama_swa_type swa_type = LLAMA_SWA_TYPE_NONE;
 
     uint32_t n_swa = 0;         // the size of the sliding window (0 - no SWA)
     uint32_t n_swa_pattern = 1; // by default, all layers use non-sliding-window attention
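
For orientation, a minimal sketch of how a model might populate these fields. The field names and enum values come from the diff above; the helper function and the concrete numbers are hypothetical, and the include assumes compilation next to src/llama-hparams.h:

```cpp
#include "llama-hparams.h"  // assumption: built inside src/ so this header resolves

// Hypothetical configuration for a model that uses standard sliding-window attention.
static void set_swa_hparams_example(llama_hparams & hparams) {
    hparams.swa_type      = LLAMA_SWA_TYPE_STANDARD; // rolling window over recent positions
    hparams.n_swa         = 4096;                    // hypothetical window size
    hparams.n_swa_pattern = 2;                       // hypothetical; the default 1 keeps all layers non-SWA

    // Models that never touch swa_type now get LLAMA_SWA_TYPE_NONE (the new default),
    // i.e. no sliding-window masking at all.
}
```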

src/llama-kv-cache.cpp (+55, -37)

@@ -30,7 +30,9 @@ llama_kv_cache_unified::llama_kv_cache_unified(
         bool v_trans,
         bool offload,
         uint32_t kv_size,
-        uint32_t padding) : model(model), hparams(model.hparams), v_trans(v_trans), padding(padding) {
+        uint32_t padding,
+        uint32_t n_swa,
+        llama_swa_type swa_type) : model(model), hparams(model.hparams), v_trans(v_trans), padding(padding), n_swa(n_swa), swa_type(swa_type) {
     GGML_ASSERT(kv_size % padding == 0 && "kv_size must be a multiple of padding");
 
     this->type_k = type_k;
@@ -640,7 +642,7 @@ ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_
     return ggml_cpy(ctx, v_cur, v_view);
 }
 
-void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn, bool swa) const {
+void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const {
     const int64_t n_tokens     = ubatch->n_tokens;
     const int64_t n_seq_tokens = ubatch->n_seq_tokens;
     const int64_t n_seqs       = ubatch->n_seqs;
@@ -667,41 +669,28 @@ void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ub
             const llama_seq_id seq_id = ubatch->seq_id[s][0];
 
             for (int j = 0; j < n_seq_tokens; ++j) {
-                const llama_pos pos = ubatch->pos[s*n_seq_tokens + j];
+                const llama_pos p1 = ubatch->pos[s*n_seq_tokens + j];
 
                 for (int i = 0; i < n_kv; ++i) {
-                    float f;
-                    // mask the token if:
-                    if (!cells[i].has_seq_id(seq_id) // not the correct sequence
-                     || (causal_attn && cells[i].pos > pos) // for causal, mask future tokens
-                    ) {
-                        f = -INFINITY;
-                    } else {
-                        if (hparams.use_alibi) {
-                            f = -std::abs(cells[i].pos - pos);
-                        } else {
-                            f = 0.0f;
-                        }
-                    }
+                    const llama_pos p0 = cells[i].pos;
+
+                    bool masked = false;
+
+                    // mask the token if not the same sequence
+                    masked = masked || (!cells[i].has_seq_id(seq_id));
+
+                    // mask future tokens
+                    masked = masked || (causal_attn && p0 > p1);
 
-                    if (swa) {
-                        // may need to cut off old tokens for sliding window
-                        switch (hparams.swa_type) {
-                            case LLAMA_SWA_TYPE_STANDARD:
-                                {
-                                    if (pos - cells[i].pos >= (int32_t) hparams.n_swa) {
-                                        f = -INFINITY;
-                                    }
-                                } break;
-                            case LLAMA_SWA_TYPE_CHUNKED:
-                                {
-                                    const llama_pos pos_chunk_start = (pos / hparams.n_swa) * hparams.n_swa;
-
-                                    if (cells[i].pos < pos_chunk_start) {
-                                        f = -INFINITY;
-                                    }
-                                } break;
-                        }
+                    // apply SWA if any
+                    masked = masked || (is_masked_swa(p0, p1));
+
+                    float f = 0.0f;
+
+                    if (masked) {
+                        f = -INFINITY;
+                    } else if (hparams.use_alibi) {
+                        f = -std::abs(p0 - p1);
                     }
 
                     data[h*(n_kv*n_tokens) + s*(n_kv*n_seq_tokens) + j*n_kv + i] = f;
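
The rewritten loop above replaces the old nested if/switch with a flat sequence of masking predicates followed by a single mask-to-value step. A standalone sketch of that decision; the free-function form and parameter names are illustrative, not the actual llama.cpp API:

```cpp
#include <cmath>
#include <limits>

// Mirrors the simplified body of set_input_kq_mask() above.
// p0 is the position of the cached token, p1 the position of the query token.
static float kq_mask_value(bool same_seq, bool causal_attn, bool use_alibi,
                           long p0, long p1, bool masked_swa) {
    bool masked = false;

    masked = masked || !same_seq;                // cell belongs to another sequence
    masked = masked || (causal_attn && p0 > p1); // future token under causal attention
    masked = masked || masked_swa;               // cut off by the sliding window (is_masked_swa)

    if (masked) {
        return -std::numeric_limits<float>::infinity();
    }

    // ALiBi replaces the plain 0.0f with a distance-based bias.
    return use_alibi ? -std::fabs(float(p0 - p1)) : 0.0f;
}
```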
@@ -1191,6 +1180,30 @@ uint32_t llama_kv_cache_unified::cell_max() const {
     return 0;
 }
 
+bool llama_kv_cache_unified::is_masked_swa(llama_pos p0, llama_pos p1) const {
+    switch (swa_type) {
+        case LLAMA_SWA_TYPE_NONE:
+            {
+            } break;
+        case LLAMA_SWA_TYPE_STANDARD:
+            {
+                if (p1 - p0 >= (int32_t) n_swa) {
+                    return true;
+                }
+            } break;
+        case LLAMA_SWA_TYPE_CHUNKED:
+            {
+                const llama_pos pos_chunk_start = (p1 / n_swa) * n_swa;
+
+                if (p0 < pos_chunk_start) {
+                    return true;
+                }
+            } break;
+    }
+
+    return false;
+}
+
 void llama_kv_cache_unified::state_write(llama_io_write_i & io, llama_seq_id seq_id) const {
     std::vector<std::pair<uint32_t, uint32_t>> cell_ranges; // ranges, from inclusive, to exclusive
     uint32_t cell_count = 0;
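
To make the two window types concrete, a small self-contained sketch of the same comparisons that is_masked_swa() performs above; the standalone enum, function name, and example values are illustrative only:

```cpp
#include <cassert>
#include <cstdint>

enum swa_type_t { SWA_NONE, SWA_STANDARD, SWA_CHUNKED };

// Same comparisons as llama_kv_cache_unified::is_masked_swa(p0, p1) above.
static bool is_masked_swa_sketch(swa_type_t type, uint32_t n_swa, int64_t p0, int64_t p1) {
    switch (type) {
        case SWA_NONE:     return false;                               // no extra masking
        case SWA_STANDARD: return p1 - p0 >= (int64_t) n_swa;          // token fell out of the window
        case SWA_CHUNKED:  return p0 < (p1 / (int64_t) n_swa) * n_swa; // token is in an earlier chunk
    }
    return false;
}

int main() {
    // Standard window of 4: a query at p1 = 10 still sees p0 = 7..10.
    assert(!is_masked_swa_sketch(SWA_STANDARD, 4, 7, 10));
    assert( is_masked_swa_sketch(SWA_STANDARD, 4, 6, 10));

    // Chunked window of 4: a query at p1 = 10 only sees tokens from its own chunk (positions 8..11).
    assert(!is_masked_swa_sketch(SWA_CHUNKED, 4, 8, 10));
    assert( is_masked_swa_sketch(SWA_CHUNKED, 4, 7, 10));

    return 0;
}
```

In the commit, n_swa and swa_type live on the cache itself, so the SWA half of llama_kv_cache_unified_iswa applies them while the base cache, constructed with LLAMA_SWA_TYPE_NONE below, never masks by window.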
@@ -1586,11 +1599,17 @@ llama_kv_cache_unified_iswa::llama_kv_cache_unified_iswa(
 
     LLAMA_LOG_INFO("%s: creating non-SWA KV cache, size = %u cells\n", __func__, kv_size_base);
 
-    kv_base = std::make_unique<llama_kv_cache_unified>(model, std::move(filter_base), type_k, type_v, v_trans, offload, kv_size_base, padding);
+    kv_base = std::make_unique<llama_kv_cache_unified>(
+            model, std::move(filter_base), type_k, type_v,
+            v_trans, offload, kv_size_base, padding,
+            0, LLAMA_SWA_TYPE_NONE);
 
     LLAMA_LOG_INFO("%s: creating SWA KV cache, size = %u cells\n", __func__, kv_size_swa);
 
-    kv_swa = std::make_unique<llama_kv_cache_unified>(model, std::move(filter_swa), type_k, type_v, v_trans, offload, kv_size_swa, padding);
+    kv_swa = std::make_unique<llama_kv_cache_unified>(
+            model, std::move(filter_swa), type_k, type_v,
+            v_trans, offload, kv_size_swa, padding,
+            hparams.n_swa, hparams.swa_type);
 }
 
 void llama_kv_cache_unified_iswa::clear() {
@@ -2801,5 +2820,4 @@ void llama_kv_cache_view_free(llama_kv_cache_view * view) {
 void llama_kv_cache_view_update(llama_kv_cache_view * , const llama_kv_cache * ) {
     // TODO: will be removed soon, keep this for now to avoid too many changes in
     // https://github.com/ggml-org/llama.cpp/pull/13194
-    GGML_ABORT("not implemented");
 }

src/llama-kv-cache.h (+11, -2)

@@ -102,7 +102,9 @@ class llama_kv_cache_unified : public llama_kv_cache {
             bool v_trans,
             bool offload,
             uint32_t kv_size,
-            uint32_t padding);
+            uint32_t padding,
+            uint32_t n_swa,
+            llama_swa_type swa_type);
 
     ~llama_kv_cache_unified() = default;
 
@@ -169,7 +171,7 @@ class llama_kv_cache_unified : public llama_kv_cache {
     ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, int32_t il) const;
     ggml_tensor * cpy_v(ggml_context * ctx, ggml_tensor * v_cur, int32_t il) const;
 
-    void set_input_kq_mask (ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn, bool swa) const;
+    void set_input_kq_mask (ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const;
     void set_input_k_shift (ggml_tensor * dst) const;
     void set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const;
 
@@ -223,6 +225,11 @@ class llama_kv_cache_unified : public llama_kv_cache {
     ggml_type type_k = GGML_TYPE_F16;
     ggml_type type_v = GGML_TYPE_F16;
 
+    // SWA
+    uint32_t n_swa = 0;
+
+    llama_swa_type swa_type = LLAMA_SWA_TYPE_NONE;
+
     std::vector<ggml_context_ptr> ctxs;
     std::vector<ggml_backend_buffer_ptr> bufs;
 
@@ -264,6 +271,8 @@ class llama_kv_cache_unified : public llama_kv_cache {
     size_t size_k_bytes() const;
     size_t size_v_bytes() const;
 
+    bool is_masked_swa(llama_pos p0, llama_pos p1) const;
+
     ggml_tensor * build_rope_shift(
             const llama_cparams & cparams,
             ggml_context * ctx,

src/llama-model.cpp (+3, -1)

@@ -13228,7 +13228,9 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
                 !cparams.flash_attn,
                 cparams.offload_kqv,
                 cparams.n_ctx,
-                padding);
+                padding,
+                hparams.n_swa,
+                hparams.swa_type);
         }
     }
 }
