Commit 3b8469b

Infer caches/rope weights dtype from output weight
This way, precision can be changed in one place in `generate.py` and propagated throughout the model.
1 parent c142ec1 commit 3b8469b
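
The description above implies the precision is chosen once when the model weights are cast in `generate.py`. A minimal sketch of that flow (hypothetical snippet; the actual loading code in `generate.py` may differ, and the model name below is a placeholder):

    import torch
    from model import Transformer

    # Pick the precision once, when casting the model's weights...
    model = Transformer.from_name("7B")  # placeholder model name
    model = model.to(device="cuda", dtype=torch.bfloat16)

    # ...and setup_caches() now reads self.output.weight.dtype, so the
    # KV caches and freqs_cis buffers are allocated in bfloat16 as well.
    model.setup_caches(max_batch_size=1, max_seq_length=2048)

Changing the single `dtype=` argument above (e.g. to `torch.float16`) is then enough to switch the caches and RoPE tables too.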

File tree

1 file changed (+3, −2 lines)

model.py (+3, −2)
@@ -107,10 +107,11 @@ def setup_caches(self, max_batch_size, max_seq_length):
         max_seq_length = find_multiple(max_seq_length, 8)
         self.max_seq_length = max_seq_length
         self.max_batch_size = max_batch_size
+        dtype=self.output.weight.dtype
         for b in self.layers:
-            b.attention.kv_cache = KVCache(max_batch_size, max_seq_length, self.config.n_local_heads, head_dim)
+            b.attention.kv_cache = KVCache(max_batch_size, max_seq_length, self.config.n_local_heads, head_dim, dtype)
 
-        self.freqs_cis = precompute_freqs_cis(self.config.block_size, self.config.dim // self.config.n_head, self.config.rope_base)
+        self.freqs_cis = precompute_freqs_cis(self.config.block_size, self.config.dim // self.config.n_head, self.config.rope_base, dtype)
         self.causal_mask = torch.tril(torch.ones(self.max_seq_length, self.max_seq_length, dtype=torch.bool))
 
     def forward(self, idx: Tensor, input_pos: Optional[Tensor] = None) -> Tensor:
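
For this change to type the caches correctly, the `KVCache` constructor has to accept the new trailing `dtype` argument. A minimal sketch of the cache side, assuming buffer-backed k/v tensors (the real class in `model.py` may differ in details):

    import torch
    from torch import nn

    class KVCache(nn.Module):
        def __init__(self, max_batch_size, max_seq_length, n_heads, head_dim,
                     dtype=torch.bfloat16):
            super().__init__()
            cache_shape = (max_batch_size, n_heads, max_seq_length, head_dim)
            # Allocate both caches in the inferred dtype so attention updates
            # match the precision of the model weights.
            self.register_buffer("k_cache", torch.zeros(cache_shape, dtype=dtype))
            self.register_buffer("v_cache", torch.zeros(cache_shape, dtype=dtype))

On the RoPE side, the equivalent change would be for `precompute_freqs_cis` to cast its result with something like `return cache.to(dtype=dtype)` rather than a hard-coded dtype.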
