-
-
Notifications
You must be signed in to change notification settings - Fork 841
Add k-bit blockwise quantization (K=2-5) with warp-level CUDA kernels #1858
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 9 commits
c39f791
fb649f1
2825890
4b17a2f
2973bf5
03415e1
8a2817e
f52b572
f95a7f2
d1f3d75
10cf922
ad7f194
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -1005,6 +1005,199 @@ def dequantize_4bit( | |
| return out | ||
|
|
||
|
|
||
| # --------------------------------------------------------------------------- | ||
| # K-bit blockwise quantization (K=2..5, blocksize=32) | ||
| # --------------------------------------------------------------------------- | ||
|
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @TimDettmers I do think it can make some sense to hard-code the default codebooks like we have for NF4/FP4. If we did that, scipy seems more of an optional dependency. But we can handle in follow up PR. |
||
|
|
||
# Cache for precomputed normal-float codebooks, keyed by (k, device).
_kbit_codebook_cache: dict[tuple[int, torch.device], torch.Tensor] = {}


def create_normal_float_codebook(k: int, device=None) -> torch.Tensor:
    """Create a 2^k-entry normal-float codebook (quantiles of N(0,1), normalized to [-1, 1]).

    For k bits we have 2^k reconstruction levels placed at the expected values
    of N(0,1) within 2^k equiprobable bins. The result is sorted ascending
    and normalized so the largest magnitude is 1.0. Results are cached per
    (k, device); repeat calls return the cached tensor.

    Args:
        k: Bit width (2-5).
        device: Target device. Defaults to "cuda".

    Returns:
        Float32 tensor of shape (2^k,) with values in [-1, 1].

    Raises:
        ValueError: If k is outside the supported range 2-5.
        ImportError: If scipy is not installed.
    """
    # Validate up front: the k-bit kernels only support K=2..5, and silently
    # building a codebook for another width would fail confusingly downstream.
    if k not in (2, 3, 4, 5):
        raise ValueError(f"k must be 2, 3, 4, or 5, got {k}")

    try:
        from scipy.stats import norm
    except ImportError as ie:
        raise ImportError(
            "Scipy is required for `create_normal_float_codebook`. Install `bitsandbytes` with the `[test]` extra.",
        ) from ie

    if device is None:
        device = torch.device("cuda")
    device = torch.device(device)

    cache_key = (k, device)
    if cache_key in _kbit_codebook_cache:
        return _kbit_codebook_cache[cache_key]

    n_levels = 1 << k
    # Midpoints of 2^k equiprobable bins of N(0,1); ppf maps them to z-scores.
    quantiles = torch.linspace(0.5 / n_levels, 1.0 - 0.5 / n_levels, n_levels)
    values = torch.tensor(norm.ppf(quantiles.numpy()), dtype=torch.float32)
    values = values / values.abs().max()  # normalize so max |value| == 1.0
    values = values.to(device)

    _kbit_codebook_cache[cache_key] = values
    return values
|
|
||
|
|
||
def encode_absmax_e4m4(absmax: Tensor, bias: int = 11) -> Tensor:
    """Encode fp32 absmax values to uint8 using E4M4 micro-float format.

    Format: 4-bit exponent + 4-bit mantissa with IEEE-style subnormals.
        Normal (e > 0):        2^(e - bias) * (1 + m/16)
        Subnormal (e = 0):     2^(1 - bias) * (m/16)
        Zero (e = 0, m = 0):   0.0

    Values whose mantissa rounds up to 16 are carried into the next binade
    (exponent + 1, mantissa 0). Values beyond the representable maximum
    saturate at the largest code (e = 15, m = 15).

    Args:
        absmax: float32 tensor of per-block absolute maximum values.
        bias: Exponent bias. Default 11 gives range [6.1e-5, 31.0].

    Returns:
        uint8 tensor of same shape as absmax.
    """
    result = torch.zeros_like(absmax, dtype=torch.uint8)
    nonzero = absmax > 0
    abs_nz = absmax[nonzero]

    # Exponent: floor(log2(absmax)), clamped into the 4-bit biased range.
    e_unbiased = torch.floor(torch.log2(abs_nz)).to(torch.int32)
    e_biased = (e_unbiased + bias).clamp(0, 15)

    # Values below 2^(1 - bias) are encoded as subnormals (e = 0).
    is_subnormal = (e_unbiased + bias) <= 0

    # Mantissa, allowing a provisional value of 16 so round-up into the next
    # binade can be detected and carried below.
    mantissa = torch.zeros_like(abs_nz, dtype=torch.int32)

    normal_mask = ~is_subnormal
    if normal_mask.any():
        # Normal: m = round((absmax / 2^e - 1) * 16)
        scale = torch.exp2(e_unbiased[normal_mask].float())
        m_float = (abs_nz[normal_mask] / scale - 1.0) * 16.0
        mantissa[normal_mask] = m_float.round().to(torch.int32).clamp(0, 16)

    if is_subnormal.any():
        # Subnormal: m = round(absmax / 2^(1 - bias) * 16)
        subnormal_scale = 2.0 ** (1 - bias)
        m_float = abs_nz[is_subnormal] / subnormal_scale * 16.0
        mantissa[is_subnormal] = m_float.round().to(torch.int32).clamp(0, 16)

    # Carry: m == 16 means the value rounded up to the next binade. Bump the
    # exponent and reset the mantissa; if the exponent is already saturated at
    # 15 the mantissa is clamped back to 15 instead (saturate at the max).
    carry = (mantissa == 16) & (e_biased < 15)
    e_biased = torch.where(carry, e_biased + 1, e_biased)
    mantissa = torch.where(carry, torch.zeros_like(mantissa), mantissa.clamp(0, 15))

    # True overflow (value >= 2^(16 - bias)): the exponent was clamped to 15
    # but the mantissa was computed against the unclamped exponent, which would
    # otherwise produce a wrapped (non-monotonic) code. Saturate to m = 15.
    overflow = (e_unbiased + bias) > 15
    mantissa = torch.where(overflow, torch.full_like(mantissa, 15), mantissa)

    result[nonzero] = (e_biased << 4 | mantissa).to(torch.uint8)
    return result
|
|
||
|
|
||
def decode_absmax_e4m4(encoded: Tensor, bias: int = 11) -> Tensor:
    """Decode uint8 E4M4 absmax values to fp32.

    Inverse of `encode_absmax_e4m4`:
        Normal (e > 0):    2^(e - bias) * (1 + m/16)
        Subnormal (e = 0): 2^(1 - bias) * (m/16)
    Zero (e = 0, m = 0) decodes to 0.0 via the subnormal branch.

    Args:
        encoded: uint8 tensor of E4M4-encoded absmax values.
        bias: Exponent bias (must match encoding).

    Returns:
        float32 tensor of decoded absmax values.
    """
    raw = encoded.to(torch.int32)
    exponent = raw >> 4
    frac = (raw & 0xF).float() / 16.0

    # Branchless decode: subnormals (e == 0) use effective exponent 1 and no
    # implicit leading 1; normals use their own exponent with the implicit 1.
    is_normal = exponent > 0
    effective_exp = torch.where(is_normal, exponent, torch.ones_like(exponent)).float() - bias
    significand = torch.where(is_normal, 1.0 + frac, frac)

    return torch.exp2(effective_exp) * significand
|
|
||
|
|
||
def quantize_kbit(
    A: Tensor,
    k: int = 4,
    codebook: Optional[Tensor] = None,
    absmax_format: str = "e4m4",
) -> tuple[Tensor, Tensor, Tensor]:
    """Quantize a tensor using k-bit blockwise quantization (blocksize=32).

    Uses warp-level CUDA primitives for efficient bit-plane packing.

    Args:
        A: Input tensor. Supports float16, bfloat16, or float32.
        k: Bit width (2, 3, 4, or 5). Defaults to 4.
        codebook: Optional float32 codebook tensor with 2^k entries in [-1, 1], sorted ascending.
            If None, uses a precomputed normal-float codebook.
        absmax_format: Format for absmax storage. "e4m4" (default, uint8) or "fp32".

    Returns:
        Tuple of (packed, absmax, codebook):
            - packed: int32 tensor of bit-plane packed quantized values.
            - absmax: Tensor of per-block absolute maximum values (float32 or uint8).
            - codebook: The codebook tensor used (useful when auto-generated).

    Raises:
        ValueError: If k is outside 2-5 or absmax_format is not recognized.
    """
    # Fail fast on invalid arguments: an unrecognized absmax_format would
    # otherwise silently fall through to fp32 absmax output.
    if k not in (2, 3, 4, 5):
        raise ValueError(f"k must be 2, 3, 4, or 5, got {k}")
    if absmax_format not in ("e4m4", "fp32"):
        raise ValueError(f"absmax_format must be 'e4m4' or 'fp32', got {absmax_format!r}")

    if codebook is None:
        codebook = create_normal_float_codebook(k, device=A.device)
    else:
        codebook = codebook.to(device=A.device, dtype=torch.float32)

    A_flat = A.contiguous().view(-1)
    packed, absmax = torch.ops.bitsandbytes.quantize_kbit(A_flat, codebook, k)

    if absmax_format == "e4m4":
        absmax = encode_absmax_e4m4(absmax)

    return packed, absmax, codebook
|
|
||
|
|
||
def dequantize_kbit(
    packed: Tensor,
    absmax: Tensor,
    codebook: Tensor,
    k: int,
    n: int,
    dtype: torch.dtype = torch.float16,
) -> Tensor:
    """Dequantize a k-bit blockwise quantized tensor.

    Args:
        packed: int32 tensor of bit-plane packed values (from quantize_kbit).
        absmax: Tensor of per-block absmax values (from quantize_kbit).
            Supports float32 or uint8 (E4M4 format).
        codebook: float32 codebook tensor with 2^k entries.
        k: Bit width (2, 3, 4, or 5).
        n: Number of original elements.
        dtype: Output dtype. Defaults to float16.

    Returns:
        Dequantized tensor of shape (n,) with the given dtype.

    Raises:
        ValueError: If k is outside 2-5.
    """
    # Validate here for a clear error, consistent with quantize_kbit, rather
    # than surfacing a kernel-level failure for unsupported bit widths.
    if k not in (2, 3, 4, 5):
        raise ValueError(f"k must be 2, 3, 4, or 5, got {k}")

    out = torch.ops.bitsandbytes.dequantize_kbit(packed, codebook, absmax, k, n, dtype)
    # The kernel may round the output up to a full block; trim to n elements.
    return out[:n]
|
|
||
|
|
||
| @deprecated("This function is deprecated and will be removed in a future release.", category=FutureWarning) | ||
| def quantize( | ||
| A: Tensor, | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
When the user passes fp32 absmax, the dequant dispatch silently encodes it to E4M4 before calling the kernel. This is a lossy conversion the caller may not expect — they passed fp32 precision but get E4M4 precision. Consider either warning or documenting this behavior.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@TimDettmers Let's address in a future iteration; I'm not sure if we even want to support fp32 input at all?