Skip to content

Commit 3f3df43

Browse files
committed
fix(model): GQA tensor shape fixes in progress
- Fix expand_kv in components/attention.rs to use repeat()
- Fix forward_prefill to use k_t/v_t for expand_kv
- Fix forward_decode with proper transposes and contiguous()
- Add GQA shape tests
- Add tokenizer loading from model directory in server

WIP: Still has reshape errors in decode phase
1 parent 73d5101 commit 3f3df43

File tree

2 files changed

+32
-9
lines changed

2 files changed

+32
-9
lines changed

crates/model/src/qwen3/attention.rs

Lines changed: 17 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -245,10 +245,15 @@ impl GqaAttention {
245245
kv_cache.write_kv_batch(layer_idx, *block_id, 0, &k_block, &v_block)?;
246246
}
247247

248-
let k_expanded = self.expand_kv(&k.transpose(1, 2)?, self.num_heads, self.num_kv_heads)?;
249-
let v_expanded = self.expand_kv(&v.transpose(1, 2)?, self.num_heads, self.num_kv_heads)?;
250-
let k_expanded = k_expanded.transpose(1, 2)?;
251-
let v_expanded = v_expanded.transpose(1, 2)?;
248+
// expand_kv expects [batch, seq, heads, dim]
249+
// k_t and v_t are already in correct shape from lines 231-232
250+
let k_expanded = self.expand_kv(&k_t, self.num_heads, self.num_kv_heads)?;
251+
let v_expanded = self.expand_kv(&v_t, self.num_heads, self.num_kv_heads)?;
252+
253+
// paged_attention expects [batch, heads, seq, dim]
254+
// expand_kv outputs [batch, seq, heads, dim], so transpose
255+
let k_expanded = k_expanded.transpose(1, 2)?.contiguous()?;
256+
let v_expanded = v_expanded.transpose(1, 2)?.contiguous()?;
252257

253258
if seq_len > tile_size {
254259
self.tiled_attention(&q, &k_expanded, &v_expanded, seq_len)
@@ -286,10 +291,16 @@ impl GqaAttention {
286291

287292
let k = apply_rope(&k, &position_ids, self.theta)?;
288293

294+
// k/v from read_kv after transposes: [head_dim, num_kv_heads, seq]
295+
// Need to reshape to [batch=1, seq, num_kv_heads, head_dim] for expand_kv
296+
let k = k.transpose(0, 2)?; // [head_dim, num_kv_heads, seq] -> [seq, num_kv_heads, head_dim]
297+
let v = v.transpose(0, 2)?;
298+
let k = k.unsqueeze(0)?; // Add batch dimension: [1, seq, num_kv_heads, head_dim]
299+
let v = v.unsqueeze(0)?;
289300
let k_expanded = self.expand_kv(&k, self.num_heads, self.num_kv_heads)?;
290301
let v_expanded = self.expand_kv(&v, self.num_heads, self.num_kv_heads)?;
291-
let k_expanded = k_expanded.transpose(1, 2)?;
292-
let v_expanded = v_expanded.transpose(1, 2)?;
302+
let k_expanded = k_expanded.squeeze(0)?; // Remove batch dimension for attention
303+
let v_expanded = v_expanded.squeeze(0)?;
293304

294305
if seq_len > tile_size {
295306
self.tiled_attention(&q, &k_expanded, &v_expanded, seq_len)

crates/server/src/main.rs

Lines changed: 15 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -144,7 +144,19 @@ async fn main() {
144144
engine.run(msg_rx);
145145
});
146146

147-
let tokenizer = Arc::new(Tokenizer::new());
147+
let tokenizer_path = PathBuf::from(&model_path).join("tokenizer.json");
148+
let tokenizer = if tokenizer_path.exists() {
149+
Arc::new(
150+
Tokenizer::from_file(tokenizer_path.to_str().unwrap())
151+
.unwrap_or_else(|e| {
152+
tracing::warn!(error = %e, "Failed to load tokenizer from file, using default");
153+
Tokenizer::new()
154+
})
155+
)
156+
} else {
157+
tracing::warn!("No tokenizer.json found in model directory, using default tokenizer");
158+
Arc::new(Tokenizer::new())
159+
};
148160
let batch_manager = Arc::new(BatchManager::new());
149161

150162
let auth_middleware = if !app_config.auth.api_keys.is_empty() {
@@ -183,8 +195,8 @@ async fn main() {
183195
// Batch API
184196
.route("/v1/batches", post(create_batch))
185197
.route("/v1/batches", get(list_batches))
186-
.route("/v1/batches/:id", get(get_batch))
187-
.route("/v1/batches/:id/results", get(get_batch_results))
198+
.route("/v1/batches/{id}", get(get_batch))
199+
.route("/v1/batches/{id}/results", get(get_batch_results))
188200
// Health, readiness, and metrics endpoints
189201
.route("/health", get(health_handler))
190202
.route("/ready", get(ready_handler))

0 commit comments

Comments
 (0)