Skip to content
Open
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -312,6 +312,7 @@ if(BUILD_CUDA)
set_target_properties(bitsandbytes
PROPERTIES
CUDA_SEPARABLE_COMPILATION ON
CUDA_RESOLVE_DEVICE_SYMBOLS ON
)
endif()
if(BUILD_HIP)
Expand Down
44 changes: 44 additions & 0 deletions bitsandbytes/_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -431,3 +431,47 @@ def _(
qmap2.dtype == absmax2.dtype == torch.float32,
lambda: f"Expected qmap2 and absmax2 to be float32, got qmap2.dtype={qmap2.dtype}, absmax2.dtype={absmax2.dtype}",
)


# K-bit blockwise quantization (K=2..5, blocksize=32)

# Op schema: quantize A against a 2^k-entry codebook; returns
# (int32 bit-plane packed tensor, float32 per-block absmax).
torch.library.define(
    "bitsandbytes::quantize_kbit",
    "(Tensor A, Tensor codebook, int k) -> (Tensor, Tensor)",
)


@register_fake("bitsandbytes::quantize_kbit")
def _(A: torch.Tensor, codebook: torch.Tensor, k: int) -> tuple[torch.Tensor, torch.Tensor]:
    """Meta/fake kernel: produce correctly shaped and typed outputs without computing."""
    torch._check(k >= 2 and k <= 5, lambda: f"k must be 2-5, got {k}")
    torch._check(codebook.numel() == (1 << k), lambda: f"codebook must have {1 << k} entries for k={k}")
    num_blocks = (A.numel() + 31) // 32  # ceil-division by the fixed blocksize of 32
    # k bit-plane words per block, plus k trailing padding words.
    packed = torch.empty((num_blocks + 1) * k, device=A.device, dtype=torch.int32)
    # One absmax entry per block, plus one extra slot.
    absmax = torch.empty(num_blocks + 1, device=A.device, dtype=torch.float32)
    return packed, absmax


# Op schema: reconstruct n elements of `dtype` from packed bit-planes,
# a codebook, and per-block absmax values.
torch.library.define(
    "bitsandbytes::dequantize_kbit",
    "(Tensor packed, Tensor codebook, Tensor absmax, int k, int n, ScalarType dtype) -> Tensor",
)


