
Commit 3685aae

ngxson authored and jwcolin committed
llama : Support llama 4 text-only (ggml-org#12791)
* llama4 conversion
* initial support, no chat template
* clean up a bit
* fix tokenizer conversion
* correct hparams
* try this
* fix shexp
* ffn_inp_normed
* chat template
* clean up model conversion
* add_bos
* add scale_before_ffn
* fix order
* weight_before_ffn
* llm_graph_input_attn_temp
* add chunk attn mask
* build_inp_attn_scale()
* add comment about ggml_repeat
* clarify comments
* fix build
1 parent fd63243 commit 3685aae

Showing 17 changed files with 532 additions and 22 deletions.

convert_hf_to_gguf.py

+64 -4

@@ -714,6 +714,9 @@ def get_vocab_base_pre(self, tokenizer) -> str:
        if chkhsh == "96a5f08be6259352137b512d4157e333e21df7edd3fcd152990608735a65b224":
            # ref: https://huggingface.co/inclusionAI/Ling-lite
            res = "bailingmoe"
+        if chkhsh == "d353350c764d8c3b39c763113960e4fb4919bea5fbf208a0e3b22e8469dc7406":
+            # ref: https://huggingface.co/meta-llama/Llama-4-Scout-17B-16E-Instruct
+            res = "llama4"

        if res is None:
            logger.warning("\n")
@@ -1608,6 +1611,7 @@ def prepare_tensors(self):
@Model.register("LLaMAForCausalLM", "LlamaForCausalLM", "MistralForCausalLM", "MixtralForCausalLM")
class LlamaModel(Model):
    model_arch = gguf.MODEL_ARCH.LLAMA
+    undo_permute = True

    def set_vocab(self):
        try:
@@ -1672,10 +1676,11 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
        n_head = self.hparams["num_attention_heads"]
        n_kv_head = self.hparams.get("num_key_value_heads")

-        if name.endswith(("q_proj.weight", "q_proj.bias")):
-            data_torch = LlamaModel.permute(data_torch, n_head, n_head)
-        if name.endswith(("k_proj.weight", "k_proj.bias")):
-            data_torch = LlamaModel.permute(data_torch, n_head, n_kv_head)
+        if self.undo_permute:
+            if name.endswith(("q_proj.weight", "q_proj.bias")):
+                data_torch = LlamaModel.permute(data_torch, n_head, n_head)
+            if name.endswith(("k_proj.weight", "k_proj.bias")):
+                data_torch = LlamaModel.permute(data_torch, n_head, n_kv_head)

        # process the experts separately
        if name.find("block_sparse_moe.experts") != -1:
@@ -1752,6 +1757,61 @@ def prepare_tensors(self):
            raise ValueError(f"Unprocessed experts: {experts}")


+@Model.register("Llama4ForConditionalGeneration")
+class Llama4Model(LlamaModel):
+    model_arch = gguf.MODEL_ARCH.LLAMA4
+    has_vision: bool = False
+    undo_permute = False
+
+    # TODO @ngxson : avoid duplicate this code everywhere by at least support "text_config"
+    # same with llama, but we need to merge the text_config into the root level of hparams
+    def __init__(self, *args, **kwargs):
+        hparams = kwargs["hparams"] if "hparams" in kwargs else Model.load_hparams(args[0])
+        if "text_config" in hparams:
+            hparams = {**hparams, **hparams["text_config"]}
+        kwargs["hparams"] = hparams
+        super().__init__(*args, **kwargs)
+        if "vision_config" in hparams:
+            logger.info("Has vision encoder, but it will be ignored")
+            self.has_vision = True
+        # IMPORTANT: the normal "intermediate_size" is renamed to "intermediate_size_mlp", we need to undo this
+        self.hparams["intermediate_size_moe"] = self.hparams["intermediate_size"]
+        self.hparams["intermediate_size"] = self.hparams["intermediate_size_mlp"]
+
+    def set_vocab(self):
+        self._set_vocab_gpt2()
+        self.gguf_writer.add_add_bos_token(True)
+
+    def set_gguf_parameters(self):
+        super().set_gguf_parameters()
+        self.gguf_writer.add_interleave_moe_layer_step(self.hparams["interleave_moe_layer_step"])
+        self.gguf_writer.add_expert_feed_forward_length(self.hparams["intermediate_size_moe"])
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None):
+        name = name.replace("language_model.", "")
+        name = name.replace("feed_forward.", "mlp.")  # a bit hacky for now
+        name = name.replace(".router.weight", ".gate.weight")  # a bit hacky for now
+
+        # split the gate_up into gate and up
+        if "gate_up_proj" in name:
+            name_up = name.replace("gate_up_proj", "up_proj.weight")
+            name_gate = name.replace("gate_up_proj", "gate_proj.weight")
+            dim_half = data_torch.shape[-1] // 2
+            gate_proj_weight, up_proj_weight = data_torch.transpose(-1, -2).split(dim_half, dim=-2)
+            return [
+                (self.map_tensor_name(name_gate), gate_proj_weight),
+                (self.map_tensor_name(name_up), up_proj_weight)
+            ]
+
+        if name.endswith("down_proj"):
+            name += ".weight"
+            data_torch = data_torch.transpose(-1, -2)
+
+        if "multi_modal_projector" in name or "vision_model" in name:
+            return []
+        return super().modify_tensors(data_torch, name, bid)
+
+
@Model.register("Mistral3ForConditionalGeneration")
class Mistral3Model(LlamaModel):
    model_arch = gguf.MODEL_ARCH.LLAMA
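
Note (not part of the diff): a minimal sketch of the gate_up_proj split above, using toy shapes that are made up for illustration only. The fused HF expert weight is transposed and then cut into equal gate/up halves along what was its last dimension, which is what the conversion code assumes about the checkpoint layout.

import torch

n_expert, n_embd, n_ff = 4, 8, 6                    # toy sizes, not the real model's
gate_up = torch.randn(n_expert, n_embd, 2 * n_ff)   # fused gate_up_proj as stored by HF

dim_half = gate_up.shape[-1] // 2                   # = n_ff, computed before the transpose
gate_proj, up_proj = gate_up.transpose(-1, -2).split(dim_half, dim=-2)

assert gate_proj.shape == (n_expert, n_ff, n_embd)  # per-expert layout llama.cpp expects
assert up_proj.shape == (n_expert, n_ff, n_embd)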

convert_hf_to_gguf_update.py

+1
@@ -113,6 +113,7 @@ class TOKENIZER_TYPE(IntEnum):
    {"name": "superbpe", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/UW/OLMo2-8B-SuperBPE-t180k", },
    {"name": "trillion", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/trillionlabs/Trillion-7B-preview", },
    {"name": "bailingmoe", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/inclusionAI/Ling-lite", },
+    {"name": "llama4", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/meta-llama/Llama-4-Scout-17B-16E-Instruct", },
]

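
Note (not part of the diff): the chkhsh registered in convert_hf_to_gguf.py above is the SHA-256 of the token ids the reference tokenizer produces for the script's fixed test string. A rough sketch of how convert_hf_to_gguf_update.py derives it; the repo is gated, and chktxt here is only a placeholder for the long multilingual string defined in that script.

from hashlib import sha256
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-4-Scout-17B-16E-Instruct")
chktxt = "..."  # placeholder: the long multilingual test string from the update script
chkhsh = sha256(str(tokenizer.encode(chktxt)).encode()).hexdigest()
# with the real chktxt this should print d353350c764d8c3b39c763113960e4fb4919bea5fbf208a0e3b22e8469dc7406
print(chkhsh)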

gguf-py/gguf/constants.py

+26
@@ -116,6 +116,7 @@ class LLM:
        RESIDUAL_SCALE = "{arch}.residual_scale"
        EMBEDDING_SCALE = "{arch}.embedding_scale"
        TOKEN_SHIFT_COUNT = "{arch}.token_shift_count"
+        INTERLEAVE_MOE_LAYER_STEP = "{arch}.interleave_moe_layer_step"

    class Attention:
        HEAD_COUNT = "{arch}.attention.head_count"
@@ -227,6 +228,7 @@ class GGUFType:

class MODEL_ARCH(IntEnum):
    LLAMA = auto()
+    LLAMA4 = auto()
    DECI = auto()
    FALCON = auto()
    BAICHUAN = auto()
@@ -431,6 +433,7 @@ class MODEL_TENSOR(IntEnum):

MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
    MODEL_ARCH.LLAMA: "llama",
+    MODEL_ARCH.LLAMA4: "llama4",
    MODEL_ARCH.DECI: "deci",
    MODEL_ARCH.FALCON: "falcon",
    MODEL_ARCH.BAICHUAN: "baichuan",
@@ -654,6 +657,29 @@ class MODEL_TENSOR(IntEnum):
        MODEL_TENSOR.FFN_DOWN_EXP,
        MODEL_TENSOR.FFN_UP_EXP,
    ],
+    MODEL_ARCH.LLAMA4: [
+        MODEL_TENSOR.TOKEN_EMBD,
+        MODEL_TENSOR.OUTPUT_NORM,
+        MODEL_TENSOR.OUTPUT,
+        MODEL_TENSOR.ROPE_FREQS,
+        MODEL_TENSOR.ATTN_NORM,
+        MODEL_TENSOR.ATTN_Q,
+        MODEL_TENSOR.ATTN_K,
+        MODEL_TENSOR.ATTN_V,
+        MODEL_TENSOR.ATTN_OUT,
+        MODEL_TENSOR.ATTN_ROT_EMBD,
+        MODEL_TENSOR.FFN_GATE_INP,
+        MODEL_TENSOR.FFN_NORM,
+        MODEL_TENSOR.FFN_GATE,
+        MODEL_TENSOR.FFN_DOWN,
+        MODEL_TENSOR.FFN_UP,
+        MODEL_TENSOR.FFN_GATE_EXP,
+        MODEL_TENSOR.FFN_DOWN_EXP,
+        MODEL_TENSOR.FFN_UP_EXP,
+        MODEL_TENSOR.FFN_GATE_SHEXP,
+        MODEL_TENSOR.FFN_DOWN_SHEXP,
+        MODEL_TENSOR.FFN_UP_SHEXP,
+    ],
    MODEL_ARCH.DECI: [
        MODEL_TENSOR.TOKEN_EMBD,
        MODEL_TENSOR.OUTPUT_NORM,
gguf-py/gguf/gguf_writer.py

+3
@@ -746,6 +746,9 @@ def add_wkv_head_size(self, size: int) -> None:
    def add_token_shift_count(self, count: int) -> None:
        self.add_uint32(Keys.LLM.TOKEN_SHIFT_COUNT.format(arch=self.arch), count)

+    def add_interleave_moe_layer_step(self, value: int) -> None:
+        self.add_uint32(Keys.LLM.INTERLEAVE_MOE_LAYER_STEP.format(arch=self.arch), value)
+
    def add_layer_norm_eps(self, value: float) -> None:
        self.add_float32(Keys.Attention.LAYERNORM_EPS.format(arch=self.arch), value)

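Note (not part of the diff): the per-architecture template added in constants.py is what add_interleave_moe_layer_step() formats when writing the GGUF, and the matching LLM_KV entry in llama-arch.cpp reads the same key back on the C++ side. A quick sketch of the resulting key name for this architecture:

from gguf.constants import Keys

key = Keys.LLM.INTERLEAVE_MOE_LAYER_STEP.format(arch="llama4")
print(key)  # llama4.interleave_moe_layer_step
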
include/llama.h

+1
@@ -110,6 +110,7 @@ extern "C" {
        LLAMA_VOCAB_PRE_TYPE_SUPERBPE = 30,
        LLAMA_VOCAB_PRE_TYPE_TRILLION = 31,
        LLAMA_VOCAB_PRE_TYPE_BAILINGMOE = 32,
+        LLAMA_VOCAB_PRE_TYPE_LLAMA4 = 33,
    };

    enum llama_rope_type {

models/ggml-vocab-llama4.gguf.inp

+112
@@ -0,0 +1,112 @@
+ied 4 ½ months
+__ggml_vocab_test__
+Führer
+__ggml_vocab_test__
+
+__ggml_vocab_test__
+
+__ggml_vocab_test__
+
+__ggml_vocab_test__
+
+__ggml_vocab_test__
+
+__ggml_vocab_test__
+
+
+__ggml_vocab_test__
+
+
+
+__ggml_vocab_test__
+
+
+
+
+__ggml_vocab_test__
+
+
+__ggml_vocab_test__
+Hello world
+__ggml_vocab_test__
+Hello world
+__ggml_vocab_test__
+Hello World
+__ggml_vocab_test__
+Hello World
+__ggml_vocab_test__
+Hello World!
+__ggml_vocab_test__
+Hello, world!
+__ggml_vocab_test__
+Hello, world!
+__ggml_vocab_test__
+this is 🦙.cpp
+__ggml_vocab_test__
+w048 7tuijk dsdfhu
+__ggml_vocab_test__
+нещо на Български
+__ggml_vocab_test__
+កាន់តែពិសេសអាចខលចេញ
+__ggml_vocab_test__
+🚀 (normal) 😶‍🌫️ (multiple emojis concatenated) ✅ (only emoji that has its own token)
+__ggml_vocab_test__
+Hello
+__ggml_vocab_test__
+Hello
+__ggml_vocab_test__
+Hello
+__ggml_vocab_test__
+Hello
+__ggml_vocab_test__
+Hello
+__ggml_vocab_test__
+Hello
+Hello
+__ggml_vocab_test__
+(
+__ggml_vocab_test__
+
+=
+__ggml_vocab_test__
+' era
+__ggml_vocab_test__
+Hello, y'all! How are you 😁 ?我想在apple工作1314151天~
+__ggml_vocab_test__
+!!!!!!
+__ggml_vocab_test__
+3
+__ggml_vocab_test__
+33
+__ggml_vocab_test__
+333
+__ggml_vocab_test__
+3333
+__ggml_vocab_test__
+33333
+__ggml_vocab_test__
+333333
+__ggml_vocab_test__
+3333333
+__ggml_vocab_test__
+33333333
+__ggml_vocab_test__
+333333333
+__ggml_vocab_test__
+Cửa Việt
+__ggml_vocab_test__
+discards
+__ggml_vocab_test__
+
+
+
+
+
+
+
+
+
+
+
+🚀 (normal) 😶‍🌫️ (multiple emojis concatenated) ✅ 🦙🦙 3 33 333 3333 33333 333333 3333333 33333333 3.3 3..3 3...3 កាន់តែពិសេសអាច😁 ?我想在apple工作1314151天~ ------======= нещо на Български ''''''```````""""......!!!!!!?????? I've been 'told he's there, 'RE you sure? 'M not sure I'll make it, 'D you like some tea? We'Ve a'lL
+__ggml_vocab_test__
models/ggml-vocab-llama4.gguf.out

+46
@@ -0,0 +1,46 @@
+1190 220 32 220 18215 7112
+50 16800 258
+
+220
+256
+277
+197
+198
+368
+2946
+3271
+19873 3817
+39715 3817
+19873 7353
+39715 7353
+39715 7353 13
+19873 24 3817 13
+39715 24 3817 13
+544 373 9522 112 247 26 36315
+99 39923 220 35 9607 21498 21470 3679 9433
+1595 7653 633 79829 34051 1636
+8755 102595 115960 21125 148305 96819 102816 39048 14105 22528 160234
+114590 222 330 14879 21 51358 127 12817 93293 117 24204 330 68239 881 120327 170428 21 89101 330 7384 88230 511 947 1492 3742 7233 21
+19873
+39715
+220 39715
+256 39715
+277 39715
+277 39715 198 277 39715
+330
+198 319
+19 7359
+19873 24 386 87799 13 2403 583 650 51358 223 1663 155736 1522 42056 7544 13336 28785 29 4412 20645
+17931 4959
+31
+1922
+12325
+12325 31
+12325 1922
+12325 12325
+12325 12325 31
+12325 12325 1922
+12325 12325 12325
+47 19811 12077
+3260 3579
+198 7283 51499 191231 20192 3271 3322 9287 2143 17860 114590 222 330 14879 21 51358 127 12817 93293 117 24204 330 68239 881 120327 170428 21 89101 9522 112 247 172394 247 220 31 220 1922 220 12325 220 12325 31 220 12325 1922 220 12325 12325 220 12325 12325 31 220 12325 12325 1922 220 31 26 31 220 31 396 31 220 31 1043 31 117131 102595 115960 21125 148305 96819 102816 80883 223 1663 155736 1522 42056 7544 13336 28785 29 4412 20645 79745 150278 117079 633 79829 34051 1636 25611 41990 109428 1488 91054 24072 17931 4959 29795 9296 16517 1806 481 96 1386 36633 1609 24 481 1109 650 5074 43 481 57 702 5074 27088 2170 536 24 481 48 650 1933 1696 30262 43 1665 19 32818 262 27236 56
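
Note (not part of the diff): the fixture pair above can be cross-checked against the reference tokenizer; the tokenizer test harness does the equivalent comparison for every case in the .inp file. A sketch for the first case, assuming access to the gated HF repo:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-4-Scout-17B-16E-Instruct")
ids = tokenizer.encode("ied 4 ½ months", add_special_tokens=False)
print(ids)  # expected per the .out fixture: [1190, 220, 32, 220, 18215, 7112]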

src/llama-arch.cpp

+31
@@ -6,6 +6,7 @@

static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
    { LLM_ARCH_LLAMA, "llama" },
+    { LLM_ARCH_LLAMA4, "llama4" },
    { LLM_ARCH_DECI, "deci" },
    { LLM_ARCH_FALCON, "falcon" },
    { LLM_ARCH_GROK, "grok" },
@@ -114,6 +115,7 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
    { LLM_KV_RESIDUAL_SCALE, "%s.residual_scale" },
    { LLM_KV_EMBEDDING_SCALE, "%s.embedding_scale" },
    { LLM_KV_TOKEN_SHIFT_COUNT, "%s.token_shift_count" },
+    { LLM_KV_INTERLEAVE_MOE_LAYER_STEP, "%s.interleave_moe_layer_step" },

    { LLM_KV_ATTENTION_HEAD_COUNT, "%s.attention.head_count" },
    { LLM_KV_ATTENTION_HEAD_COUNT_KV, "%s.attention.head_count_kv" },
@@ -233,6 +235,35 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
            { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
        },
    },
+    {
+        LLM_ARCH_LLAMA4,
+        {
+            { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
+            { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
+            { LLM_TENSOR_OUTPUT, "output" },
+            { LLM_TENSOR_ROPE_FREQS, "rope_freqs" },
+            { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
+            { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
+            { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
+            { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
+            { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
+            { LLM_TENSOR_ATTN_ROT_EMBD, "blk.%d.attn_rot_embd" },
+            { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
+            { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
+            { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
+            { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
+            { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
+            { LLM_TENSOR_FFN_GATE_EXP, "blk.%d.ffn_gate.%d" },
+            { LLM_TENSOR_FFN_DOWN_EXP, "blk.%d.ffn_down.%d" },
+            { LLM_TENSOR_FFN_UP_EXP, "blk.%d.ffn_up.%d" },
+            { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
+            { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
+            { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
+            { LLM_TENSOR_FFN_GATE_SHEXP, "blk.%d.ffn_gate_shexp" },
+            { LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" },
+            { LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" },
+        },
+    },
    {
        LLM_ARCH_DECI,
        {
