Skip to content

Commit fefe943

Browse files
committed
fix(model): correct seq_len extraction in paged_attention
Fix a bug where seq_len was incorrectly extracted relative to the transpose, causing reshape to use the wrong sequence-length dimension. Before: `seq_len = attn_output.dims()[1]` (which was heads, not seq). After: `actual_seq_len = attn_output.dims()[1]`, taken after the transpose so index 1 is the sequence dimension.
1 parent a961e7b commit fefe943

File tree

1 file changed

+3
-2
lines changed

1 file changed

+3
-2
lines changed

crates/model/src/components/attention.rs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -95,8 +95,9 @@ pub fn paged_attention(
9595

9696
let attn_output = Tensor::matmul(&attn_weights, v)?;
9797
let attn_output = attn_output.transpose(1, 2)?;
98-
let seq_len = attn_output.dims()[1];
99-
let attn_output = attn_output.reshape((batch_size, seq_len, num_heads * head_dim))?;
98+
// attn_output now [batch, seq, heads, dim]
99+
let actual_seq_len = attn_output.dims()[1];
100+
let attn_output = attn_output.reshape((batch_size, actual_seq_len, num_heads * head_dim))?;
100101
Ok(attn_output)
101102
}
102103

0 commit comments

Comments (0)