diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py
index 838999531e580..9a6f513761afd 100755
--- a/convert_hf_to_gguf.py
+++ b/convert_hf_to_gguf.py
@@ -1773,7 +1773,8 @@ def prepare_tensors(self):
     "MistralForCausalLM",
     "MixtralForCausalLM",
     "VLlama3ForCausalLM",
-    "LlavaForConditionalGeneration")
+    "LlavaForConditionalGeneration",
+)
 class LlamaModel(TextModel):
     model_arch = gguf.MODEL_ARCH.LLAMA
     undo_permute = True
@@ -2741,6 +2742,32 @@ def set_gguf_parameters(self):
         self.gguf_writer.add_causal_attention(False)
 
 
+@ModelBase.register("MiMoForCausalLM")
+class MimoModel(Qwen2Model):
+    model_arch = gguf.MODEL_ARCH.QWEN2
+    n_multi_token_predict: int
+    n_layers_no_mtp: int
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.n_multi_token_predict = self.hparams["num_nextn_predict_layers"]
+        self.n_layers_no_mtp = self.block_count
+        self.block_count = self.block_count + self.n_multi_token_predict
+        self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count)
+
+    def set_gguf_parameters(self):
+        super().set_gguf_parameters()
+        print(self.hparams)
+        self.gguf_writer.add_n_multi_token_predict(self.hparams["num_nextn_predict_layers"])
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        if "mtp_layers" in name and bid is not None:
+            name = name.replace(".mtp_layers", ".layers")
+            for i in range(self.n_multi_token_predict):
+                name = name.replace(f"layers.{i}.", f"layers.{self.n_layers_no_mtp + i}.")
+        return super().modify_tensors(data_torch, name, bid)
+
+
 @ModelBase.register("Qwen2MoeForCausalLM")
 class Qwen2MoeModel(TextModel):
     model_arch = gguf.MODEL_ARCH.QWEN2MOE
diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py
index 7dd7bb6d1b5d9..951a44e1fde02 100644
--- a/gguf-py/gguf/constants.py
+++ b/gguf-py/gguf/constants.py
@@ -118,6 +118,7 @@ class LLM:
         EMBEDDING_SCALE           = "{arch}.embedding_scale"
         TOKEN_SHIFT_COUNT         = "{arch}.token_shift_count"
         INTERLEAVE_MOE_LAYER_STEP = "{arch}.interleave_moe_layer_step"
+        N_MULTI_TOKEN_PREDICT     = "{arch}.n_multi_token_predict"
 
     class Attention:
         HEAD_COUNT = "{arch}.attention.head_count"
@@ -375,6 +376,9 @@ class MODEL_TENSOR(IntEnum):
     ATTN_Q_NORM     = auto()
     ATTN_K_NORM     = auto()
     LAYER_OUT_NORM  = auto()
+    MTP_INP_PROJ    = auto()
+    MTP_TOKEN_NORM  = auto()  # token_layernorm
+    MTP_HIDDEN_NORM = auto()  # hidden_layernorm
     SSM_IN          = auto()
     SSM_CONV1D      = auto()
     SSM_X           = auto()
@@ -632,6 +636,9 @@ class MODEL_TENSOR(IntEnum):
     MODEL_TENSOR.FFN_UP_EXP:      "blk.{bid}.ffn_up_exps",
     MODEL_TENSOR.FFN_EXP_PROBS_B: "blk.{bid}.exp_probs_b",
     MODEL_TENSOR.LAYER_OUT_NORM:  "blk.{bid}.layer_output_norm",
+    MODEL_TENSOR.MTP_INP_PROJ:    "blk.{bid}.mtp_inp_proj",
+    MODEL_TENSOR.MTP_TOKEN_NORM:  "blk.{bid}.mtp_token_norm",
+    MODEL_TENSOR.MTP_HIDDEN_NORM: "blk.{bid}.mtp_hidden_norm",
     MODEL_TENSOR.SSM_IN:          "blk.{bid}.ssm_in",
     MODEL_TENSOR.SSM_CONV1D:      "blk.{bid}.ssm_conv1d",
     MODEL_TENSOR.SSM_X:           "blk.{bid}.ssm_x",
@@ -1103,6 +1110,10 @@ class MODEL_TENSOR(IntEnum):
         MODEL_TENSOR.FFN_GATE,
         MODEL_TENSOR.FFN_DOWN,
         MODEL_TENSOR.FFN_UP,
+        MODEL_TENSOR.MTP_INP_PROJ,     # xiaomi mimo
+        MODEL_TENSOR.MTP_HIDDEN_NORM,  # xiaomi mimo
+        MODEL_TENSOR.MTP_TOKEN_NORM,   # xiaomi mimo
+        MODEL_TENSOR.LAYER_OUT_NORM,   # xiaomi mimo
     ],
     MODEL_ARCH.QWEN2VL: [
         MODEL_TENSOR.TOKEN_EMBD,
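Aside, not part of the patch: a minimal sketch of how a converted MiMo GGUF could be checked for the new metadata key, assuming ggml's gguf.h C API is available under that header name. The key expands to "qwen2.n_multi_token_predict" because MimoModel registers under MODEL_ARCH.QWEN2.

// standalone check of the new GGUF key (sketch, assumes gguf.h from ggml)
#include <cstdio>
#include "gguf.h"

int main(int argc, char ** argv) {
    if (argc < 2) { std::fprintf(stderr, "usage: %s model.gguf\n", argv[0]); return 1; }

    struct gguf_init_params params = { /*.no_alloc =*/ true, /*.ctx =*/ nullptr };
    struct gguf_context * ctx = gguf_init_from_file(argv[1], params);
    if (!ctx) { std::fprintf(stderr, "failed to read %s\n", argv[1]); return 1; }

    // MimoModel reuses MODEL_ARCH.QWEN2, so the arch prefix is "qwen2"
    const auto kid = gguf_find_key(ctx, "qwen2.n_multi_token_predict");
    if (kid < 0) {
        std::printf("no MTP metadata (n_mtp = 0)\n");
    } else {
        std::printf("n_multi_token_predict = %u\n", gguf_get_val_u32(ctx, kid));
    }

    gguf_free(ctx);
    return 0;
}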
diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py
index ff50d3de31287..e43f3a7076669 100644
--- a/gguf-py/gguf/gguf_writer.py
+++ b/gguf-py/gguf/gguf_writer.py
@@ -899,6 +899,9 @@ def add_remove_extra_whitespaces(self, value: bool) -> None:
     def add_precompiled_charsmap(self, charsmap: Sequence[bytes]) -> None:
         self.add_array(Keys.Tokenizer.PRECOMPILED_CHARSMAP, charsmap)
 
+    def add_n_multi_token_predict(self, value: int) -> None:
+        self.add_uint32(Keys.LLM.N_MULTI_TOKEN_PREDICT.format(arch=self.arch), value)
+
     def add_chat_template(self, value: str | Sequence[Mapping[str, str]]) -> None:
         if not isinstance(value, str):
             template_default = None
diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py
index 2b089f84a841a..f97e956e5b2af 100644
--- a/gguf-py/gguf/tensor_mapping.py
+++ b/gguf-py/gguf/tensor_mapping.py
@@ -457,7 +457,20 @@ class TensorNameMap:
             "encoder.layers.{bid}.norm2",                  # nomic-bert
             "transformer.decoder_layer.{bid}.rms_norm_3",  # Grok
             "encoder.layer.{bid}.mlp.layernorm",           # jina-bert-v2
-            "encoder.layer.{bid}.layer_norm_2"             # jina-v2-code
+            "encoder.layer.{bid}.layer_norm_2",            # jina-v2-code
+            "model.layers.{bid}.final_layernorm",          # xiaomi mimo
+        ),
+
+        MODEL_TENSOR.MTP_INP_PROJ: (
+            "model.layers.{bid}.input_proj",               # xiaomi mimo
+        ),
+
+        MODEL_TENSOR.MTP_TOKEN_NORM: (
+            "model.layers.{bid}.token_layernorm",          # xiaomi mimo
+        ),
+
+        MODEL_TENSOR.MTP_HIDDEN_NORM: (
+            "model.layers.{bid}.hidden_layernorm",         # xiaomi mimo
         ),
 
         MODEL_TENSOR.SSM_IN: (
diff --git a/include/llama.h b/include/llama.h
index 06c56395c139f..8cb3f68a607c4 100644
--- a/include/llama.h
+++ b/include/llama.h
@@ -496,6 +496,12 @@ extern "C" {
     LLAMA_API int32_t llama_model_n_head   (const struct llama_model * model);
     LLAMA_API int32_t llama_model_n_head_kv(const struct llama_model * model);
 
+    // If the model supports multi-token prediction (MTP), this returns the number of MTP heads; returns 0 otherwise
+    LLAMA_API int32_t llama_model_n_mtp(const struct llama_model * model);
+
+    // Get the i-th multi-token predict model (used by speculative decoding)
+    LLAMA_API struct llama_model * llama_model_get_mtp(struct llama_model * model, int32_t i);
+
     // Get the model's RoPE frequency scaling factor
     LLAMA_API float llama_model_rope_freq_scale_train(const struct llama_model * model);
 
@@ -959,6 +965,9 @@ extern "C" {
     // If set to true, the model will only attend to the past tokens
    LLAMA_API void llama_set_causal_attn(struct llama_context * ctx, bool causal_attn);
 
+    // Select which multi-token predict (MTP) head to use; 0 means MTP is not used
+    LLAMA_API void llama_set_mtp_head(struct llama_context * ctx, int32_t n_mtp);
+
     // Set whether the model is in warmup mode or not
     // If true, all model tensors are activated during llama_decode() to load and cache their weights.
     LLAMA_API void llama_set_warmup(struct llama_context * ctx, bool warmup);
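Aside, not part of the patch: a hypothetical caller-side sketch of the new MTP API. The draft-and-verify flow for speculative decoding is not defined by this change, so the sequencing below is an assumption rather than the intended integration.

// sketch: run one decode step with the first MTP head enabled, then switch back
#include "llama.h"

static void mtp_draft_step(struct llama_context * ctx, struct llama_batch batch) {
    const struct llama_model * model = llama_get_model(ctx);

    if (llama_model_n_mtp(model) == 0) {
        return; // model has no multi-token predict heads
    }

    llama_set_mtp_head(ctx, 1);   // use the 1st MTP head for this decode
    llama_decode(ctx, batch);     // hypothetical: draft an extra token from the MTP head
    llama_set_mtp_head(ctx, 0);   // back to regular next-token prediction
}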
diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp
index f2bc8ca768502..5789e96c911d2 100644
--- a/src/llama-arch.cpp
+++ b/src/llama-arch.cpp
@@ -121,6 +121,7 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
     { LLM_KV_EMBEDDING_SCALE,           "%s.embedding_scale"           },
     { LLM_KV_TOKEN_SHIFT_COUNT,         "%s.token_shift_count"         },
     { LLM_KV_INTERLEAVE_MOE_LAYER_STEP, "%s.interleave_moe_layer_step" },
+    { LLM_KV_N_MULTI_TOKEN_PREDICT,     "%s.n_multi_token_predict"     },
 
     { LLM_KV_ATTENTION_HEAD_COUNT,      "%s.attention.head_count"      },
     { LLM_KV_ATTENTION_HEAD_COUNT_KV,   "%s.attention.head_count_kv"   },
@@ -579,6 +580,10 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_NAMES = {
             { LLM_TENSOR_FFN_GATE,        "blk.%d.ffn_gate" },
             { LLM_TENSOR_FFN_DOWN,        "blk.%d.ffn_down" },
             { LLM_TENSOR_FFN_UP,          "blk.%d.ffn_up" },
+            { LLM_TENSOR_MTP_INP_PROJ,    "blk.%d.mtp_inp_proj" },
+            { LLM_TENSOR_MTP_TOKEN_NORM,  "blk.%d.mtp_token_norm" },
+            { LLM_TENSOR_MTP_HIDDEN_NORM, "blk.%d.mtp_hidden_norm" },
+            { LLM_TENSOR_LAYER_OUT_NORM,  "blk.%d.layer_output_norm" },
         },
     },
     {
@@ -1678,6 +1683,9 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
     {LLM_TENSOR_FFN_GATE_EXPS,            {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT_ID}},
     {LLM_TENSOR_FFN_UP_EXPS,              {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT_ID}},
     {LLM_TENSOR_FFN_EXP_PROBS_B,          {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
+    {LLM_TENSOR_MTP_INP_PROJ,             {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_MTP_TOKEN_NORM,           {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
+    {LLM_TENSOR_MTP_HIDDEN_NORM,          {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
     // this tensor is loaded for T5, but never used
     {LLM_TENSOR_DEC_CROSS_ATTN_REL_B,     {LLM_TENSOR_LAYER_REPEATING, GGML_OP_NONE}},
     {LLM_TENSOR_CONV1D,                   {LLM_TENSOR_LAYER_INPUT, GGML_OP_IM2COL}},
diff --git a/src/llama-arch.h b/src/llama-arch.h
index 41a023da3da6e..3223d798ef9ba 100644
--- a/src/llama-arch.h
+++ b/src/llama-arch.h
@@ -125,6 +125,7 @@ enum llm_kv {
     LLM_KV_EMBEDDING_SCALE,
     LLM_KV_TOKEN_SHIFT_COUNT,
     LLM_KV_INTERLEAVE_MOE_LAYER_STEP,
+    LLM_KV_N_MULTI_TOKEN_PREDICT,
 
     LLM_KV_ATTENTION_HEAD_COUNT,
     LLM_KV_ATTENTION_HEAD_COUNT_KV,
@@ -362,6 +363,9 @@ enum llm_tensor {
     LLM_TENSOR_POS_NET_ATTN_K,
     LLM_TENSOR_POS_NET_ATTN_V,
     LLM_TENSOR_POS_NET_ATTN_OUT,
+    LLM_TENSOR_MTP_INP_PROJ,
+    LLM_TENSOR_MTP_TOKEN_NORM,
+    LLM_TENSOR_MTP_HIDDEN_NORM,
 };
 
 enum llm_tensor_layer {
diff --git a/src/llama-context.cpp b/src/llama-context.cpp
index 45591be992d87..f6584e743c1be 100644
--- a/src/llama-context.cpp
+++ b/src/llama-context.cpp
@@ -625,6 +625,12 @@ void llama_context::set_causal_attn(bool value) {
     cparams.causal_attn = value;
 }
 
+void llama_context::set_mtp_head(int32_t value) {
+    LLAMA_LOG_DEBUG("%s: value = %d\n", __func__, value);
+
+    cparams.curr_mtp = value;
+}
+
 void llama_context::set_warmup(bool value) {
     LLAMA_LOG_DEBUG("%s: value = %d\n", __func__, value);
@@ -1981,6 +1987,11 @@ void llama_set_causal_attn(llama_context * ctx, bool causal_attn) {
     ctx->set_causal_attn(causal_attn);
 }
 
+void llama_set_mtp_head(llama_context * ctx, int32_t n_mtp) {
+    GGML_ASSERT(n_mtp <= llama_model_n_mtp(llama_get_model(ctx)));
+    ctx->set_mtp_head(n_mtp);
+}
+
 void llama_set_warmup(llama_context * ctx, bool warmup) {
     ctx->set_warmup(warmup);
 }
diff --git a/src/llama-context.h b/src/llama-context.h
index cf41ac57b9fba..e90b9751dbfac 100644
--- a/src/llama-context.h
+++ b/src/llama-context.h
@@ -69,6 +69,7 @@ struct llama_context {
     void set_embeddings (bool value);
     void set_causal_attn(bool value);
+    void set_mtp_head(int32_t value);
     void set_warmup(bool value);
 
     void set_adapter_lora(
diff --git a/src/llama-cparams.h b/src/llama-cparams.h
index 30e550f023a9e..20b1611feacc2 100644
--- a/src/llama-cparams.h
+++ b/src/llama-cparams.h
@@ -31,6 +31,11 @@ struct llama_cparams {
     bool no_perf;
     bool warmup;
 
+    // multi-token predict
+    // 0 means not using MTP
+    // N means using the N-th MTP head
+    int32_t curr_mtp = 0;
+
     enum llama_pooling_type pooling_type;
 
     ggml_backend_sched_eval_callback cb_eval;
diff --git a/src/llama-hparams.h b/src/llama-hparams.h
index 7ee6a5b75ad1e..72708a98a7087 100644
--- a/src/llama-hparams.h
+++ b/src/llama-hparams.h
@@ -47,6 +47,9 @@ struct llama_hparams {
     uint32_t n_embd_head_k_mla = 0;
     uint32_t n_embd_head_v_mla = 0;
 
+    // for multi-token predict
+    uint32_t n_mtp = 0;
+
     // for WavTokenizer
     struct llama_hparams_posnet   posnet;
     struct llama_hparams_convnext convnext;
diff --git a/src/llama-model.cpp b/src/llama-model.cpp
index 8d25070f42f77..fb00195a3c863 100644
--- a/src/llama-model.cpp
+++ b/src/llama-model.cpp
@@ -455,6 +455,9 @@ void llama_model::load_hparams(llama_model_loader & ml) {
         GGML_ASSERT(hparams.n_expert_used == 0);
     }
 
+    // multi-token predict
+    ml.get_key(LLM_KV_N_MULTI_TOKEN_PREDICT, hparams.n_mtp, false);
+
     // zero-out the array hparams
     std::fill(hparams.n_head_arr.begin(),    hparams.n_head_arr.end(),    0);
     std::fill(hparams.n_head_kv_arr.begin(), hparams.n_head_kv_arr.end(), 0);
@@ -2371,6 +2374,12 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                     layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd,   n_ff}, 0);
                     layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd}, 0);
                     layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff}, 0);
+
+                    // optional MTP (multi token predict), used by Xiaomi MiMo
+                    layer.mtp_inp_proj    = create_tensor(tn(LLM_TENSOR_MTP_INP_PROJ,    "weight", i), {n_embd*2, n_embd}, TENSOR_NOT_REQUIRED);
+                    layer.mtp_token_norm  = create_tensor(tn(LLM_TENSOR_MTP_TOKEN_NORM,  "weight", i), {n_embd}, TENSOR_NOT_REQUIRED);
+                    layer.mtp_hidden_norm = create_tensor(tn(LLM_TENSOR_MTP_HIDDEN_NORM, "weight", i), {n_embd}, TENSOR_NOT_REQUIRED);
+                    layer.layer_out_norm  = create_tensor(tn(LLM_TENSOR_LAYER_OUT_NORM,  "weight", i), {n_embd}, TENSOR_NOT_REQUIRED);
                 }
             } break;
         case LLM_ARCH_QWEN2MOE:
@@ -4317,6 +4326,10 @@ void llama_model::print_info() const {
         LLAMA_LOG_INFO("%s: ssm_dt_b_c_rms   = %d\n",     __func__, hparams.ssm_dt_b_c_rms);
     }
 
+    if (hparams.n_mtp) {
+        LLAMA_LOG_INFO("%s: n_mtp            = %u\n",     __func__, hparams.n_mtp);
+    }
+
     LLAMA_LOG_INFO("%s: model type       = %s\n",     __func__, type_name().c_str());
 
     if (pimpl->n_elements >= 1e12) {
         LLAMA_LOG_INFO("%s: model params     = %.2f T\n", __func__, pimpl->n_elements*1e-12);
@@ -6429,6 +6442,7 @@ struct llm_build_qwen2 : public llm_graph_context {
         ggml_tensor * inpL;
 
         inpL = build_inp_embd(model.tok_embd);
+        ggml_tensor * inp_embd = inpL;
 
         // inp_pos - contains the positions
         ggml_tensor * inp_pos = build_inp_pos();
@@ -6436,8 +6450,24 @@
         auto * inp_attn = build_attn_inp_kv_unified();
 
         for (int il = 0; il < n_layer; ++il) {
+            bool is_mtp = model.layers[il].mtp_inp_proj != nullptr;
             ggml_tensor * inpSA = inpL;
+            // multi token predict
+            // https://huggingface.co/XiaomiMiMo/MiMo-7B-RL/blob/main/modeling_mimo.py
+            if (is_mtp) {
+                ggml_tensor * tmp = build_norm(inp_embd,
+                        model.layers[il].mtp_token_norm,
+                        NULL,
+                        LLM_NORM_RMS, il);
+                ggml_tensor * prev = build_norm(inpL,
+                        model.layers[il].mtp_hidden_norm,
+                        NULL,
+                        LLM_NORM_RMS, il);
+                tmp = ggml_concat(ctx0, prev, tmp, 0); // concat normalized prev state with token embd
+                inpL = build_lora_mm(model.layers[il].mtp_inp_proj, tmp);
+            }
+
             // norm
             cur = build_norm(inpL,
                     model.layers[il].attn_norm,
                     NULL,
@@ -6510,6 +6540,12 @@
 
             cur = ggml_add(ctx0, cur, ffn_inp);
 
+            if (is_mtp && model.layers[il].layer_out_norm) {
+                cur = build_norm(cur,
+                        model.layers[il].layer_out_norm, NULL,
+                        LLM_NORM_RMS, il);
+            }
+
             cur = build_cvec(cur, il);
             cb(cur, "l_out", il);
 
@@ -13209,6 +13245,10 @@ int32_t llama_model_n_head_kv(const llama_model * model) {
     return model->hparams.n_head_kv();
 }
 
+int32_t llama_model_n_mtp(const llama_model * model) {
+    return model->hparams.n_mtp;
+}
+
 // deprecated
 int32_t llama_n_ctx_train(const llama_model * model) {
     return llama_model_n_ctx_train(model);
 }
diff --git a/src/llama-model.h b/src/llama-model.h
index 815fa11ebca59..e91743e20b7ea 100644
--- a/src/llama-model.h
+++ b/src/llama-model.h
@@ -313,6 +313,11 @@ struct llama_layer {
     struct ggml_tensor * ffn_up_scale   = nullptr;
     struct ggml_tensor * ffn_down_scale = nullptr;
 
+    // MTP (multi token predict)
+    struct ggml_tensor * mtp_inp_proj    = nullptr;
+    struct ggml_tensor * mtp_token_norm  = nullptr;
+    struct ggml_tensor * mtp_hidden_norm = nullptr;
+
     struct llama_layer_posnet posnet;
 
     struct llama_layer_convnext convnext;
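Aside, not part of the patch: a toy scalar sketch of what the MTP block above builds, following the referenced modeling_mimo.py. Both inputs are RMS-normalized (token_layernorm / hidden_layernorm), concatenated, and passed through the input projection before the MTP layer's usual attention/FFN block and the final layer_output_norm. Plain std::vector stand-ins are used here instead of ggml tensors.

// h_in = W_proj * concat(rmsnorm(h_prev), rmsnorm(tok_embd))  -- illustrative only
#include <cmath>
#include <vector>

static std::vector<float> rms_norm(const std::vector<float> & x, const std::vector<float> & w, float eps = 1e-6f) {
    float ss = 0.0f;
    for (float v : x) ss += v*v;
    const float scale = 1.0f / std::sqrt(ss/x.size() + eps);
    std::vector<float> y(x.size());
    for (size_t i = 0; i < x.size(); ++i) y[i] = x[i]*scale*w[i];
    return y;
}

// w_proj is row-major [n_embd][2*n_embd], mirroring the {n_embd*2, n_embd} mtp_inp_proj weight
static std::vector<float> mtp_input_proj(const std::vector<float> & h_prev,
                                         const std::vector<float> & tok_embd,
                                         const std::vector<float> & w_hidden_norm,
                                         const std::vector<float> & w_token_norm,
                                         const std::vector<std::vector<float>> & w_proj) {
    std::vector<float> cat = rms_norm(h_prev, w_hidden_norm);       // normalized previous hidden state
    const std::vector<float> t = rms_norm(tok_embd, w_token_norm);  // normalized token embedding
    cat.insert(cat.end(), t.begin(), t.end());

    std::vector<float> out(w_proj.size(), 0.0f);
    for (size_t r = 0; r < w_proj.size(); ++r) {
        for (size_t c = 0; c < cat.size(); ++c) {
            out[r] += w_proj[r][c]*cat[c];
        }
    }
    return out; // feeds the MTP layer's attention/FFN block, then layer_output_norm
}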