@@ -30,7 +30,9 @@ llama_kv_cache_unified::llama_kv_cache_unified(
                     bool    v_trans,
                     bool    offload,
                 uint32_t    kv_size,
-                uint32_t    padding) : model(model), hparams(model.hparams), v_trans(v_trans), padding(padding) {
+                uint32_t    padding,
+                uint32_t    n_swa,
+          llama_swa_type    swa_type) : model(model), hparams(model.hparams), v_trans(v_trans), padding(padding), n_swa(n_swa), swa_type(swa_type) {
     GGML_ASSERT(kv_size % padding == 0 && "kv_size must be a multiple of padding");
 
     this->type_k = type_k;
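The two new constructor arguments thread the sliding-window configuration into the cache itself. For orientation, swa_type selects between the SWA variants named in the hunks below; a minimal sketch of that enum, assuming the usual llama.cpp naming (the exact numeric values and header location are not shown in this diff):

// Sketch only: the SWA variants referenced by the new swa_type parameter.
// The constant names appear in the diff; the explicit values are an assumption.
enum llama_swa_type {
    LLAMA_SWA_TYPE_NONE     = 0, // full attention, no sliding window
    LLAMA_SWA_TYPE_STANDARD = 1, // window over the last n_swa positions
    LLAMA_SWA_TYPE_CHUNKED  = 2, // attention restricted to the current n_swa-sized chunk
};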
@@ -640,7 +642,7 @@ ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_
     return ggml_cpy(ctx, v_cur, v_view);
 }
 
-void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn, bool swa) const {
+void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const {
     const int64_t n_tokens     = ubatch->n_tokens;
     const int64_t n_seq_tokens = ubatch->n_seq_tokens;
     const int64_t n_seqs       = ubatch->n_seqs;
@@ -667,41 +669,28 @@ void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ub
             const llama_seq_id seq_id = ubatch->seq_id[s][0];
 
             for (int j = 0; j < n_seq_tokens; ++j) {
-                const llama_pos pos = ubatch->pos[s*n_seq_tokens + j];
+                const llama_pos p1 = ubatch->pos[s*n_seq_tokens + j];
 
                 for (int i = 0; i < n_kv; ++i) {
-                    float f;
-                    // mask the token if:
-                    if (!cells[i].has_seq_id(seq_id) // not the correct sequence
-                        || (causal_attn && cells[i].pos > pos) // for causal, mask future tokens
-                       ) {
-                        f = -INFINITY;
-                    } else {
-                        if (hparams.use_alibi) {
-                            f = -std::abs(cells[i].pos - pos);
-                        } else {
-                            f = 0.0f;
-                        }
-                    }
+                    const llama_pos p0 = cells[i].pos;
+
+                    bool masked = false;
+
+                    // mask the token if not the same sequence
+                    masked = masked || (!cells[i].has_seq_id(seq_id));
+
+                    // mask future tokens
+                    masked = masked || (causal_attn && p0 > p1);
 
-                    if (swa) {
-                        // may need to cut off old tokens for sliding window
-                        switch (hparams.swa_type) {
-                            case LLAMA_SWA_TYPE_STANDARD:
-                                {
-                                    if (pos - cells[i].pos >= (int32_t) hparams.n_swa) {
-                                        f = -INFINITY;
-                                    }
-                                } break;
-                            case LLAMA_SWA_TYPE_CHUNKED:
-                                {
-                                    const llama_pos pos_chunk_start = (pos / hparams.n_swa) * hparams.n_swa;
-
-                                    if (cells[i].pos < pos_chunk_start) {
-                                        f = -INFINITY;
-                                    }
-                                } break;
-                        }
+                    // apply SWA if any
+                    masked = masked || (is_masked_swa(p0, p1));
+
+                    float f = 0.0f;
+
+                    if (masked) {
+                        f = -INFINITY;
+                    } else if (hparams.use_alibi) {
+                        f = -std::abs(p0 - p1);
                     }
 
                     data[h*(n_kv*n_tokens) + s*(n_kv*n_seq_tokens) + j*n_kv + i] = f;
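The rewritten loop body first folds the three conditions (wrong sequence, future token under causal attention, outside the sliding window via is_masked_swa()) into one boolean, and only then converts it into the mask value, with ALiBi contributing a distance bias for unmasked entries. A standalone sketch of that decision, using int32_t in place of llama_pos and taking the SWA check as a plain input so it stays independent of the cache class; not the llama.cpp API itself:

// Sketch of the mask-value decision in set_input_kq_mask() above.
#include <cmath>
#include <cstdint>
#include <cstdlib>
#include <limits>

float kq_mask_value(bool same_seq, bool causal_attn, bool use_alibi,
                    int32_t p0 /* cached token position */,
                    int32_t p1 /* query token position */,
                    bool masked_swa /* result of an is_masked_swa-style check */) {
    bool masked = false;
    masked = masked || !same_seq;                // not the same sequence
    masked = masked || (causal_attn && p0 > p1); // future token under causal attention
    masked = masked || masked_swa;               // outside the sliding window

    if (masked) {
        return -std::numeric_limits<float>::infinity();
    }
    return use_alibi ? -std::abs(p0 - p1) : 0.0f; // ALiBi: bias by distance
}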
@@ -1191,6 +1180,30 @@ uint32_t llama_kv_cache_unified::cell_max() const {
     return 0;
 }
 
+bool llama_kv_cache_unified::is_masked_swa(llama_pos p0, llama_pos p1) const {
+    switch (swa_type) {
+        case LLAMA_SWA_TYPE_NONE:
+            {
+            } break;
+        case LLAMA_SWA_TYPE_STANDARD:
+            {
+                if (p1 - p0 >= (int32_t) n_swa) {
+                    return true;
+                }
+            } break;
+        case LLAMA_SWA_TYPE_CHUNKED:
+            {
+                const llama_pos pos_chunk_start = (p1 / n_swa) * n_swa;
+
+                if (p0 < pos_chunk_start) {
+                    return true;
+                }
+            } break;
+    }
+
+    return false;
+}
+
 void llama_kv_cache_unified::state_write(llama_io_write_i & io, llama_seq_id seq_id) const {
     std::vector<std::pair<uint32_t, uint32_t>> cell_ranges; // ranges, from inclusive, to exclusive
     uint32_t cell_count = 0;
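To make the two window types concrete, here is a standalone re-implementation of the predicate with a couple of worked values. n_swa = 4 is an arbitrary choice for illustration; this is a sketch, not the member function added above:

// Sketch of the sliding-window predicates; p0 = cached token position, p1 = query position.
#include <cassert>
#include <cstdint>

static bool masked_standard(int32_t p0, int32_t p1, uint32_t n_swa) {
    return p1 - p0 >= (int32_t) n_swa;                          // keep only the last n_swa positions
}

static bool masked_chunked(int32_t p0, int32_t p1, uint32_t n_swa) {
    const int32_t chunk_start = (p1 / (int32_t) n_swa) * (int32_t) n_swa;
    return p0 < chunk_start;                                    // keep only the current chunk
}

int main() {
    // Standard window, n_swa = 4, query at p1 = 10: positions 7..10 remain visible.
    assert(!masked_standard(7, 10, 4));
    assert( masked_standard(6, 10, 4));

    // Chunked window, n_swa = 4, query at p1 = 10: chunk starts at 8, so 8..10 remain visible.
    assert(!masked_chunked(8, 10, 4));
    assert( masked_chunked(7, 10, 4));
    return 0;
}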
@@ -1586,11 +1599,17 @@ llama_kv_cache_unified_iswa::llama_kv_cache_unified_iswa(
 
     LLAMA_LOG_INFO("%s: creating non-SWA KV cache, size = %u cells\n", __func__, kv_size_base);
 
-    kv_base = std::make_unique<llama_kv_cache_unified>(model, std::move(filter_base), type_k, type_v, v_trans, offload, kv_size_base, padding);
+    kv_base = std::make_unique<llama_kv_cache_unified>(
+            model, std::move(filter_base), type_k, type_v,
+            v_trans, offload, kv_size_base, padding,
+            0, LLAMA_SWA_TYPE_NONE);
 
     LLAMA_LOG_INFO("%s: creating SWA KV cache, size = %u cells\n", __func__, kv_size_swa);
 
-    kv_swa = std::make_unique<llama_kv_cache_unified>(model, std::move(filter_swa), type_k, type_v, v_trans, offload, kv_size_swa, padding);
+    kv_swa = std::make_unique<llama_kv_cache_unified>(
+            model, std::move(filter_swa), type_k, type_v,
+            v_trans, offload, kv_size_swa, padding,
+            hparams.n_swa, hparams.swa_type);
 }
 
 void llama_kv_cache_unified_iswa::clear() {
@@ -2801,5 +2820,4 @@ void llama_kv_cache_view_free(llama_kv_cache_view * view) {
 
 void llama_kv_cache_view_update(llama_kv_cache_view *, const llama_kv_cache *) {
     // TODO: will be removed soon, keep this for now to avoid too many changes in
     // https://github.com/ggml-org/llama.cpp/pull/13194
-    GGML_ABORT("not implemented");
 }