
Eval bug: b5237 broke Llama Scout #13287


Closed
steampunque opened this issue May 3, 2025 · 14 comments · Fixed by #13294 or #13299
Labels
bug Something isn't working

Comments

@steampunque

steampunque commented May 3, 2025

Name and Version

b5237

Operating systems

Linux

GGML backends

CUDA

Hardware

4070

Models

Llama 4 Scout. The quant should not matter, but I am using my hybrid layer quant here:

https://huggingface.co/steampunque/Llama-4-Scout-17B-16E-Instruct-GGUF/blob/main/Llama-4-Scout-17B-16E-Instruct.Q3_K_H.gguf

Problem description & steps to reproduce

Crash with an illegal memory access when running perplexity:

short.txt is the first 9783 bytes of wiki.test.raw

llama-perplexity -m /data3hd/models/Llama-4-Scout-17B-16E-Instruct.Q3_K_H.gguf -ngl 10 -c 1024 -b 128 -fa -f short.txt

First Bad Commit

b5237. b5236 works fine.

Relevant log output

ggml_cuda_init: GGML_CUDA_FORCE_MMQ:    no
ggml_cuda_init: GGML_CUDA_FORCE_CUBLAS: no
ggml_cuda_init: found 1 CUDA devices:
  Device 0: NVIDIA GeForce RTX 4070, compute capability 8.9, VMM: yes
build: 5237 (e1e8e099) with cc (GCC) 11.2.0 for x86_64-slackware-linux
llama_model_load_from_file_impl: using device CUDA0 (NVIDIA GeForce RTX 4070) - 11536 MiB free
llama_model_loader: loaded meta data with 42 key-value pairs and 628 tensors from /data3hd/models/Llama-4-Scout-17B-16E-Instruct.Q3_K_H.gguf (version GGUF V3 (latest))
llama_model_loader: Dumping metadata keys/values. Note: KV overrides do not apply in this output.
llama_model_loader: - kv   0:                       general.architecture str              = llama4
llama_model_loader: - kv   1:                               general.type str              = model
llama_model_loader: - kv   2:                               general.name str              = Llama 4 Scout 17B 16E Instruct
llama_model_loader: - kv   3:                           general.finetune str              = 16E-Instruct
llama_model_loader: - kv   4:                           general.basename str              = Llama-4-Scout
llama_model_loader: - kv   5:                         general.size_label str              = 17B
llama_model_loader: - kv   6:                            general.license str              = other
llama_model_loader: - kv   7:                       general.license.name str              = llama4
llama_model_loader: - kv   8:                   general.base_model.count u32              = 1
llama_model_loader: - kv   9:                  general.base_model.0.name str              = Llama 4 Scout 17B 16E
llama_model_loader: - kv  10:          general.base_model.0.organization str              = Meta Llama
llama_model_loader: - kv  11:              general.base_model.0.repo_url str              = https://huggingface.co/meta-llama/Lla...
llama_model_loader: - kv  12:                               general.tags arr[str,5]       = ["facebook", "meta", "pytorch", "llam...
llama_model_loader: - kv  13:                          general.languages arr[str,12]      = ["ar", "de", "en", "es", "fr", "hi", ...
llama_model_loader: - kv  14:                         llama4.block_count u32              = 48
llama_model_loader: - kv  15:                      llama4.context_length u32              = 10485760
llama_model_loader: - kv  16:                    llama4.embedding_length u32              = 5120
llama_model_loader: - kv  17:                 llama4.feed_forward_length u32              = 16384
llama_model_loader: - kv  18:                llama4.attention.head_count u32              = 40
llama_model_loader: - kv  19:             llama4.attention.head_count_kv u32              = 8
llama_model_loader: - kv  20:                      llama4.rope.freq_base f32              = 500000.000000
llama_model_loader: - kv  21:    llama4.attention.layer_norm_rms_epsilon f32              = 0.000010
llama_model_loader: - kv  22:                        llama4.expert_count u32              = 16
llama_model_loader: - kv  23:                   llama4.expert_used_count u32              = 1
llama_model_loader: - kv  24:                llama4.attention.key_length u32              = 128
llama_model_loader: - kv  25:              llama4.attention.value_length u32              = 128
llama_model_loader: - kv  26:                          llama4.vocab_size u32              = 202048
llama_model_loader: - kv  27:                llama4.rope.dimension_count u32              = 128
llama_model_loader: - kv  28:           llama4.interleave_moe_layer_step u32              = 1
llama_model_loader: - kv  29:          llama4.expert_feed_forward_length u32              = 8192
llama_model_loader: - kv  30:                       tokenizer.ggml.model str              = gpt2
llama_model_loader: - kv  31:                         tokenizer.ggml.pre str              = llama4
llama_model_loader: - kv  32:                      tokenizer.ggml.tokens arr[str,202048]  = ["À", "Á", "õ", "ö", "÷", "ø", ...
llama_model_loader: - kv  33:                  tokenizer.ggml.token_type arr[i32,202048]  = [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...
llama_model_loader: - kv  34:                      tokenizer.ggml.merges arr[str,439802]  = ["Ġ Ġ", "Ġ ĠĠĠ", "ĠĠ ĠĠ", "...
llama_model_loader: - kv  35:                tokenizer.ggml.bos_token_id u32              = 200000
llama_model_loader: - kv  36:                tokenizer.ggml.eos_token_id u32              = 200008
llama_model_loader: - kv  37:            tokenizer.ggml.padding_token_id u32              = 201134
llama_model_loader: - kv  38:                    tokenizer.chat_template str              = {{- bos_token }}\n{%- if custom_tools ...
llama_model_loader: - kv  39:               tokenizer.ggml.add_bos_token bool             = true
llama_model_loader: - kv  40:               general.quantization_version u32              = 2
llama_model_loader: - kv  41:                          general.file_type u32              = 12
llama_model_loader: - type  f32:  146 tensors
llama_model_loader: - type q2_K:   54 tensors
llama_model_loader: - type q3_K:  372 tensors
llama_model_loader: - type q4_K:   51 tensors
llama_model_loader: - type q5_K:    5 tensors
print_info: file format = GGUF V3 (latest)
print_info: file type   = Q3_K - Medium
print_info: file size   = 43.34 GiB (3.45 BPW) 
load: special_eos_id is not in special_eog_ids - the tokenizer config may be incorrect
load: special tokens cache size = 1135
load: token to piece cache size = 1.3873 MB
print_info: arch             = llama4
print_info: vocab_only       = 0
print_info: n_ctx_train      = 10485760
print_info: n_embd           = 5120
print_info: n_layer          = 48
print_info: n_head           = 40
print_info: n_head_kv        = 8
print_info: n_rot            = 128
print_info: n_swa            = 1
print_info: n_swa_pattern    = 4
print_info: n_embd_head_k    = 128
print_info: n_embd_head_v    = 128
print_info: n_gqa            = 5
print_info: n_embd_k_gqa     = 1024
print_info: n_embd_v_gqa     = 1024
print_info: f_norm_eps       = 0.0e+00
print_info: f_norm_rms_eps   = 1.0e-05
print_info: f_clamp_kqv      = 0.0e+00
print_info: f_max_alibi_bias = 0.0e+00
print_info: f_logit_scale    = 0.0e+00
print_info: f_attn_scale     = 0.0e+00
print_info: n_ff             = 16384
print_info: n_expert         = 16
print_info: n_expert_used    = 1
print_info: causal attn      = 1
print_info: pooling type     = 0
print_info: rope type        = 0
print_info: rope scaling     = linear
print_info: freq_base_train  = 500000.0
print_info: freq_scale_train = 1
print_info: n_ctx_orig_yarn  = 10485760
print_info: rope_finetuned   = unknown
print_info: ssm_d_conv       = 0
print_info: ssm_d_inner      = 0
print_info: ssm_d_state      = 0
print_info: ssm_dt_rank      = 0
print_info: ssm_dt_b_c_rms   = 0
print_info: model type       = 17Bx16E (Scout)
print_info: model params     = 107.77 B
print_info: general.name     = Llama 4 Scout 17B 16E Instruct
print_info: vocab type       = BPE
print_info: n_vocab          = 202048
print_info: n_merges         = 439802
print_info: BOS token        = 200000 '<|begin_of_text|>'
print_info: EOS token        = 200008 '<|eot|>'
print_info: PAD token        = 201134 '<|finetune_right_pad_id|>'
print_info: LF token         = 198 'Ċ'
print_info: FIM PRE token    = 200002 '<|fim_prefix|>'
print_info: FIM SUF token    = 200004 '<|fim_suffix|>'
print_info: FIM MID token    = 200003 '<|fim_middle|>'
print_info: EOG token        = 200008 '<|eot|>'
print_info: max token length = 192
load_tensors: loading model tensors, this can take a while... (mmap = true)
load_tensors: offloading 10 repeating layers to GPU
load_tensors: offloaded 10/49 layers to GPU
load_tensors:   CPU_Mapped model buffer size = 34140.05 MiB
load_tensors:        CUDA0 model buffer size = 10243.28 MiB
...................................................................................................
llama_context: constructing llama_context
llama_context: n_seq_max     = 1
llama_context: n_ctx         = 1024
llama_context: n_ctx_per_seq = 1024
llama_context: n_batch       = 128
llama_context: n_ubatch      = 128
llama_context: causal_attn   = 1
llama_context: flash_attn    = 1
llama_context: freq_base     = 500000.0
llama_context: freq_scale    = 1
llama_context: yarn_log_mul  = 0
llama_context: n_ctx_per_seq (1024) < n_ctx_train (10485760) -- the full capacity of the model will not be utilized
llama_context:        CPU  output buffer size =     0.77 MiB
init: kv_size = 1024, offload = 1, type_k = 'f16', type_v = 'f16', n_layer = 48, can_shift = 1
init:        CPU KV buffer size =   152.00 MiB
init:      CUDA0 KV buffer size =    40.00 MiB
llama_context: KV self size  =  192.00 MiB, K (f16):   96.00 MiB, V (f16):   96.00 MiB
llama_context:      CUDA0 compute buffer size =   786.92 MiB
llama_context:  CUDA_Host compute buffer size =     3.50 MiB
llama_context: graph nodes  = 2324
llama_context: graph splits = 575 (with bs=128), 3 (with bs=1)
common_init_from_params: setting dry_penalty_last_n to ctx_size = 1024
common_init_from_params: warming up the model with an empty run - please wait ... (--no-warmup to disable)

system_info: n_threads = 8 (n_threads_batch = 8) / 16 | CUDA : ARCHS = 600,610,700,750 | F16 = 1 | USE_GRAPHS = 1 | PEER_MAX_BATCH_SIZE = 128 | CPU : SSE3 = 1 | SSSE3 = 1 | AVX = 1 | AVX2 = 1 | F16C = 1 | FMA = 1 | BMI2 = 1 | LLAMAFILE = 1 | OPENMP = 1 | AARCH64_REPACK = 1 | 
perplexity: tokenizing the input ..
perplexity: tokenization took 12.048 ms
perplexity: calculating perplexity over 2 chunks, n_ctx=1024, batch_size=128, n_seq=1
/usr/local/src/ai/llamacpp/llama.cpp/ggml/src/ggml-cuda/ggml-cuda.cu:75: CUDA error
CUDA error: an illegal memory access was encountered
  current device: 0, in function ggml_backend_cuda_synchronize at /usr/local/src/ai/llamacpp/llama.cpp/ggml/src/ggml-cuda/ggml-cuda.cu:2443
  cudaStreamSynchronize(cuda_ctx->stream())
[New LWP 15539]
[New LWP 15543]
[New LWP 15544]
[New LWP 15545]
[New LWP 15546]
[New LWP 15547]
[New LWP 15548]
[New LWP 15549]
[New LWP 15550]
[New LWP 15551]
[New LWP 15552]
[Thread debugging using libthread_db enabled]
Using host libthread_db library "/lib64/libthread_db.so.1".
0x00007fa1640d33c7 in wait4 () from /lib64/libc.so.6
#0  0x00007fa1640d33c7 in wait4 () from /lib64/libc.so.6
#1  0x00007fa1646601e1 in ggml_abort () from /usr/lib64/libggml-base.so
#2  0x00007fa1647cb422 in ggml_cuda_error(char const*, char const*, char const*, int, char const*) () from /usr/lib64/libggml-cuda.so
#3  0x00007fa1647cc87b in ggml_backend_cuda_synchronize(ggml_backend*) () from /usr/lib64/libggml-cuda.so
#4  0x00007fa164675565 in ggml_backend_sched_graph_compute_async () from /usr/lib64/libggml-base.so
#5  0x00007fa1797d3a91 in llama_context::graph_compute(ggml_cgraph*, bool) () from /usr/lib64/libllama.so
#6  0x00007fa1797d68d2 in llama_context::decode(llama_batch&) () from /usr/lib64/libllama.so
#7  0x00007fa1797d7adc in llama_decode () from /usr/lib64/libllama.so
#8  0x0000000000435d8c in perplexity(llama_context*, common_params const&, int) ()
#9  0x000000000042f63b in main ()
[Inferior 1 (process 15538) detached]
/usr/local/bin/ll_start: line 2092: 15538 Aborted                 ${EXPREFIX}$PERPLEXITY -m $MODEL_ROOT/$MODEL_FILE $RPC -ngl $NGL -c $NKV -b $BATCH $FATTN_FLAG -f $DATA_FILE
@CISC CISC added bug Something isn't working and removed bug-unconfirmed labels May 3, 2025
@CISC
Collaborator

CISC commented May 3, 2025

@JohannesGaessler Same issue as in #13286?

@JohannesGaessler
Collaborator

No, the linked PR specifically fixes imatrix. A CUDA error with illegal memory access is almost always an issue with the CUDA code where some edge case is not being considered correctly.
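
For readers unfamiliar with this failure mode, here is a minimal illustration of such an edge case (a hypothetical kernel, not llama.cpp code): the guard on the partially filled last block is exactly the kind of detail that, when wrong, produces an out-of-bounds access, and because kernel launches are asynchronous the error only surfaces at a later synchronization point, as in the ggml_backend_cuda_synchronize backtrace above.

// Hypothetical example, not llama.cpp code: a kernel whose grid is rounded up
// to a multiple of the block size. The bounds check below is the edge case;
// remove it and the tail block reads/writes past the end of the allocation,
// which compute-sanitizer flags precisely and which can surface at runtime as
// the illegal-memory-access abort seen in the log (or as silent corruption,
// depending on allocation padding).
#include <cuda_runtime.h>
#include <cstdio>

__global__ void scale(float * x, int n, float s) {
    const int i = blockIdx.x*blockDim.x + threadIdx.x;
    if (i >= n) { // edge-case guard for the partially filled last block
        return;
    }
    x[i] *= s;
}

int main() {
    const int n = 1000; // not a multiple of the block size
    float * d_x;
    cudaMalloc(&d_x, n*sizeof(float));
    scale<<<(n + 255)/256, 256>>>(d_x, n, 2.0f); // 1024 threads for 1000 elements
    // Errors from the asynchronous launch are only reported at a sync point:
    printf("%s\n", cudaGetErrorString(cudaStreamSynchronize(0)));
    cudaFree(d_x);
    return 0;
}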

@CISC
Collaborator

CISC commented May 3, 2025

No, the linked PR specifically fixes imatrix. A CUDA error with illegal memory access is almost always an issue with the CUDA code where some edge case is not being considered correctly.

Yes, but perplexity basically does the same thing as imatrix, no?
Hmmm, guess not, I thought it accessed tensors too...

@JohannesGaessler
Collaborator

This issue should be fixed by #13294, please confirm.

@steampunque
Author

This issue should be fixed by #13294, please confirm.

Same issue with this patch applied:

diff --git a/ggml/src/ggml-cuda/mmq.cuh b/ggml/src/ggml-cuda/mmq.cuh
index 8c93e8326e20b..fc6ce0083007a 100644
--- a/ggml/src/ggml-cuda/mmq.cuh
+++ b/ggml/src/ggml-cuda/mmq.cuh
@@ -2636,6 +2636,7 @@ static __global__ void mul_mat_q(
 
         ids_dst_shared[j] = j;
     }
+    __syncthreads();
 
     // On AMD or old CUDA the performance with stream-k was worse, use conventional tiling instead:
 #if (defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ < GGML_CUDA_CC_VOLTA
@@ -2664,6 +2665,7 @@ static __global__ void mul_mat_q(
                 return;
             }
 
+            // __syncthreads(); // There is no previous tile that could cause a race condition.
 #pragma unroll
             for (int j0 = 0; j0 < mmq_x; j0 += nwarps*WARP_SIZE) {
                 const int j = j0 + threadIdx.y*WARP_SIZE + threadIdx.x;
@@ -2674,6 +2676,7 @@ static __global__ void mul_mat_q(
 
                 ids_dst_shared[j] = ids_dst[col_low + jt*mmq_x + j];
             }
+            __syncthreads();
         }
 
         offset_y   += (col_low + jt*mmq_x)*(sizeof(block_q8_1_mmq)/sizeof(int));
@@ -2740,6 +2743,7 @@ static __global__ void mul_mat_q(
                 continue;
             }
 
+            __syncthreads();
 #pragma unroll
             for (int j0 = 0; j0 < mmq_x; j0 += nwarps*WARP_SIZE) {
                 const int j = j0 + threadIdx.y*WARP_SIZE + threadIdx.x;
@@ -2750,6 +2754,7 @@ static __global__ void mul_mat_q(
 
                 ids_dst_shared[j] = ids_dst[col_low + jt*mmq_x + j];
             }
+            __syncthreads();
         }
 
         offset_y   += (col_low + jt*mmq_x)*(sizeof(block_q8_1_mmq)/sizeof(int));
@@ -2805,6 +2810,7 @@ static __global__ void mul_mat_q(
         }
 
         // The memory layout for the fixup buffer is always contiguous, therefore reset ids:
+        __syncthreads();
 #pragma unroll
         for (int j0 = 0; j0 < mmq_x; j0 += nwarps*WARP_SIZE) {
             const int j = j0 + threadIdx.y*WARP_SIZE + threadIdx.x;
@@ -2815,6 +2821,7 @@ static __global__ void mul_mat_q(
 
             ids_dst_shared[j] = j;
         }
+        __syncthreads();
     }
 
     offset_y   += (col_low + jt*mmq_x)*(sizeof(block_q8_1_mmq)/sizeof(int));
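
For context, the __syncthreads() calls added above guard the classic shared-memory write-then-read pattern. A minimal standalone sketch of that pattern follows (a hypothetical kernel, not the actual mul_mat_q code):

// Hypothetical kernel, not the actual mul_mat_q code: one phase cooperatively
// fills a shared index table, a later phase uses it to address global memory.
// Without a barrier between the phases a thread can read a not-yet-written
// (garbage) index and trigger an illegal memory access.
#include <cuda_runtime.h>

__global__ void gather_rows(const float * src, float * dst, const int * ids, int ncols, int nrows) {
    __shared__ int ids_shared[128]; // assumes nrows <= 128

    // Phase 1: threads cooperatively load the destination row ids.
    for (int j = threadIdx.x; j < 128; j += blockDim.x) {
        ids_shared[j] = (j < nrows) ? ids[j] : 0;
    }

    __syncthreads(); // the kind of barrier the patch inserts; without it phase 2
                     // may index global memory through uninitialized entries

    // Phase 2: every thread reads the shared table, including entries written by other threads.
    for (int j = 0; j < nrows; ++j) {
        const int row = ids_shared[j];
        for (int i = threadIdx.x; i < ncols; i += blockDim.x) {
            dst[row*ncols + i] = src[j*ncols + i];
        }
    }
}

In the same spirit, the barriers placed before refilling the table serve the complementary purpose: they make sure no thread is still reading the old entries when they get overwritten.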

bash-5.1$ llama-perplexity -m /data3hd/models/Llama-4-Scout-17B-16E-Instruct.Q3_K_H.gguf -ngl 10 -c 1024 -b 128 -fa -f short.txt
ggml_cuda_init: GGML_CUDA_FORCE_MMQ: no
ggml_cuda_init: GGML_CUDA_FORCE_CUBLAS: no
ggml_cuda_init: found 1 CUDA devices:
Device 0: NVIDIA GeForce RTX 4070, compute capability 8.9, VMM: yes
build: 5237 (e1e8e09) with cc (GCC) 11.2.0 for x86_64-slackware-linux
llama_model_load_from_file_impl: using device CUDA0 (NVIDIA GeForce RTX 4070) - 11604 MiB free
llama_model_loader: loaded meta data with 42 key-value pairs and 628 tensors from /data3hd/models/Llama-4-Scout-17B-16E-Instruct.Q3_K_H.gguf (version GGUF V3 (latest))
llama_model_loader: Dumping metadata keys/values. Note: KV overrides do not apply in this output.
llama_model_loader: - kv 0: general.architecture str = llama4
llama_model_loader: - kv 1: general.type str = model
llama_model_loader: - kv 2: general.name str = Llama 4 Scout 17B 16E Instruct
llama_model_loader: - kv 3: general.finetune str = 16E-Instruct
llama_model_loader: - kv 4: general.basename str = Llama-4-Scout
llama_model_loader: - kv 5: general.size_label str = 17B
llama_model_loader: - kv 6: general.license str = other
llama_model_loader: - kv 7: general.license.name str = llama4
llama_model_loader: - kv 8: general.base_model.count u32 = 1
llama_model_loader: - kv 9: general.base_model.0.name str = Llama 4 Scout 17B 16E
llama_model_loader: - kv 10: general.base_model.0.organization str = Meta Llama
llama_model_loader: - kv 11: general.base_model.0.repo_url str = https://huggingface.co/meta-llama/Lla...
llama_model_loader: - kv 12: general.tags arr[str,5] = ["facebook", "meta", "pytorch", "llam...
llama_model_loader: - kv 13: general.languages arr[str,12] = ["ar", "de", "en", "es", "fr", "hi", ...
llama_model_loader: - kv 14: llama4.block_count u32 = 48
llama_model_loader: - kv 15: llama4.context_length u32 = 10485760
llama_model_loader: - kv 16: llama4.embedding_length u32 = 5120
llama_model_loader: - kv 17: llama4.feed_forward_length u32 = 16384
llama_model_loader: - kv 18: llama4.attention.head_count u32 = 40
llama_model_loader: - kv 19: llama4.attention.head_count_kv u32 = 8
llama_model_loader: - kv 20: llama4.rope.freq_base f32 = 500000.000000
llama_model_loader: - kv 21: llama4.attention.layer_norm_rms_epsilon f32 = 0.000010
llama_model_loader: - kv 22: llama4.expert_count u32 = 16
llama_model_loader: - kv 23: llama4.expert_used_count u32 = 1
llama_model_loader: - kv 24: llama4.attention.key_length u32 = 128
llama_model_loader: - kv 25: llama4.attention.value_length u32 = 128
llama_model_loader: - kv 26: llama4.vocab_size u32 = 202048
llama_model_loader: - kv 27: llama4.rope.dimension_count u32 = 128
llama_model_loader: - kv 28: llama4.interleave_moe_layer_step u32 = 1
llama_model_loader: - kv 29: llama4.expert_feed_forward_length u32 = 8192
llama_model_loader: - kv 30: tokenizer.ggml.model str = gpt2
llama_model_loader: - kv 31: tokenizer.ggml.pre str = llama4
llama_model_loader: - kv 32: tokenizer.ggml.tokens arr[str,202048] = ["À", "Á", "õ", "ö", "÷", "ø", ...
llama_model_loader: - kv 33: tokenizer.ggml.token_type arr[i32,202048] = [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...
llama_model_loader: - kv 34: tokenizer.ggml.merges arr[str,439802] = ["Ġ Ġ", "Ġ ĠĠĠ", "ĠĠ ĠĠ", "...
llama_model_loader: - kv 35: tokenizer.ggml.bos_token_id u32 = 200000
llama_model_loader: - kv 36: tokenizer.ggml.eos_token_id u32 = 200008
llama_model_loader: - kv 37: tokenizer.ggml.padding_token_id u32 = 201134
llama_model_loader: - kv 38: tokenizer.chat_template str = {{- bos_token }}\n{%- if custom_tools ...
llama_model_loader: - kv 39: tokenizer.ggml.add_bos_token bool = true
llama_model_loader: - kv 40: general.quantization_version u32 = 2
llama_model_loader: - kv 41: general.file_type u32 = 12
llama_model_loader: - type f32: 146 tensors
llama_model_loader: - type q2_K: 54 tensors
llama_model_loader: - type q3_K: 372 tensors
llama_model_loader: - type q4_K: 51 tensors
llama_model_loader: - type q5_K: 5 tensors
print_info: file format = GGUF V3 (latest)
print_info: file type = Q3_K - Medium
print_info: file size = 43.34 GiB (3.45 BPW)
load: special_eos_id is not in special_eog_ids - the tokenizer config may be incorrect
load: special tokens cache size = 1135
load: token to piece cache size = 1.3873 MB
print_info: arch = llama4
print_info: vocab_only = 0
print_info: n_ctx_train = 10485760
print_info: n_embd = 5120
print_info: n_layer = 48
print_info: n_head = 40
print_info: n_head_kv = 8
print_info: n_rot = 128
print_info: n_swa = 1
print_info: n_swa_pattern = 4
print_info: n_embd_head_k = 128
print_info: n_embd_head_v = 128
print_info: n_gqa = 5
print_info: n_embd_k_gqa = 1024
print_info: n_embd_v_gqa = 1024
print_info: f_norm_eps = 0.0e+00
print_info: f_norm_rms_eps = 1.0e-05
print_info: f_clamp_kqv = 0.0e+00
print_info: f_max_alibi_bias = 0.0e+00
print_info: f_logit_scale = 0.0e+00
print_info: f_attn_scale = 0.0e+00
print_info: n_ff = 16384
print_info: n_expert = 16
print_info: n_expert_used = 1
print_info: causal attn = 1
print_info: pooling type = 0
print_info: rope type = 0
print_info: rope scaling = linear
print_info: freq_base_train = 500000.0
print_info: freq_scale_train = 1
print_info: n_ctx_orig_yarn = 10485760
print_info: rope_finetuned = unknown
print_info: ssm_d_conv = 0
print_info: ssm_d_inner = 0
print_info: ssm_d_state = 0
print_info: ssm_dt_rank = 0
print_info: ssm_dt_b_c_rms = 0
print_info: model type = 17Bx16E (Scout)
print_info: model params = 107.77 B
print_info: general.name = Llama 4 Scout 17B 16E Instruct
print_info: vocab type = BPE
print_info: n_vocab = 202048
print_info: n_merges = 439802
print_info: BOS token = 200000 '<|begin_of_text|>'
print_info: EOS token = 200008 '<|eot|>'
print_info: PAD token = 201134 '<|finetune_right_pad_id|>'
print_info: LF token = 198 'Ċ'
print_info: FIM PRE token = 200002 '<|fim_prefix|>'
print_info: FIM SUF token = 200004 '<|fim_suffix|>'
print_info: FIM MID token = 200003 '<|fim_middle|>'
print_info: EOG token = 200008 '<|eot|>'
print_info: max token length = 192
load_tensors: loading model tensors, this can take a while... (mmap = true)
load_tensors: offloading 10 repeating layers to GPU
load_tensors: offloaded 10/49 layers to GPU
load_tensors: CPU_Mapped model buffer size = 34140.05 MiB
load_tensors: CUDA0 model buffer size = 10243.28 MiB
...................................................................................................
llama_context: constructing llama_context
llama_context: n_seq_max = 1
llama_context: n_ctx = 1024
llama_context: n_ctx_per_seq = 1024
llama_context: n_batch = 128
llama_context: n_ubatch = 128
llama_context: causal_attn = 1
llama_context: flash_attn = 1
llama_context: freq_base = 500000.0
llama_context: freq_scale = 1
llama_context: yarn_log_mul = 0
llama_context: n_ctx_per_seq (1024) < n_ctx_train (10485760) -- the full capacity of the model will not be utilized
llama_context: CPU output buffer size = 0.77 MiB
init: kv_size = 1024, offload = 1, type_k = 'f16', type_v = 'f16', n_layer = 48, can_shift = 1
init: CPU KV buffer size = 152.00 MiB
init: CUDA0 KV buffer size = 40.00 MiB
llama_context: KV self size = 192.00 MiB, K (f16): 96.00 MiB, V (f16): 96.00 MiB
llama_context: CUDA0 compute buffer size = 786.92 MiB
llama_context: CUDA_Host compute buffer size = 3.50 MiB
llama_context: graph nodes = 2324
llama_context: graph splits = 575 (with bs=128), 3 (with bs=1)
common_init_from_params: setting dry_penalty_last_n to ctx_size = 1024
common_init_from_params: warming up the model with an empty run - please wait ... (--no-warmup to disable)

system_info: n_threads = 8 (n_threads_batch = 8) / 16 | CUDA : ARCHS = 600,610,700,750 | F16 = 1 | USE_GRAPHS = 1 | PEER_MAX_BATCH_SIZE = 128 | CPU : SSE3 = 1 | SSSE3 = 1 | AVX = 1 | AVX2 = 1 | F16C = 1 | FMA = 1 | BMI2 = 1 | LLAMAFILE = 1 | OPENMP = 1 | AARCH64_REPACK = 1 |
perplexity: tokenizing the input ..
perplexity: tokenization took 11.36 ms
perplexity: calculating perplexity over 2 chunks, n_ctx=1024, batch_size=128, n_seq=1
/usr/local/src/ai/llamacpp/llama.cpp/ggml/src/ggml-cuda/ggml-cuda.cu:75: CUDA error
CUDA error: an illegal memory access was encountered
current device: 0, in function ggml_backend_cuda_synchronize at /usr/local/src/ai/llamacpp/llama.cpp/ggml/src/ggml-cuda/ggml-cuda.cu:2443
cudaStreamSynchronize(cuda_ctx->stream())
[New LWP 31246]
[New LWP 31250]
[New LWP 31251]
[New LWP 31252]
[New LWP 31298]
[New LWP 31299]
[New LWP 31300]
[New LWP 31301]
[New LWP 31302]
[New LWP 31303]
[New LWP 31304]
[Thread debugging using libthread_db enabled]
Using host libthread_db library "/lib64/libthread_db.so.1".
0x00007f3bf5cd33c7 in wait4 () from /lib64/libc.so.6
#0 0x00007f3bf5cd33c7 in wait4 () from /lib64/libc.so.6
#1 0x00007f3bf62601e1 in ggml_abort () from /usr/lib64/libggml-base.so
#2 0x00007f3bf63cb422 in ggml_cuda_error(char const*, char const*, char const*, int, char const*) () from /usr/lib64/libggml-cuda.so
#3 0x00007f3bf63cc87b in ggml_backend_cuda_synchronize(ggml_backend*) () from /usr/lib64/libggml-cuda.so
#4 0x00007f3bf6275565 in ggml_backend_sched_graph_compute_async () from /usr/lib64/libggml-base.so
#5 0x00007f3c0b3eea91 in llama_context::graph_compute(ggml_cgraph*, bool) () from /usr/lib64/libllama.so
#6 0x00007f3c0b3f18d2 in llama_context::decode(llama_batch&) () from /usr/lib64/libllama.so
#7 0x00007f3c0b3f2adc in llama_decode () from /usr/lib64/libllama.so
#8 0x0000000000435d8c in perplexity(llama_context*, common_params const&, int) ()
#9 0x000000000042f63b in main ()
[Inferior 1 (process 31245) detached]
Aborted

A smaller gguf is now available if you would like to test it:

https://huggingface.co/steampunque/Llama-4-Scout-17B-16E-Instruct-GGUF/resolve/main/Llama-4-Scout-17B-16E-Instruct.Q2_K_H.gguf

@JohannesGaessler
Collaborator

Can you edit line 103 of ggml/src/ggml-cuda/CMakeLists.txt to set(CUDA_FLAGS -use_fast_math -lineinfo), then delete the build directory and recompile, and then re-run the command producing the memory error with compute-sanitizer prepended? The compute sanitizer may not be on the path by default; on my system it's under /opt/cuda/extras/compute-sanitizer/. The output you get should be a list of memory errors showing what kind of error it is and which lines of code are responsible.

@JohannesGaessler
Collaborator

Potential fix: #13299

@steampunque
Author

Potential fix: #13299

Seems to work with this patch. Will be doing further testing later today. Thanks for the fast response!

@CISC CISC linked a pull request May 4, 2025 that will close this issue
@steampunque
Author

Potential fix: #13299

I ran some regressions. While it no longer crashes, the generation quality appears to be noticeably degraded with the Q2_K_H quant on Llama 4 Scout. The new perplexity is 10.47605444724703512860 over 59513 tokens, while the old build (prior to b5237) showed Final perplexity=10.46921429389673307437 over 59513 tokens, so only a tiny increase in perplexity. However, across prompting b5237 is noticeably worse on 3 test prompts (2 fairly hard questions and one code gen: b5237 got both questions wrong and generated worse code, while the build prior to b5237 got both questions right and generated better code). So it seems to have gone backwards in performance for some reason (a very tiny difference in the objective perplexity result, but very noticeable on the test prompts).

@JohannesGaessler
Collaborator

Did you select 3 random problems and the model just happened to be able to solve all of them or did you select those 3 problems specifically because the model could solve them?

@steampunque
Author

Did you select 3 random problems and the model just happened to be able to solve all of them or did you select those 3 problems specifically because the model could solve them?

They are 3 problems I use to help optimize the hybrid layer quants. Yesterday I optimized the Q2_K_H Llama 4 quant with b5236 and it was working very well across my entire gauntlet of test prompts, but it went 0 for 3 on the harder problems I use with the b5237 update. It's still functional on easier prompts, but suddenly going 0 for 3 on the harder prompts is concerning.

@JohannesGaessler
Collaborator

Okay, but that doesn't answer how you selected those 3 prompts. I'm specifically asking because you may be experiencing what is called a regression towards the mean. For example, let's say someone tests 100 prompts, the model can solve 3/100 prompts, and those 3 prompts are then used to determine model performance. The model likely got very lucky on those 3 prompts. If you add a small perturbation to the model it will mostly just shuffle the performance across the 100 prompts around; the performance on the 3 best prompts will likely get worse but there may be other prompts where it performs better. If you selected 3 prompts completely at random and the performance got worse on those 3 prompts that is a very different result than if you selected specifically 3 prompts where the model was performing well above average. Sometime next week I should be able to do a high-precision benchmark run using Elo HeLLM to check your findings.

Please also consider https://github.com/ggml-org/llama.cpp/tree/master/tools/perplexity#llama-3-8b-scoreboard , specifically the columns for the token probabilities. At high bit quantization the token probabilities change on average by multiple percent, but on average the probability of predicting the "correct" token barely moves.
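
A toy numeric illustration of that selection effect (hypothetical numbers, plain host-side C++, unrelated to llama.cpp): if the scored prompts are exactly the ones the model happened to get right, a re-run with unchanged average ability still looks like a regression on them. This is the extreme case where each run's outcome on a prompt is independent of the other run.

// Toy Monte Carlo sketch with hypothetical numbers: a model answers each prompt
// correctly with a fixed probability in both runs. Selecting the 3 prompts it
// got right in run 1 and re-scoring only those in run 2 shows an apparent drop
// even though nothing about the model changed.
#include <cstdio>
#include <random>

int main() {
    std::mt19937 rng(42);
    std::uniform_real_distribution<double> unif(0.0, 1.0);

    const int n_prompts = 100;
    const int n_trials  = 10000;
    const double skill  = 0.30; // same per-prompt ability in both runs
    double selected_first = 0.0, selected_second = 0.0;

    for (int t = 0; t < n_trials; ++t) {
        int n_selected = 0;
        int correct_second = 0;
        for (int p = 0; p < n_prompts && n_selected < 3; ++p) {
            if (unif(rng) < skill) {                   // run 1: prompt solved -> selected
                ++n_selected;
                correct_second += unif(rng) < skill;   // run 2: same ability, re-scored
            }
        }
        selected_first  += n_selected;                 // by construction all selected were solved
        selected_second += correct_second;
    }

    printf("run 1 on selected prompts: %.2f/3\n", selected_first  / n_trials);
    printf("run 2 on selected prompts: %.2f/3\n", selected_second / n_trials);
    // Typical output: 3.00/3 vs ~0.90/3 -- an apparent regression with no real change.
    return 0;
}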

@steampunque
Author

Okay, but that doesn't answer how you selected those 3 prompts. I'm specifically asking because you may be experiencing what is called a regression towards the mean. For example, let's say someone tests 100 prompts, the model can solve 3/100 prompts, and those 3 prompts are then used to determine model performance. The model likely got very lucky on those 3 prompts. If you add a small perturbation to the model it will mostly just shuffle the performance across the 100 prompts around; the performance on the 3 best prompts will likely get worse but there may be other prompts where it performs better. If you selected 3 prompts completely at random and the performance got worse on those 3 prompts that is a very different result than if you selected specifically 3 prompts where the model was performing well above average. Sometime next week I should be able to do a high-precision benchmark run using Elo HeLLM to check your findings.

Please also consider https://github.com/ggml-org/llama.cpp/tree/master/tools/perplexity#llama-3-8b-scoreboard , specifically the columns for the token probabilities. At high bit quantization the token probabilities change on average by multiple percent, but on average the probability of predicting the "correct" token barely moves.

I did some follow-up testing. The Q2_K_H quant is still very good at writing prose and I am getting no generation artifacts, which was my main goal (I wasn't too worried about the tricky questions, but I was happy it got them too).

Given a strong enough model (this thing is 108B and should be strong) I think it's reasonable to assume the model is going to handle a wide range of trick questions right, so in my view suddenly faltering on these 3 tests was an alarm bell.

I followed up with an adaptive beam search test (an experimental feature of my server) and it corrected both of the tricky questions, which suggests the model might have been on the knife edge of answering correctly. I would have expected the severe quants to dominate performance, though, not the higher precision computations going on in the attention block. I'm guessing at random here, but there may be some rounding/truncation bias that was close to 0 in the previous version and is now coming in with a small offset, just enough to kick the model off a correct generation (one early non-optimal token can do it).

Based on all my tests I don't see any glaring issue, so I think this bug report can be closed. Thanks again for the fast response!

@CISC CISC closed this as completed May 4, 2025
@steampunque
Author

I ran a 100-question math bench on 5236 and 5237. These 100 questions are guaranteed to be unseen by any model since I created them all; about 1/3 are quite tricky and 2/3 are GSM8K level (fairly simple).

Q2_K_H Llama 4 Scout:
5236 = 84/100 correct (0.84)
5237 (with cuda patches) = 81/100 correct (0.81)

The result confirms the small performance degradation on 5237 that I originally concluded from my 3-prompt screening test.
