Commit 661e783

context : introduce llama_batch_manager
ggml-ci
1 parent 9c21995 commit 661e783

3 files changed: +162 -73

src/llama-context.cpp

+126 -4

@@ -32,6 +32,132 @@ static int32_t llama_relative_position_bucket(llama_pos x, llama_pos y, uint64_t
     return relative_bucket;
 }

+struct llama_batch_manager : public llama_batch_manager_i {
+    llama_batch_manager(llama_context & lctx, const llama_batch & batch, bool logits_all) : lctx(lctx), batch(batch), kv_slot_restorer(lctx.kv_self) {
+        const auto & hparams = lctx.model.hparams;
+        const auto & n_embd  = hparams.n_embd;
+
+        const auto & kv_self = lctx.kv_self;
+
+        lctx.sbatch.from_batch(batch, n_embd,
+                /* simple_split */ !kv_self.recurrent,
+                /* logits_all   */ logits_all);
+    }
+
+    ~llama_batch_manager() override {
+    }
+
+    virtual llama_ubatch next() override {
+        ubatch = llama_ubatch();
+
+        const auto & cparams = lctx.cparams;
+        const auto & kv_self = lctx.kv_self;
+
+        const auto & n_ubatch = cparams.n_ubatch;
+
+        const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;
+
+        if (kv_self.recurrent) {
+            if (embd_pooled) {
+                // Pooled embeddings cannot be split across ubatches (yet)
+                ubatch = lctx.sbatch.split_seq(n_ubatch);
+            } else {
+                // recurrent model architectures are easier to implement
+                // with equal-length sequences
+                ubatch = lctx.sbatch.split_equal(n_ubatch);
+            }
+        } else {
+            ubatch = lctx.sbatch.split_simple(n_ubatch);
+        }
+
+        return ubatch;
+    }
+
+    virtual bool prepare() override {
+        const auto & cparams = lctx.cparams;
+        const auto & hparams = lctx.model.hparams;
+
+        auto & kv_self = lctx.kv_self;
+
+        // non-causal masks do not use the KV cache
+        if (hparams.causal_attn) {
+            llama_kv_self_update(&lctx);
+
+            // if we have enough unused cells before the current head ->
+            //   better to start searching from the beginning of the cache, hoping to fill it
+            if (kv_self.head > kv_self.used + 2*ubatch.n_tokens) {
+                kv_self.head = 0;
+            }
+
+            const auto slot_info = kv_self.find_slot(ubatch);
+            if (!slot_info) {
+                return false;
+            }
+
+            kv_slot_restorer.save(slot_info);
+
+            if (!kv_self.recurrent) {
+                // a heuristic, to avoid attending the full cache if it is not yet utilized
+                // after enough generations, the benefit from this heuristic disappears
+                // if we start defragmenting the cache, the benefit from this will be more important
+                const uint32_t pad = kv_self.get_padding(cparams);
+                kv_self.n = std::min(kv_self.size, std::max(pad, GGML_PAD(kv_self.cell_max(), pad)));
+                //kv_self.n = llama_kv_cache_cell_max(kv_self);
+            }
+        }
+
+        return true;
+    }
+
+    virtual void restore() override {
+        kv_slot_restorer.restore(lctx.kv_self);
+    }
+
+    virtual void update() override {
+        auto & kv_self = lctx.kv_self;
+
+        // update the kv ring buffer
+        {
+            kv_self.head += ubatch.n_tokens;
+
+            // Ensure kv cache head points to a valid index.
+            if (kv_self.head >= kv_self.size) {
+                kv_self.head = 0;
+            }
+        }
+    }
+
+    virtual void finalize() override {
+        const auto & cparams = lctx.cparams;
+
+        auto & kv_self = lctx.kv_self;
+
+        // decide if we need to defrag the kv cache
+        if (cparams.causal_attn && cparams.defrag_thold >= 0.0f) {
+            const float fragmentation = kv_self.n >= 128 ? 1.0f - float(kv_self.used)/float(kv_self.n) : 0.0f;
+
+            // queue defragmentation for next llama_kv_cache_update
+            if (fragmentation > cparams.defrag_thold) {
+                //LLAMA_LOG_INFO("fragmentation: %.2f\n", fragmentation);
+
+                kv_self.defrag();
+            }
+        }
+    }
+
+    llama_context & lctx;
+
+    const llama_batch & batch;
+
+    llama_ubatch ubatch;
+
+    llama_kv_slot_restorer kv_slot_restorer;
+};
+
+std::unique_ptr<llama_batch_manager_i> llama_context::prepare_batch(const llama_batch & batch, bool logits_all) {
+    return std::make_unique<llama_batch_manager>(*this, batch, logits_all);
+}
+
 enum ggml_status llama_context::compute_graph(
             ggml_cgraph * graph,
                    bool   batched) {
@@ -59,7 +185,6 @@ enum ggml_status llama_context::compute_graph(
     return status;
 }

-
 llama_pos llama_context::pos_max() const {
     return kv_self.pos_max();
 }
@@ -94,9 +219,6 @@ void llama_context::prepare_k_shift() {
 void llama_context::prepare_defrag() {
 }

-void llama_context::prepare_decode(const llama_ubatch & /*ubatch*/) {
-}
-
 // llama input

 void llama_context::set_inputs(const llama_ubatch & ubatch) {
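
Taken together, llama_batch_manager now owns the per-decode bookkeeping (sbatch splitting, KV slot search and rollback, KV head advance, defrag scheduling) that previously lived inline in llama_decode_impl. Below is a minimal sketch of the call sequence the interface implies; decode_with_manager is a hypothetical wrapper name, not part of this commit, and the graph build/compute step is elided:

static int decode_with_manager(llama_context & lctx, const llama_batch & batch, bool logits_all) {
    // construct the manager; this also splits the batch into lctx.sbatch
    auto batch_manager = lctx.prepare_batch(batch, logits_all);

    while (lctx.sbatch.n_tokens > 0) {
        // take the next ubatch (simple/equal/seq split depending on the model)
        llama_ubatch ubatch = batch_manager->next();

        // reserve a KV cache slot; on failure, roll back and bail out
        if (!batch_manager->prepare()) {
            batch_manager->restore();
            return -3;
        }

        // ... build and compute the graph for ubatch here ...
        // (on compute failure: batch_manager->restore(); return an error code)

        // advance the KV cache ring buffer past the consumed tokens
        batch_manager->update();
    }

    // decide whether to queue KV cache defragmentation
    batch_manager->finalize();

    return 0;
}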

src/llama-context.h

+17 -1

@@ -16,6 +16,20 @@

 using llama_loras = std::unordered_map<struct llama_adapter_lora *, float>;

+// TODO: this is very WIP - improve
+struct llama_batch_manager_i {
+    virtual ~llama_batch_manager_i() = default;
+
+    //bool is_done() const;
+
+    virtual llama_ubatch next() = 0;
+
+    virtual bool prepare() = 0;
+    virtual void restore() = 0;
+    virtual void update() = 0;
+    virtual void finalize() = 0;
+};
+
 struct llama_context {
     llama_context(const llama_model & model)
         : model(model)
@@ -80,6 +94,9 @@ struct llama_context {
     ggml_abort_callback abort_callback      = nullptr;
     void *              abort_callback_data = nullptr;

+    // TODO: do not pass logits_all explicitly
+    std::unique_ptr<llama_batch_manager_i> prepare_batch(const llama_batch & batch, bool logits_all);
+
     // returns the result of ggml_backend_sched_graph_compute_async execution
     enum ggml_status compute_graph(
                 ggml_cgraph * graph,
@@ -95,7 +112,6 @@ struct llama_context {

     void prepare_k_shift();
     void prepare_defrag();
-    void prepare_decode(const llama_ubatch & ubatch);

     void set_inputs(const llama_ubatch & ubatch);

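For illustration of the interface contract only, here is a minimal sketch of an alternative llama_batch_manager_i implementation that performs plain simple splits and does no KV cache bookkeeping; llama_batch_manager_dummy is a hypothetical name and does not exist in this commit:

// Hypothetical manager: only simple splits, no KV cache slot handling.
struct llama_batch_manager_dummy : public llama_batch_manager_i {
    llama_batch_manager_dummy(llama_context & lctx, const llama_batch & batch, bool logits_all) : lctx(lctx) {
        // always use simple, contiguous splits
        lctx.sbatch.from_batch(batch, lctx.model.hparams.n_embd,
                /* simple_split */ true,
                /* logits_all   */ logits_all);
    }

    virtual llama_ubatch next() override {
        return lctx.sbatch.split_simple(lctx.cparams.n_ubatch);
    }

    virtual bool prepare()  override { return true; } // nothing to reserve
    virtual void restore()  override {}               // nothing to roll back
    virtual void update()   override {}               // no KV head to advance
    virtual void finalize() override {}               // no defrag decision

    llama_context & lctx;
};
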
src/llama.cpp

+19 -68

@@ -7807,8 +7807,6 @@ static int llama_decode_impl(
     uint32_t n_outputs = 0;
     uint32_t n_outputs_prev = 0;

-    const auto n_ubatch = cparams.n_ubatch;
-
     // this indicates we are doing pooled embedding, so we ignore batch.logits and output all tokens
     const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;

@@ -7832,27 +7830,19 @@ static int llama_decode_impl(
         return -2;
     };

-    auto & kv_self = lctx.kv_self;
-    llama_kv_slot_restorer kv_slot_restorer(kv_self);
+    const bool logits_all = n_outputs == n_tokens_all;
+
+    //auto & kv_self = lctx.kv_self;
+    //llama_kv_slot_restorer kv_slot_restorer(kv_self);
+
+    //lctx.sbatch.from_batch(batch, n_embd,
+    //        /* simple_split */ !kv_self.recurrent,
+    //        /* logits_all   */ logits_all);

-    lctx.sbatch.from_batch(batch, n_embd,
-            /* simple_split */ !kv_self.recurrent,
-            /* logits_all   */ n_outputs == n_tokens_all);
+    auto batch_manager = lctx.prepare_batch(batch, logits_all);

     while (lctx.sbatch.n_tokens > 0) {
-        llama_ubatch ubatch;
-        if (kv_self.recurrent) {
-            if (embd_pooled) {
-                // Pooled embeddings cannot be split across ubatches (yet)
-                ubatch = lctx.sbatch.split_seq(n_ubatch);
-            } else {
-                // recurrent model architectures are easier to implement
-                // with equal-length sequences
-                ubatch = lctx.sbatch.split_equal(n_ubatch);
-            }
-        } else {
-            ubatch = lctx.sbatch.split_simple(n_ubatch);
-        }
+        llama_ubatch ubatch = batch_manager->next();

         const uint32_t n_tokens = ubatch.n_tokens;

@@ -7873,32 +7863,10 @@ static int llama_decode_impl(
             lctx.n_outputs = n_outputs_new;
         }

-        lctx.prepare_decode(ubatch);
-
-        // non-causal masks do not use the KV cache
-        if (hparams.causal_attn) {
-            llama_kv_self_update(&lctx);
-
-            // if we have enough unused cells before the current head ->
-            //   better to start searching from the beginning of the cache, hoping to fill it
-            if (kv_self.head > kv_self.used + 2*n_tokens) {
-                kv_self.head = 0;
-            }
-
-            const auto slot_info = kv_self.find_slot(ubatch);
-            if (!slot_info) {
-                return 1;
-            }
-            kv_slot_restorer.save(slot_info);
-
-            if (!kv_self.recurrent) {
-                // a heuristic, to avoid attending the full cache if it is not yet utilized
-                // after enough generations, the benefit from this heuristic disappears
-                // if we start defragmenting the cache, the benefit from this will be more important
-                const uint32_t pad = kv_self.get_padding(cparams);
-                kv_self.n = std::min(kv_self.size, std::max(pad, GGML_PAD(kv_self.cell_max(), pad)));
-                //kv_self.n = llama_kv_cache_cell_max(kv_self);
-            }
+        if (!batch_manager->prepare()) {
+            LLAMA_LOG_ERROR("%s: failed to prepare ubatch\n", __func__);
+            batch_manager->restore();
+            return -3;
         }

         // reserve a worst case graph if needed
@@ -7963,7 +7931,7 @@ static int llama_decode_impl(

         const auto compute_status = lctx.compute_graph(gf, n_tokens > 1);
         if (compute_status != GGML_STATUS_SUCCESS) {
-            kv_slot_restorer.restore(kv_self);
+            batch_manager->restore();
             switch (compute_status) {
                 case GGML_STATUS_ABORTED:
                     return 2;
@@ -7975,15 +7943,7 @@ static int llama_decode_impl(
             }
         }

-        // update the kv ring buffer
-        {
-            kv_self.head += n_tokens;
-
-            // Ensure kv cache head points to a valid index.
-            if (kv_self.head >= kv_self.size) {
-                kv_self.head = 0;
-            }
-        }
+        batch_manager->update();

         // plot the computation graph in dot format (for debugging purposes)
         //if (n_past%100 == 0) {
@@ -8061,6 +8021,7 @@ static int llama_decode_impl(
                 }
             }
         }
+
         n_outputs_prev += lctx.n_outputs;
     }

@@ -8089,17 +8050,7 @@ static int llama_decode_impl(
     // wait for the computation to finish (automatically done when obtaining the model output)
     //llama_synchronize(&lctx);

-    // decide if we need to defrag the kv cache
-    if (cparams.causal_attn && cparams.defrag_thold >= 0.0f) {
-        const float fragmentation = kv_self.n >= 128 ? 1.0f - float(kv_self.used)/float(kv_self.n) : 0.0f;
-
-        // queue defragmentation for next llama_kv_cache_update
-        if (fragmentation > cparams.defrag_thold) {
-            //LLAMA_LOG_INFO("fragmentation: %.2f\n", fragmentation);
-
-            kv_self.defrag();
-        }
-    }
+    batch_manager->finalize();

     // Reset state for the next token before backend sync, to allow the CPU activities in the reset to
     // overlap with device computation.
@@ -8178,7 +8129,7 @@ static int llama_encode_impl(
     lctx.inp_embd_enc = NULL;
     lctx.n_outputs = n_tokens;

-    lctx.prepare_decode(ubatch);
+    //batch_manager->prepare(ubatch);

     // reserve a worst case graph if needed
     // TODO: extract to a function
