@@ -7807,8 +7807,6 @@ static int llama_decode_impl(
     uint32_t n_outputs = 0;
     uint32_t n_outputs_prev = 0;
 
-    const auto n_ubatch = cparams.n_ubatch;
-
     // this indicates we are doing pooled embedding, so we ignore batch.logits and output all tokens
     const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;
 
@@ -7832,27 +7830,19 @@ static int llama_decode_impl(
         return -2;
     };
 
-    auto & kv_self = lctx.kv_self;
-    llama_kv_slot_restorer kv_slot_restorer(kv_self);
+    const bool logits_all = n_outputs == n_tokens_all;
+
+    // auto & kv_self = lctx.kv_self;
+    // llama_kv_slot_restorer kv_slot_restorer(kv_self);
+
+    // lctx.sbatch.from_batch(batch, n_embd,
+    //         /* simple_split */ !kv_self.recurrent,
+    //         /* logits_all */ logits_all);
 
-    lctx.sbatch.from_batch(batch, n_embd,
-            /* simple_split */ !kv_self.recurrent,
-            /* logits_all */ n_outputs == n_tokens_all);
+    auto batch_manager = lctx.prepare_batch(batch, logits_all);
 
     while (lctx.sbatch.n_tokens > 0) {
-        llama_ubatch ubatch;
-        if (kv_self.recurrent) {
-            if (embd_pooled) {
-                // Pooled embeddings cannot be split across ubatches (yet)
-                ubatch = lctx.sbatch.split_seq(n_ubatch);
-            } else {
-                // recurrent model architectures are easier to implement
-                // with equal-length sequences
-                ubatch = lctx.sbatch.split_equal(n_ubatch);
-            }
-        } else {
-            ubatch = lctx.sbatch.split_simple(n_ubatch);
-        }
+        llama_ubatch ubatch = batch_manager->next();
 
         const uint32_t n_tokens = ubatch.n_tokens;
 
@@ -7873,32 +7863,10 @@ static int llama_decode_impl(
             lctx.n_outputs = n_outputs_new;
         }
 
-        lctx.prepare_decode(ubatch);
-
-        // non-causal masks do not use the KV cache
-        if (hparams.causal_attn) {
-            llama_kv_self_update(&lctx);
-
-            // if we have enough unused cells before the current head ->
-            //   better to start searching from the beginning of the cache, hoping to fill it
-            if (kv_self.head > kv_self.used + 2*n_tokens) {
-                kv_self.head = 0;
-            }
-
-            const auto slot_info = kv_self.find_slot(ubatch);
-            if (!slot_info) {
-                return 1;
-            }
-            kv_slot_restorer.save(slot_info);
-
-            if (!kv_self.recurrent) {
-                // a heuristic, to avoid attending the full cache if it is not yet utilized
-                // after enough generations, the benefit from this heuristic disappears
-                // if we start defragmenting the cache, the benefit from this will be more important
-                const uint32_t pad = kv_self.get_padding(cparams);
-                kv_self.n = std::min(kv_self.size, std::max(pad, GGML_PAD(kv_self.cell_max(), pad)));
-                // kv_self.n = llama_kv_cache_cell_max(kv_self);
-            }
+        if (!batch_manager->prepare()) {
+            LLAMA_LOG_ERROR("%s: failed to prepare ubatch\n", __func__);
+            batch_manager->restore();
+            return -3;
         }
 
         // reserve a worst case graph if needed
@@ -7963,7 +7931,7 @@ static int llama_decode_impl(
 
         const auto compute_status = lctx.compute_graph(gf, n_tokens > 1);
         if (compute_status != GGML_STATUS_SUCCESS) {
-            kv_slot_restorer.restore(kv_self);
+            batch_manager->restore();
             switch (compute_status) {
                 case GGML_STATUS_ABORTED:
                     return 2;
@@ -7975,15 +7943,7 @@ static int llama_decode_impl(
             }
         }
 
-        // update the kv ring buffer
-        {
-            kv_self.head += n_tokens;
-
-            // Ensure kv cache head points to a valid index.
-            if (kv_self.head >= kv_self.size) {
-                kv_self.head = 0;
-            }
-        }
+        batch_manager->update();
 
         // plot the computation graph in dot format (for debugging purposes)
         // if (n_past%100 == 0) {
@@ -8061,6 +8021,7 @@ static int llama_decode_impl(
                     }
             }
         }
+
         n_outputs_prev += lctx.n_outputs;
     }
 
@@ -8089,17 +8050,7 @@ static int llama_decode_impl(
     // wait for the computation to finish (automatically done when obtaining the model output)
     // llama_synchronize(&lctx);
 
-    // decide if we need to defrag the kv cache
-    if (cparams.causal_attn && cparams.defrag_thold >= 0.0f) {
-        const float fragmentation = kv_self.n >= 128 ? 1.0f - float(kv_self.used)/float(kv_self.n) : 0.0f;
-
-        // queue defragmentation for next llama_kv_cache_update
-        if (fragmentation > cparams.defrag_thold) {
-            // LLAMA_LOG_INFO("fragmentation: %.2f\n", fragmentation);
-
-            kv_self.defrag();
-        }
-    }
+    batch_manager->finalize();
 
     // Reset state for the next token before backend sync, to allow the CPU activities in the reset to
     // overlap with device computation.
@@ -8178,7 +8129,7 @@ static int llama_encode_impl(
     lctx.inp_embd_enc = NULL;
     lctx.n_outputs = n_tokens;
 
-    lctx.prepare_decode(ubatch);
+    // batch_manager->prepare(ubatch);
 
     // reserve a worst case graph if needed
     // TODO: extract to a function
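
The refactor above funnels all per-ubatch bookkeeping through the batch_manager object returned by lctx.prepare_batch(batch, logits_all). The PR defines that type elsewhere in the file; purely as an illustration of the contract these call sites imply, here is a minimal C++ sketch. The interface name llama_batch_manager_i, the stub llama_ubatch, and every member signature below are assumptions for illustration, not the PR's actual declarations.

#include <cstdint>

// stand-in for the real struct defined in llama.cpp
struct llama_ubatch {
    uint32_t n_tokens = 0;
};

// Hypothetical interface implied by the call sites in this diff;
// names and signatures are illustrative, not the PR's real code.
struct llama_batch_manager_i {
    virtual ~llama_batch_manager_i() = default;

    // split the next micro-batch off the queued batch
    // (replaces the inline split_seq/split_equal/split_simple logic)
    virtual llama_ubatch next() = 0;

    // reserve KV-cache slots for the current ubatch; returning false
    // maps to the "failed to prepare ubatch" / return -3 path above
    virtual bool prepare() = 0;

    // roll back KV-cache bookkeeping after a failed prepare/compute
    virtual void restore() = 0;

    // advance the KV-cache ring buffer head after a successful compute
    virtual void update() = 0;

    // end-of-decode work, e.g. deciding whether to queue defragmentation
    virtual void finalize() = 0;
};

// llama_context::prepare_batch would then hand one back, roughly:
//   std::unique_ptr<llama_batch_manager_i> prepare_batch(const llama_batch & batch, bool logits_all);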