Skip to content

Commit 38db8a5

Browse files
committed
llama : introduce concept of llama_memory
ggml-ci
1 parent 828effd commit 38db8a5

6 files changed

+1345
-45
lines changed

src/llama-context.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ llama_context_base::llama_context_base(
4949
const llama_model & model,
5050
llama_context_params params,
5151
llama_graph_type gtype) :
52-
llama_context_i(),
52+
llama_context(),
5353
llama_graph_i(gtype),
5454
model(model) {
5555
LLAMA_LOG_INFO("%s: constructing llama_context_base, gtype = %d\n", __func__, gtype);

src/llama-context.h

+8-9
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,10 @@ class llama_io_write_i;
2121
using llama_loras = std::unordered_map<struct llama_adapter_lora *, float>;
2222

2323
// abstract interface corresponding to the public C API
24-
struct llama_context {
24+
class llama_context_i {
2525
public:
26-
llama_context() = default;
27-
virtual ~llama_context() = default;
26+
llama_context_i() = default;
27+
virtual ~llama_context_i() = default;
2828

2929
virtual void init() = 0;
3030

@@ -157,14 +157,13 @@ struct llama_context {
157157
size_t n_token_count) = 0;
158158
};
159159

160-
// C++ alias
161-
class llama_context_i : public llama_context {
162-
public:
163-
using llama_context::llama_context;
160+
// C alias
161+
struct llama_context : public llama_context_i {
162+
using llama_context_i::llama_context_i;
164163
};
165164

166165
// basic transformer without KV cache
167-
class llama_context_base : public llama_context_i, public llama_graph_i {
166+
class llama_context_base : public llama_context, public llama_graph_i {
168167
public:
169168
llama_context_base(
170169
const llama_model & model,
@@ -821,7 +820,7 @@ class llama_context_dec : public llama_context_kv_self {
821820
llama_cross * cross = nullptr;
822821
};
823822

824-
class llama_context_enc_dec : public llama_context_i {
823+
class llama_context_enc_dec : public llama_context {
825824
public:
826825
llama_context_enc_dec(
827826
const llama_model & model,

src/llama-kv-cache.cpp

+4-4
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,7 @@ bool llama_kv_cache_unified::init(
122122
return true;
123123
}
124124

125-
int32_t llama_kv_cache_unified::n_tokens() const {
125+
int32_t llama_kv_cache_unified::get_n_tokens() const {
126126
int32_t result = 0;
127127

128128
for (uint32_t i = 0; i < size; i++) {
@@ -132,7 +132,7 @@ int32_t llama_kv_cache_unified::n_tokens() const {
132132
return result;
133133
}
134134

135-
uint32_t llama_kv_cache_unified::used_cells() const {
135+
uint32_t llama_kv_cache_unified::get_used_cells() const {
136136
return used;
137137
}
138138

@@ -1091,15 +1091,15 @@ int32_t llama_kv_cache_n_tokens(const llama_kv_cache * kv) {
10911091
return 0;
10921092
}
10931093

1094-
return kv->n_tokens();
1094+
return kv->get_n_tokens();
10951095
}
10961096

10971097
int32_t llama_kv_cache_used_cells(const llama_kv_cache * kv) {
10981098
if (!kv) {
10991099
return 0;
11001100
}
11011101

1102-
return kv->used_cells();
1102+
return kv->get_used_cells();
11031103
}
11041104

11051105
void llama_kv_cache_clear(llama_kv_cache * kv) {

src/llama-kv-cache.h

+16-31
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
#include "llama.h"
44
#include "llama-io.h"
5-
#include "llama-graph.h"
5+
#include "llama-memory.h"
66

77
#include "ggml-cpp.h"
88

@@ -13,6 +13,17 @@ struct llama_cparams;
1313
struct llama_hparams;
1414
struct llama_ubatch;
1515

16+
struct llama_kv_cache : public llama_memory_i {
17+
using llama_memory_i::llama_memory_i;
18+
19+
virtual int32_t get_n_tokens() const = 0;
20+
virtual uint32_t get_used_cells() const = 0; // TODO: remove, this is too-specific to the unified cache
21+
22+
virtual bool get_can_shift() const = 0;
23+
24+
bool get_can_edit() const override { return get_can_shift(); }
25+
};
26+
1627
struct llama_kv_cell {
1728
llama_pos pos = -1;
1829
llama_pos delta = 0;
@@ -45,36 +56,10 @@ struct llama_kv_cache_slot_info {
4556
operator bool() const { return found; }
4657
};
4758

48-
struct llama_kv_cache {
49-
public:
50-
virtual int32_t n_tokens() const = 0;
51-
virtual uint32_t used_cells() const = 0; // TODO: remove
52-
53-
virtual void clear() = 0;
54-
virtual bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) = 0;
55-
virtual void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) = 0;
56-
virtual void seq_keep(llama_seq_id seq_id) = 0;
57-
virtual void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) = 0;
58-
virtual void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) = 0;
59-
60-
virtual llama_pos seq_pos_max(llama_seq_id seq_id) = 0;
61-
62-
virtual void defrag() = 0;
63-
virtual bool get_can_shift() const = 0;
64-
};
65-
66-
67-
// C++ alias
68-
class llama_kv_cache_i : public llama_kv_cache {
69-
public:
70-
using llama_kv_cache::llama_kv_cache;
71-
};
72-
73-
7459
// ring-buffer of cached KV data
7560
// TODO: pimpl
7661
// TODO: add notion of max sequences
77-
class llama_kv_cache_unified : public llama_kv_cache_i {
62+
class llama_kv_cache_unified : public llama_kv_cache {
7863
public:
7964
llama_kv_cache_unified(const llama_hparams & hparams);
8065
virtual ~llama_kv_cache_unified() = default;
@@ -88,15 +73,16 @@ class llama_kv_cache_unified : public llama_kv_cache_i {
8873
uint32_t kv_size,
8974
bool offload);
9075

91-
int32_t n_tokens() const override;
92-
uint32_t used_cells() const override;
76+
int32_t get_n_tokens() const override;
77+
uint32_t get_used_cells() const override;
9378

9479
size_t total_size() const;
9580

9681
// TODO: better data structures to reduce the cost of this operation
9782
llama_pos pos_max() const;
9883

9984
void clear() override;
85+
void defrag() override;
10086

10187
bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override;
10288
void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
@@ -106,7 +92,6 @@ class llama_kv_cache_unified : public llama_kv_cache_i {
10692

10793
llama_pos seq_pos_max(llama_seq_id seq_id) override;
10894

109-
void defrag() override;
11095
bool get_can_shift() const override;
11196

11297
// find an empty slot of size "n_tokens" in the cache

0 commit comments

Comments
 (0)