@register_fake("bitsandbytes::dequantize_kbit")
def _(
    packed: torch.Tensor,
    codebook: torch.Tensor,
    absmax: torch.Tensor,
    k: int,
    n: int,
    dtype: torch.dtype,
) -> torch.Tensor:
    """Meta/fake kernel for dequantize_kbit: shape/dtype inference only.

    The output is padded to a whole number of 32-element blocks; callers are
    expected to slice down to n elements.
    """
    torch._check(k >= 2 and k <= 5, lambda: f"k must be 2-5, got {k}")
    # Accept the same absmax dtypes as the CUDA implementation, which also
    # supports float16; previously fp16 absmax failed under meta tracing even
    # though the eager kernel handles it.
    torch._check(
        absmax.dtype in (torch.float32, torch.float16, torch.uint8),
        lambda: f"absmax must be float32, float16, or uint8 (E4M4), got {absmax.dtype}",
    )
    num_blocks = -(n // -32)  # ceil(n / 32)
    return torch.empty(num_blocks * 32, device=packed.device, dtype=dtype)
90 changes: 90 additions & 0 deletions bitsandbytes/backends/cuda/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -764,3 +764,93 @@ def _optimizer_update_8bit_blockwise_impl(

register_kernel("bitsandbytes::optimizer_update_8bit_blockwise", "cuda")(_optimizer_update_8bit_blockwise_impl)
register_kernel("bitsandbytes::optimizer_update_32bit", "cuda")(_optimizer_update_32bit_impl)


# K-bit blockwise quantization (K=2..5, blocksize=32)

# Maps supported element dtypes to the suffix used in native kernel symbol
# names (e.g. cquantize_kbit_fp16_k4).
_KBIT_DTYPE_SUFFIX = {
    torch.float16: "fp16",
    torch.bfloat16: "bf16",
    torch.float32: "fp32",
}


@register_kernel("bitsandbytes::quantize_kbit", "cuda")
def _(A: torch.Tensor, codebook: torch.Tensor, k: int) -> tuple[torch.Tensor, torch.Tensor]:
    """CUDA implementation of k-bit blockwise quantization (blocksize=32).

    Dispatches to the native symbol ``cquantize_kbit_<dtype>_k<k>`` selected
    by the input dtype and the bit width.

    Args:
        A: Input tensor (float16, bfloat16, or float32).
        codebook: float32 tensor with 2^k entries.
        k: Bit width (2-5).

    Returns:
        Tuple of (int32 packed tensor of num_blocks * k + k words,
        float32 absmax tensor of num_blocks + 1 entries).
    """
    torch._check(k >= 2 and k <= 5, lambda: f"k must be 2-5, got {k}")
    torch._check(
        A.dtype in _KBIT_DTYPE_SUFFIX,
        lambda: f"quantize_kbit only supports float16/bfloat16/float32, got {A.dtype}",
    )
    torch._check(codebook.dtype == torch.float32, lambda: f"codebook must be float32, got {codebook.dtype}")
    torch._check(codebook.numel() == (1 << k), lambda: f"codebook must have {1 << k} entries for k={k}")

    n = A.numel()
    num_blocks = -(n // -32)  # ceil(n / 32)
    # zeros (not empty): presumably so the k padding words and any unwritten
    # tail stay deterministic -- TODO confirm against the kernel's write pattern.
    packed = torch.zeros(num_blocks * k + k, device=A.device, dtype=torch.int32)
    absmax = torch.zeros(num_blocks + 1, device=A.device, dtype=torch.float32)

    with _cuda_device_of(A):
        tname = _KBIT_DTYPE_SUFFIX[A.dtype]
        fn = getattr(lib, f"cquantize_kbit_{tname}_k{k}")
        # NOTE(review): unlike the dequantize kernel call, no stream argument
        # is passed here -- verify the native wrapper launches on the correct
        # stream.
        fn(
            get_ptr(codebook),
            get_ptr(A),
            get_ptr(absmax),
            get_ptr(packed),
            ct.c_int(n),
        )

    return packed, absmax


# Maps absmax dtypes to the suffix used in dequantize kernel symbol names
# (e.g. cdequantize_kbit_fp16_u8abs_k4). float32 is absent on purpose: fp32
# absmax is encoded to E4M4 (uint8) before dispatch.
_KBIT_ABSMAX_SUFFIX = {
    torch.uint8: "u8abs",
    torch.float16: "fp16abs",
}


Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When the user passes fp32 absmax, the dequant dispatch silently encodes it to E4M4 before calling the kernel. This is a lossy conversion the caller may not expect — they passed fp32 precision but get E4M4 precision. Consider either warning or documenting this behavior.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@TimDettmers Let's address in a future iteration; I'm not sure if we even want to support fp32 input at all?

@register_kernel("bitsandbytes::dequantize_kbit", "cuda")
def _(
    packed: torch.Tensor,
    codebook: torch.Tensor,
    absmax: torch.Tensor,
    k: int,
    n: int,
    dtype: torch.dtype,
) -> torch.Tensor:
    """CUDA implementation of k-bit blockwise dequantization (blocksize=32).

    Dispatches to ``cdequantize_kbit_<dtype>_<absmax-format>_k<k>``. The
    output is padded to a whole number of 32-element blocks; callers (e.g.
    functional.dequantize_kbit) slice down to n.

    Args:
        packed: int32 bit-plane packed tensor from quantize_kbit.
        codebook: float32 tensor with 2^k entries.
        absmax: Per-block absmax values (float32, float16, or uint8/E4M4).
        k: Bit width (2-5).
        n: Number of original elements.
        dtype: Output dtype (float16, bfloat16, or float32).
    """
    torch._check(k >= 2 and k <= 5, lambda: f"k must be 2-5, got {k}")
    torch._check(
        dtype in _KBIT_DTYPE_SUFFIX,
        lambda: f"dequantize_kbit only supports float16/bfloat16/float32, got {dtype}",
    )
    torch._check(codebook.dtype == torch.float32, lambda: f"codebook must be float32, got {codebook.dtype}")
    torch._check(
        absmax.dtype in (torch.float32, torch.float16, torch.uint8),
        lambda: f"absmax must be float32, float16, or uint8 (E4M4), got {absmax.dtype}",
    )

    # If fp32 absmax, encode to E4M4 first (no fp32 kernel variant exists).
    # NOTE(review): this conversion is lossy -- callers passing fp32 absmax
    # silently get E4M4 (8-bit micro-float) precision.
    if absmax.dtype == torch.float32:
        from bitsandbytes.functional import encode_absmax_e4m4

        absmax = encode_absmax_e4m4(absmax)

    num_blocks = -(n // -32)  # ceil(n / 32); output is block-padded
    out = torch.empty(num_blocks * 32, device=packed.device, dtype=dtype)

    tname = _KBIT_DTYPE_SUFFIX[dtype]
    aname = _KBIT_ABSMAX_SUFFIX[absmax.dtype]

    with _cuda_device_of(packed):
        fn = getattr(lib, f"cdequantize_kbit_{tname}_{aname}_k{k}")
        fn(
            get_ptr(packed),
            get_ptr(codebook),
            get_ptr(absmax),
            get_ptr(out),
            ct.c_int(n),
            _get_tensor_stream(packed),
        )

    return out
193 changes: 193 additions & 0 deletions bitsandbytes/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -1005,6 +1005,199 @@ def dequantize_4bit(
return out


# ---------------------------------------------------------------------------
# K-bit blockwise quantization (K=2..5, blocksize=32)
# ---------------------------------------------------------------------------
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The create_normal_float_codebook function imports scipy at runtime. Since scipy is only in the [test] extra, calling quantize_kbit() without a pre-built codebook in a production install will raise ImportError. Consider hardcoding the codebooks (like get_4bit_type does for NF4) or documenting the scipy requirement.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@TimDettmers I do think it can make some sense to hard-code the default codebooks like we have for NF4/FP4. If we did that, scipy seems more of an optional dependency. But we can handle in follow up PR.


# Cache for precomputed normal-float codebooks, keyed by (k, device).
_kbit_codebook_cache: dict[tuple[int, torch.device], torch.Tensor] = {}


def create_normal_float_codebook(k: int, device=None) -> torch.Tensor:
    """Create a 2^k-entry normal-float codebook (quantiles of N(0,1), normalized to [-1, 1]).

    For k bits we have 2^k reconstruction levels placed at the expected values
    of N(0,1) within 2^k equiprobable bins. The result is sorted ascending
    and normalized so the largest magnitude is 1.0.

    Results are cached per (k, device); cache hits do not require scipy.

    Args:
        k: Bit width (2-5).
        device: Target device. Defaults to "cuda".

    Returns:
        Float32 tensor of shape (2^k,) with values in [-1, 1].

    Raises:
        ImportError: If scipy is not installed and the codebook is not cached.
    """
    if device is None:
        device = torch.device("cuda")
    device = torch.device(device)

    # Check the cache before importing scipy so that pre-warmed codebooks work
    # even in installs without the optional scipy dependency.
    cache_key = (k, device)
    cached = _kbit_codebook_cache.get(cache_key)
    if cached is not None:
        return cached

    try:
        from scipy.stats import norm
    except ImportError as ie:
        raise ImportError(
            "Scipy is required for `create_normal_float_codebook`. Install `bitsandbytes` with the `[test]` extra.",
        ) from ie

    n_levels = 1 << k
    # Midpoints of 2^k equiprobable bins of N(0,1); ppf maps them to values.
    quantiles = torch.linspace(0.5 / n_levels, 1.0 - 0.5 / n_levels, n_levels)
    values = torch.tensor(norm.ppf(quantiles.numpy()), dtype=torch.float32)
    values = values / values.abs().max()  # normalize largest magnitude to 1.0
    values = values.to(device)

    _kbit_codebook_cache[cache_key] = values
    return values


def encode_absmax_e4m4(absmax: Tensor, bias: int = 11) -> Tensor:
    """Encode fp32 absmax values to uint8 using E4M4 micro-float format.

    Format: 4-bit exponent + 4-bit mantissa with IEEE-style subnormals.
        Normal (e > 0):    2^(e - bias) * (1 + m/16)
        Subnormal (e = 0): 2^(1 - bias) * (m/16)
        Zero (e = 0, m = 0): 0.0

    Values above the maximum representable magnitude saturate to the largest
    code, and a mantissa that rounds up to 16 carries into the exponent (so
    e.g. 1.99 encodes as 2.0 rather than 1.9375). Non-positive inputs encode
    as 0.

    Args:
        absmax: float32 tensor of per-block absolute maximum values.
        bias: Exponent bias. Default 11 gives range [6.1e-5, 31.0].

    Returns:
        uint8 tensor of same shape as absmax.
    """
    result = torch.zeros_like(absmax, dtype=torch.uint8)
    nonzero = absmax > 0
    if not bool(nonzero.any()):
        return result

    vals = absmax[nonzero].float()

    # Biased exponent clamped to the representable range [0, 15];
    # e_biased == 0 selects the subnormal encoding.
    e_biased = (torch.floor(torch.log2(vals)).to(torch.int32) + bias).clamp(0, 15)

    # Unified mantissa computation: normals quantize (val - 2^(e-bias)) in
    # steps of 2^(e-bias)/16; subnormals reuse the e=1 step with no implicit
    # leading 1. Using the *clamped* exponent makes overflowing values
    # saturate correctly (the previous code used the unclamped exponent,
    # under-encoding out-of-range values, e.g. 100.0 -> 25.0 after decode).
    step = torch.exp2((e_biased.clamp(min=1) - bias).float())
    implicit = torch.where(e_biased > 0, step, torch.zeros_like(step))
    mantissa = torch.round((vals - implicit) / step * 16.0).to(torch.int32)

    # Rounding carry: a mantissa of 16 is exactly the next exponent's (e+1, m=0).
    # At e_biased == 15 there is no next exponent, so clamp (saturate) instead.
    carry = (mantissa >= 16) & (e_biased < 15)
    e_biased = torch.where(carry, e_biased + 1, e_biased)
    mantissa = torch.where(carry, torch.zeros_like(mantissa), mantissa.clamp(0, 15))

    result[nonzero] = ((e_biased << 4) | mantissa).to(torch.uint8)
    return result


def decode_absmax_e4m4(encoded: Tensor, bias: int = 11) -> Tensor:
    """Decode uint8 E4M4 absmax values back to fp32.

    Inverse of `encode_absmax_e4m4`:
        Normal (e > 0):    2^(e - bias) * (1 + m/16)
        Subnormal (e = 0): 2^(1 - bias) * (m/16)

    Args:
        encoded: uint8 tensor of E4M4-encoded absmax values.
        bias: Exponent bias (must match encoding).

    Returns:
        float32 tensor of decoded absmax values.
    """
    bits = encoded.to(torch.int32)
    exponent = bits >> 4
    frac16 = (bits & 0xF).to(torch.float32)

    subnormal = exponent == 0
    # Subnormals reuse the e=1 scale but drop the implicit leading 1, giving a
    # branch-free decode of all codes at once.
    eff_exp = torch.where(subnormal, torch.ones_like(exponent), exponent).to(torch.float32) - bias
    significand = torch.where(subnormal, frac16 / 16.0, 1.0 + frac16 / 16.0)
    return torch.exp2(eff_exp) * significand


def quantize_kbit(
    A: Tensor,
    k: int = 4,
    codebook: Optional[Tensor] = None,
    absmax_format: str = "e4m4",
) -> tuple[Tensor, Tensor, Tensor]:
    """Quantize a tensor using k-bit blockwise quantization (blocksize=32).

    Uses warp-level CUDA primitives for efficient bit-plane packing.

    Args:
        A: Input tensor. Supports float16, bfloat16, or float32.
        k: Bit width (2, 3, 4, or 5). Defaults to 4.
        codebook: Optional float32 codebook tensor with 2^k entries in [-1, 1], sorted ascending.
            If None, uses a precomputed normal-float codebook.
        absmax_format: Format for absmax storage. "e4m4" (default, uint8) or "fp32".

    Returns:
        Tuple of (packed, absmax, codebook):
            - packed: int32 tensor of bit-plane packed quantized values.
            - absmax: Tensor of per-block absolute maximum values (float32 or uint8).
            - codebook: The codebook tensor used (useful when auto-generated).

    Raises:
        ValueError: If absmax_format is not "e4m4" or "fp32".
    """
    # Validate up front: previously an unknown format string fell through
    # silently and behaved like "fp32".
    if absmax_format not in ("e4m4", "fp32"):
        raise ValueError(f"absmax_format must be 'e4m4' or 'fp32', got {absmax_format!r}")

    if codebook is None:
        codebook = create_normal_float_codebook(k, device=A.device)
    else:
        codebook = codebook.to(device=A.device, dtype=torch.float32)

    A_flat = A.contiguous().view(-1)
    packed, absmax = torch.ops.bitsandbytes.quantize_kbit(A_flat, codebook, k)

    if absmax_format == "e4m4":
        absmax = encode_absmax_e4m4(absmax)

    return packed, absmax, codebook


def dequantize_kbit(
    packed: Tensor,
    absmax: Tensor,
    codebook: Tensor,
    k: int,
    n: int,
    dtype: torch.dtype = torch.float16,
) -> Tensor:
    """Reconstruct a tensor from its k-bit blockwise quantized representation.

    Args:
        packed: int32 tensor of bit-plane packed values (from quantize_kbit).
        absmax: Tensor of per-block absmax values (from quantize_kbit).
            Supports float32 or uint8 (E4M4 format).
        codebook: float32 codebook tensor with 2^k entries.
        k: Bit width (2, 3, 4, or 5).
        n: Number of original elements.
        dtype: Output dtype. Defaults to float16.

    Returns:
        Dequantized tensor of shape (n,) with the given dtype.
    """
    # The op emits whole 32-element blocks; trim the padded tail down to n.
    padded = torch.ops.bitsandbytes.dequantize_kbit(packed, codebook, absmax, k, n, dtype)
    return padded.narrow(0, 0, n)


@deprecated("This function is deprecated and will be removed in a future release.", category=FutureWarning)
def quantize(
A: Tensor,
Expand Down
2 changes: 2 additions & 0 deletions csrc/kernels.cu
Original file line number Diff line number Diff line change
Expand Up @@ -2601,3 +2601,5 @@ MAKE_OptimizerStatic8bit1StateBlockwise(LION, __nv_bfloat16, 256, 1)
MAKE_OptimizerStatic8bit1StateBlockwise(ADAGRAD, float, 256, 1)
MAKE_OptimizerStatic8bit1StateBlockwise(ADAGRAD, half, 256, 1)
MAKE_OptimizerStatic8bit1StateBlockwise(ADAGRAD, __nv_bfloat16, 256, 1)

// K-bit kernel definitions moved to ops.cu to avoid RDC device linking issues.
3 changes: 3 additions & 0 deletions csrc/kernels.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -125,4 +125,7 @@ __global__ void kgemm_4bit_inference_naive(

template <typename T, int FUNC> __global__ void kfunc(T* A, T* B, T value, long n);

// K-bit kernel definitions live in ops.cu (not kernels.cu) to keep kernel
// and launch wrapper in the same compilation unit. No declarations needed here.

#endif
Loading