diff --git a/aiter/ops/quant.py b/aiter/ops/quant.py
index b1668c4c11..8af023f537 100644
--- a/aiter/ops/quant.py
+++ b/aiter/ops/quant.py
@@ -72,15 +72,46 @@ def pertoken_quant(
     return y, y_scale
 
 
-def per_1x32_f4_quant(x, scale=None, quant_dtype=dtypes.fp4x2, shuffle=False):
+def per_1x32_f4_quant(
+    x, scale=None, quant_dtype=dtypes.fp4x2, shuffle=False, pack_dim=-1
+):
+    """Quantize a tensor to MXFP4 (e2m1) format with per-1x32 block scaling.
+
+    By default, packing is along the last dimension (dim=-1), which produces
+    output suitable for ``tl.dot_scaled`` **LHS** operand:
+        A(M, K) -> fp4=(M, K//2), scale=(M, K//32)
+
+    For ``tl.dot_scaled`` **RHS** operand, set ``pack_dim=0`` so the packing
+    is along the first dimension (the K / contraction dimension):
+        B(K, N) -> fp4=(K//2, N), scale=(K//32, N)
+
+    Args:
+        x: Input tensor of shape (..., N) or (M, N).
+        scale: Pre-computed scale tensor (optional, usually None).
+        quant_dtype: Target quantized dtype, must be ``dtypes.fp4x2``.
+        shuffle: Whether to apply e8m0 scale shuffling for hardware.
+        pack_dim: Dimension along which to pack two FP4 values into one byte.
+            -1 (default): pack along the last dimension (for dot_scaled LHS).
+            0: pack along the first dimension (for dot_scaled RHS).
+
+    Returns:
+        Tuple of (quantized_tensor, scale_tensor).
+    """
     assert quant_dtype == dtypes.fp4x2
     block_size = 32
-    F8E8M0_EXP_BIAS = 127
+    F8E8M0_EXP_BIAS = 127  # noqa:F841
     F4E2M1_MAX = 6.0
     MAX_POW2 = int(torch.log2(torch.tensor(F4E2M1_MAX, dtype=torch.float32)).item())
     # dtypeMax = F4E2M1_MAX
     dtypeMax = 2.0**MAX_POW2
 
+    # For pack_dim=0, transpose so packing always happens along last dim internally
+    transposed = False
+    if pack_dim == 0:
+        assert x.dim() == 2, "pack_dim=0 requires a 2D input tensor (K, N)"
+        x = x.T.contiguous()
+        transposed = True
+
     shape_original = x.shape
     x = x.view(-1, shape_original[-1])
 
@@ -102,7 +133,40 @@ def per_1x32_f4_quant(x, scale=None, quant_dtype=dtypes.fp4x2, shuffle=False):
     scale = scale_e8m0_biased.view(m, -1).view(torch.uint8)
     if shuffle:
         scale = fp4_utils.e8m0_shuffle(scale)
-    return y, scale.view(dtypes.fp8_e8m0)
+    scale = scale.view(dtypes.fp8_e8m0)
+
+    # For pack_dim=0, transpose results back: (N, K//2) -> (K//2, N)
+    if transposed:
+        y = y.T.contiguous()
+        scale = scale.view(torch.uint8).T.contiguous().view(dtypes.fp8_e8m0)
+
+    return y, scale
+
+
+def per_1x32_f4_quant_for_dot_scaled(lhs, rhs, quant_dtype=dtypes.fp4x2, shuffle=False):
+    """Convenience function: quantize both LHS and RHS for ``tl.dot_scaled``.
+
+    Handles the packing dimension automatically:
+    - LHS A(M, K): packed along K (dim=-1) -> fp4=(M, K//2), scale=(M, K//32)
+    - RHS B(K, N): packed along K (dim=0) -> fp4=(K//2, N), scale=(K//32, N)
+
+    Note: Triton 3.6+ expects rhs_scale in transposed form (N, K//32). Users
+    should transpose the returned rhs_scale accordingly if using Triton >= 3.6.
+
+    Args:
+        lhs: LHS tensor of shape (M, K).
+        rhs: RHS tensor of shape (K, N).
+
+    Returns:
+        Tuple of (lhs_fp4, lhs_scale, rhs_fp4, rhs_scale).
+    """
+    lhs_fp4, lhs_scale = per_1x32_f4_quant(
+        lhs, quant_dtype=quant_dtype, shuffle=shuffle, pack_dim=-1
+    )
+    rhs_fp4, rhs_scale = per_1x32_f4_quant(
+        rhs, quant_dtype=quant_dtype, shuffle=shuffle, pack_dim=0
+    )
+    return lhs_fp4, lhs_scale, rhs_fp4, rhs_scale
 
 
 def per_1x32_f8_scale_f8_quant(