
graph : normalize Q, K, V shapes and add comments #12449


Merged
ggerganov merged 3 commits into master from gg/graph-normalize-qkv-shapes on Mar 18, 2025

Conversation

@ggerganov (Member) commented Mar 18, 2025

ref #12435 (comment)

  • Always pass 3D shapes to the attention operations
  • Add missing synchronization when getting the cross attention data
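A minimal sketch of the first point, using an assumed helper name and call site rather than the actual patch: the current V tensor arrives flattened as [n_embd_v_gqa, n_tokens] and is reinterpreted, without copying, as the 3D shape [n_embd_head_v, n_head_kv, n_tokens] before being handed to the attention operation (the log below confirms n_embd_v_gqa == n_embd_head_v * n_head_kv, e.g. 2048 == 128 * 16 for this T5 model).

// sketch only: illustrates "always pass 3D shapes to the attention operations";
// the helper name is an assumption, not the actual llama.cpp change
#include "ggml.h"

static ggml_tensor * view_v_as_3d(ggml_context * ctx, ggml_tensor * v_cur,
                                  int64_t n_embd_head_v, int64_t n_head_kv, int64_t n_tokens) {
    // v_cur is flat: ne[0] == n_embd_head_v*n_head_kv (== n_embd_v_gqa), ne[1] == n_tokens
    GGML_ASSERT(v_cur->ne[0] == n_embd_head_v*n_head_kv);
    GGML_ASSERT(v_cur->ne[1] == n_tokens);

    // reinterpret the same data as [n_embd_head_v, n_head_kv, n_tokens] (no copy)
    return ggml_reshape_3d(ctx, v_cur, n_embd_head_v, n_head_kv, n_tokens);
}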

@@ -1378,7 +1378,7 @@ ggml_tensor * llm_graph_context::build_attn(
// note: storing RoPE-ed version of K in the KV cache
ggml_build_forward_expand(gf, ggml_cpy(ctx0, k_cur, k_cache_view));

assert(v_cur->ne[0] == n_embd_v_gqa && v_cur->ne[1] == n_tokens);
v_cur = ggml_reshape_2d(ctx0, v_cur, n_embd_v_gqa, n_tokens);

@ggerganov (Member, Author) commented on the diff:

This should fix the T5 bug: #12435

@ggerganov ggerganov requested a review from fairydreaming March 18, 2025 11:55

@fairydreaming (Collaborator) left a comment

I confirm that this fixed the remaining issue with T5 models. I went over the code and don't see any obvious problems.

Since these changes affect multiple models, I briefly tested some that I had on disk: llama-3.1, llama-3.2, phi-4, gemma-2, deepseek-r1, qwen-2.5, and the t5 family. All seem to work fine.

@steampunque

I applied this diff to b4914 and it still does not work:

llama-cli -m /datahd/models/madlad400-7b-mt.Q6_K.gguf      --color -n -1 -ngl 60 -c 512 -ctk f16 -ctv f16 -b 512 -ub 512    -n 512 --keep 0    --temp 0.0 --dynatemp-range 0.0 --dynatemp-exp 1.0    --top-k 40 --top-p 0.95 --typical 1.0 --min-p 0.00    --repeat-last-n 64 --repeat-penalty 1.0    --presence-penalty 0.0 --frequency-penalty 0.0    --mirostat 0 --mirostat-lr 0.1 --mirostat-ent 5.0        -p "<2de> Today it rains" --in-prefix "" --in-suffix ""

ggml_cuda_init: GGML_CUDA_FORCE_MMQ: yes
ggml_cuda_init: GGML_CUDA_FORCE_CUBLAS: no
ggml_cuda_init: found 1 CUDA devices:
Device 0: NVIDIA GeForce GTX 1070, compute capability 6.1, VMM: yes
build: 4914 (8551c44) with cc (GCC) 11.2.0 for x86_64-slackware-linux
main: llama backend init
main: load the model and apply lora adapter, if any
llama_model_load_from_file_impl: using device CUDA0 (NVIDIA GeForce GTX 1070) - 7828 MiB free
llama_model_loader: loaded meta data with 26 key-value pairs and 1110 tensors from /datahd/models/madlad400-7b-mt.Q6_K.gguf (version GGUF V3 (latest))
llama_model_loader: Dumping metadata keys/values. Note: KV overrides do not apply in this output.
llama_model_loader: - kv 0: general.architecture str = t5
llama_model_loader: - kv 1: general.name str = T5
llama_model_loader: - kv 2: t5.context_length u32 = 512
llama_model_loader: - kv 3: t5.embedding_length u32 = 2048
llama_model_loader: - kv 4: t5.feed_forward_length u32 = 8192
llama_model_loader: - kv 5: t5.block_count u32 = 48
llama_model_loader: - kv 6: t5.attention.head_count u32 = 16
llama_model_loader: - kv 7: t5.attention.key_length u32 = 128
llama_model_loader: - kv 8: t5.attention.value_length u32 = 128
llama_model_loader: - kv 9: t5.attention.layer_norm_epsilon f32 = 0.000001
llama_model_loader: - kv 10: t5.attention.relative_buckets_count u32 = 32
llama_model_loader: - kv 11: t5.attention.layer_norm_rms_epsilon f32 = 0.000001
llama_model_loader: - kv 12: t5.decoder_start_token_id u32 = 0
llama_model_loader: - kv 13: general.file_type u32 = 18
llama_model_loader: - kv 14: tokenizer.ggml.model str = t5
llama_model_loader: - kv 15: tokenizer.ggml.pre str = default
llama_model_loader: - kv 16: tokenizer.ggml.tokens arr[str,256000] = ["", "", "", "\n", "<2ace>...
llama_model_loader: - kv 17: tokenizer.ggml.scores arr[f32,256000] = [0.000000, 0.000000, 0.000000, 0.0000...
llama_model_loader: - kv 18: tokenizer.ggml.token_type arr[i32,256000] = [2, 3, 3, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...
llama_model_loader: - kv 19: tokenizer.ggml.add_space_prefix bool = true
llama_model_loader: - kv 20: tokenizer.ggml.remove_extra_whitespaces bool = false
llama_model_loader: - kv 21: tokenizer.ggml.eos_token_id u32 = 2
llama_model_loader: - kv 22: tokenizer.ggml.padding_token_id u32 = 1
llama_model_loader: - kv 23: tokenizer.ggml.add_bos_token bool = false
llama_model_loader: - kv 24: tokenizer.ggml.add_eos_token bool = true
llama_model_loader: - kv 25: general.quantization_version u32 = 2
llama_model_loader: - type f32: 242 tensors
llama_model_loader: - type q6_K: 866 tensors
llama_model_loader: - type bf16: 2 tensors
print_info: file format = GGUF V3 (latest)
print_info: file type = Q6_K
print_info: file size = 6.34 GiB (6.56 BPW)
load: special_eos_id is not in special_eog_ids - the tokenizer config may be incorrect
load: special tokens cache size = 3
load: token to piece cache size = 1.7509 MB
print_info: arch = t5
print_info: vocab_only = 0
print_info: n_ctx_train = 512
print_info: n_embd = 2048
print_info: n_layer = 48
print_info: n_head = 16
print_info: n_head_kv = 16
print_info: n_rot = 128
print_info: n_swa = 0
print_info: n_swa_pattern = 1
print_info: n_embd_head_k = 128
print_info: n_embd_head_v = 128
print_info: n_gqa = 1
print_info: n_embd_k_gqa = 2048
print_info: n_embd_v_gqa = 2048
print_info: f_norm_eps = 0.0e+00
print_info: f_norm_rms_eps = 1.0e-06
print_info: f_clamp_kqv = 0.0e+00
print_info: f_max_alibi_bias = 0.0e+00
print_info: f_logit_scale = 0.0e+00
print_info: f_attn_scale = 0.0e+00
print_info: n_ff = 8192
print_info: n_expert = 0
print_info: n_expert_used = 0
print_info: causal attn = 1
print_info: pooling type = 0
print_info: rope type = -1
print_info: rope scaling = linear
print_info: freq_base_train = 10000.0
print_info: freq_scale_train = 1
print_info: n_ctx_orig_yarn = 512
print_info: rope_finetuned = unknown
print_info: ssm_d_conv = 0
print_info: ssm_d_inner = 0
print_info: ssm_d_state = 0
print_info: ssm_dt_rank = 0
print_info: ssm_dt_b_c_rms = 0
print_info: model type = ?B
print_info: model params = 8.30 B
print_info: general.name = T5
print_info: vocab type = UGM
print_info: n_vocab = 256000
print_info: n_merges = 0
print_info: EOS token = 2 ''
print_info: UNK token = 2 ''
print_info: PAD token = 1 ''
print_info: LF token = 805 '▁'
print_info: EOG token = 2 '
'
print_info: max token length = 48
load_tensors: loading model tensors, this can take a while... (mmap = true)
load_tensors: offloading 48 repeating layers to GPU
load_tensors: offloading output layer to GPU
load_tensors: offloaded 49/49 layers to GPU
load_tensors: CPU_Mapped model buffer size = 2917.78 MiB
load_tensors: CUDA0 model buffer size = 6082.05 MiB
..........................................................................................
llama_context: constructing llama_context
llama_context: n_seq_max = 1
llama_context: n_ctx = 512
llama_context: n_ctx_per_seq = 512
llama_context: n_batch = 512
llama_context: n_ubatch = 512
llama_context: causal_attn = 1
llama_context: flash_attn = 0
llama_context: freq_base = 10000.0
llama_context: freq_scale = 1
llama_context: yarn_log_mul = 0
llama_context: CUDA_Host output buffer size = 0.98 MiB
init: kv_size = 512, offload = 1, type_k = 'f16', type_v = 'f16', n_layer = 48, can_shift = 1
init: CUDA0 KV buffer size = 192.00 MiB
llama_context: KV self size = 192.00 MiB, K (f16): 96.00 MiB, V (f16): 96.00 MiB
llama_context: CUDA0 compute buffer size = 508.03 MiB
llama_context: CUDA_Host compute buffer size = 23.00 MiB
llama_context: graph nodes = 2790
llama_context: graph splits = 98
common_init_from_params: setting dry_penalty_last_n to ctx_size = 512
common_init_from_params: warming up the model with an empty run - please wait ... (--no-warmup to disable)
main: llama threadpool init, n_threads = 4

system_info: n_threads = 4 (n_threads_batch = 4) / 4 | CUDA : ARCHS = 520,610,700,750 | FORCE_MMQ = 1 | USE_GRAPHS = 1 | PEER_MAX_BATCH_SIZE = 128 | CPU : SSE3 = 1 | SSSE3 = 1 | LLAMAFILE = 1 | OPENMP = 1 | AARCH64_REPACK = 1 |

sampler seed: 2296139272
sampler params:
repeat_last_n = 64, repeat_penalty = 1.000, frequency_penalty = 0.000, presence_penalty = 0.000
dry_multiplier = 0.000, dry_base = 1.750, dry_allowed_length = 2, dry_penalty_last_n = 512
top_k = 40, top_p = 0.950, min_p = 0.000, xtc_probability = 0.000, xtc_threshold = 0.100, typical_p = 1.000, top_n_sigma = -1.000, temp = 0.000
mirostat = 0, mirostat_lr = 0.100, mirostat_ent = 5.000
sampler chain: logits -> logit-bias -> penalties -> dry -> top-k -> typical -> top-p -> min-p -> xtc -> temp-ext -> dist
generate: n_ctx = 512, n_batch = 512, n_predict = 512, n_keep = 0

                    1. [end of text]

llama_perf_sampler_print: sampling time = 46.71 ms / 128 runs ( 0.36 ms per token, 2740.55 tokens per second)
llama_perf_context_print: load time = 4489.38 ms
llama_perf_context_print: prompt eval time = 130.47 ms / 8 tokens ( 16.31 ms per token, 61.32 tokens per second)
llama_perf_context_print: eval time = 7664.42 ms / 126 runs ( 60.83 ms per token, 16.44 tokens per second)
llama_perf_context_print: total time = 8016.61 ms / 134 tokens

@steampunque

It does work now with -ngl 0 though, if that might give any clue as to the problem.

ggml_cuda_init: GGML_CUDA_FORCE_MMQ: yes
ggml_cuda_init: GGML_CUDA_FORCE_CUBLAS: no
ggml_cuda_init: found 1 CUDA devices:
Device 0: NVIDIA GeForce GTX 1070, compute capability 6.1, VMM: yes
build: 4914 (8551c44) with cc (GCC) 11.2.0 for x86_64-slackware-linux
main: llama backend init
main: load the model and apply lora adapter, if any
llama_model_load_from_file_impl: using device CUDA0 (NVIDIA GeForce GTX 1070) - 7824 MiB free
llama_model_loader: loaded meta data with 26 key-value pairs and 1110 tensors from /datahd/models/madlad400-7b-mt.Q6_K.gguf (version GGUF V3 (latest))
llama_model_loader: Dumping metadata keys/values. Note: KV overrides do not apply in this output.
llama_model_loader: - kv 0: general.architecture str = t5
llama_model_loader: - kv 1: general.name str = T5
llama_model_loader: - kv 2: t5.context_length u32 = 512
llama_model_loader: - kv 3: t5.embedding_length u32 = 2048
llama_model_loader: - kv 4: t5.feed_forward_length u32 = 8192
llama_model_loader: - kv 5: t5.block_count u32 = 48
llama_model_loader: - kv 6: t5.attention.head_count u32 = 16
llama_model_loader: - kv 7: t5.attention.key_length u32 = 128
llama_model_loader: - kv 8: t5.attention.value_length u32 = 128
llama_model_loader: - kv 9: t5.attention.layer_norm_epsilon f32 = 0.000001
llama_model_loader: - kv 10: t5.attention.relative_buckets_count u32 = 32
llama_model_loader: - kv 11: t5.attention.layer_norm_rms_epsilon f32 = 0.000001
llama_model_loader: - kv 12: t5.decoder_start_token_id u32 = 0
llama_model_loader: - kv 13: general.file_type u32 = 18
llama_model_loader: - kv 14: tokenizer.ggml.model str = t5
llama_model_loader: - kv 15: tokenizer.ggml.pre str = default
llama_model_loader: - kv 16: tokenizer.ggml.tokens arr[str,256000] = ["", "", "", "\n", "<2ace>...
llama_model_loader: - kv 17: tokenizer.ggml.scores arr[f32,256000] = [0.000000, 0.000000, 0.000000, 0.0000...
llama_model_loader: - kv 18: tokenizer.ggml.token_type arr[i32,256000] = [2, 3, 3, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...
llama_model_loader: - kv 19: tokenizer.ggml.add_space_prefix bool = true
llama_model_loader: - kv 20: tokenizer.ggml.remove_extra_whitespaces bool = false
llama_model_loader: - kv 21: tokenizer.ggml.eos_token_id u32 = 2
llama_model_loader: - kv 22: tokenizer.ggml.padding_token_id u32 = 1
llama_model_loader: - kv 23: tokenizer.ggml.add_bos_token bool = false
llama_model_loader: - kv 24: tokenizer.ggml.add_eos_token bool = true
llama_model_loader: - kv 25: general.quantization_version u32 = 2
llama_model_loader: - type f32: 242 tensors
llama_model_loader: - type q6_K: 866 tensors
llama_model_loader: - type bf16: 2 tensors
print_info: file format = GGUF V3 (latest)
print_info: file type = Q6_K
print_info: file size = 6.34 GiB (6.56 BPW)
load: special_eos_id is not in special_eog_ids - the tokenizer config may be incorrect
load: special tokens cache size = 3
load: token to piece cache size = 1.7509 MB
print_info: arch = t5
print_info: vocab_only = 0
print_info: n_ctx_train = 512
print_info: n_embd = 2048
print_info: n_layer = 48
print_info: n_head = 16
print_info: n_head_kv = 16
print_info: n_rot = 128
print_info: n_swa = 0
print_info: n_swa_pattern = 1
print_info: n_embd_head_k = 128
print_info: n_embd_head_v = 128
print_info: n_gqa = 1
print_info: n_embd_k_gqa = 2048
print_info: n_embd_v_gqa = 2048
print_info: f_norm_eps = 0.0e+00
print_info: f_norm_rms_eps = 1.0e-06
print_info: f_clamp_kqv = 0.0e+00
print_info: f_max_alibi_bias = 0.0e+00
print_info: f_logit_scale = 0.0e+00
print_info: f_attn_scale = 0.0e+00
print_info: n_ff = 8192
print_info: n_expert = 0
print_info: n_expert_used = 0
print_info: causal attn = 1
print_info: pooling type = 0
print_info: rope type = -1
print_info: rope scaling = linear
print_info: freq_base_train = 10000.0
print_info: freq_scale_train = 1
print_info: n_ctx_orig_yarn = 512
print_info: rope_finetuned = unknown
print_info: ssm_d_conv = 0
print_info: ssm_d_inner = 0
print_info: ssm_d_state = 0
print_info: ssm_dt_rank = 0
print_info: ssm_dt_b_c_rms = 0
print_info: model type = ?B
print_info: model params = 8.30 B
print_info: general.name = T5
print_info: vocab type = UGM
print_info: n_vocab = 256000
print_info: n_merges = 0
print_info: EOS token = 2 ''
print_info: UNK token = 2 ''
print_info: PAD token = 1 ''
print_info: LF token = 805 '▁'
print_info: EOG token = 2 '
'
print_info: max token length = 48
load_tensors: loading model tensors, this can take a while... (mmap = true)
load_tensors: offloading 0 repeating layers to GPU
load_tensors: offloaded 0/49 layers to GPU
load_tensors: CPU_Mapped model buffer size = 6492.21 MiB
..........................................................................................
llama_context: constructing llama_context
llama_context: n_seq_max = 1
llama_context: n_ctx = 512
llama_context: n_ctx_per_seq = 512
llama_context: n_batch = 512
llama_context: n_ubatch = 512
llama_context: causal_attn = 1
llama_context: flash_attn = 1
llama_context: freq_base = 10000.0
llama_context: freq_scale = 1
llama_context: yarn_log_mul = 0
llama_context: CPU output buffer size = 0.98 MiB
init: kv_size = 512, offload = 1, type_k = 'f16', type_v = 'f16', n_layer = 48, can_shift = 1
init: CPU KV buffer size = 192.00 MiB
llama_context: KV self size = 192.00 MiB, K (f16): 96.00 MiB, V (f16): 96.00 MiB
llama_context: CUDA0 compute buffer size = 947.28 MiB
llama_context: CUDA_Host compute buffer size = 51.00 MiB
llama_context: graph nodes = 2600
llama_context: graph splits = 772 (with bs=512), 151 (with bs=1)
common_init_from_params: setting dry_penalty_last_n to ctx_size = 512
common_init_from_params: warming up the model with an empty run - please wait ... (--no-warmup to disable)
main: llama threadpool init, n_threads = 4

system_info: n_threads = 4 (n_threads_batch = 4) / 4 | CUDA : ARCHS = 520,610,700,750 | FORCE_MMQ = 1 | USE_GRAPHS = 1 | PEER_MAX_BATCH_SIZE = 128 | CPU : SSE3 = 1 | SSSE3 = 1 | LLAMAFILE = 1 | OPENMP = 1 | AARCH64_REPACK = 1 |

sampler seed: 2041154203
sampler params:
repeat_last_n = 64, repeat_penalty = 1.000, frequency_penalty = 0.000, presence_penalty = 0.000
dry_multiplier = 0.000, dry_base = 1.750, dry_allowed_length = 2, dry_penalty_last_n = 512
top_k = 40, top_p = 0.950, min_p = 0.000, xtc_probability = 0.000, xtc_threshold = 0.100, typical_p = 1.000, top_n_sigma = -1.000, temp = 0.000
mirostat = 0, mirostat_lr = 0.100, mirostat_ent = 5.000
sampler chain: logits -> logit-bias -> penalties -> dry -> top-k -> typical -> top-p -> min-p -> xtc -> temp-ext -> dist
generate: n_ctx = 512, n_batch = 512, n_predict = 512, n_keep = 0

Heute regnet es [end of text]

llama_perf_sampler_print: sampling time = 2.02 ms / 7 runs ( 0.29 ms per token, 3458.50 tokens per second)
llama_perf_context_print: load time = 3213.88 ms
llama_perf_context_print: prompt eval time = 4676.34 ms / 8 tokens ( 584.54 ms per token, 1.71 tokens per second)
llama_perf_context_print: eval time = 6373.64 ms / 5 runs ( 1274.73 ms per token, 0.78 tokens per second)
llama_perf_context_print: total time = 11061.44 ms / 13 tokens

@fairydreaming (Collaborator)

@steampunque Good catch, there still seems to be a problem when using T5 with the CUDA backend. I'm going to take a look at this now.

@fairydreaming (Collaborator)

@steampunque OK, it looks like the problem you found is unrelated to this PR; it's another issue, introduced by the kv cache refactor PR. A temporary fix:

diff --git a/src/llama-context.cpp b/src/llama-context.cpp
index 42332acf..8d441b0c 100644
--- a/src/llama-context.cpp
+++ b/src/llama-context.cpp
@@ -1100,7 +1100,8 @@ int llama_context::encode(llama_batch & inp_batch) {
                 {
                     // extract token embeddings
                     GGML_ASSERT(n_tokens*n_embd <= (int64_t) embd_size);
-                    ggml_backend_tensor_get_async(backend_embd, t_embd, embd, 0, n_tokens*n_embd*sizeof(float));
+                    ggml_backend_synchronize(backend_embd);
+                    ggml_backend_tensor_get(t_embd, embd, 0, n_tokens*n_embd*sizeof(float));
                 } break;
             case LLAMA_POOLING_TYPE_MEAN:
             case LLAMA_POOLING_TYPE_CLS:
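
The general pattern behind this patch, sketched below with assumed names (this is not the actual llama.cpp code): an async tensor read only enqueues the device-to-host copy, so the host buffer must not be touched until the backend has been synchronized. Synchronizing first and then using the blocking getter makes the read safe:

// sketch (assumed helper name, f32 tensor assumed): read a tensor back to the host safely
#include "ggml.h"
#include "ggml-backend.h"
#include <vector>

static std::vector<float> read_tensor_to_host(ggml_backend_t backend, const ggml_tensor * t) {
    std::vector<float> out(ggml_nelements(t));
    // wait for the backend to finish the graph that produced `t`
    ggml_backend_synchronize(backend);
    // blocking copy: `out` holds valid data once this returns
    ggml_backend_tensor_get(t, out.data(), 0, ggml_nbytes(t));
    return out;
}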

@ggerganov (Member, Author)

@fairydreaming Can you check if this alternative patch also fixes the issue:

diff --git a/src/llama-context.cpp b/src/llama-context.cpp
index 42332acf1..664703a89 100644
--- a/src/llama-context.cpp
+++ b/src/llama-context.cpp
@@ -1143,6 +1143,8 @@ int llama_context::encode(llama_batch & inp_batch) {
     if (model.arch == LLM_ARCH_T5 && t_embd) {
         //cross.t_embd = t_embd;
 
+        synchronize();
+
         cross.n_embd = t_embd->ne[0];
         cross.n_enc  = t_embd->ne[1];
         cross.v_embd.resize(cross.n_embd*cross.n_enc);

@steampunque

Thanks for tracking that down so quickly, t5 is back online:

bash-5.1$ lm "<2de> Today the sun shines"
Heute scheint die Sonne

@fairydreaming (Collaborator)

@ggerganov Yeah, the synchronize() call fixes the issue too.

@ggerganov ggerganov merged commit 75422e8 into master Mar 18, 2025
46 checks passed
@ggerganov ggerganov deleted the gg/graph-normalize-qkv-shapes branch March 18, 2025 19:35