
llama : add Xiaomi Mimo (with proper MTP - multi token predict) #13236


Draft: ngxson wants to merge 6 commits into master

Conversation

ngxson (Collaborator) commented on May 1, 2025

This is a WIP.

Given N input tokens, I can now generate either token N+1 or token N+2, but not yet both tokens at the same time.

The way it works is:

  • The model has 36+1 layers: the first 36 are normal layers and the 1 extra is the MTP layer
  • To generate token N+1, we pass the N input tokens through the 36 layers, then pass the output to lm_head
  • To generate token N+2, we take the output of the 36 layers plus the input embedding, pass it through the MTP layer, and then finally go to lm_head (see the sketch below)
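
A minimal sketch of those two paths, for illustration only; forward_main_layers, forward_mtp_layer and lm_head are hypothetical placeholders, not existing llama.cpp functions:

```cpp
#include <vector>

// Hypothetical stand-ins for illustration only - these are not llama.cpp APIs.
using tensor = std::vector<float>;  // think of it as a [n_tokens, n_embd] activation

tensor forward_main_layers(const tensor & tok_emb);                   // the 36 "normal" layers
tensor forward_mtp_layer  (const tensor & h, const tensor & tok_emb); // the 1 extra MTP layer
tensor lm_head            (const tensor & h);                         // shared output head

// Path A: logits for token N+1
tensor predict_n_plus_1(const tensor & tok_emb) {
    const tensor h = forward_main_layers(tok_emb);
    return lm_head(h);
}

// Path B: logits for token N+2 - the MTP layer sees both the main-model output h
// and the raw input embeddings (the embeddings bypass the 36 main layers)
tensor predict_n_plus_2(const tensor & tok_emb) {
    const tensor h = forward_main_layers(tok_emb);
    return lm_head(forward_mtp_layer(h, tok_emb));
}
```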

If we had an API for multiple output heads, this might be easier. WDYT @ggerganov?


Illustration from their technical report:

[image]

@ngxson linked an issue on May 1, 2025 that may be closed by this pull request
@github-actions bot added the python (python script changes) label on May 1, 2025

ngxson (Collaborator, Author) commented on May 1, 2025

Hmm, ok, I could be missing something here: I'm reusing the same set of input tokens for both the N+1 and N+2 steps, while the N+2 step needs the sampled N+1 token as input.

Not sure yet how we can do this

[image]

sorasoras commented:

Would this implementation work on DeepSeek V3?

ngxson (Collaborator, Author) commented on May 1, 2025

@sorasoras Judging from this illustration, yes it's the same:

https://dataturbo.medium.com/deepseek-technical-analysis-3-multi-token-prediction-f8f3ea7eaf9c

[image: DeepSeek MTP illustration]

ggerganov (Member) commented:

Maybe this can be implemented by loading the MTP layers as a separate draft model and reusing the speculative decoding functionality. AFAICT, the predicted tokens from the MTP blocks are technically draft tokens and they have to be accepted by the main decoder.

Btw, based on these 2 diagrams, if they are correct, there is a small difference between DS and Mimo - Mimo uses the same h from the main decoder for all MTP blocks, while DS updates the h after each MTP block.

What is not clear to me is how big N is. In both diagrams, we have N = 4. Is this a parameter?
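
A minimal sketch of what that accept/reject flow could look like with a single draft token, assuming greedy sampling; mtp_draft and main_verify below are hypothetical helpers, not existing llama.cpp or speculative-decoding functions:

```cpp
#include <cstdint>
#include <vector>

using llama_token = int32_t;

// Hypothetical helpers for illustration only (not existing llama.cpp APIs):
//   mtp_draft(ctx)          - forward pass of the MTP layer, proposes the token after ctx
//   main_verify(ctx, draft) - one batched main-model pass over ctx + draft, returning
//                             greedy tokens for the last two positions
struct verify_result {
    llama_token after_ctx;    // main-model prediction right after ctx
    llama_token after_draft;  // main-model prediction right after the draft token
};
llama_token   mtp_draft  (const std::vector<llama_token> & ctx);
verify_result main_verify(const std::vector<llama_token> & ctx, llama_token draft);

// The MTP output is only a draft and must be accepted by the main decoder,
// in the same spirit as the existing speculative decoding loop.
void generate(std::vector<llama_token> & ctx, int n_predict) {
    for (int n = 0; n < n_predict; ) {
        const llama_token   draft = mtp_draft(ctx);
        const verify_result res   = main_verify(ctx, draft);

        ctx.push_back(res.after_ctx); n++;           // token N+1 always comes from the main model
        if (n < n_predict && draft == res.after_ctx) {
            ctx.push_back(res.after_draft); n++;     // draft accepted: token N+2 from the same main pass
        }
    }
}
```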

ngxson (Collaborator, Author) commented on May 1, 2025

Ok, thanks for the clue. It sounds like what you suggest is exactly what they did in the vLLM implementation.

I think N is not important, since the MTP layer has its own KV cache. The number of tokens N just needs to correspond to the number of embedding vectors in h. In other words, this is a way to implement a residual connection, basically letting the input token embeddings bypass the whole 36 "normal" layers.
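
For illustration, assuming a DeepSeek-V3-style MTP input (normalize both streams, concatenate, project back to n_embd, as the diagrams above suggest), this part of the graph could be built with existing ggml ops roughly as below; the tensor names are hypothetical, not taken from the actual model conversion:

```cpp
#include "ggml.h"

// Sketch only: how the MTP input might be assembled from the main-model output and the
// input token embeddings that bypass the 36 "normal" layers. Names are hypothetical.
static struct ggml_tensor * build_mtp_input(
        struct ggml_context * ctx,
        struct ggml_tensor  * h_main,    // [n_embd, n_tokens] output of the 36 main layers
        struct ggml_tensor  * tok_emb,   // [n_embd, n_tokens] input token embeddings
        struct ggml_tensor  * mtp_proj,  // [2*n_embd, n_embd] projection back to n_embd
        float eps) {
    struct ggml_tensor * a = ggml_rms_norm(ctx, h_main,  eps);
    struct ggml_tensor * b = ggml_rms_norm(ctx, tok_emb, eps);
    struct ggml_tensor * c = ggml_concat (ctx, a, b, 0);   // [2*n_embd, n_tokens]
    return ggml_mul_mat(ctx, mtp_proj, c);                 // [n_embd, n_tokens]
}
```

The result would then go through the single MTP transformer layer and the shared lm_head, as in the diagrams.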

ngxson (Collaborator, Author) commented on May 4, 2025

I've been thinking more about the implementation today. I have 2 main ideas in mind, but both have downsides:

The first idea is to have an API like llama_model_get_mtp(struct llama_model * model, int32_t i) which returns a shallow copy of the llama_model object, meaning only the pointers to tensors are copied, not the actual data. The copied llama_model object would have a different layer index (or n_layer) to specify that it is a "child" model.

The downsides are:

  • Because this is a shallow copy, llama_model_free would free the tensors of both the main and the child model
  • More importantly, this means we need 2 different llama_context objects, which in turn means we need yet another API to pass the hidden embeddings from one context to the other

My second idea is to have something equivalent to llama_set_causal_attn, meaning there would be an attribute in llama_cparams to specify whether the next llama_decode should run the main layers or the MTP layer. However, managing the KV cache in this case is a bit tricky, and I don't yet have a good idea how to handle it.
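
For reference, a rough sketch of what the first idea's API surface could look like; every declaration below is hypothetical, including the extra call for passing hidden embeddings between contexts mentioned in the second downside:

```cpp
// Hypothetical additions to llama.h - nothing here exists today, sketch of idea 1 only.

// Return a shallow copy of the model exposing only the i-th MTP layer
// (tensor pointers are shared with the parent model; no weight data is duplicated).
LLAMA_API struct llama_model * llama_model_get_mtp(struct llama_model * model, int32_t i);

// Because the MTP layer would run in its own llama_context, an extra API would be
// needed to feed the hidden embeddings produced by ctx_main into ctx_mtp.
LLAMA_API int32_t llama_set_embd_from_ctx(struct llama_context * ctx_mtp,
                                          struct llama_context * ctx_main);
```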

Labels: python (python script changes)
Successfully merging this pull request may close these issues:
  • Feature Request: XiaomiMiMo/MiMo-7B-RL
3 participants