Skip to content

Update quant.pyfix: add pack_dim to per_1x32_f4_quant for tl.dot_scaled RHS compatibility#2704

Open
GeisYaO wants to merge 5 commits intoROCm:mainfrom
GeisYaO:GeisYaO-patch-1
Open

Update quant.pyfix: add pack_dim to per_1x32_f4_quant for tl.dot_scaled RHS compatibility#2704
GeisYaO wants to merge 5 commits intoROCm:mainfrom
GeisYaO:GeisYaO-patch-1

Conversation

@GeisYaO
Copy link
Copy Markdown

@GeisYaO GeisYaO commented Apr 12, 2026

Problem

per_1x32_f4_quant always packs FP4 values along the last dimension (dim=-1). This is correct for tl.dot_scaled LHS operand A(M, K) → (M, K//2), but produces the wrong layout for the RHS operand:

  • Current: per_1x32_f4_quant(B) where B(K, N)fp4=(K, N//2) (packed along N)
  • Expected by tl.dot_scaled: fp4=(K//2, N) (packed along K)

This dimension mismatch causes Memory Access Fault when AITER-quantized data is passed to tl.dot_scaled on MI355X/MI350 hardware.

Additional finding: Triton 3.6 breaking change

Triton 3.6 also changed the expected rhs_scale shape:

  • Triton 3.4: rhs_scale = (K//32, N)
  • Triton 3.6: rhs_scale = (N, K//32) ← transposed

Solution

Add a pack_dim parameter to per_1x32_f4_quant:

  • pack_dim=-1 (default): pack along last dimension → for dot_scaled LHS (backward compatible)
  • pack_dim=0: pack along first dimension → for dot_scaled RHS

Also adds per_1x32_f4_quant_for_dot_scaled() convenience function.

Usage

# LHS (unchanged)
a_fp4, a_scale = per_1x32_f4_quant(A)

# RHS (new)
b_fp4, b_scale = per_1x32_f4_quant(B, pack_dim=0)

# Or one-shot:
a_fp4, a_scale, b_fp4, b_scale = per_1x32_f4_quant_for_dot_scaled(A, B)

Verification

Tested on MI350 (GFX950/CDNA4) with both Triton 3.4 and 3.6:

Shape (M, K, N) Old AITER With pack_dim=0
(4, 512, 2880) ❌ wrong shape ✅ pass
(16, 512, 2112) ❌ wrong shape ✅ pass
(64, 512, 7168) ❌ wrong shape ✅ pass
(16, 2048, 7168) ❌ wrong shape ✅ pass
(64, 7168, 2048) ❌ wrong shape ✅ pass
(64, 1536, 7168) ❌ wrong shape ✅ pass

Backward compatibility verified: pack_dim=-1 (default) produces identical output to the original function.

@GeisYaO GeisYaO requested a review from a team April 12, 2026 12:28
@github-actions
Copy link
Copy Markdown
Contributor

🏷️ CI Guide

Runs automatically on every PR:

  • ✅ Pre-checks (submodule verification, code formatting)
  • ✅ Aiter op tests (gfx942 + gfx950)
  • ✅ Triton tests (only when aiter/ops/triton/** or related paths are changed)

Extended tests (opt-in via labels):

Label Tests
ci:triton-355 Run Triton tests on MI355 in addition to MI325
ci:sglang SGLang integration tests
ci:atom ATOM benchmark (DeepSeek-R1 + GPT-OSS)
ci:vllm vLLM benchmark
ci:all All of the above

Add labels via the sidebar or gh pr edit 2704 --add-label <label>

@GeisYaO GeisYaO changed the title Update quant.py Update quant.pyfix: add pack_dim to per_1x32_f4_quant for tl.dot_scaled RHS compatibility Apr 12, 2026
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant