70 changes: 67 additions & 3 deletions aiter/ops/quant.py
@@ -72,15 +72,46 @@ def pertoken_quant(
     return y, y_scale


-def per_1x32_f4_quant(x, scale=None, quant_dtype=dtypes.fp4x2, shuffle=False):
+def per_1x32_f4_quant(
+    x, scale=None, quant_dtype=dtypes.fp4x2, shuffle=False, pack_dim=-1
+):
"""Quantize a tensor to MXFP4 (e2m1) format with per-1x32 block scaling.

By default, packing is along the last dimension (dim=-1), which produces
output suitable for ``tl.dot_scaled`` **LHS** operand:
A(M, K) -> fp4=(M, K//2), scale=(M, K//32)

For ``tl.dot_scaled`` **RHS** operand, set ``pack_dim=0`` so the packing
is along the first dimension (the K / contraction dimension):
B(K, N) -> fp4=(K//2, N), scale=(K//32, N)

Args:
x: Input tensor of shape (..., N) or (M, N).
scale: Pre-computed scale tensor (optional, usually None).
quant_dtype: Target quantized dtype, must be ``dtypes.fp4x2``.
shuffle: Whether to apply e8m0 scale shuffling for hardware.
pack_dim: Dimension along which to pack two FP4 values into one byte.
-1 (default): pack along the last dimension (for dot_scaled LHS).
0: pack along the first dimension (for dot_scaled RHS).

Returns:
Tuple of (quantized_tensor, scale_tensor).
"""
     assert quant_dtype == dtypes.fp4x2
     block_size = 32
-    F8E8M0_EXP_BIAS = 127
+    F8E8M0_EXP_BIAS = 127  # noqa:F841
     F4E2M1_MAX = 6.0
     MAX_POW2 = int(torch.log2(torch.tensor(F4E2M1_MAX, dtype=torch.float32)).item())
     # dtypeMax = F4E2M1_MAX
     dtypeMax = 2.0**MAX_POW2

+    # For pack_dim=0, transpose so packing always happens along last dim internally
+    transposed = False
+    if pack_dim == 0:
+        assert x.dim() == 2, "pack_dim=0 requires a 2D input tensor (K, N)"
+        x = x.T.contiguous()
+        transposed = True
+
     shape_original = x.shape
     x = x.view(-1, shape_original[-1])

@@ -102,7 +133,40 @@ def per_1x32_f4_quant(x, scale=None, quant_dtype=dtypes.fp4x2, shuffle=False):
     scale = scale_e8m0_biased.view(m, -1).view(torch.uint8)
     if shuffle:
         scale = fp4_utils.e8m0_shuffle(scale)
-    return y, scale.view(dtypes.fp8_e8m0)
+    scale = scale.view(dtypes.fp8_e8m0)
+
+    # For pack_dim=0, transpose results back: (N, K//2) -> (K//2, N)
+    if transposed:
+        y = y.T.contiguous()
+        scale = scale.view(torch.uint8).T.contiguous().view(dtypes.fp8_e8m0)
+
+    return y, scale
+
+
+def per_1x32_f4_quant_for_dot_scaled(lhs, rhs, quant_dtype=dtypes.fp4x2, shuffle=False):
+    """Convenience function: quantize both LHS and RHS for ``tl.dot_scaled``.
+
+    Handles the packing dimension automatically:
+      - LHS A(M, K): packed along K (dim=-1) -> fp4=(M, K//2), scale=(M, K//32)
+      - RHS B(K, N): packed along K (dim=0) -> fp4=(K//2, N), scale=(K//32, N)
+
+    Note: Triton 3.6+ expects rhs_scale in transposed form (N, K//32). Users
+    should transpose the returned rhs_scale accordingly if using Triton >= 3.6.
+
+    Args:
+        lhs: LHS tensor of shape (M, K).
+        rhs: RHS tensor of shape (K, N).
+
+    Returns:
+        Tuple of (lhs_fp4, lhs_scale, rhs_fp4, rhs_scale).
+    """
+    lhs_fp4, lhs_scale = per_1x32_f4_quant(
+        lhs, quant_dtype=quant_dtype, shuffle=shuffle, pack_dim=-1
+    )
+    rhs_fp4, rhs_scale = per_1x32_f4_quant(
+        rhs, quant_dtype=quant_dtype, shuffle=shuffle, pack_dim=0
+    )
+    return lhs_fp4, lhs_scale, rhs_fp4, rhs_scale


 def per_1x32_f8_scale_f8_quant(
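
For reference, a minimal usage sketch of the new pack_dim parameter, assuming the import path matches this file's location (aiter/ops/quant.py) and a supported GPU device is available; shapes follow the docstring above:

import torch
from aiter.ops.quant import per_1x32_f4_quant

M, K, N = 128, 256, 64
a = torch.randn(M, K, dtype=torch.float32, device="cuda")  # LHS of a matmul
b = torch.randn(K, N, dtype=torch.float32, device="cuda")  # RHS of a matmul

# LHS: pack along the last (K) dim -> fp4 (M, K//2), e8m0 scales (M, K//32)
a_fp4, a_scale = per_1x32_f4_quant(a)  # pack_dim=-1 is the default
# RHS: pack along the first (K) dim -> fp4 (K//2, N), e8m0 scales (K//32, N)
b_fp4, b_scale = per_1x32_f4_quant(b, pack_dim=0)

assert a_fp4.shape == (M, K // 2) and a_scale.shape == (M, K // 32)
assert b_fp4.shape == (K // 2, N) and b_scale.shape == (K // 32, N)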
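
Likewise, a sketch of the convenience wrapper, reusing a and b from the sketch above. The version check and the uint8 round-trip transpose are illustrative assumptions based on the docstring's Triton >= 3.6 note and on the transpose pattern used inside per_1x32_f4_quant, not part of this diff:

import torch
import triton
from aiter.ops.quant import per_1x32_f4_quant_for_dot_scaled

lhs_fp4, lhs_scale, rhs_fp4, rhs_scale = per_1x32_f4_quant_for_dot_scaled(a, b)

# Per the docstring, Triton 3.6+ wants rhs_scale as (N, K//32);
# earlier versions take (K//32, N) as returned.
major, minor = (int(v) for v in triton.__version__.split(".")[:2])
if (major, minor) >= (3, 6):
    # e8m0 tensors may not transpose directly, so round-trip through uint8.
    rhs_scale = rhs_scale.view(torch.uint8).T.contiguous().view(rhs_scale.dtype)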