Commit 03e10bc

Infer cache/RoPE weight dtype from output weights (#146)
- Add a `dtype` argument to `precompute_freqs_cis`
- Infer the `dtype` of the KV caches and RoPE weights from the output weight's `dtype` in `Transformer.setup_caches`

This way one can change precision in one place in `generate.py` and have it propagate throughout the model.
2 parents: 1190c08 + 3b8469b
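For context, a minimal sketch of the dtype-inference pattern this commit introduces. `TinyModel` is a toy stand-in, not the repo's `Transformer`; only the "read the output weight's dtype in `setup_caches`" behavior mirrors the actual change:

import torch
from torch import nn

# Toy stand-in (hypothetical, not the repo's Transformer) illustrating
# how cache dtype is inferred from the output projection's weights.
class TinyModel(nn.Module):
    def __init__(self, dim: int = 8, vocab: int = 16):
        super().__init__()
        self.output = nn.Linear(dim, vocab, bias=False)
        self.kv_cache = None

    def setup_caches(self, max_batch_size: int, max_seq_length: int):
        # Mirror of the commit: infer dtype from the output weights,
        # so caches follow whatever precision the model was cast to.
        dtype = self.output.weight.dtype
        self.kv_cache = torch.zeros(max_batch_size, max_seq_length, 8, dtype=dtype)

model = TinyModel().to(dtype=torch.float16)  # change precision in one place
model.setup_caches(max_batch_size=1, max_seq_length=32)
assert model.kv_cache.dtype == torch.float16  # caches followed automatically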

File tree: 1 file changed — model.py (+6, −4)
--- a/model.py
+++ b/model.py
@@ -107,10 +107,11 @@ def setup_caches(self, max_batch_size, max_seq_length):
         max_seq_length = find_multiple(max_seq_length, 8)
         self.max_seq_length = max_seq_length
         self.max_batch_size = max_batch_size
+        dtype=self.output.weight.dtype
         for b in self.layers:
-            b.attention.kv_cache = KVCache(max_batch_size, max_seq_length, self.config.n_local_heads, head_dim)
+            b.attention.kv_cache = KVCache(max_batch_size, max_seq_length, self.config.n_local_heads, head_dim, dtype)

-        self.freqs_cis = precompute_freqs_cis(self.config.block_size, self.config.dim // self.config.n_head, self.config.rope_base)
+        self.freqs_cis = precompute_freqs_cis(self.config.block_size, self.config.dim // self.config.n_head, self.config.rope_base, dtype)
         self.causal_mask = torch.tril(torch.ones(self.max_seq_length, self.max_seq_length, dtype=torch.bool))

     def forward(self, idx: Tensor, input_pos: Optional[Tensor] = None) -> Tensor:
@@ -222,14 +223,15 @@ def forward(self, x: Tensor) -> Tensor:


 def precompute_freqs_cis(
-    seq_len: int, n_elem: int, base: int = 10000
+    seq_len: int, n_elem: int, base: int = 10000,
+    dtype: torch.dtype = torch.bfloat16
 ) -> Tensor:
     freqs = 1.0 / (base ** (torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem))
     t = torch.arange(seq_len, device=freqs.device)
     freqs = torch.outer(t, freqs)
     freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
     cache = torch.stack([freqs_cis.real, freqs_cis.imag], dim=-1)
-    return cache.to(dtype=torch.bfloat16)
+    return cache.to(dtype=dtype)


 def apply_rotary_emb(x: Tensor, freqs_cis: Tensor) -> Tensor:
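As a quick self-contained check of the new `precompute_freqs_cis` signature, the sketch below copies the function body from the diff above (with `Tensor` written as `torch.Tensor` so it runs standalone); the arguments and shape in the assertions are illustrative, not from the repo:

import torch

def precompute_freqs_cis(
    seq_len: int, n_elem: int, base: int = 10000,
    dtype: torch.dtype = torch.bfloat16
) -> torch.Tensor:
    # Per-pair inverse frequencies, then outer product with positions.
    freqs = 1.0 / (base ** (torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem))
    t = torch.arange(seq_len, device=freqs.device)
    freqs = torch.outer(t, freqs)
    # Unit-magnitude complex rotations, stored as (real, imag) pairs.
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
    cache = torch.stack([freqs_cis.real, freqs_cis.imag], dim=-1)
    # The cache now follows the caller-supplied dtype instead of a
    # hard-coded bfloat16.
    return cache.to(dtype=dtype)

cache = precompute_freqs_cis(2048, 128, dtype=torch.float16)
assert cache.dtype == torch.float16
assert cache.shape == (2048, 64, 2)  # (seq_len, n_elem // 2, real/imag)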
