llama : add Xiaomi Mimo (with proper MTP - multi token predict) #13236

Draft · wants to merge 6 commits into base: master
29 changes: 28 additions & 1 deletion convert_hf_to_gguf.py
@@ -1773,7 +1773,8 @@ def prepare_tensors(self):
"MistralForCausalLM",
"MixtralForCausalLM",
"VLlama3ForCausalLM",
"LlavaForConditionalGeneration")
"LlavaForConditionalGeneration",
)
class LlamaModel(TextModel):
model_arch = gguf.MODEL_ARCH.LLAMA
undo_permute = True
@@ -2741,6 +2742,32 @@ def set_gguf_parameters(self):
self.gguf_writer.add_causal_attention(False)


@ModelBase.register("MiMoForCausalLM")
class MimoModel(Qwen2Model):
model_arch = gguf.MODEL_ARCH.QWEN2
n_multi_token_predict: int
n_layers_no_mtp: int

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.n_multi_token_predict = self.hparams["num_nextn_predict_layers"]
self.n_layers_no_mtp = self.block_count
self.block_count = self.block_count + self.n_multi_token_predict
self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count)

def set_gguf_parameters(self):
super().set_gguf_parameters()
print(self.hparams)
self.gguf_writer.add_n_multi_token_predict(self.hparams["num_nextn_predict_layers"])

def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
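# hypothetical example (layer counts depend on the checkpoint's config.json):
# with 36 regular layers and num_nextn_predict_layers == 1, the HF tensor
# "model.mtp_layers.0.input_proj.weight" is first renamed to
# "model.layers.36.input_proj.weight" and then mapped to "blk.36.mtp_inp_proj.weight"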
if "mtp_layers" in name and bid is not None:
name = name.replace(".mtp_layers", ".layers")
for i in range(self.n_multi_token_predict):
name = name.replace(f"layers.{i}.", f"layers.{self.n_layers_no_mtp + i}.")
return super().modify_tensors(data_torch, name, bid)


@ModelBase.register("Qwen2MoeForCausalLM")
class Qwen2MoeModel(TextModel):
model_arch = gguf.MODEL_ARCH.QWEN2MOE
11 changes: 11 additions & 0 deletions gguf-py/gguf/constants.py
@@ -118,6 +118,7 @@ class LLM:
EMBEDDING_SCALE = "{arch}.embedding_scale"
TOKEN_SHIFT_COUNT = "{arch}.token_shift_count"
INTERLEAVE_MOE_LAYER_STEP = "{arch}.interleave_moe_layer_step"
N_MULTI_TOKEN_PREDICT = "{arch}.n_multi_token_predict"

class Attention:
HEAD_COUNT = "{arch}.attention.head_count"
@@ -375,6 +376,9 @@ class MODEL_TENSOR(IntEnum):
ATTN_Q_NORM = auto()
ATTN_K_NORM = auto()
LAYER_OUT_NORM = auto()
MTP_INP_PROJ = auto()
MTP_TOKEN_NORM = auto() # token_layernorm
MTP_HIDDEN_NORM = auto() # hidden_layernorm
SSM_IN = auto()
SSM_CONV1D = auto()
SSM_X = auto()
@@ -632,6 +636,9 @@ class MODEL_TENSOR(IntEnum):
MODEL_TENSOR.FFN_UP_EXP: "blk.{bid}.ffn_up_exps",
MODEL_TENSOR.FFN_EXP_PROBS_B: "blk.{bid}.exp_probs_b",
MODEL_TENSOR.LAYER_OUT_NORM: "blk.{bid}.layer_output_norm",
MODEL_TENSOR.MTP_INP_PROJ: "blk.{bid}.mtp_inp_proj",
MODEL_TENSOR.MTP_TOKEN_NORM: "blk.{bid}.mtp_token_norm",
MODEL_TENSOR.MTP_HIDDEN_NORM: "blk.{bid}.mtp_hidden_norm",
MODEL_TENSOR.SSM_IN: "blk.{bid}.ssm_in",
MODEL_TENSOR.SSM_CONV1D: "blk.{bid}.ssm_conv1d",
MODEL_TENSOR.SSM_X: "blk.{bid}.ssm_x",
@@ -1103,6 +1110,10 @@ class MODEL_TENSOR(IntEnum):
MODEL_TENSOR.FFN_GATE,
MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.FFN_UP,
MODEL_TENSOR.MTP_INP_PROJ, # xiaomi mimo
MODEL_TENSOR.MTP_HIDDEN_NORM, # xiaomi mimo
MODEL_TENSOR.MTP_TOKEN_NORM, # xiaomi mimo
MODEL_TENSOR.LAYER_OUT_NORM, # xiaomi mimo
],
MODEL_ARCH.QWEN2VL: [
MODEL_TENSOR.TOKEN_EMBD,
3 changes: 3 additions & 0 deletions gguf-py/gguf/gguf_writer.py
@@ -899,6 +899,9 @@ def add_remove_extra_whitespaces(self, value: bool) -> None:
def add_precompiled_charsmap(self, charsmap: Sequence[bytes]) -> None:
self.add_array(Keys.Tokenizer.PRECOMPILED_CHARSMAP, charsmap)

def add_n_multi_token_predict(self, value: int) -> None:
self.add_uint32(Keys.LLM.N_MULTI_TOKEN_PREDICT.format(arch=self.arch), value)

def add_chat_template(self, value: str | Sequence[Mapping[str, str]]) -> None:
if not isinstance(value, str):
template_default = None
15 changes: 14 additions & 1 deletion gguf-py/gguf/tensor_mapping.py
@@ -457,7 +457,20 @@ class TensorNameMap:
"encoder.layers.{bid}.norm2", # nomic-bert
"transformer.decoder_layer.{bid}.rms_norm_3", # Grok
"encoder.layer.{bid}.mlp.layernorm", # jina-bert-v2
"encoder.layer.{bid}.layer_norm_2" # jina-v2-code
"encoder.layer.{bid}.layer_norm_2", # jina-v2-code
"model.layers.{bid}.final_layernorm", # xiaomi mimo
),

MODEL_TENSOR.MTP_INP_PROJ: (
"model.layers.{bid}.input_proj", # xiaomi mimo
),

MODEL_TENSOR.MTP_TOKEN_NORM: (
"model.layers.{bid}.token_layernorm", # xiaomi mimo
),

MODEL_TENSOR.MTP_HIDDEN_NORM: (
"model.layers.{bid}.hidden_layernorm", # xiaomi mimo
),

MODEL_TENSOR.SSM_IN: (
9 changes: 9 additions & 0 deletions include/llama.h
@@ -496,6 +496,12 @@ extern "C" {
LLAMA_API int32_t llama_model_n_head (const struct llama_model * model);
LLAMA_API int32_t llama_model_n_head_kv (const struct llama_model * model);

// If the model supports multi-token prediction (MTP), returns the number of MTP heads; returns 0 otherwise
LLAMA_API int32_t llama_model_n_mtp(const struct llama_model * model);

// Get the i-th multi-token predict model (used for speculative decoding)
LLAMA_API struct llama_model * llama_model_get_mtp(struct llama_model * model, int32_t i);

// Get the model's RoPE frequency scaling factor
LLAMA_API float llama_model_rope_freq_scale_train(const struct llama_model * model);

@@ -959,6 +965,9 @@ extern "C" {
// If set to true, the model will only attend to the past tokens
LLAMA_API void llama_set_causal_attn(struct llama_context * ctx, bool causal_attn);

// Select which multi-token predict head to use; 0 means no MTP (use the normal head)
LLAMA_API void llama_set_mpt_head(struct llama_context * ctx, int32_t n_mtp);

// Set whether the model is in warmup mode or not
// If true, all model tensors are activated during llama_decode() to load and cache their weights.
LLAMA_API void llama_set_warmup(struct llama_context * ctx, bool warmup);
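Note: the sketch below is a hypothetical illustration (not part of this diff) of how a caller might drive the two new entry points for speculative decoding; the draft/verify loop and the exact decode semantics of an MTP head are not defined by this header change, and the helper name is made up.

// hypothetical usage sketch: draft follow-up tokens with the MTP head(s),
// then verify them with the normal head
static void draft_with_mtp(llama_context * ctx, const llama_model * model, llama_token last_tok) {
    const int32_t n_mtp = llama_model_n_mtp(model);
    if (n_mtp == 0) {
        return; // model has no MTP heads
    }
    for (int32_t i = 1; i <= n_mtp; ++i) {
        llama_set_mpt_head(ctx, i); // route the next decode through the i-th MTP head
        llama_batch batch = llama_batch_get_one(&last_tok, 1);
        if (llama_decode(ctx, batch) != 0) {
            break;
        }
        // sample a draft token from llama_get_logits_ith(ctx, -1) here ...
    }
    llama_set_mpt_head(ctx, 0); // back to the normal head
}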
8 changes: 8 additions & 0 deletions src/llama-arch.cpp
@@ -121,6 +121,7 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
{ LLM_KV_EMBEDDING_SCALE, "%s.embedding_scale" },
{ LLM_KV_TOKEN_SHIFT_COUNT, "%s.token_shift_count" },
{ LLM_KV_INTERLEAVE_MOE_LAYER_STEP, "%s.interleave_moe_layer_step" },
{ LLM_KV_N_MULTI_TOKEN_PREDICT, "%s.n_multi_token_predict" },

{ LLM_KV_ATTENTION_HEAD_COUNT, "%s.attention.head_count" },
{ LLM_KV_ATTENTION_HEAD_COUNT_KV, "%s.attention.head_count_kv" },
@@ -579,6 +580,10 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_NAMES = {
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
{ LLM_TENSOR_MTP_INP_PROJ, "blk.%d.mtp_inp_proj" },
{ LLM_TENSOR_MTP_TOKEN_NORM, "blk.%d.mtp_token_norm" },
{ LLM_TENSOR_MTP_HIDDEN_NORM, "blk.%d.mtp_hidden_norm" },
{ LLM_TENSOR_LAYER_OUT_NORM, "blk.%d.layer_output_norm" },
},
},
{
@@ -1678,6 +1683,9 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
{LLM_TENSOR_FFN_GATE_EXPS, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT_ID}},
{LLM_TENSOR_FFN_UP_EXPS, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT_ID}},
{LLM_TENSOR_FFN_EXP_PROBS_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
{LLM_TENSOR_MTP_INP_PROJ, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_MTP_TOKEN_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
{LLM_TENSOR_MTP_HIDDEN_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
// this tensor is loaded for T5, but never used
{LLM_TENSOR_DEC_CROSS_ATTN_REL_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_NONE}},
{LLM_TENSOR_CONV1D, {LLM_TENSOR_LAYER_INPUT, GGML_OP_IM2COL}},
4 changes: 4 additions & 0 deletions src/llama-arch.h
@@ -125,6 +125,7 @@ enum llm_kv {
LLM_KV_EMBEDDING_SCALE,
LLM_KV_TOKEN_SHIFT_COUNT,
LLM_KV_INTERLEAVE_MOE_LAYER_STEP,
LLM_KV_N_MULTI_TOKEN_PREDICT,

LLM_KV_ATTENTION_HEAD_COUNT,
LLM_KV_ATTENTION_HEAD_COUNT_KV,
@@ -362,6 +363,9 @@ enum llm_tensor {
LLM_TENSOR_POS_NET_ATTN_K,
LLM_TENSOR_POS_NET_ATTN_V,
LLM_TENSOR_POS_NET_ATTN_OUT,
LLM_TENSOR_MTP_INP_PROJ,
LLM_TENSOR_MTP_TOKEN_NORM,
LLM_TENSOR_MTP_HIDDEN_NORM,
};

enum llm_tensor_layer {
17 changes: 17 additions & 0 deletions src/llama-context.cpp
@@ -625,6 +625,18 @@ void llama_context::set_causal_attn(bool value) {
cparams.causal_attn = value;
}

void llama_context::set_mpt_head(int32_t value) {
LLAMA_LOG_DEBUG("%s: value = %d\n", __func__, value);

cparams.curr_mtp = value;
}

void llama_context::set_warmup(bool value) {
LLAMA_LOG_DEBUG("%s: value = %d\n", __func__, value);

@@ -1981,6 +1993,11 @@ void llama_set_causal_attn(llama_context * ctx, bool causal_attn) {
ctx->set_causal_attn(causal_attn);
}

void llama_set_mpt_head(llama_context * ctx, int32_t n_mtp) {
GGML_ASSERT(n_mtp <= llama_model_n_mtp(llama_get_model(ctx)));
ctx->set_mpt_head(n_mtp);
}

void llama_set_warmup(llama_context * ctx, bool warmup) {
ctx->set_warmup(warmup);
}
1 change: 1 addition & 0 deletions src/llama-context.h
@@ -69,6 +69,7 @@ struct llama_context {

void set_embeddings (bool value);
void set_causal_attn(bool value);
void set_mpt_head(int32_t value);
void set_warmup(bool value);

void set_adapter_lora(
5 changes: 5 additions & 0 deletions src/llama-cparams.h
@@ -31,6 +31,11 @@ struct llama_cparams {
bool no_perf;
bool warmup;

// multi-token predict
// 0 means not using MTP
// N means using the nth MTP head
int32_t curr_mtp = 0;

enum llama_pooling_type pooling_type;

ggml_backend_sched_eval_callback cb_eval;
3 changes: 3 additions & 0 deletions src/llama-hparams.h
@@ -47,6 +47,9 @@ struct llama_hparams {
uint32_t n_embd_head_k_mla = 0;
uint32_t n_embd_head_v_mla = 0;

// for multi-token predict
uint32_t n_mtp = 0;

// for WavTokenizer
struct llama_hparams_posnet posnet;
struct llama_hparams_convnext convnext;
36 changes: 36 additions & 0 deletions src/llama-model.cpp
@@ -455,6 +455,9 @@ void llama_model::load_hparams(llama_model_loader & ml) {
GGML_ASSERT(hparams.n_expert_used == 0);
}

// multi-token predict
ml.get_key(LLM_KV_N_MULTI_TOKEN_PREDICT, hparams.n_mtp, false);

// zero-out the array hparams
std::fill(hparams.n_head_arr.begin(), hparams.n_head_arr.end(), 0);
std::fill(hparams.n_head_kv_arr.begin(), hparams.n_head_kv_arr.end(), 0);
@@ -2371,6 +2374,12 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0);
layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);

// optional MTP (multi token predict), used by Xiaomi Mimo
layer.mtp_inp_proj = create_tensor(tn(LLM_TENSOR_MTP_INP_PROJ, "weight", i), {n_embd*2, n_embd}, TENSOR_NOT_REQUIRED);
layer.mtp_token_norm = create_tensor(tn(LLM_TENSOR_MTP_TOKEN_NORM, "weight", i), {n_embd}, TENSOR_NOT_REQUIRED);
layer.mtp_hidden_norm = create_tensor(tn(LLM_TENSOR_MTP_HIDDEN_NORM, "weight", i), {n_embd}, TENSOR_NOT_REQUIRED);
layer.layer_out_norm = create_tensor(tn(LLM_TENSOR_LAYER_OUT_NORM, "weight", i), {n_embd}, TENSOR_NOT_REQUIRED);
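// all four tensors are TENSOR_NOT_REQUIRED: they only exist in the extra MTP blocks
// appended after the regular layers by the converter, so plain Qwen2 models load unchanged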
}
} break;
case LLM_ARCH_QWEN2MOE:
@@ -4317,6 +4326,10 @@ void llama_model::print_info() const {
LLAMA_LOG_INFO("%s: ssm_dt_b_c_rms = %d\n", __func__, hparams.ssm_dt_b_c_rms);
}

if (hparams.n_mtp) {
LLAMA_LOG_INFO("%s: n_mtp = %u\n", __func__, hparams.n_mtp);
}

LLAMA_LOG_INFO("%s: model type = %s\n", __func__, type_name().c_str());
if (pimpl->n_elements >= 1e12) {
LLAMA_LOG_INFO("%s: model params = %.2f T\n", __func__, pimpl->n_elements*1e-12);
@@ -6429,15 +6442,28 @@ struct llm_build_qwen2 : public llm_graph_context {
ggml_tensor * inpL;

inpL = build_inp_embd(model.tok_embd);
ggml_tensor * inp_embd = inpL;

// inp_pos - contains the positions
ggml_tensor * inp_pos = build_inp_pos();

auto * inp_attn = build_attn_inp_kv_unified();

for (int il = 0; il < n_layer; ++il) {
bool is_mtp = model.layers[il].mtp_inp_proj != nullptr;
ggml_tensor * inpSA = inpL;

// multi token predict
// https://huggingface.co/XiaomiMiMo/MiMo-7B-RL/blob/main/modeling_mimo.py
if (is_mtp) {
ggml_tensor * tmp = build_norm(inp_embd,
model.layers[il].mtp_token_norm,
NULL,
LLM_NORM_RMS, il);
tmp = ggml_concat(ctx0, inpL, tmp, 0); // concat prev state with token embd
inpL = build_lora_mm(model.layers[il].mtp_inp_proj, tmp);
}
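// i.e. inpL = mtp_inp_proj @ concat(prev_hidden, RMSNorm(tok_embd));
// note: mtp_hidden_norm is created in load_tensors but not applied to the previous
// hidden state here (the reference modeling_mimo.py appears to normalize both inputs
// before the concat)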

// norm
cur = build_norm(inpL,
model.layers[il].attn_norm, NULL,
@@ -6510,6 +6536,12 @@ struct llm_build_qwen2 : public llm_graph_context {

cur = ggml_add(ctx0, cur, ffn_inp);

if (is_mtp && model.layers[il].layer_out_norm) {
cur = build_norm(cur,
model.layers[il].layer_out_norm, NULL,
LLM_NORM_RMS, il);
}

cur = build_cvec(cur, il);
cb(cur, "l_out", il);

@@ -13209,6 +13241,10 @@ int32_t llama_model_n_head_kv(const llama_model * model) {
return model->hparams.n_head_kv();
}

int32_t llama_model_n_mtp(const llama_model * model) {
return model->hparams.n_mtp;
}

// deprecated
int32_t llama_n_ctx_train(const llama_model * model) {
return llama_model_n_ctx_train(model);
5 changes: 5 additions & 0 deletions src/llama-model.h
@@ -313,6 +313,11 @@ struct llama_layer {
struct ggml_tensor * ffn_up_scale = nullptr;
struct ggml_tensor * ffn_down_scale = nullptr;

// MTP (multi token predict)
struct ggml_tensor * mtp_inp_proj = nullptr;
struct ggml_tensor * mtp_token_norm = nullptr;
struct ggml_tensor * mtp_hidden_norm = nullptr;

struct llama_layer_posnet posnet;

struct llama_layer_convnext convnext;