2
2
3
3
#include " llama.h"
4
4
#include " llama-io.h"
5
- #include " llama-graph .h"
5
+ #include " llama-memory .h"
6
6
7
7
#include " ggml-cpp.h"
8
8
@@ -13,6 +13,17 @@ struct llama_cparams;
13
13
struct llama_hparams ;
14
14
struct llama_ubatch ;
15
15
16
+ struct llama_kv_cache : public llama_memory_i {
17
+ using llama_memory_i::llama_memory_i;
18
+
19
+ virtual int32_t get_n_tokens () const = 0;
20
+ virtual uint32_t get_used_cells () const = 0; // TODO: remove, this is too-specific to the unified cache
21
+
22
+ virtual bool get_can_shift () const = 0;
23
+
24
+ bool get_can_edit () const override { return get_can_shift (); }
25
+ };
26
+
16
27
struct llama_kv_cell {
17
28
llama_pos pos = -1 ;
18
29
llama_pos delta = 0 ;
@@ -45,36 +56,10 @@ struct llama_kv_cache_slot_info {
45
56
operator bool () const { return found; }
46
57
};
47
58
48
- struct llama_kv_cache {
49
- public:
50
- virtual int32_t n_tokens () const = 0;
51
- virtual uint32_t used_cells () const = 0; // TODO: remove
52
-
53
- virtual void clear () = 0;
54
- virtual bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) = 0;
55
- virtual void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) = 0;
56
- virtual void seq_keep (llama_seq_id seq_id) = 0;
57
- virtual void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) = 0;
58
- virtual void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) = 0;
59
-
60
- virtual llama_pos seq_pos_max (llama_seq_id seq_id) = 0;
61
-
62
- virtual void defrag () = 0;
63
- virtual bool get_can_shift () const = 0;
64
- };
65
-
66
-
67
- // C++ alias
68
- class llama_kv_cache_i : public llama_kv_cache {
69
- public:
70
- using llama_kv_cache::llama_kv_cache;
71
- };
72
-
73
-
74
59
// ring-buffer of cached KV data
75
60
// TODO: pimpl
76
61
// TODO: add notion of max sequences
77
- class llama_kv_cache_unified : public llama_kv_cache_i {
62
+ class llama_kv_cache_unified : public llama_kv_cache {
78
63
public:
79
64
llama_kv_cache_unified (const llama_hparams & hparams);
80
65
virtual ~llama_kv_cache_unified () = default ;
@@ -88,15 +73,16 @@ class llama_kv_cache_unified : public llama_kv_cache_i {
88
73
uint32_t kv_size,
89
74
bool offload);
90
75
91
- int32_t n_tokens () const override ;
92
- uint32_t used_cells () const override ;
76
+ int32_t get_n_tokens () const override ;
77
+ uint32_t get_used_cells () const override ;
93
78
94
79
size_t total_size () const ;
95
80
96
81
// TODO: better data structures to reduce the cost of this operation
97
82
llama_pos pos_max () const ;
98
83
99
84
void clear () override ;
85
+ void defrag () override ;
100
86
101
87
bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override ;
102
88
void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override ;
@@ -106,7 +92,6 @@ class llama_kv_cache_unified : public llama_kv_cache_i {
106
92
107
93
llama_pos seq_pos_max (llama_seq_id seq_id) override ;
108
94
109
- void defrag () override ;
110
95
bool get_can_shift () const override ;
111
96
112
97
// find an empty slot of size "n_tokens" in the cache
0 commit comments