Commit 0fcc8fb

fix(model): correct k_expanded transpose in forward_decode
Replace squeeze(0) with transpose(1, 2) to reorder the expanded key/value tensors from [1, N, 14, 64] to [1, 14, N, 64], the [batch, heads, seq, dim] layout that paged_attention expects.
1 parent: fefe943

File tree

1 file changed: +2 -2 lines changed
crates/model/src/qwen3/attention.rs

Lines changed: 2 additions & 2 deletions
@@ -315,8 +315,8 @@ impl GqaAttention {
         let v = v.unsqueeze(0)?;
         let k_expanded = self.expand_kv(&k, self.num_heads, self.num_kv_heads)?;
         let v_expanded = self.expand_kv(&v, self.num_heads, self.num_kv_heads)?;
-        let k_expanded = k_expanded.squeeze(0)?; // Remove batch dimension for attention
-        let v_expanded = v_expanded.squeeze(0)?;
+        let k_expanded = k_expanded.transpose(1, 2)?; // [1, N, 14, 64] -> [1, 14, N, 64]
+        let v_expanded = v_expanded.transpose(1, 2)?;

         if seq_len > tile_size {
             self.tiled_attention(&q, &k_expanded, &v_expanded, seq_len)
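
For context, a minimal sketch of why the fix is a transpose rather than a squeeze, assuming a candle-style tensor API (candle_core is an assumption here; the repo's actual tensor crate may differ, and n is a hypothetical cached sequence length). The expanded cache leaves expand_kv as [1, N, 14, 64], while the attention path wants [batch, heads, seq, dim]:

    use candle_core::{DType, Device, Result, Tensor};

    fn main() -> Result<()> {
        let dev = Device::Cpu;
        let n: usize = 5; // hypothetical cached sequence length

        // Shape coming out of expand_kv: [batch=1, seq=N, heads=14, head_dim=64].
        let k_expanded = Tensor::zeros((1, n, 14, 64), DType::F32, &dev)?;

        // Old code: squeeze(0) only drops the batch dim, leaving [N, 14, 64],
        // which still has the seq axis ahead of the heads axis.
        assert_eq!(k_expanded.squeeze(0)?.dims(), &[n, 14, 64]);

        // Fix: transpose(1, 2) swaps the seq and heads axes, giving
        // [1, 14, N, 64] = [batch, heads, seq, dim].
        assert_eq!(k_expanded.transpose(1, 2)?.dims(), &[1, 14, n, 64]);
        Ok(())
    }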
