
llama : Support llama 4 text-only #12791


Merged · 21 commits · Apr 7, 2025

68 changes: 64 additions & 4 deletions convert_hf_to_gguf.py
@@ -714,6 +714,9 @@ def get_vocab_base_pre(self, tokenizer) -> str:
if chkhsh == "96a5f08be6259352137b512d4157e333e21df7edd3fcd152990608735a65b224":
# ref: https://huggingface.co/inclusionAI/Ling-lite
res = "bailingmoe"
if chkhsh == "d353350c764d8c3b39c763113960e4fb4919bea5fbf208a0e3b22e8469dc7406":
# ref: https://huggingface.co/meta-llama/Llama-4-Scout-17B-16E-Instruct
res = "llama4"

if res is None:
logger.warning("\n")
@@ -1608,6 +1611,7 @@ def prepare_tensors(self):
@Model.register("LLaMAForCausalLM", "LlamaForCausalLM", "MistralForCausalLM", "MixtralForCausalLM")
class LlamaModel(Model):
model_arch = gguf.MODEL_ARCH.LLAMA
undo_permute = True

def set_vocab(self):
try:
@@ -1672,10 +1676,11 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
n_head = self.hparams["num_attention_heads"]
n_kv_head = self.hparams.get("num_key_value_heads")

if name.endswith(("q_proj.weight", "q_proj.bias")):
data_torch = LlamaModel.permute(data_torch, n_head, n_head)
if name.endswith(("k_proj.weight", "k_proj.bias")):
data_torch = LlamaModel.permute(data_torch, n_head, n_kv_head)
if self.undo_permute:
if name.endswith(("q_proj.weight", "q_proj.bias")):
data_torch = LlamaModel.permute(data_torch, n_head, n_head)
if name.endswith(("k_proj.weight", "k_proj.bias")):
data_torch = LlamaModel.permute(data_torch, n_head, n_kv_head)

# process the experts separately
if name.find("block_sparse_moe.experts") != -1:
@@ -1752,6 +1757,61 @@ def prepare_tensors(self):
raise ValueError(f"Unprocessed experts: {experts}")


@Model.register("Llama4ForConditionalGeneration")
class Llama4Model(LlamaModel):
model_arch = gguf.MODEL_ARCH.LLAMA4
has_vision: bool = False
undo_permute = False

# TODO @ngxson : avoid duplicate this code everywhere by at least support "text_config"
# same with llama, but we need to merge the text_config into the root level of hparams
def __init__(self, *args, **kwargs):
hparams = kwargs["hparams"] if "hparams" in kwargs else Model.load_hparams(args[0])
if "text_config" in hparams:
hparams = {**hparams, **hparams["text_config"]}
kwargs["hparams"] = hparams
super().__init__(*args, **kwargs)
if "vision_config" in hparams:
logger.info("Has vision encoder, but it will be ignored")
self.has_vision = True
# IMPORTANT: the normal "intermediate_size" is renamed to "intermediate_size_mlp", we need to undo this
self.hparams["intermediate_size_moe"] = self.hparams["intermediate_size"]
self.hparams["intermediate_size"] = self.hparams["intermediate_size_mlp"]

def set_vocab(self):
self._set_vocab_gpt2()
self.gguf_writer.add_add_bos_token(True)

def set_gguf_parameters(self):
super().set_gguf_parameters()
self.gguf_writer.add_interleave_moe_layer_step(self.hparams["interleave_moe_layer_step"])
self.gguf_writer.add_expert_feed_forward_length(self.hparams["intermediate_size_moe"])

def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None):
name = name.replace("language_model.", "")
name = name.replace("feed_forward.", "mlp.") # a bit hacky for now
name = name.replace(".router.weight", ".gate.weight") # a bit hacky for now

# split the gate_up into gate and up
if "gate_up_proj" in name:
name_up = name.replace("gate_up_proj", "up_proj.weight")
name_gate = name.replace("gate_up_proj", "gate_proj.weight")
dim_half = data_torch.shape[-1] // 2
gate_proj_weight, up_proj_weight = data_torch.transpose(-1, -2).split(dim_half, dim=-2)
compilade (Collaborator) commented on Apr 7, 2025:

Lazy evaluation doesn't support splitting yet, so this will always evaluate eagerly (and so it will take more RAM than ideal during conversion).

This may or may not explain the conversion slowness others are noticing.

This can be fixed in gguf-py/gguf/lazy.py by handling tuples of tensors as output values. I have the necessary changes somewhere; I'll open a PR once I find them.

(EDIT: see #12809)
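
A minimal sketch of the tuple-handling idea described above, assuming nothing about the real gguf-py lazy machinery (LazyNode and lazy_tuple below are illustrative names, not part of the codebase):

# Toy lazy wrapper: defer one op whose result is a tuple, and expose each
# element as its own lazy handle, so the split is only computed when needed.
from typing import Any, Callable

class LazyNode:
    def __init__(self, fn: Callable[[], Any]):
        self._fn = fn          # deferred computation
        self._val: Any = None
        self._done = False

    def realize(self) -> Any:
        if not self._done:
            self._val, self._done = self._fn(), True
        return self._val

def lazy_tuple(fn: Callable[[], tuple], n: int) -> tuple:
    """Wrap an op returning an n-tuple: one shared deferred call, n lazy views."""
    shared = LazyNode(fn)
    return tuple(LazyNode(lambda i=i: shared.realize()[i]) for i in range(n))

# Usage sketch for the gate_up split above (data_torch/dim_half as in the diff):
# gate, up = lazy_tuple(lambda: data_torch.transpose(-1, -2).split(dim_half, dim=-2), 2)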

ngxson (Collaborator, PR author) commented on Apr 7, 2025:

Off-topic question: is it possible to somehow extend LazyTorchTensor to load a tensor remotely? FYI, the Hugging Face backend supports byte ranges, so one idea would be to read the tensors one by one entirely into RAM, without having to download them to disk.

compilade (Collaborator) commented on Apr 7, 2025:

> Off-topic question: is it possible to somehow extend LazyTorchTensor to load a tensor remotely? FYI, the Hugging Face backend supports byte ranges, so one idea would be to read the tensors one by one entirely into RAM, without having to download them to disk.

Huh, I never thought of it, but yes, technically this should totally be possible. Lazy tensors only need the name, shape and type of each tensor for the fake tensors, and then a way to turn the original fake tensors into real tensors.

The hardest part wouldn't necessarily be the lazy tensors themselves, but how the remote paths would be specified, how they would interact with the default output path and default name of the output file, and how the tensors, the config file and the tokenizer would be enumerated and fetched.

There's a lot of tokenizer-related code which assumes local files.
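
For reference, a rough sketch of the byte-range idea (assumptions: the remote file is in safetensors format, the server honors HTTP Range requests as the Hugging Face CDN does, and the URL and tensor name in the usage comment are hypothetical; this is not code from this PR):

import json
import struct
import numpy as np
import requests

NUMPY_DTYPES = {"F32": np.float32, "F16": np.float16, "I64": np.int64}  # BF16 would need extra handling

def fetch_range(url: str, start: int, end: int) -> bytes:
    # HTTP Range is inclusive on both ends; add an Authorization header for gated repos
    r = requests.get(url, headers={"Range": f"bytes={start}-{end}"})
    r.raise_for_status()
    return r.content

def remote_tensor(url: str, name: str) -> np.ndarray:
    # safetensors layout: 8-byte little-endian header size, JSON header, then raw tensor data
    header_len = struct.unpack("<Q", fetch_range(url, 0, 7))[0]
    header = json.loads(fetch_range(url, 8, 8 + header_len - 1))
    meta = header[name]
    begin, end = meta["data_offsets"]  # offsets are relative to the start of the data section
    data_start = 8 + header_len
    raw = fetch_range(url, data_start + begin, data_start + end - 1)
    return np.frombuffer(raw, dtype=NUMPY_DTYPES[meta["dtype"]]).reshape(meta["shape"])

# t = remote_tensor("https://huggingface.co/org/model/resolve/main/model-00001-of-00055.safetensors",
#                   "language_model.model.layers.0.self_attn.q_proj.weight")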

ngxson (Collaborator, PR author) commented:

I think we can rely on AutoTokenizer.from_pretrained, which downloads the tokenizer files to a temporary directory. I will have a look at how it works.

Alternatively, we can rely on huggingface_hub's snapshot download, which accepts file-name patterns (so, for example, we can skip downloading the safetensors).

In my case, loading safetensors remotely would be very useful. I couldn't test the 400B Maverick model because it requires about 1.5 TB in total to store both the HF model and the GGUF, but an HF Space provides at most 1 TB of storage.
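
A small sketch of the second option, assuming the huggingface_hub Python package; the repo id is the one already referenced in this PR and the pattern lists are illustrative:

from huggingface_hub import snapshot_download

# Download only the config/tokenizer files and skip the weight shards entirely.
local_dir = snapshot_download(
    repo_id="meta-llama/Llama-4-Scout-17B-16E-Instruct",
    allow_patterns=["*.json", "*.txt", "tokenizer.model"],   # config + tokenizer only
    ignore_patterns=["*.safetensors", "*.bin"],               # never pull the weights
)
print(local_dir)  # local snapshot path containing just the small files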

A contributor commented:

@ngxson I just tested Maverick. Sadly, the BF16 conversion doesn't work; the cause seems to be "interleave_moe_layer_step": 2, so every second layer is MoE while the rest are dense FFN.

Error:

INFO:hf-to-gguf:gguf: loading model part 'model-00002-of-00055.safetensors'
INFO:hf-to-gguf:blk.1.ffn_down_exps.weight,   torch.bfloat16 --> BF16, shape = {8192, 5120, 128}
INFO:hf-to-gguf:gguf: loading model part 'model-00003-of-00050.safetensors'
INFO:hf-to-gguf:blk.1.ffn_down_exps.weight,   torch.bfloat16 --> BF16, shape = {8192, 5120, 16}
Traceback (most recent call last):
ValueError: Duplicated tensor name 'blk.1.ffn_down_exps.weight'
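
(For context on the interleaving itself, here is a toy sketch of what "interleave_moe_layer_step": 2 implies for the layer layout, assuming a layer is MoE when (layer_index + 1) is a multiple of the step; the exact rule used by the C++ side is not shown in this excerpt.)

def is_moe_layer(il: int, interleave_moe_layer_step: int) -> bool:
    # assumed convention: step 1 = every layer is MoE (Scout), step 2 = every second layer (Maverick)
    return interleave_moe_layer_step > 0 and (il + 1) % interleave_moe_layer_step == 0

layout = ["moe" if is_moe_layer(il, 2) else "dense" for il in range(8)]
print(layout)  # ['dense', 'moe', 'dense', 'moe', 'dense', 'moe', 'dense', 'moe']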

A contributor commented:
I just downloaded unsloth/Llama-4-Maverick-17B-128E-Instruct from HF. Everything looks good so far. (On a machine with 2 TB RAM and 13 TB SSD.)

It took me about 50 minutes to reach this point.

INFO:hf-to-gguf:gguf: loading model part 'model-00030-of-00055.safetensors'
INFO:hf-to-gguf:blk.27.ffn_gate_exps.weight,  torch.bfloat16 --> F16, shape = {5120, 8192, 128}
INFO:hf-to-gguf:blk.27.ffn_up_exps.weight,    torch.bfloat16 --> F16, shape = {5120, 8192, 128}
INFO:hf-to-gguf:gguf: loading model part 'model-00031-of-00055.safetensors'
INFO:hf-to-gguf:blk.27.ffn_down_exps.weight,  torch.bfloat16 --> F16, shape = {8192, 5120, 128}
INFO:hf-to-gguf:gguf: loading model part 'model-00032-of-00055.safetensors'
INFO:hf-to-gguf:blk.29.ffn_gate_exps.weight,  torch.bfloat16 --> F16, shape = {5120, 8192, 128}

ngxson (Collaborator, PR author) commented:

Leaving this comment here for visibility: we discussed via DM, and it turns out Daniel was using the wrong directory 😂

+1 reason to support converting HF --> GGUF without downloading to disk.

A maintainer (Member) commented:

> +1 reason to support converting HF --> GGUF without downloading to disk

This feature will be 🔥

Another user commented:
Got past the ValueError: Duplicated tensor name 'blk.1.ffn_down_exps.weight' error, thanks!:

INFO:gguf.gguf_writer:Writing the following files:
INFO:gguf.gguf_writer://Volumes/storage 1/models/LLaMA-4-Maverick-17B.gguf: n_tensors = 531, total_size = 801.5G
Writing:   0%|                                                                                                     | 0.00/801G [00:00<?, ?byte/s]huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
To disable this warning, you can either:
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
Writing:   1%|█▏    

Now I wait!

yeahdongcn (Contributor) commented on Apr 9, 2025:
Update: More Information on Llama 4 Maverick

🛠️ Model Conversion

root@bb22ebf4525a:/ws# time python convert_hf_to_gguf.py model_path
...
INFO:hf-to-gguf:Set model quantization version  
INFO:gguf.gguf_writer:Writing the following files:  
INFO:gguf.gguf_writer:/model_path/840ed22d9bc7731246bc119cca026a48a0ff8ec6-128x17B-840ed22d9bc7731246bc119cca026a48a0ff8ec6-F16.gguf: n_tensors = 531, total_size = 801.5G  
...
real    302m33.600s  
user    285m44.179s  
sys     113m0.490s  

🔧 Quantization

root@bb22ebf4525a:/ws# time ./build/bin/llama-quantize model_path llama4_maverick_q4_k_m.gguf Q4_K_M
...
llama_model_quantize_impl: model size  = 764328.14 MB  
llama_model_quantize_impl: quant size  = 231508.31 MB  

main: quantize time = 1372408.18 ms  
main:    total time = 1372408.18 ms  
./build/bin/llama-quantize   Q4_K_M  55951.43s user 4741.92s system 4421% cpu 22:52.55 total  

🧪 Tested with MUSA backend:

> Hi    
Hello! It's nice to meet you. Is there something I can help you with, or would you like to chat?

> How many 'R' in strawberry?
There are 2 'R's in the word "strawberry".

> 

return [
(self.map_tensor_name(name_gate), gate_proj_weight),
(self.map_tensor_name(name_up), up_proj_weight)
]

if name.endswith("down_proj"):
name += ".weight"
data_torch = data_torch.transpose(-1, -2)

if "multi_modal_projector" in name or "vision_model" in name:
return []
return super().modify_tensors(data_torch, name, bid)


@Model.register("Mistral3ForConditionalGeneration")
class Mistral3Model(LlamaModel):
model_arch = gguf.MODEL_ARCH.LLAMA
1 change: 1 addition & 0 deletions convert_hf_to_gguf_update.py
@@ -113,6 +113,7 @@ class TOKENIZER_TYPE(IntEnum):
{"name": "superbpe", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/UW/OLMo2-8B-SuperBPE-t180k", },
{"name": "trillion", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/trillionlabs/Trillion-7B-preview", },
{"name": "bailingmoe", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/inclusionAI/Ling-lite", },
{"name": "llama4", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/meta-llama/Llama-4-Scout-17B-16E-Instruct", },
]


26 changes: 26 additions & 0 deletions gguf-py/gguf/constants.py
@@ -116,6 +116,7 @@ class LLM:
RESIDUAL_SCALE = "{arch}.residual_scale"
EMBEDDING_SCALE = "{arch}.embedding_scale"
TOKEN_SHIFT_COUNT = "{arch}.token_shift_count"
INTERLEAVE_MOE_LAYER_STEP = "{arch}.interleave_moe_layer_step"

class Attention:
HEAD_COUNT = "{arch}.attention.head_count"
@@ -227,6 +228,7 @@ class GGUFType:

class MODEL_ARCH(IntEnum):
LLAMA = auto()
LLAMA4 = auto()
DECI = auto()
FALCON = auto()
BAICHUAN = auto()
@@ -431,6 +433,7 @@ class MODEL_TENSOR(IntEnum):

MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
MODEL_ARCH.LLAMA: "llama",
MODEL_ARCH.LLAMA4: "llama4",
MODEL_ARCH.DECI: "deci",
MODEL_ARCH.FALCON: "falcon",
MODEL_ARCH.BAICHUAN: "baichuan",
@@ -654,6 +657,29 @@ class MODEL_TENSOR(IntEnum):
MODEL_TENSOR.FFN_DOWN_EXP,
MODEL_TENSOR.FFN_UP_EXP,
],
MODEL_ARCH.LLAMA4: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM,
MODEL_TENSOR.OUTPUT,
MODEL_TENSOR.ROPE_FREQS,
MODEL_TENSOR.ATTN_NORM,
MODEL_TENSOR.ATTN_Q,
MODEL_TENSOR.ATTN_K,
MODEL_TENSOR.ATTN_V,
MODEL_TENSOR.ATTN_OUT,
MODEL_TENSOR.ATTN_ROT_EMBD,
MODEL_TENSOR.FFN_GATE_INP,
MODEL_TENSOR.FFN_NORM,
MODEL_TENSOR.FFN_GATE,
MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.FFN_UP,
MODEL_TENSOR.FFN_GATE_EXP,
MODEL_TENSOR.FFN_DOWN_EXP,
MODEL_TENSOR.FFN_UP_EXP,
MODEL_TENSOR.FFN_GATE_SHEXP,
MODEL_TENSOR.FFN_DOWN_SHEXP,
MODEL_TENSOR.FFN_UP_SHEXP,
],
MODEL_ARCH.DECI: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM,
3 changes: 3 additions & 0 deletions gguf-py/gguf/gguf_writer.py
@@ -746,6 +746,9 @@ def add_wkv_head_size(self, size: int) -> None:
def add_token_shift_count(self, count: int) -> None:
self.add_uint32(Keys.LLM.TOKEN_SHIFT_COUNT.format(arch=self.arch), count)

def add_interleave_moe_layer_step(self, value: int) -> None:
self.add_uint32(Keys.LLM.INTERLEAVE_MOE_LAYER_STEP.format(arch=self.arch), value)

def add_layer_norm_eps(self, value: float) -> None:
self.add_float32(Keys.Attention.LAYERNORM_EPS.format(arch=self.arch), value)

1 change: 1 addition & 0 deletions include/llama.h
@@ -110,6 +110,7 @@ extern "C" {
LLAMA_VOCAB_PRE_TYPE_SUPERBPE = 30,
LLAMA_VOCAB_PRE_TYPE_TRILLION = 31,
LLAMA_VOCAB_PRE_TYPE_BAILINGMOE = 32,
LLAMA_VOCAB_PRE_TYPE_LLAMA4 = 33,
};

enum llama_rope_type {
112 changes: 112 additions & 0 deletions models/ggml-vocab-llama4.gguf.inp
@@ -0,0 +1,112 @@
ied 4 ½ months
__ggml_vocab_test__
Führer
__ggml_vocab_test__

__ggml_vocab_test__

__ggml_vocab_test__

__ggml_vocab_test__

__ggml_vocab_test__

__ggml_vocab_test__


__ggml_vocab_test__



__ggml_vocab_test__




__ggml_vocab_test__


__ggml_vocab_test__
Hello world
__ggml_vocab_test__
Hello world
__ggml_vocab_test__
Hello World
__ggml_vocab_test__
Hello World
__ggml_vocab_test__
Hello World!
__ggml_vocab_test__
Hello, world!
__ggml_vocab_test__
Hello, world!
__ggml_vocab_test__
this is 🦙.cpp
__ggml_vocab_test__
w048 7tuijk dsdfhu
__ggml_vocab_test__
нещо на Български
__ggml_vocab_test__
កាន់តែពិសេសអាចខលចេញ
__ggml_vocab_test__
🚀 (normal) 😶‍🌫️ (multiple emojis concatenated) ✅ (only emoji that has its own token)
__ggml_vocab_test__
Hello
__ggml_vocab_test__
Hello
__ggml_vocab_test__
Hello
__ggml_vocab_test__
Hello
__ggml_vocab_test__
Hello
__ggml_vocab_test__
Hello
Hello
__ggml_vocab_test__
(
__ggml_vocab_test__

=
__ggml_vocab_test__
' era
__ggml_vocab_test__
Hello, y'all! How are you 😁 ?我想在apple工作1314151天~
__ggml_vocab_test__
!!!!!!
__ggml_vocab_test__
3
__ggml_vocab_test__
33
__ggml_vocab_test__
333
__ggml_vocab_test__
3333
__ggml_vocab_test__
33333
__ggml_vocab_test__
333333
__ggml_vocab_test__
3333333
__ggml_vocab_test__
33333333
__ggml_vocab_test__
333333333
__ggml_vocab_test__
Cửa Việt
__ggml_vocab_test__
discards
__ggml_vocab_test__











🚀 (normal) 😶‍🌫️ (multiple emojis concatenated) ✅ 🦙🦙 3 33 333 3333 33333 333333 3333333 33333333 3.3 3..3 3...3 កាន់តែពិសេសអាច😁 ?我想在apple工作1314151天~ ------======= нещо на Български ''''''```````""""......!!!!!!?????? I've been 'told he's there, 'RE you sure? 'M not sure I'll make it, 'D you like some tea? We'Ve a'lL
__ggml_vocab_test__
46 changes: 46 additions & 0 deletions models/ggml-vocab-llama4.gguf.out
@@ -0,0 +1,46 @@
1190 220 32 220 18215 7112
50 16800 258

220
256
277
197
198
368
2946
3271
19873 3817
39715 3817
19873 7353
39715 7353
39715 7353 13
19873 24 3817 13
39715 24 3817 13
544 373 9522 112 247 26 36315
99 39923 220 35 9607 21498 21470 3679 9433
1595 7653 633 79829 34051 1636
8755 102595 115960 21125 148305 96819 102816 39048 14105 22528 160234
114590 222 330 14879 21 51358 127 12817 93293 117 24204 330 68239 881 120327 170428 21 89101 330 7384 88230 511 947 1492 3742 7233 21
19873
39715
220 39715
256 39715
277 39715
277 39715 198 277 39715
330
198 319
19 7359
19873 24 386 87799 13 2403 583 650 51358 223 1663 155736 1522 42056 7544 13336 28785 29 4412 20645
17931 4959
31
1922
12325
12325 31
12325 1922
12325 12325
12325 12325 31
12325 12325 1922
12325 12325 12325
47 19811 12077
3260 3579
198 7283 51499 191231 20192 3271 3322 9287 2143 17860 114590 222 330 14879 21 51358 127 12817 93293 117 24204 330 68239 881 120327 170428 21 89101 9522 112 247 172394 247 220 31 220 1922 220 12325 220 12325 31 220 12325 1922 220 12325 12325 220 12325 12325 31 220 12325 12325 1922 220 31 26 31 220 31 396 31 220 31 1043 31 117131 102595 115960 21125 148305 96819 102816 80883 223 1663 155736 1522 42056 7544 13336 28785 29 4412 20645 79745 150278 117079 633 79829 34051 1636 25611 41990 109428 1488 91054 24072 17931 4959 29795 9296 16517 1806 481 96 1386 36633 1609 24 481 1109 650 5074 43 481 57 702 5074 27088 2170 536 24 481 48 650 1933 1696 30262 43 1665 19 32818 262 27236 56
31 changes: 31 additions & 0 deletions src/llama-arch.cpp
@@ -6,6 +6,7 @@

static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
{ LLM_ARCH_LLAMA, "llama" },
{ LLM_ARCH_LLAMA4, "llama4" },
{ LLM_ARCH_DECI, "deci" },
{ LLM_ARCH_FALCON, "falcon" },
{ LLM_ARCH_GROK, "grok" },
@@ -114,6 +115,7 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
{ LLM_KV_RESIDUAL_SCALE, "%s.residual_scale" },
{ LLM_KV_EMBEDDING_SCALE, "%s.embedding_scale" },
{ LLM_KV_TOKEN_SHIFT_COUNT, "%s.token_shift_count" },
{ LLM_KV_INTERLEAVE_MOE_LAYER_STEP, "%s.interleave_moe_layer_step" },

{ LLM_KV_ATTENTION_HEAD_COUNT, "%s.attention.head_count" },
{ LLM_KV_ATTENTION_HEAD_COUNT_KV, "%s.attention.head_count_kv" },
@@ -233,6 +235,35 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
{ LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
},
},
{
LLM_ARCH_LLAMA4,
{
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
{ LLM_TENSOR_OUTPUT, "output" },
{ LLM_TENSOR_ROPE_FREQS, "rope_freqs" },
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
{ LLM_TENSOR_ATTN_ROT_EMBD, "blk.%d.attn_rot_embd" },
{ LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
{ LLM_TENSOR_FFN_GATE_EXP, "blk.%d.ffn_gate.%d" },
{ LLM_TENSOR_FFN_DOWN_EXP, "blk.%d.ffn_down.%d" },
{ LLM_TENSOR_FFN_UP_EXP, "blk.%d.ffn_up.%d" },
{ LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
{ LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
{ LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
{ LLM_TENSOR_FFN_GATE_SHEXP, "blk.%d.ffn_gate_shexp" },
{ LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" },
{ LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" },
},
},
{
LLM_ARCH_DECI,
{