diff --git a/AGENTS.md b/AGENTS.md new file mode 100644 index 00000000..48e29fe9 --- /dev/null +++ b/AGENTS.md @@ -0,0 +1,37 @@ +# Kernel-specific instructions + +## flash-attn4 + +When the user asks to sync a flash-attn4 release, carry out the following +steps: + +- Fetch the upstream Git repository from https://github.com/Dao-AILab/flash-attention.git +- Check out the tag that the user specified. +- Flash Attention 4 is in the directory `flash_attn/cute` of the upstream repo. +- Copy Flash Attention 4 upstream files to `flash-attn4/torch-ext/flash_attn4`. +- Copy tests from the tests from the upstream directory `tests/cute` to + `flash-attn4/tests/cute`. +- Check in `flash_attn/cute/pyproject.toml` upstream what version of quack is + required. +- Get this version of quack from https://github.com/Dao-AILab/quack.git +- Copy the `quack` directory from quack to `flash-attn4/torch-ext/flash_attn4/quack` +- Now make all imports of Flash Attention 4 and quack in + `flash-attn4/torch-ext/flash_attn4` and `flash-attn4/torch-ext/flash_attn4/quack` + relative imports. +- Remove all quack files in `flash-attn4/torch-ext/flash_attn4/quack` that are not used. +- Update imports of `flash_attn.cute` in `flash-attn4/tests/cute` to `flash_attn4`. +- Set `__version__` in `flash-attn4/torch-ext/flash_attn4/__init__.py` to the + version from the tag (e.g. for tag `fa4-v4.0.0.beta8` set it to + `"4.0.0.beta8"`). Remove any `importlib.metadata` version lookup code. +- Check whether any Torch custom ops are defined in `flash-attn4/torch-ext/flash_attn4` + or `flash-attn4/torch-ext/flash_attn4/quack` (look for `torch.library.custom_op`, + `torch.library.define`, etc.). If any are found, update them to use + `add_op_namespace_prefix` for the op name. For example, a definition like + `@torch.library.custom_op("_flash_attn_forward", mutates_args=(), device_types="cuda")` + should become + `@torch.library.custom_op(add_op_namespace_prefix("_flash_attn_forward"), mutates_args=(), device_types="cuda")`. + `add_op_namespace_prefix` is imported from `._ops` (see + `flash-attn3/torch-ext/flash_attn3/flash_attn_interface.py` for an example). + +If the user did not specify the version tag, stop and ask which tag to sync +from. diff --git a/flash-attn4/tests/cute/benchmark_mask_mod.py b/flash-attn4/tests/cute/benchmark_mask_mod.py index 4eef607c..841834de 100644 --- a/flash-attn4/tests/cute/benchmark_mask_mod.py +++ b/flash-attn4/tests/cute/benchmark_mask_mod.py @@ -14,7 +14,7 @@ import numpy as np import torch -from flash_attn4.flash_fwd import FlashAttentionForwardSm90 +from flash_attn4.flash_fwd_sm90 import FlashAttentionForwardSm90 from mask_mod_definitions import ( get_mask_pair, random_doc_id_tensor, diff --git a/flash-attn4/tests/cute/conftest.py b/flash-attn4/tests/cute/conftest.py index 6ee05e9a..d2162255 100644 --- a/flash-attn4/tests/cute/conftest.py +++ b/flash-attn4/tests/cute/conftest.py @@ -1,5 +1,11 @@ import os import subprocess +import logging +import tempfile +import json +import time +from pathlib import Path +from getpass import getuser def _get_gpu_ids(): @@ -16,16 +22,50 @@ def _get_gpu_ids(): ) if result.returncode == 0: return result.stdout.strip().splitlines() - except (FileNotFoundError, subprocess.TimeoutExpired): + except (FileNotFoundError,): pass + logging.warning("Failed to get gpu ids, use default '0'") return ["0"] def pytest_configure(config): + tmp = Path(tempfile.gettempdir()) / getuser() / "flash_attention_tests" + tmp.mkdir(parents=True, exist_ok=True) + worker_id = os.environ.get("PYTEST_XDIST_WORKER") + logging.basicConfig( + format=config.getini("log_file_format"), + filename=str(tmp / f"tests_{worker_id}.log"), + level=config.getini("log_file_level"), + ) if not worker_id: return worker_num = int(worker_id.replace("gw", "")) - gpu_ids = _get_gpu_ids() + + # cache gpu_ids, because nvidia-smi is expensive when we launch many workers doing torch initialization + # Always elect worker_0 to get gpu_ids. + cached_gpu_ids = tmp / "gpu_ids.json" + if worker_num == 0: + gpu_ids = _get_gpu_ids() + with cached_gpu_ids.open(mode="w") as f: + json.dump(gpu_ids, f) + else: + while not cached_gpu_ids.exists(): + time.sleep(1) + with cached_gpu_ids.open() as f: + gpu_ids = json.load(f) + os.environ["CUDA_VISIBLE_DEVICES"] = gpu_ids[worker_num % len(gpu_ids)] + +def pytest_collection_finish(session): + if not session.config.option.collectonly: + return + + # file_name -> test_name -> counter + test_counts: dict[str, dict[str, int]] = {} + for item in session.items: + funcname = item.function.__name__ + parent = test_counts.setdefault(item.parent.name, {}) + parent[funcname] = parent.setdefault(funcname, 0) + 1 + print(json.dumps(test_counts, indent=2)) diff --git a/flash-attn4/tests/cute/test_clc_fuzz.py b/flash-attn4/tests/cute/test_clc_fuzz.py new file mode 100644 index 00000000..380b26fa --- /dev/null +++ b/flash-attn4/tests/cute/test_clc_fuzz.py @@ -0,0 +1,576 @@ +"""Adversarial regression tests for CLC tile scheduling. + +These cases intentionally target scheduler-sensitive shapes: mismatched +sequence lengths, non-aligned tiles, GQA ratios, minimal problems, and +larger persistent workloads. This is deterministic adversarial coverage, +not randomized fuzzing. +""" + +from contextlib import contextmanager +import os +from unittest import mock + +import pytest +import torch + +from flash_attn4 import utils as cute_utils +from flash_attn4.flash_fwd_sm100 import FlashAttentionForwardSm100 +from flash_attn4.interface import flash_attn_func, flash_attn_varlen_func +from flash_attn4.testing import attention_ref +from flash_attn4.tile_scheduler import SchedulingMode, SingleTileLPTScheduler, SingleTileVarlenScheduler + + +if torch.cuda.is_available(): + COMPUTE_CAPABILITY = torch.cuda.get_device_capability()[0] + SM_COUNT = torch.cuda.get_device_properties("cuda").multi_processor_count +else: + COMPUTE_CAPABILITY = 0 + SM_COUNT = 0 +pytestmark = pytest.mark.skipif( + COMPUTE_CAPABILITY not in (10, 11), + reason="CLC adversarial tests require SM100/SM110 persistent forward", +) + +_captured_schedulers: list[tuple[type, SchedulingMode, bool]] = [] +_orig_init = FlashAttentionForwardSm100.__init__ + + +def _spy_init(self_inner, *a, **kw): + _orig_init(self_inner, *a, **kw) + _captured_schedulers.append(( + self_inner.TileScheduler, + self_inner.scheduling_mode, + self_inner.use_2cta_instrs, + )) + + +@contextmanager +def clc_scheduler_enabled(): + with ( + mock.patch.dict(os.environ, {"FA_CLC": "1"}, clear=False), + mock.patch.object(cute_utils, "_fa_clc_enabled", True), + mock.patch.object(FlashAttentionForwardSm100, "__init__", _spy_init), + ): + yield + + +def check_output(q, k, v, *, causal=False, window_size=(None, None), num_splits=1, assert_clc=True, assert_2cta=False): + _captured_schedulers.clear() + out, _ = flash_attn_func(q, k, v, causal=causal, window_size=window_size, num_splits=num_splits) + torch.cuda.synchronize() + if assert_clc and _captured_schedulers: + sched_cls, sched_mode, use_2cta = _captured_schedulers[-1] + assert sched_cls is SingleTileLPTScheduler, f"Expected SingleTileLPTScheduler, got {sched_cls.__name__}" + assert sched_mode == SchedulingMode.CLC, f"Expected CLC scheduling mode, got {sched_mode!r}" + if assert_2cta: + assert use_2cta, "Expected use_2cta_instrs=True but got False" + out_ref, _ = attention_ref(q, k, v, causal=causal, window_size=window_size) + out_pt, _ = attention_ref(q, k, v, causal=causal, window_size=window_size, upcast=False, reorder_ops=True) + fwd_atol = 2 * (out_ref + 0.3 - 0.3 - out_ref).abs().max().item() + assert (out - out_ref).abs().max().item() <= 2 * ( + out_pt - out_ref + ).abs().max().item() + fwd_atol, ( + f"max_diff={(out - out_ref).abs().max().item()}, " + f"pt_max_diff={(out_pt - out_ref).abs().max().item()}, " + f"fwd_atol={fwd_atol}, " + f"q={list(q.shape)} k={list(k.shape)} v={list(v.shape)} " + f"causal={causal} window_size={window_size} num_splits={num_splits}" + ) + + +def randn(b, s, h, d): + return torch.randn(b, s, h, d, device="cuda", dtype=torch.bfloat16) + + +def expected_total_tiles_mha(batch, seqlen_q, heads): + q_stage = 2 if COMPUTE_CAPABILITY == 10 and seqlen_q > 128 else 1 + num_block = (seqlen_q + q_stage * 128 - 1) // (q_stage * 128) + return num_block * heads * batch + + +@pytest.fixture(autouse=True) +def seed(): + torch.random.manual_seed(42) + + +@pytest.fixture(autouse=True) +def enable_clc_scheduler(): + with clc_scheduler_enabled(): + yield + + +class TestCLCMismatchedSeqlens: + + @pytest.mark.parametrize("sq,sk", [ + (128, 512), + (128, 1024), + (128, 2048), + (256, 64), + (256, 128), + (512, 127), + (512, 129), + (64, 4096), + (1, 128), + (1, 512), + (1, 1024), + ]) + def test_qk_mismatch(self, sq, sk): + check_output(randn(4, sq, 4, 128), randn(4, sk, 4, 128), randn(4, sk, 4, 128)) + + @pytest.mark.parametrize("sq,sk", [ + (128, 513), + (256, 1023), + (64, 257), + (192, 383), + (1, 255), + ]) + def test_qk_mismatch_nonaligned_k(self, sq, sk): + check_output(randn(4, sq, 4, 128), randn(4, sk, 4, 128), randn(4, sk, 4, 128)) + + @pytest.mark.parametrize("sq,sk", [ + (1, 128), + (1, 256), + (1, 1024), + (2, 128), + (3, 512), + ]) + def test_tiny_q_long_k(self, sq, sk): + check_output(randn(2, sq, 4, 128), randn(2, sk, 4, 128), randn(2, sk, 4, 128)) + + +class TestCLCNonAlignedShapes: + @pytest.mark.parametrize("sq", [1, 3, 7, 15, 31, 33, 63, 65, 127, 129, 191, 193, 255, 257]) + def test_nonaligned_q(self, sq): + check_output(randn(2, sq, 4, 128), randn(2, 256, 4, 128), randn(2, 256, 4, 128)) + + @pytest.mark.parametrize("sk", [1, 7, 31, 33, 63, 65, 127, 129, 255, 257, 511, 513]) + def test_nonaligned_k(self, sk): + check_output(randn(2, 256, 4, 128), randn(2, sk, 4, 128), randn(2, sk, 4, 128)) + + +class TestCLCPrimes: + @pytest.mark.parametrize("batch,heads,sq,sk", [ + (1, 1, 127, 131), + (3, 5, 131, 127), + (7, 3, 257, 251), + (11, 7, 67, 509), + (13, 1, 191, 193), + (5, 11, 61, 67), + (2, 3, 509, 127), + ]) + def test_all_prime(self, batch, heads, sq, sk): + check_output( + randn(batch, sq, heads, 128), + randn(batch, sk, heads, 128), + randn(batch, sk, heads, 128), + ) + + +class TestCLC2CTA: + @pytest.mark.parametrize("sq,sk", [ + (512, 512), + (512, 127), + (512, 129), + (512, 2048), + (1024, 64), + (768, 1024), + (512, 64), + ]) + def test_2cta_qk_mismatch(self, sq, sk): + check_output(randn(4, sq, 4, 128), randn(4, sk, 4, 128), randn(4, sk, 4, 128), assert_2cta=True) + + @pytest.mark.parametrize("batch,heads,sq,sk", [ + (1, 1, 512, 128), + (1, 1, 512, 512), + (3, 5, 768, 1024), + (7, 3, 512, 127), + (9, 7, 1024, 257), + (13, 1, 512, 64), + ]) + def test_2cta_adversarial_combos(self, batch, heads, sq, sk): + check_output( + randn(batch, sq, heads, 128), + randn(batch, sk, heads, 128), + randn(batch, sk, heads, 128), + assert_2cta=True, + ) + + +class TestCLCGQA: + @pytest.mark.parametrize("q_heads,kv_heads,sq,sk", [ + (4, 1, 128, 512), + (4, 1, 256, 127), + (8, 1, 64, 1024), + (8, 2, 512, 129), + (8, 4, 1, 256), + (6, 2, 192, 383), + (6, 3, 128, 1), + (12, 4, 257, 511), + ]) + def test_gqa_mismatch(self, q_heads, kv_heads, sq, sk): + check_output( + randn(4, sq, q_heads, 128), + randn(4, sk, kv_heads, 128), + randn(4, sk, kv_heads, 128), + ) + + @pytest.mark.parametrize("q_heads,kv_heads", [ + (4, 1), (4, 2), (8, 1), (8, 2), (8, 4), (6, 2), (6, 3), (12, 4), + ]) + def test_gqa_ratios(self, q_heads, kv_heads): + check_output( + randn(4, 512, q_heads, 128), + randn(4, 512, kv_heads, 128), + randn(4, 512, kv_heads, 128), + ) + + +class TestCLCHeadDim: + @pytest.mark.parametrize("d,dv,sq,sk", [ + (64, 64, 128, 512), + (64, 64, 1, 256), + (96, 96, 255, 127), + (128, 64, 192, 384), + (128, 64, 1, 1024), + ]) + def test_head_dims_adversarial(self, d, dv, sq, sk): + check_output(randn(4, sq, 4, d), randn(4, sk, 4, d), randn(4, sk, 4, dv)) + + def test_overlap_sO_sQ_fallback(self): + from flash_attn4.tile_scheduler import SingleTileScheduler + + _captured_schedulers.clear() + check_output(randn(4, 128, 4, 192), randn(4, 257, 4, 192), randn(4, 257, 4, 128), assert_clc=False) + assert _captured_schedulers, "No scheduler was captured" + sched_cls, sched_mode, *_ = _captured_schedulers[-1] + assert sched_cls is SingleTileScheduler, f"Expected SingleTileScheduler fallback, got {sched_cls.__name__}" + assert sched_mode == SchedulingMode.STATIC, f"Expected STATIC fallback, got {sched_mode!r}" + + +class TestCLCFallback: + + def test_varlen_uses_clc(self): + _captured_schedulers.clear() + batch, seqlen, heads, d = 4, 256, 4, 128 + lens = torch.tensor([64, 128, 32, 32], dtype=torch.int32) + cu_seqlens = torch.cat([torch.zeros(1, dtype=torch.int32), lens.cumsum(0)]).to(device="cuda", dtype=torch.int32) + total = int(cu_seqlens[-1]) + q = torch.randn(total, heads, d, device="cuda", dtype=torch.bfloat16) + k = torch.randn(total, heads, d, device="cuda", dtype=torch.bfloat16) + v = torch.randn(total, heads, d, device="cuda", dtype=torch.bfloat16) + out, _ = flash_attn_varlen_func( + q, k, v, + cu_seqlens_q=cu_seqlens, + cu_seqlens_k=cu_seqlens, + max_seqlen_q=int(lens.max()), + max_seqlen_k=int(lens.max()), + ) + torch.cuda.synchronize() + assert _captured_schedulers, "No scheduler was captured" + sched_cls, sched_mode, *_ = _captured_schedulers[-1] + assert sched_cls is SingleTileVarlenScheduler, ( + f"Expected SingleTileVarlenScheduler for varlen, got {sched_cls.__name__}" + ) + assert sched_mode == SchedulingMode.CLC, f"Expected CLC scheduling mode, got {sched_mode!r}" + + @pytest.mark.parametrize("sq,sk,wl,wr", [ + (512, 512, 128, 128), + (256, 1024, 64, 64), + (512, 512, 255, 0), + (128, 2048, 32, 512), + ]) + def test_local_window_with_clc(self, sq, sk, wl, wr): + check_output( + randn(4, sq, 4, 128), + randn(4, sk, 4, 128), + randn(4, sk, 4, 128), + window_size=(wl, wr), + ) + + +def check_varlen_output(seqlens, heads, d, *, causal=False, kv_heads=None, num_splits=1): + kv_heads = kv_heads or heads + cu_seqlens = torch.cat([torch.zeros(1, dtype=torch.int32), torch.tensor(seqlens, dtype=torch.int32).cumsum(0)]).to(device="cuda", dtype=torch.int32) + total = int(cu_seqlens[-1]) + max_s = max(seqlens) + q = torch.randn(total, heads, d, device="cuda", dtype=torch.bfloat16) + k = torch.randn(total, kv_heads, d, device="cuda", dtype=torch.bfloat16) + v = torch.randn(total, kv_heads, d, device="cuda", dtype=torch.bfloat16) + + _captured_schedulers.clear() + out, _ = flash_attn_varlen_func( + q, k, v, + cu_seqlens_q=cu_seqlens, + cu_seqlens_k=cu_seqlens, + max_seqlen_q=max_s, + max_seqlen_k=max_s, + causal=causal, + num_splits=num_splits, + ) + torch.cuda.synchronize() + if _captured_schedulers: + sched_cls, sched_mode, *_ = _captured_schedulers[-1] + assert sched_cls is SingleTileVarlenScheduler, f"Expected SingleTileVarlenScheduler, got {sched_cls.__name__}" + assert sched_mode == SchedulingMode.CLC, f"Expected CLC scheduling mode, got {sched_mode!r}" + + for i in range(len(seqlens)): + s = slice(cu_seqlens[i], cu_seqlens[i + 1]) + qi, ki, vi, oi = q[s].unsqueeze(0), k[s].unsqueeze(0), v[s].unsqueeze(0), out[s].unsqueeze(0) + out_ref_i, _ = attention_ref(qi, ki, vi, causal=causal) + out_pt_i, _ = attention_ref(qi, ki, vi, causal=causal, upcast=False, reorder_ops=True) + fwd_atol = 2 * (out_ref_i + 0.3 - 0.3 - out_ref_i).abs().max().item() + assert (oi - out_ref_i).abs().max().item() <= 2 * ( + out_pt_i - out_ref_i + ).abs().max().item() + fwd_atol, ( + f"batch={i} max_diff={(oi - out_ref_i).abs().max().item()}, " + f"pt_max_diff={(out_pt_i - out_ref_i).abs().max().item()}, " + f"seqlens={seqlens} heads={heads} d={d} causal={causal} num_splits={num_splits}" + ) + + +def check_varlen_output_seqused(seqlens, heads, d, *, causal=False, kv_heads=None, num_splits=1): + kv_heads = kv_heads or heads + batch = len(seqlens) + max_s = max(seqlens) + seqused = torch.tensor(seqlens, device="cuda", dtype=torch.int32) + q = torch.randn(batch, max_s, heads, d, device="cuda", dtype=torch.bfloat16) + k = torch.randn(batch, max_s, kv_heads, d, device="cuda", dtype=torch.bfloat16) + v = torch.randn(batch, max_s, kv_heads, d, device="cuda", dtype=torch.bfloat16) + q_mask = torch.arange(max_s, device="cuda")[None, :] < seqused[:, None] + k_mask = q_mask + + _captured_schedulers.clear() + out, _ = flash_attn_varlen_func( + q, + k, + v, + max_seqlen_q=max_s, + max_seqlen_k=max_s, + seqused_q=seqused, + seqused_k=seqused, + causal=causal, + num_splits=num_splits, + ) + torch.cuda.synchronize() + if _captured_schedulers: + sched_cls, sched_mode, *_ = _captured_schedulers[-1] + assert sched_cls is SingleTileVarlenScheduler, f"Expected SingleTileVarlenScheduler, got {sched_cls.__name__}" + assert sched_mode == SchedulingMode.CLC, f"Expected CLC scheduling mode, got {sched_mode!r}" + + out_ref, _ = attention_ref(q, k, v, q_mask, k_mask, causal=causal) + out_pt, _ = attention_ref(q, k, v, q_mask, k_mask, causal=causal, upcast=False, reorder_ops=True) + q_mask_4d = q_mask.unsqueeze(-1).unsqueeze(-1) + out_masked = out.clone().masked_fill_(~q_mask_4d, 0.0) + out_ref_masked = out_ref.clone().masked_fill_(~q_mask_4d, 0.0) + out_pt_masked = out_pt.clone().masked_fill_(~q_mask_4d, 0.0) + fwd_atol = 2 * (out_ref_masked + 0.3 - 0.3 - out_ref_masked).abs().max().item() + assert (out_masked - out_ref_masked).abs().max().item() <= 2 * ( + out_pt_masked - out_ref_masked + ).abs().max().item() + fwd_atol, ( + f"max_diff={(out_masked - out_ref_masked).abs().max().item()}, " + f"pt_max_diff={(out_pt_masked - out_ref_masked).abs().max().item()}, " + f"seqlens={seqlens} heads={heads} d={d} causal={causal} num_splits={num_splits}" + ) + + +class TestCLCVarlen: + + @pytest.mark.parametrize("seqlens", [ + [64, 128, 32, 32], + [256, 64, 128, 256], + [1, 512, 1, 1], + [128, 128, 128, 128], + [512, 256, 128, 64], + [1, 1, 1, 1], + [255, 129, 63, 193], + ]) + def test_varlen_basic(self, seqlens): + check_varlen_output(seqlens, heads=4, d=128) + + @pytest.mark.parametrize("seqlens", [ + [64, 128, 32, 32], + [256, 64, 128, 256], + [512, 256, 128, 64], + [255, 129, 63, 193], + ]) + def test_varlen_causal(self, seqlens): + check_varlen_output(seqlens, heads=4, d=128, causal=True) + + @pytest.mark.parametrize("seqlens", [ + [64, 128, 32, 32], + [1, 512, 1, 1], + [255, 129, 63, 193], + ]) + def test_varlen_gqa(self, seqlens): + check_varlen_output(seqlens, heads=8, d=128, kv_heads=2) + + @pytest.mark.parametrize("seqlens,heads", [ + pytest.param([512], 4, id="single_batch"), + pytest.param([256, 128], 8, id="two_batch"), + pytest.param([64] * 32, 4, id="many_batches"), + pytest.param([1, 1, 1, 1024, 1, 1, 1, 1], 4, id="imbalanced"), + ]) + def test_varlen_edge_cases(self, seqlens, heads): + check_varlen_output(seqlens, heads=heads, d=128) + + @pytest.mark.parametrize("seqlens", [ + [127, 131, 251, 193], + [1, 3, 7, 13, 31, 61], + [509, 127, 251, 67], + ]) + def test_varlen_primes(self, seqlens): + check_varlen_output(seqlens, heads=4, d=128) + + @pytest.mark.parametrize("d", [64, 96, 128]) + def test_varlen_head_dims(self, d): + check_varlen_output([128, 256, 64, 192], heads=4, d=d) + + @pytest.mark.parametrize("trial", range(3)) + def test_varlen_repeatability(self, trial): + torch.random.manual_seed(trial) + check_varlen_output([64, 128, 32, 256, 1, 512], heads=4, d=128) + + @pytest.mark.parametrize("seqlens", [ + [64, 128, 32, 256], + [255, 129, 63, 193], + ]) + @pytest.mark.parametrize("num_splits", [2, 3]) + def test_varlen_splitkv(self, seqlens, num_splits): + check_varlen_output(seqlens, heads=4, d=64, num_splits=num_splits) + + @pytest.mark.parametrize("seqlens", [ + [64, 128, 32, 256], + [255, 129, 63, 193], + ]) + @pytest.mark.parametrize("num_splits", [2, 3]) + def test_varlen_seqused_splitkv(self, seqlens, num_splits): + check_varlen_output_seqused(seqlens, heads=4, d=64, num_splits=num_splits) + + @pytest.mark.parametrize("seqlens", [ + [64, 128, 32, 256], + [255, 129, 63, 193], + ]) + @pytest.mark.parametrize("num_splits", [2, 3]) + def test_varlen_splitkv_gqa(self, seqlens, num_splits): + check_varlen_output(seqlens, heads=8, kv_heads=2, d=64, num_splits=num_splits) + + @pytest.mark.parametrize("seqlens", [ + [64, 128, 32, 256], + [255, 129, 63, 193], + ]) + @pytest.mark.parametrize("num_splits", [2, 3]) + def test_varlen_seqused_splitkv_gqa(self, seqlens, num_splits): + check_varlen_output_seqused(seqlens, heads=8, kv_heads=2, d=64, num_splits=num_splits) + + +class TestCLCMinimal: + @pytest.mark.parametrize("sq,sk", [(1, 1), (1, 2), (2, 1), (1, 128), (128, 1)]) + def test_minimal(self, sq, sk): + check_output(randn(1, sq, 1, 128), randn(1, sk, 1, 128), randn(1, sk, 1, 128)) + + def test_single_element(self): + check_output(randn(1, 1, 1, 64), randn(1, 1, 1, 64), randn(1, 1, 1, 64)) + + +class TestCLCCausal: + + @pytest.mark.parametrize("batch,heads,sq,sk", [ + (3, 5, 259, 259), + (7, 3, 513, 513), + (1, 7, 1023, 1023), + (5, 11, 2049, 2049), + (2, 3, 4097, 4097), + ]) + def test_causal_square(self, batch, heads, sq, sk): + check_output(randn(batch, sq, heads, 128), randn(batch, sk, heads, 128), randn(batch, sk, heads, 128), causal=True) + + @pytest.mark.parametrize("batch,heads,sq,sk", [ + (3, 7, 127, 513), + (5, 3, 259, 1023), + (7, 5, 63, 2049), + (11, 1, 1, 511), + (2, 9, 1, 1025), + (9, 3, 33, 4097), + ]) + def test_causal_qk_mismatch(self, batch, heads, sq, sk): + check_output(randn(batch, sq, heads, 128), randn(batch, sk, heads, 128), randn(batch, sk, heads, 128), causal=True) + + @pytest.mark.parametrize("batch,heads,sq,sk", [ + (3, 7, 191, 191), + (7, 5, 193, 193), + (5, 3, 383, 383), + (11, 1, 129, 509), + (2, 13, 1, 131), + (9, 3, 67, 251), + ]) + def test_causal_nonaligned(self, batch, heads, sq, sk): + check_output(randn(batch, sq, heads, 128), randn(batch, sk, heads, 128), randn(batch, sk, heads, 128), causal=True) + + @pytest.mark.parametrize("batch,q_heads,kv_heads,sq", [ + (3, 6, 2, 513), + (7, 8, 1, 259), + (5, 12, 4, 1023), + (2, 8, 2, 2049), + (11, 4, 1, 191), + ]) + def test_causal_gqa(self, batch, q_heads, kv_heads, sq): + check_output( + randn(batch, sq, q_heads, 128), + randn(batch, sq, kv_heads, 128), + randn(batch, sq, kv_heads, 128), + causal=True, + ) + + def test_causal_large(self): + check_output(randn(3, 4097, 13, 128), randn(3, 4097, 13, 128), randn(3, 4097, 13, 128), causal=True) + + +class TestCLCLargeScale: + def test_large_batch(self): + check_output(randn(32, 512, 8, 128), randn(32, 512, 8, 128), randn(32, 512, 8, 128)) + + def test_long_seq(self): + check_output(randn(2, 4096, 4, 128), randn(2, 4096, 4, 128), randn(2, 4096, 4, 128)) + + def test_many_heads(self): + check_output(randn(4, 512, 32, 128), randn(4, 512, 32, 128), randn(4, 512, 32, 128)) + + @pytest.mark.parametrize("batch,heads,sq,sk", [ + (24, 8, 768, 2048), + (16, 8, 1536, 4096), + (12, 8, 2305, 4096), + ]) + def test_work_stealing_pressure(self, batch, heads, sq, sk): + total_tiles = expected_total_tiles_mha(batch, sq, heads) + assert total_tiles > SM_COUNT, f"expected total_tiles={total_tiles} > sm_count={SM_COUNT}" + check_output( + randn(batch, sq, heads, 128), + randn(batch, sk, heads, 128), + randn(batch, sk, heads, 128), + ) + + def test_long_k_short_q(self): + check_output(randn(8, 64, 8, 128), randn(8, 8192, 8, 128), randn(8, 8192, 8, 128)) + + def test_long_q_short_k(self): + check_output(randn(4, 4096, 4, 128), randn(4, 64, 4, 128), randn(4, 64, 4, 128)) + + +class TestCLCRepeatability: + @pytest.mark.parametrize("trial", range(5)) + def test_repeat_mismatch(self, trial): + torch.random.manual_seed(trial) + check_output(randn(7, 192, 5, 128), randn(7, 513, 5, 128), randn(7, 513, 5, 128)) + + @pytest.mark.parametrize("trial", range(5)) + def test_repeat_2cta(self, trial): + torch.random.manual_seed(trial) + check_output(randn(9, 257, 3, 128), randn(9, 511, 3, 128), randn(9, 511, 3, 128)) + + @pytest.mark.parametrize("trial", range(5)) + def test_repeat_gqa_mismatch(self, trial): + torch.random.manual_seed(trial) + check_output(randn(5, 128, 8, 128), randn(5, 1024, 2, 128), randn(5, 1024, 2, 128)) + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/flash-attn4/tests/cute/test_flash_attn.py b/flash-attn4/tests/cute/test_flash_attn.py index df628d9f..76c3f8de 100644 --- a/flash-attn4/tests/cute/test_flash_attn.py +++ b/flash-attn4/tests/cute/test_flash_attn.py @@ -4,6 +4,7 @@ import itertools import os import random +import re import pytest import torch @@ -27,15 +28,15 @@ from flash_attn4.interface import ( flash_attn_func, flash_attn_varlen_func, - flash_attn_combine, ) # torch FakeTensorMode would enable fast cutedsl kernel compilation without allocating the actual GPU memory or running the kernel # When operating fake tensors, we cannot perform data-dependent operations (e.g., `tensor.max()`). USE_FAKE_TENSOR = int(os.getenv("FLASH_ATTENTION_FAKE_TENSOR", 0)) == 1 DISABLE_SPLIT = os.getenv("FLASH_ATTENTION_DISABLE_SPLIT", "FALSE") == "TRUE" -# SplitKV and paged KV are not supported on SM90 +# SplitKV is not supported on SM90 IS_SM90 = torch.cuda.get_device_capability()[0] == 9 +IS_SM100 = torch.cuda.get_device_capability()[0] == 10 TEST_BWD_ONLY = False VERBOSE = True @@ -47,8 +48,8 @@ # @pytest.mark.parametrize("has_learnable_sink", [False]) # @pytest.mark.parametrize("has_qv", [False, True]) @pytest.mark.parametrize("has_qv", [False]) -# @pytest.mark.parametrize("deterministic", [False, True]) -@pytest.mark.parametrize("deterministic", [False]) +@pytest.mark.parametrize("deterministic", [False, True]) +# @pytest.mark.parametrize("deterministic", [False]) # @pytest.mark.parametrize("softcap", [0.0, 15.0]) @pytest.mark.parametrize("softcap", [0.0]) @pytest.mark.parametrize("local_enum", [0, 1, 2, 3]) @@ -63,7 +64,7 @@ # @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128]) # @pytest.mark.parametrize("d", [64, 96, 128, 192]) # @pytest.mark.parametrize("d", [128, 192]) -@pytest.mark.parametrize("d", [64, 128]) +@pytest.mark.parametrize("d", [64, 96, 128, 192, 256]) # @pytest.mark.parametrize("d", [128]) @pytest.mark.parametrize( "seqlen_q,seqlen_k", @@ -252,6 +253,8 @@ def test_flash_attn_output( # SplitKV not supported on SM90 - skip this iteration if IS_SM90 and num_splits > 1: continue + if IS_SM100 and (d >= 192 and dv >= 192): # hdim 192 and 256 not support on SM100 + continue out, lse = flash_attn_func( q, k, @@ -294,16 +297,10 @@ def test_flash_attn_output( # and False and not ((causal or local) and seqlen_k < seqlen_q) ): - # TODO: SM90 backward pass has invalid MMA tile config for d=64 + non-causal - # The m_block_size=80 (non-causal) with head_dim=64 creates an invalid tile. - # Fix requires adjusting m_block_size or MMA config in flash_bwd_sm90.py - if IS_SM90 and d == 64 and not causal: - pytest.xfail("SM90 backward: d=64 + non-causal has invalid MMA tile config (m_block=80)") - # TODO: SM90 backward pass does not support local attention yet - if IS_SM90 and local: - pytest.xfail("SM90 backward: local attention not supported yet") - if d == 192 and local: - pytest.xfail("hdim 192 backward: local attention not supported yet") + if d > 192 and IS_SM90: + pytest.xfail("hdim > 192 backward: SM90 not supported yet") + if d != dv and mha_type != "mha" and IS_SM90: + pytest.xfail("SM90 GQA bwd currently requires headdim == headdim_v") g = torch.randn_like(out) # do_o = ((g.float() * out.float()).sum(-1)).transpose(1, 2) dq, dk, dv = torch.autograd.grad(out, (q, k, v), g) @@ -389,8 +386,8 @@ def test_flash_attn_output( @pytest.mark.parametrize("has_learnable_sink", [False]) # @pytest.mark.parametrize("has_qv", [False, True]) @pytest.mark.parametrize("has_qv", [False]) -# @pytest.mark.parametrize("deterministic", [False, True]) -@pytest.mark.parametrize("deterministic", [False]) +@pytest.mark.parametrize("deterministic", [False, True]) +# @pytest.mark.parametrize("deterministic", [False]) # @pytest.mark.parametrize("softcap", [0.0, 15.0]) @pytest.mark.parametrize("softcap", [0.0]) @pytest.mark.parametrize("local_enum", [0, 1, 2, 3]) @@ -406,7 +403,7 @@ def test_flash_attn_output( # @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128]) # @pytest.mark.parametrize("d", [64, 96, 128]) # @pytest.mark.parametrize("d", [128, 192]) -@pytest.mark.parametrize("d", [64, 128]) +@pytest.mark.parametrize("d", [64, 128, 192]) @pytest.mark.parametrize( "seqlen_q,seqlen_k", [ @@ -484,7 +481,7 @@ def test_flash_attn_varlen_output( seed = seqlen_q + seqlen_k + d + int(causal) * 2 + int(local) random.seed(seed) torch.random.manual_seed(seed) - batch_size = 49 if seqlen_q <= 1024 else 7 + batch_size = 49 if seqlen_q <= 512 else 7 nheads = 6 # nheads = 1 nheads_kv = nheads if mha_type == "mha" else (3 if mha_type == "gqa" else 1) @@ -717,16 +714,26 @@ def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device): continue if query_unused_mask is not None: out.masked_fill_(q_zero_masking, 0.0) - print(f"Output max diff: {(out - out_ref).abs().max().item()}") - print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") + # When unpad_q=False with seqused_q, the kernel doesn't write positions + # beyond seqused_q, so those contain uninitialized values. Mask them out + # before comparing. + out_cmp, out_ref_cmp, out_pt_cmp = out, out_ref, out_pt + if not unpad_q and seqused_q is not None: + seqused_mask = torch.arange(seqlen_q, device=device)[None, :] < seqused_q[:, None] + seqused_mask = rearrange(seqused_mask, "b s -> b s 1 1") + out_cmp = out.clone().masked_fill_(~seqused_mask, 0.0) + out_ref_cmp = out_ref.clone().masked_fill_(~seqused_mask, 0.0) + out_pt_cmp = out_pt.clone().masked_fill_(~seqused_mask, 0.0) + print(f"Output max diff: {(out_cmp - out_ref_cmp).abs().max().item()}") + print(f"Output mean diff: {(out_cmp - out_ref_cmp).abs().mean().item()}") # if not causal: # print(f"LSE max diff: {(lse - lse_ref).abs().max().item()}") # breakpoint() # Check that FlashAttention's numerical error is at most 3x the numerical error # of a Pytorch implementation. - assert (out - out_ref).abs().max().item() <= rtol * ( - out_pt - out_ref + assert (out_cmp - out_ref_cmp).abs().max().item() <= rtol * ( + out_pt_cmp - out_ref_cmp ).abs().max().item() + fwd_atol if ( @@ -736,11 +743,12 @@ def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device): and not attention_chunk != 0 and ((dv == d and d <= 128) or (d == 192 and dv == 128)) and not has_learnable_sink - and not IS_SM90 # and False ): - if d == 192 and local: - pytest.xfail("hdim 192 backward: local attention not supported yet") + if d > 192 and IS_SM90: + pytest.xfail("hdim > 192 backward: SM90 not supported yet") + if d != dv and mha_type != "mha" and IS_SM90: + pytest.xfail("SM90 GQA bwd currently requires headdim == headdim_v") g_unpad = torch.randn_like(out_unpad) # do_o = ((g_unpad.float() * out_unpad.float()).sum(-1)).transpose(-1, -2) # import flash_attn_3_cuda @@ -941,8 +949,6 @@ def test_flash_attn_kvcache( ): if page_size is not None and seqlen_k % page_size != 0: pytest.skip() - if page_size is not None and IS_SM90: - pytest.xfail("paged KV not supported on SM90") if seqlen_q > seqlen_k and new_kv: pytest.skip() if not new_kv and rotary_fraction > 0.0: @@ -1432,9 +1438,6 @@ def test_flash_attn_kvcache( @pytest.mark.parametrize("seqlen_q,seqlen_k", [(128, 128), (256, 256)]) @maybe_fake_tensor_mode(USE_FAKE_TENSOR) def test_flash_attn_bwd_preallocated_outputs(seqlen_q, seqlen_k, d, causal, dtype): - if IS_SM90 and d == 64 and not causal: - pytest.xfail("SM90 backward: d=64 + non-causal has invalid MMA tile config (m_block=80)") - from flash_attn4.interface import _flash_attn_fwd, _flash_attn_bwd device = "cuda" @@ -1468,6 +1471,141 @@ def test_flash_attn_bwd_preallocated_outputs(seqlen_q, seqlen_k, d, causal, dtyp assert torch.allclose(dv, dv_ref, atol=1e-5, rtol=1e-5) +@pytest.mark.parametrize("dtype", [torch.bfloat16]) +@pytest.mark.parametrize("causal", [False, True]) +@pytest.mark.parametrize("d", [64, 128]) +@pytest.mark.parametrize("seqlen_q,seqlen_k", [(128, 128), (256, 256)]) +@maybe_fake_tensor_mode(USE_FAKE_TENSOR) +def test_flash_attn_lse_grad(seqlen_q, seqlen_k, d, causal, dtype): + """Test that gradient flows through the returned LSE tensor.""" + device = "cuda" + torch.random.manual_seed(42) + batch_size = 2 + nheads = 4 + + q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True) + k = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True) + v = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True) + + out, lse = flash_attn_func(q, k, v, causal=causal, return_lse=True) + + if is_fake_mode(): + return + + assert lse is not None + assert lse.requires_grad + + # Compute loss = sum(out * g) + sum(lse * dlse_weight) to test gradient flows through both + g = torch.randn_like(out) + dlse_weight = torch.randn_like(lse) + loss = (out * g).sum() + (lse * dlse_weight).sum() + dq, dk, dv = torch.autograd.grad(loss, (q, k, v)) + + # Compare against reference: manually compute what the gradients should be + # Reference: standard attention in float + q_ref = q.detach().float().requires_grad_() + k_ref = k.detach().float().requires_grad_() + v_ref = v.detach().float().requires_grad_() + # (batch, seqlen_q, nheads, d) -> (batch, nheads, seqlen_q, d) + qk = torch.einsum("bshd,bthd->bhst", q_ref, k_ref) / (d ** 0.5) + if causal: + mask = torch.triu(torch.ones(seqlen_q, seqlen_k, device=device, dtype=torch.bool), diagonal=seqlen_k - seqlen_q + 1) + qk = qk.masked_fill(mask, float("-inf")) + lse_ref = torch.logsumexp(qk, dim=-1) # (batch, nheads, seqlen_q) + p = torch.softmax(qk, dim=-1) + # v_ref: (batch, seqlen_k, nheads, d) + out_ref = torch.einsum("bhst,bthd->bshd", p, v_ref) + loss_ref = (out_ref * g.float()).sum() + (lse_ref * dlse_weight.float()).sum() + dq_ref, dk_ref, dv_ref = torch.autograd.grad(loss_ref, (q_ref, k_ref, v_ref)) + + # Use relaxed tolerances since flash_attn operates in bf16 while reference is float32. + # The reference is also not a perfect bf16 simulation (it doesn't reorder ops), so + # we use a generous tolerance. + print(f"dQ max diff: {(dq.float() - dq_ref).abs().max().item()}") + print(f"dK max diff: {(dk.float() - dk_ref).abs().max().item()}") + print(f"dV max diff: {(dv.float() - dv_ref).abs().max().item()}") + # Absolute tolerance: bf16 has ~0.004-0.02 error for these sizes + atol = 0.02 + assert (dq.float() - dq_ref).abs().max().item() <= atol, f"dQ error too large" + assert (dk.float() - dk_ref).abs().max().item() <= atol, f"dK error too large" + assert (dv.float() - dv_ref).abs().max().item() <= atol, f"dV error too large" + + # Also test: gradient with only dLSE (no dO) + out2, lse2 = flash_attn_func(q, k, v, causal=causal, return_lse=True) + loss_lse_only = (lse2 * dlse_weight).sum() + dq2, dk2, dv2 = torch.autograd.grad(loss_lse_only, (q, k, v)) + + q_ref2 = q.detach().float().requires_grad_() + k_ref2 = k.detach().float().requires_grad_() + qk2 = torch.einsum("bshd,bthd->bhst", q_ref2, k_ref2) / (d ** 0.5) + if causal: + qk2 = qk2.masked_fill(mask, float("-inf")) + lse_ref2 = torch.logsumexp(qk2, dim=-1) + loss_ref2 = (lse_ref2 * dlse_weight.float()).sum() + dq_ref2, dk_ref2 = torch.autograd.grad(loss_ref2, (q_ref2, k_ref2)) + + print(f"LSE-only dQ max diff: {(dq2.float() - dq_ref2).abs().max().item()}") + print(f"LSE-only dK max diff: {(dk2.float() - dk_ref2).abs().max().item()}") + # dV should be zero when only LSE gradient flows (LSE doesn't depend on V) + print(f"LSE-only dV max: {dv2.abs().max().item()}") + assert dv2.abs().max().item() == 0.0, "dV should be zero when loss depends only on LSE" + + +@pytest.mark.parametrize("dtype", [torch.bfloat16]) +@pytest.mark.parametrize("causal", [False, True]) +@pytest.mark.parametrize("d", [128]) +@pytest.mark.parametrize("seqlen_q,seqlen_k", [(128, 128)]) +@maybe_fake_tensor_mode(USE_FAKE_TENSOR) +def test_flash_attn_lse_grad_unused(seqlen_q, seqlen_k, d, causal, dtype): + """Test return_lse=True when LSE is returned but not used in the loss. + + With set_materialize_grads(False), dlse should be None (not a zero tensor), + so no extra zeroing kernel is launched. Gradients should match the standard + backward (without return_lse). + """ + device = "cuda" + torch.random.manual_seed(42) + batch_size = 2 + nheads = 4 + + q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True) + k = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True) + v = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True) + g = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype) + + # Case 1: return_lse=False (standard path, lse marked non-differentiable) + out1, lse1 = flash_attn_func(q, k, v, causal=causal, return_lse=False) + if is_fake_mode(): + return + dq1, dk1, dv1 = torch.autograd.grad(out1, (q, k, v), g) + + # Case 2: return_lse=True but lse NOT used in loss (dlse should be None) + out2, lse2 = flash_attn_func(q, k, v, causal=causal, return_lse=True) + dq2, dk2, dv2 = torch.autograd.grad(out2, (q, k, v), g) + + # Case 3: return_lse=True and lse IS used in loss + out3, lse3 = flash_attn_func(q, k, v, causal=causal, return_lse=True) + dlse_weight = torch.randn_like(lse3) + loss3 = (out3 * g).sum() + (lse3 * dlse_weight).sum() + dq3, dk3, dv3 = torch.autograd.grad(loss3, (q, k, v)) + + # Cases 1 and 2 should produce identical gradients + assert torch.equal(dq1, dq2), "dQ should be identical when LSE is unused" + assert torch.equal(dk1, dk2), "dK should be identical when LSE is unused" + assert torch.equal(dv1, dv2), "dV should be identical when LSE is unused" + + # Case 3 should differ from case 1 (LSE gradient adds extra contribution to dQ, dK) + assert not torch.equal(dq1, dq3), "dQ should differ when LSE gradient is included" + assert not torch.equal(dk1, dk3), "dK should differ when LSE gradient is included" + # dV should be the same since LSE doesn't depend on V + assert torch.equal(dv1, dv3), "dV should be identical since LSE doesn't depend on V" + + print("Case 1 vs 2 (unused LSE): dQ diff =", (dq1 - dq2).abs().max().item()) + print("Case 1 vs 3 (used LSE): dQ diff =", (dq1 - dq3).abs().max().item()) + print("Case 1 vs 3 (used LSE): dK diff =", (dk1 - dk3).abs().max().item()) + print("Case 1 vs 3 (used LSE): dV diff =", (dv1 - dv3).abs().max().item()) + + def _generate_block_kvcache( seqlen_k, page_size, batch_size, nheads_k, d, dv, device, dtype, dtype_ref ): @@ -1500,85 +1638,67 @@ def _generate_block_kvcache( return k_cache, v_cache, page_table, k_cache_paged, v_cache_paged, num_blocks -def attention_combine_ref(out_partial, lse_partial): - """ - out_partial: (num_splits, batch_size, seqlen, nheads, d) - lse_partial: (num_splits, batch_size, seqlen, nheads) +@pytest.mark.parametrize("page_size", [16, 64, 256]) +@pytest.mark.parametrize("seqlen_q", [64, 128, 256]) +@maybe_fake_tensor_mode(USE_FAKE_TENSOR) +def test_flash_attn_paged_deepseek(seqlen_q, page_size): + """Regression test: paged non-TMA with DeepSeek MLA shape (d=192, dv=128). + seqlen_q<=128 triggers q_stage=1, seqlen_q>128 triggers q_stage=2. """ - lse = torch.logsumexp(lse_partial, dim=0) - scale = torch.exp(lse_partial - lse) - scale = torch.where( - torch.isinf(scale) | torch.isnan(scale), torch.zeros_like(scale), scale + if IS_SM90: + pytest.skip("paged KV not supported on SM90") + device = "cuda" + dtype = torch.bfloat16 + d, dv = 192, 128 + nheads = 16 + nheads_kv = 16 + + torch.random.manual_seed(0) + q = torch.randn(seqlen_q, nheads, d, device=device, dtype=dtype) + k = torch.randn(seqlen_q, nheads_kv, d, device=device, dtype=dtype) + v = torch.randn(seqlen_q, nheads_kv, dv, device=device, dtype=dtype) + cu_seqlens = torch.tensor([0, seqlen_q], dtype=torch.int32, device=device) + + # Non-paged reference + out_ref, _ = flash_attn_varlen_func( + q, k, v, cu_seqlens_q=cu_seqlens, cu_seqlens_k=cu_seqlens, + max_seqlen_q=seqlen_q, max_seqlen_k=seqlen_q, causal=True, ) - out = (scale.unsqueeze(-1) * out_partial).sum(0) - return out, lse + # Paged + num_pages = (seqlen_q + page_size - 1) // page_size + k_cache_paged = torch.zeros(num_pages, page_size, nheads_kv, d, device=device, dtype=dtype) + v_cache_paged = torch.zeros(num_pages, page_size, nheads_kv, dv, device=device, dtype=dtype) + for i in range(seqlen_q): + k_cache_paged[i // page_size, i % page_size] = k[i] + v_cache_paged[i // page_size, i % page_size] = v[i] + page_table = torch.arange(num_pages, dtype=torch.int32, device=device).unsqueeze(0) + cache_seqlens = torch.tensor([seqlen_q], dtype=torch.int32, device=device) -@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16]) -# @pytest.mark.parametrize("dtype", [torch.float32]) -# @pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256]) -@pytest.mark.parametrize("d", [64, 96, 128, 192, 256, 512]) -# @pytest.mark.parametrize("d", [128]) -@pytest.mark.parametrize("seqlen", [1, 2, 3, 32, 64, 256, 113, 108, 640, 1024]) -# @pytest.mark.parametrize("seqlen", [12, 32, 64, 256, 112, 108, 640, 1024, 2048, 8192]) -# @pytest.mark.parametrize("seqlen", [15]) -@pytest.mark.parametrize("num_splits", [1, 2, 3, 5, 17, 32, 55, 97, 133]) -# @pytest.mark.parametrize("num_splits", [1, 2, 3, 5, 11]) -# @pytest.mark.parametrize("num_splits", [11]) -@maybe_fake_tensor_mode(USE_FAKE_TENSOR) -def test_flash_attn_combine(num_splits, seqlen, d, dtype): - device = "cuda" - # set seed - torch.random.manual_seed(1) - batch_size = 5 - nheads = 16 - # batch_size = 1 - # nheads = 1 - # Create tensors in the expected format: (num_splits, batch_size, seqlen, nheads, d) and (num_splits, batch_size, seqlen, nheads) - out_partial = torch.randn( - num_splits * 2, - batch_size, - nheads, - seqlen, - d, - device=device, - dtype=torch.float32, - ).transpose(2, 3)[:num_splits] # To test non-contiguous tensor - lse_partial = torch.randn( - num_splits, batch_size, nheads * 2, seqlen, device=device, dtype=torch.float32 - ).transpose(-1, -2)[:, :, :, :nheads] # To test non-contiguous tensor - # To test short-circuiting based on num_splits - lse_partial[num_splits // 2 :, : batch_size // 3] = -float("inf") - - # Test with LSE returned (default behavior) - out, lse = flash_attn_combine( - out_partial, lse_partial, out_dtype=dtype, return_lse=True + out, _ = flash_attn_varlen_func( + q, k_cache_paged, v_cache_paged, + cu_seqlens_q=cu_seqlens, cu_seqlens_k=None, + max_seqlen_q=seqlen_q, max_seqlen_k=None, + seqused_k=cache_seqlens, page_table=page_table, causal=True, ) + if is_fake_mode(): return - out_ref, lse_ref = attention_combine_ref(out_partial, lse_partial) - out_pt = out_ref.to(dtype) - print(f"LSE max diff: {(lse - lse_ref).abs().max().item()}") - print(f"LSE mean diff: {(lse - lse_ref).abs().mean().item()}") print(f"Output max diff: {(out - out_ref).abs().max().item()}") print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") - print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}") - print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}") - # breakpoint() - - assert torch.allclose(lse, lse_ref, atol=1e-5, rtol=1e-5) - multiple = 2 - assert ( - (out - out_ref).abs().max().item() - <= multiple * (out_pt - out_ref).abs().max().item() - ) or torch.allclose(out, out_pt, atol=1e-5, rtol=1e-5) - - # Test with LSE not returned - out_no_lse, lse_no_lse = flash_attn_combine( - out_partial, lse_partial, out_dtype=dtype, return_lse=False - ) - assert lse_no_lse is None, "LSE should be None when return_lse=False" - assert torch.allclose(out_no_lse, out, atol=1e-5, rtol=1e-5), ( - "Output should be the same regardless of return_lse" - ) + assert torch.equal(out, out_ref) + + +@pytest.mark.parametrize("head_dim", [4, 148, 288]) +def test_flash_attn_invalid_head_dim(head_dim): + device = "cuda" + dtype = torch.bfloat16 + batch_size, seqlen, nheads = 1, 64, 4 + + q = torch.randn(batch_size, seqlen, nheads, head_dim, device=device, dtype=dtype) + k = torch.randn(batch_size, seqlen, nheads, head_dim, device=device, dtype=dtype) + v = torch.randn(batch_size, seqlen, nheads, head_dim, device=device, dtype=dtype) + + with pytest.raises(AssertionError, match=re.escape(f"(head_dim, head_dim_v)=({head_dim}, {head_dim}) is not supported on SM")): + flash_attn_func(q, k, v) diff --git a/flash-attn4/tests/cute/test_flash_attn_combine.py b/flash-attn4/tests/cute/test_flash_attn_combine.py new file mode 100644 index 00000000..b6a5672c --- /dev/null +++ b/flash-attn4/tests/cute/test_flash_attn_combine.py @@ -0,0 +1,286 @@ +# Copyright (c) 2025, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. + +import os + +import pytest +import torch + +from flash_attn4.testing import ( + maybe_fake_tensor_mode, + is_fake_mode, +) +from flash_attn4.interface import ( + flash_attn_combine, +) + +USE_FAKE_TENSOR = int(os.getenv("FLASH_ATTENTION_FAKE_TENSOR", 0)) == 1 + + +def attention_combine_ref(out_partial, lse_partial): + """ + out_partial: (num_splits, batch_size, seqlen, nheads, d) + lse_partial: (num_splits, batch_size, seqlen, nheads) + """ + lse = torch.logsumexp(lse_partial, dim=0) + scale = torch.exp(lse_partial - lse) + scale = torch.where( + torch.isinf(scale) | torch.isnan(scale), torch.zeros_like(scale), scale + ) + out = (scale.unsqueeze(-1) * out_partial).sum(0) + return out, lse + + +def check_combine_results(out, lse, out_ref, lse_ref, dtype): + """Check combine kernel output against reference for a single (seqlen, nheads, d) chunk.""" + out_pt = out_ref.to(dtype) + print(f"LSE max diff: {(lse - lse_ref).abs().max().item()}, " + f"Output max diff: {(out - out_ref).abs().max().item()}, " + f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}") + assert torch.allclose(lse, lse_ref, atol=1e-5, rtol=1e-5) + assert ( + (out - out_ref).abs().max().item() + <= 2 * (out_pt - out_ref).abs().max().item() + ) or torch.allclose(out, out_pt, atol=1e-5, rtol=1e-5) + + +@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16]) +# @pytest.mark.parametrize("dtype", [torch.float32]) +# @pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256]) +@pytest.mark.parametrize("d", [64, 96, 128, 192, 256, 512]) +# @pytest.mark.parametrize("d", [128]) +@pytest.mark.parametrize("seqlen", [1, 2, 3, 32, 64, 256, 113, 108, 640, 1024]) +# @pytest.mark.parametrize("seqlen", [12, 32, 64, 256, 112, 108, 640, 1024, 2048, 8192]) +# @pytest.mark.parametrize("seqlen", [15]) +@pytest.mark.parametrize("num_splits", [1, 2, 3, 5, 17, 32, 55, 97, 133]) +# @pytest.mark.parametrize("num_splits", [1, 2, 3, 5, 11]) +# @pytest.mark.parametrize("num_splits", [11]) +@maybe_fake_tensor_mode(USE_FAKE_TENSOR) +def test_flash_attn_combine(num_splits, seqlen, d, dtype): + device = "cuda" + # set seed + torch.random.manual_seed(1) + batch_size = 5 + nheads = 16 + # batch_size = 1 + # nheads = 1 + # Create tensors in the expected format: (num_splits, batch_size, seqlen, nheads, d) and (num_splits, batch_size, seqlen, nheads) + out_partial = torch.randn( + num_splits * 2, + batch_size, + nheads, + seqlen, + d, + device=device, + dtype=torch.float32, + ).transpose(2, 3)[:num_splits] # To test non-contiguous tensor + lse_partial = torch.randn( + num_splits, batch_size, nheads * 2, seqlen, device=device, dtype=torch.float32 + ).transpose(-1, -2)[:, :, :, :nheads] # To test non-contiguous tensor + # To test short-circuiting based on num_splits + lse_partial[num_splits // 2 :, : batch_size // 3] = -float("inf") + + # Test with LSE returned (default behavior) + out, lse = flash_attn_combine( + out_partial, lse_partial, out_dtype=dtype, return_lse=True + ) + if is_fake_mode(): + return + out_ref, lse_ref = attention_combine_ref(out_partial, lse_partial) + check_combine_results(out, lse, out_ref, lse_ref, dtype) + + # Test with LSE not returned + out_no_lse, lse_no_lse = flash_attn_combine( + out_partial, lse_partial, out_dtype=dtype, return_lse=False + ) + assert lse_no_lse is None, "LSE should be None when return_lse=False" + assert torch.allclose(out_no_lse, out, atol=1e-5, rtol=1e-5), ( + "Output should be the same regardless of return_lse" + ) + + +@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16]) +# @pytest.mark.parametrize("dtype", [torch.bfloat16]) +@pytest.mark.parametrize("d", [64, 96, 128, 256]) +# @pytest.mark.parametrize("d", [128]) +@pytest.mark.parametrize("seqlen", [1, 32, 113, 256, 1024]) +# @pytest.mark.parametrize("seqlen", [113]) +@pytest.mark.parametrize("num_splits", [2, 5, 17, 55]) +# @pytest.mark.parametrize("num_splits", [5]) +@pytest.mark.parametrize( + "varlen_mode", + ["cu_seqlens", "seqused", "cu_seqlens_seqused"], +) +# @pytest.mark.parametrize("varlen_mode", ["cu_seqlens"]) +@maybe_fake_tensor_mode(USE_FAKE_TENSOR) +def test_flash_attn_combine_varlen(varlen_mode, num_splits, seqlen, d, dtype): + device = "cuda" + torch.random.manual_seed(1) + batch_size = 3 + nheads = 8 + use_cu_seqlens = "cu_seqlens" in varlen_mode + use_seqused = "seqused" in varlen_mode + + # Generate variable-length sequences + seqlens = torch.randint(1, seqlen + 1, (batch_size,), device=device, dtype=torch.int32) + # For cu_seqlens+seqused mode, seqused < seqlen (kernel processes fewer tokens) + seqused_vals = ( + torch.clamp( + seqlens - torch.randint(0, max(1, seqlen // 4), (batch_size,), device=device, dtype=torch.int32), + min=1, + ) + if use_cu_seqlens and use_seqused + else seqlens + ) + + if use_cu_seqlens: + # Packed varlen layout: (num_splits, total_q, nheads, d) + total_q = seqlens.sum().item() + cu_seqlens_q = torch.zeros(batch_size + 1, device=device, dtype=torch.int32) + cu_seqlens_q[1:] = torch.cumsum(seqlens, dim=0) + + out_partial = torch.randn( + num_splits * 2, total_q, nheads, d, device=device, dtype=torch.float32, + )[:num_splits] # Non-contiguous in splits dim + # lse_partial needs stride(-2)==1 (seqlen dim contiguous) + lse_partial = torch.randn( + num_splits, nheads, total_q, device=device, dtype=torch.float32 + ).transpose(-1, -2) + lse_partial[num_splits // 2:, :total_q // 3] = -float("inf") + + out, lse = flash_attn_combine( + out_partial, lse_partial, out_dtype=dtype, + cu_seqlens=cu_seqlens_q, + seqused=seqused_vals if use_seqused else None, + return_lse=True, + ) + if is_fake_mode(): + return + + # Reference on full packed tensor + out_ref, lse_ref = attention_combine_ref( + out_partial.unsqueeze(1), lse_partial.unsqueeze(1) + ) + out_ref = out_ref.squeeze(0) + lse_ref = lse_ref.squeeze(0) + + # Validate per-batch (only seqused_vals tokens are guaranteed correct) + for i in range(batch_size): + start = cu_seqlens_q[i].item() + sl = seqused_vals[i].item() + check_combine_results( + out[start:start + sl], lse[start:start + sl], + out_ref[start:start + sl], lse_ref[start:start + sl], dtype, + ) + + # Also test return_lse=False + out_no_lse, lse_no_lse = flash_attn_combine( + out_partial, lse_partial, out_dtype=dtype, + cu_seqlens=cu_seqlens_q, + seqused=seqused_vals if use_seqused else None, + return_lse=False, + ) + assert lse_no_lse is None + # Only compare valid positions (beyond seqused, output is undefined) + for i in range(batch_size): + start = cu_seqlens_q[i].item() + sl = seqused_vals[i].item() + assert torch.allclose(out_no_lse[start:start + sl], out[start:start + sl], atol=1e-5, rtol=1e-5) + + else: + # seqused only โ€” batched layout: (num_splits, batch, max_seqlen, nheads, d) + max_seqlen = seqlens.max().item() + out_partial = torch.randn( + num_splits, batch_size, max_seqlen, nheads, d, device=device, dtype=torch.float32, + ) + # lse_partial needs stride(-2)==1 (seqlen dim contiguous) + lse_partial = torch.randn( + num_splits, batch_size, nheads, max_seqlen, device=device, dtype=torch.float32, + ).transpose(-1, -2) + lse_partial[num_splits // 2:, :batch_size // 2] = -float("inf") + # Zero out / -inf beyond seqused so reference matches kernel + for i in range(batch_size): + out_partial[:, i, seqlens[i]:] = 0 + lse_partial[:, i, seqlens[i]:] = -float("inf") + + out, lse = flash_attn_combine( + out_partial, lse_partial, out_dtype=dtype, seqused=seqlens, return_lse=True, + ) + if is_fake_mode(): + return + + out_ref, lse_ref = attention_combine_ref(out_partial, lse_partial) + + # Validate per-batch (only seqused tokens) + for i in range(batch_size): + sl = seqlens[i].item() + check_combine_results( + out[i, :sl], lse[i, :sl], + out_ref[i, :sl], lse_ref[i, :sl], dtype, + ) + + +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +# @pytest.mark.parametrize("dtype", [torch.bfloat16]) +@pytest.mark.parametrize("d", [64, 128, 256]) +# @pytest.mark.parametrize("d", [128]) +@pytest.mark.parametrize("seqlen", [32, 113, 256]) +# @pytest.mark.parametrize("seqlen", [113]) +@pytest.mark.parametrize("num_splits", [2, 5, 17]) +# @pytest.mark.parametrize("num_splits", [5]) +@maybe_fake_tensor_mode(USE_FAKE_TENSOR) +def test_flash_attn_combine_varlen_batch_idx(num_splits, seqlen, d, dtype): + """Test that varlen_batch_idx correctly remaps virtual batch indices to real batch indices. + + varlen_batch_idx maps blockIdx.z (virtual batch) -> real batch index. The kernel + reads AND writes using the remapped batch_idx, so with a permutation the output + should match running without varlen_batch_idx (each real batch is processed once). + + We also test with seqused to verify interaction with variable-length sequences. + """ + device = "cuda" + torch.random.manual_seed(42) + batch_size = 4 + nheads = 8 + + # Create batched input data + out_partial = torch.randn( + num_splits, batch_size, seqlen, nheads, d, device=device, dtype=torch.float32, + ) + lse_partial = torch.randn( + num_splits, batch_size, nheads, seqlen, device=device, dtype=torch.float32, + ).transpose(-1, -2) # stride(-2)==1 + lse_partial[num_splits // 2:, :batch_size // 2] = -float("inf") + + # Create a permuted batch index mapping: virtual batch -> real batch + perm = torch.tensor([2, 0, 3, 1], device=device, dtype=torch.int32) + assert perm.shape[0] == batch_size + + # Also test with seqused to verify interaction with varlen_batch_idx + seqused = torch.randint(1, seqlen + 1, (batch_size,), device=device, dtype=torch.int32) + # Zero out / -inf beyond seqused so reference matches kernel + for i in range(batch_size): + out_partial[:, i, seqused[i]:] = 0 + lse_partial[:, i, seqused[i]:] = -float("inf") + + # Run with varlen_batch_idx and seqused via public API + out, lse = flash_attn_combine( + out_partial, lse_partial, out_dtype=dtype, + seqused=seqused, + varlen_batch_idx=perm, + return_lse=True, + ) + if is_fake_mode(): + return + + # Reference: standard combine (no remapping needed since perm is a bijection + # and both reads and writes use the remapped batch_idx) + out_ref, lse_ref = attention_combine_ref(out_partial, lse_partial) + + # The kernel reads from input[perm[v]] and writes to output[perm[v]], + # so the net result is output[b] = combine(input[b]) for all b. + for b in range(batch_size): + sl = seqused[b].item() + check_combine_results( + out[b, :sl], lse[b, :sl], + out_ref[b, :sl], lse_ref[b, :sl], dtype, + ) diff --git a/flash-attn4/tests/cute/test_flash_attn_fast.py b/flash-attn4/tests/cute/test_flash_attn_fast.py new file mode 100644 index 00000000..3425cbe1 --- /dev/null +++ b/flash-attn4/tests/cute/test_flash_attn_fast.py @@ -0,0 +1,331 @@ +# Copyright (c) 2025, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +# Fast subset of test_flash_attn.py for quick iteration. +# Covers: causal/noncausal, varlen/not varlen, MHA/GQA, split/not split, fwd+bwd. + +import os +import random + +import pytest +import torch + +from einops import rearrange + +from flash_attn4.testing import ( + attention_ref, + generate_random_padding_mask, + generate_qkv, + maybe_fake_tensor_mode, + is_fake_mode, +) +from flash_attn4.interface import ( + flash_attn_func, + flash_attn_varlen_func, + flash_attn_combine, +) + +USE_FAKE_TENSOR = int(os.getenv("FLASH_ATTENTION_FAKE_TENSOR", 0)) == 1 +IS_SM90 = torch.cuda.get_device_capability()[0] == 9 + + +# --------------------------------------------------------------------------- +# Forward + backward (non-varlen) +# --------------------------------------------------------------------------- + +@pytest.mark.parametrize("dtype", [torch.bfloat16]) +@pytest.mark.parametrize("mha_type", ["mha", "gqa", "mqa"]) +@pytest.mark.parametrize("num_splits", [1, 3]) +@pytest.mark.parametrize("causal", [False, True]) +@pytest.mark.parametrize("d", [64, 128]) +@pytest.mark.parametrize( + "seqlen_q,seqlen_k", + [ + (128, 128), + (256, 256), + (113, 203), + (1024, 1024), + ], +) +@maybe_fake_tensor_mode(USE_FAKE_TENSOR) +def test_flash_attn_output(seqlen_q, seqlen_k, d, causal, num_splits, mha_type, dtype): + if IS_SM90 and num_splits > 1: + pytest.skip("SM90 fwd doens't support num_splits > 1") + device = "cuda" + torch.random.manual_seed(0) + random.seed(0) + torch.cuda.empty_cache() + batch_size = 4 + nheads = 6 + nheads_kv = nheads if mha_type == "mha" else (3 if mha_type == "gqa" else 1) + + q_ref = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype).to(dtype).requires_grad_() + k_ref = torch.randn(batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype).to(dtype).requires_grad_() + v_ref = torch.randn(batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype).to(dtype).requires_grad_() + + q = q_ref.detach().to(dtype).requires_grad_() + k = k_ref.detach().to(dtype).requires_grad_() + v = v_ref.detach().to(dtype).requires_grad_() + + out_ref, _ = attention_ref(q_ref, k_ref, v_ref, None, None, causal=causal) + out_pt, _ = attention_ref( + q_ref, k_ref, v_ref, None, None, causal=causal, upcast=False, reorder_ops=True, + ) + + out, lse = flash_attn_func(q, k, v, causal=causal, num_splits=num_splits) + + if is_fake_mode(): + return + + fwd_atol = 2 * (out_ref + 0.3 - 0.3 - out_ref).abs().max().item() + assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item() + fwd_atol + + # Backward (only for non-split, matching d) + can_bwd = ( + num_splits == 1 + and d <= 128 + and not (causal and seqlen_k < seqlen_q) + ) + if IS_SM90 and d == 64 and not causal: + can_bwd = False # SM90 d=64 non-causal xfail + if not can_bwd: + return + + g = torch.randn_like(out) + dq, dk, dv = torch.autograd.grad(out, (q, k, v), g) + + dq_ref, dk_ref, dv_ref = torch.autograd.grad(out_ref, (q_ref, k_ref, v_ref), g) + dq_pt, dk_pt, dv_pt = torch.autograd.grad(out_pt, (q_ref, k_ref, v_ref), g) + + dq_atol = 2 * (dq_ref + 0.3 - 0.3 - dq_ref).abs().max().item() + dk_atol = 2 * (dk_ref + 0.3 - 0.3 - dk_ref).abs().max().item() + dv_atol = 2 * (dv_ref + 0.3 - 0.3 - dv_ref).abs().max().item() + assert (dq - dq_ref).abs().max().item() <= 2 * (dq_pt - dq_ref).abs().max().item() + dq_atol + assert (dk - dk_ref).abs().max().item() <= 2 * (dk_pt - dk_ref).abs().max().item() + dk_atol + assert (dv - dv_ref).abs().max().item() <= 2 * (dv_pt - dv_ref).abs().max().item() + dv_atol + + +# --------------------------------------------------------------------------- +# Forward + backward (varlen with cu_seqlens) +# --------------------------------------------------------------------------- + +@pytest.mark.parametrize("dtype", [torch.bfloat16]) +@pytest.mark.parametrize("mha_type", ["mha", "gqa", "mqa"]) +@pytest.mark.parametrize("causal", [False, True]) +@pytest.mark.parametrize("d", [64, 128]) +@pytest.mark.parametrize("seqlen", [128, 256, 1024]) +@maybe_fake_tensor_mode(USE_FAKE_TENSOR) +def test_flash_attn_varlen_output(seqlen, d, causal, mha_type, dtype): + """Varlen test with cu_seqlens (packed): equal seqlens so we can compare with non-varlen ref.""" + device = "cuda" + seed = seqlen + d + int(causal) * 2 + torch.random.manual_seed(seed) + random.seed(seed) + batch_size = 9 + nheads = 6 + nheads_kv = nheads if mha_type == "mha" else (3 if mha_type == "gqa" else 1) + + q_ref = torch.randn(batch_size, seqlen, nheads, d, device=device, dtype=dtype).to(dtype).requires_grad_() + k_ref = torch.randn(batch_size, seqlen, nheads_kv, d, device=device, dtype=dtype).to(dtype).requires_grad_() + v_ref = torch.randn(batch_size, seqlen, nheads_kv, d, device=device, dtype=dtype).to(dtype).requires_grad_() + + out_ref, _ = attention_ref(q_ref, k_ref, v_ref, None, None, causal=causal) + out_pt, _ = attention_ref( + q_ref, k_ref, v_ref, None, None, causal=causal, upcast=False, reorder_ops=True, + ) + + cu_seqlens = torch.arange(0, (batch_size + 1) * seqlen, seqlen, device=device, dtype=torch.int32) + q_varlen = rearrange(q_ref.detach(), "b s h d -> (b s) h d").requires_grad_() + k_varlen = rearrange(k_ref.detach(), "b s h d -> (b s) h d").requires_grad_() + v_varlen = rearrange(v_ref.detach(), "b s h d -> (b s) h d").requires_grad_() + + out_varlen, lse = flash_attn_varlen_func( + q_varlen, k_varlen, v_varlen, + cu_seqlens, cu_seqlens, + seqlen, seqlen, + causal=causal, + ) + + if is_fake_mode(): + return + + out_reshaped = rearrange(out_varlen, "(b s) h d -> b s h d", b=batch_size) + fwd_atol = 2 * (out_ref + 0.3 - 0.3 - out_ref).abs().max().item() + assert (out_reshaped - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item() + fwd_atol + + # Backward + can_bwd = d <= 128 + if not can_bwd: + return + + g = torch.randn_like(out_varlen) + dq_varlen, dk_varlen, dv_varlen = torch.autograd.grad(out_varlen, (q_varlen, k_varlen, v_varlen), g) + + assert dq_varlen.isfinite().all(), "dq contains non-finite values" + assert dk_varlen.isfinite().all(), "dk contains non-finite values" + assert dv_varlen.isfinite().all(), "dv contains non-finite values" + assert dq_varlen.abs().max().item() > 0, "dq is all zeros" + assert dk_varlen.abs().max().item() > 0, "dk is all zeros" + assert dv_varlen.abs().max().item() > 0, "dv is all zeros" + + +# --------------------------------------------------------------------------- +# Forward + backward (varlen with padding masks โ€” all unpad combinations) +# Covers 4 compile-key-distinct paths: +# (unpad_q, unpad_kv) = (T,T): cu_seqlens for both Q and K +# (unpad_q, unpad_kv) = (F,F): seqused for both Q and K +# (unpad_q, unpad_kv) = (T,F): cu_seqlens_q + seqused_k +# (unpad_q, unpad_kv) = (F,T): seqused_q + cu_seqlens_k +# --------------------------------------------------------------------------- + +@pytest.mark.parametrize("dtype", [torch.bfloat16]) +@pytest.mark.parametrize("mha_type", ["mha", "gqa", "mqa"]) +@pytest.mark.parametrize("causal", [False, True]) +@pytest.mark.parametrize("d", [64, 128]) +@pytest.mark.parametrize("seqlen", [128, 256]) +@pytest.mark.parametrize( + "unpad_q,unpad_kv", + [(True, True), (False, False), (True, False), (False, True)], +) +@maybe_fake_tensor_mode(USE_FAKE_TENSOR) +def test_flash_attn_varlen_unpad_output(seqlen, d, causal, mha_type, unpad_q, unpad_kv, dtype): + """Varlen test with all 4 (unpad_q, unpad_kv) combos: cu_seqlens vs seqused.""" + device = "cuda" + seed = seqlen + d + int(causal) * 2 + int(unpad_q) * 7 + int(unpad_kv) * 13 + torch.random.manual_seed(seed) + random.seed(seed) + batch_size = 9 + nheads = 6 + nheads_kv = nheads if mha_type == "mha" else (3 if mha_type == "gqa" else 1) + + q = torch.randn(batch_size, seqlen, nheads, d, device=device, dtype=dtype) + k = torch.randn(batch_size, seqlen, nheads_kv, d, device=device, dtype=dtype) + v = torch.randn(batch_size, seqlen, nheads_kv, d, device=device, dtype=dtype) + q_ref = q.detach().to(dtype).requires_grad_() + k_ref = k.detach().to(dtype).requires_grad_() + v_ref = v.detach().to(dtype).requires_grad_() + + query_padding_mask = generate_random_padding_mask(seqlen, batch_size, device, mode="random") + key_padding_mask = query_padding_mask if causal else generate_random_padding_mask( + seqlen, batch_size, device, mode="random" + ) + + ( + q_unpad_t, k_unpad_t, v_unpad_t, _qv_unpad, + cu_seqlens_q, cu_seqlens_k, + seqused_q, seqused_k, + max_seqlen_q, max_seqlen_k, + q_padded, k_padded, v_padded, _qv_padded, + output_pad_fn, dq_pad_fn, dk_pad_fn, + ) = generate_qkv(q, k, v, query_padding_mask, key_padding_mask) + + out_ref, _ = attention_ref( + q_ref, k_ref, v_ref, query_padding_mask, key_padding_mask, causal=causal, + ) + out_pt, _ = attention_ref( + q_ref, k_ref, v_ref, query_padding_mask, key_padding_mask, causal=causal, + upcast=False, reorder_ops=True, + ) + + # Select Q input: packed (unpad) or padded (seqused) + if unpad_q: + q_in = q_unpad_t.detach().to(dtype).requires_grad_() + else: + q_in = q.detach().to(dtype).requires_grad_() + # Select KV input: packed (unpad) or padded (seqused) + if unpad_kv: + k_in = k_unpad_t.detach().to(dtype).requires_grad_() + v_in = v_unpad_t.detach().to(dtype).requires_grad_() + else: + k_in = k.detach().to(dtype).requires_grad_() + v_in = v.detach().to(dtype).requires_grad_() + + out_unpad, lse = flash_attn_varlen_func( + q_in, k_in, v_in, + cu_seqlens_q=cu_seqlens_q if unpad_q else None, + cu_seqlens_k=cu_seqlens_k if unpad_kv else None, + max_seqlen_q=seqlen, + max_seqlen_k=seqlen, + seqused_q=seqused_q if not unpad_q else None, + seqused_k=seqused_k if not unpad_kv else None, + causal=causal, + ) + + if is_fake_mode(): + return + + # Reshape output to (batch, seqlen, nheads, d) for comparison + out = output_pad_fn(out_unpad) if unpad_q else out_unpad + + # Mask out padding positions โ€” kernel output at padding positions is undefined + q_mask = rearrange(query_padding_mask, "b s -> b s 1 1") + out_masked = out.clone().masked_fill_(~q_mask, 0.0) + out_ref_masked = out_ref.clone().masked_fill_(~q_mask, 0.0) + out_pt_masked = out_pt.clone().masked_fill_(~q_mask, 0.0) + + fwd_atol = 2 * (out_ref_masked + 0.3 - 0.3 - out_ref_masked).abs().max().item() + assert (out_masked - out_ref_masked).abs().max().item() <= 2 * (out_pt_masked - out_ref_masked).abs().max().item() + fwd_atol + + # Backward (original test skips all SM90 varlen backward) + can_bwd = d <= 128 and not IS_SM90 + if not can_bwd: + return + + g = torch.randn_like(out_unpad) + dq_in, dk_in, dv_in = torch.autograd.grad(out_unpad, (q_in, k_in, v_in), g) + + # Mask out padding positions again + k_mask = rearrange(key_padding_mask, "b s -> b s 1 1") + if not unpad_q: + dq_in = dq_in.clone().masked_fill_(~q_mask, 0.0) + if not unpad_kv: + dk_in = dk_in.clone().masked_fill_(~k_mask, 0.0) + dv_in = dv_in.clone().masked_fill_(~k_mask, 0.0) + + assert dq_in.isfinite().all(), "dq contains non-finite values" + assert dk_in.isfinite().all(), "dk contains non-finite values" + assert dv_in.isfinite().all(), "dv contains non-finite values" + assert dq_in.abs().max().item() > 0, "dq is all zeros" + assert dk_in.abs().max().item() > 0, "dk is all zeros" + assert dv_in.abs().max().item() > 0, "dv is all zeros" + + +# --------------------------------------------------------------------------- +# Combine kernel +# --------------------------------------------------------------------------- + +def attention_combine_ref(out_partial, lse_partial): + lse = torch.logsumexp(lse_partial, dim=0) + scale = torch.exp(lse_partial - lse) + scale = torch.where(torch.isinf(scale) | torch.isnan(scale), torch.zeros_like(scale), scale) + out = (scale.unsqueeze(-1) * out_partial).sum(0) + return out, lse + + +@pytest.mark.parametrize("dtype", [torch.bfloat16]) +@pytest.mark.parametrize("d", [64, 128]) +@pytest.mark.parametrize("seqlen", [32, 256]) +@pytest.mark.parametrize("num_splits", [2, 5, 17]) +@maybe_fake_tensor_mode(USE_FAKE_TENSOR) +def test_flash_attn_combine(num_splits, seqlen, d, dtype): + device = "cuda" + torch.random.manual_seed(1) + batch_size = 3 + nheads = 8 + + # out_partial: (num_splits, batch, seqlen, nheads, d) with stride(-1)==1 + # lse_partial: (num_splits, batch, seqlen, nheads) with stride(-2)==1 (seqlen contiguous) + out_partial = torch.randn( + num_splits, batch_size, seqlen, nheads, d, device=device, dtype=torch.float32, + ) + lse_partial = torch.randn( + num_splits, batch_size, nheads, seqlen, device=device, dtype=torch.float32, + ).transpose(-1, -2) + lse_partial[num_splits // 2 :, : batch_size // 3] = -float("inf") + + out, lse = flash_attn_combine(out_partial, lse_partial, out_dtype=dtype, return_lse=True) + if is_fake_mode(): + return + out_ref, lse_ref = attention_combine_ref(out_partial, lse_partial) + out_pt = out_ref.to(dtype) + + assert torch.allclose(lse, lse_ref, atol=1e-5, rtol=1e-5) + assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item() or torch.allclose(out, out_pt, atol=1e-5, rtol=1e-5) diff --git a/flash-attn4/tests/cute/test_flash_attn_race_condition.py b/flash-attn4/tests/cute/test_flash_attn_race_condition.py index 752d9f2c..cc489db4 100644 --- a/flash-attn4/tests/cute/test_flash_attn_race_condition.py +++ b/flash-attn4/tests/cute/test_flash_attn_race_condition.py @@ -76,6 +76,9 @@ def test_flash_attn_output( local = local_enum > 0 if local and causal: pytest.skip() + is_sm90 = torch.cuda.get_device_capability()[0] == 9 + if is_sm90 and d == 192: + pytest.xfail("headdim 192 not supported on sm90") device = "cuda" # set seed torch.random.manual_seed(0) @@ -252,8 +255,6 @@ def test_flash_attn_output( pytest.xfail("SM90 backward: GQA/MQA has tensor layout issue (qhead_per_kvhead > 1)") if IS_SM90 and local: pytest.xfail("SM90 backward: local attention not supported yet") - if d == 192 and local: - pytest.xfail("hdim 192 backward: local attention not supported yet") g = torch.randn_like(out) # do_o = ((g.float() * out.float()).sum(-1)).transpose(1, 2) dq, dk, dv = torch.autograd.grad(out, (q, k, v), g) @@ -412,8 +413,8 @@ def test_flash_attn_varlen_output( is_sm90 = torch.cuda.get_device_capability()[0] == 9 if is_sm90 and local: pytest.xfail("bwd local attention not supported on sm90") - if is_sm90 and deterministic: - pytest.xfail("bwd deterministic not supported on sm90") + if is_sm90 and d == 192: + pytest.xfail("headdim 192 not supported on sm90") if ( causal or local ): # Right now reference only supports causal attention with seqlen_k == seqlen_q @@ -655,8 +656,6 @@ def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device): and not is_sm90 # and False ): - if d == 192 and local: - pytest.xfail("hdim 192 backward: local attention not supported yet") g_unpad = torch.randn_like(out_unpad) # do_o = ((g_unpad.float() * out_unpad.float()).sum(-1)).transpose(-1, -2) # import flash_attn_3_cuda diff --git a/flash-attn4/tests/cute/test_flash_attn_varlen.py b/flash-attn4/tests/cute/test_flash_attn_varlen.py index e2178187..0822c461 100644 --- a/flash-attn4/tests/cute/test_flash_attn_varlen.py +++ b/flash-attn4/tests/cute/test_flash_attn_varlen.py @@ -1,15 +1,10 @@ -import itertools from typing import Optional -from einops import rearrange import pytest import torch import torch.nn.functional as F from flash_attn4 import flash_attn_varlen_func -IS_SM90 = torch.cuda.get_device_capability()[0] == 9 - - @pytest.mark.parametrize("B", [1, 7, 20]) @pytest.mark.parametrize("H", [1, 4, 6]) @pytest.mark.parametrize("D", [64, 128]) @@ -43,9 +38,6 @@ def test_varlen( dtype=dtype ) - # SM90 backward pass doesn't support varlen yet - skip_backward = IS_SM90 - ok = check_varlen_vs_torch_flash( q, k, v, cu_seqlens_q, cu_seqlens_k, @@ -53,7 +45,6 @@ def test_varlen( softmax_scale=softmax_scale, causal=causal, mha_type=mha_type, - skip_backward=skip_backward, ) assert ok @@ -71,7 +62,6 @@ def check_varlen_vs_torch_flash( softcap=0.0, atol=3e-2, rtol=3e-2, - skip_backward=False, ): assert q.requires_grad and k.requires_grad and v.requires_grad, "Set requires_grad=True on inputs" @@ -128,10 +118,6 @@ def clone_like(t): if not ok_fwd: return False - # Skip backward if not supported (e.g., SM90 varlen) - if skip_backward: - return True - # Use the same upstream gradient to compare backward paths grad_out = torch.randn_like(out_fa) @@ -312,4 +298,4 @@ def _stats(name, a, b, atol, rtol): mean_abs = diff.abs().mean().item() mean_rel = (diff.abs().mean() / b.abs().clamp_min(1e-6).mean().item()) print(f"{name}: mean_abs={mean_abs:.4e}, mean_rel={mean_rel:.4e}, sum_fa={a.sum()}, sum_ref={b.sum()}") - return mean_abs < atol and mean_rel < rtol \ No newline at end of file + return mean_abs < atol and mean_rel < rtol diff --git a/flash-attn4/tests/cute/test_mask_mod.py b/flash-attn4/tests/cute/test_mask_mod.py index 643e9957..2da84b3a 100644 --- a/flash-attn4/tests/cute/test_mask_mod.py +++ b/flash-attn4/tests/cute/test_mask_mod.py @@ -108,6 +108,76 @@ def compute_reference_flex_attn(tensors, mask_mod_flex, block_size: tuple[int, i return out_ref.transpose(1, 2).contiguous() +def assert_fwd_matches_reference(out_cute, out_ref_fp32, out_pt, test_desc: str | None = None): + assert out_cute.shape == out_ref_fp32.shape == out_pt.shape + assert not torch.isnan(out_cute).any() + assert not torch.isnan(out_ref_fp32).any() + assert torch.isfinite(out_cute).all() + assert torch.isfinite(out_ref_fp32).all() + + fwd_atol = 2 * (out_ref_fp32 + 0.3 - 0.3 - out_ref_fp32).abs().max().item() + rtol = 2 + + pt_error = (out_pt - out_ref_fp32).abs().max().item() + cute_error = (out_cute - out_ref_fp32).abs().max().item() + + if test_desc is not None: + print(f"\n{test_desc}") + print(" Reference implementation: FlexAttention") + print(f" PyTorch vs FP32: {pt_error:.2e}") + print(f" Kernel vs FP32: {cute_error:.2e}") + print(f" Tolerance: rtol={rtol} * {pt_error:.2e} + {fwd_atol:.2e}") + + assert cute_error <= rtol * pt_error + fwd_atol, ( + f"Kernel error {cute_error:.2e} exceeds {rtol}x PyTorch error {pt_error:.2e} + {fwd_atol:.2e}" + ) + + +def assert_bwd_matches_reference( + dq_cute, + dk_cute, + dv_cute, + dq_ref_fp32, + dk_ref_fp32, + dv_ref_fp32, + dq_pt, + dk_pt, + dv_pt, + dtype, + min_seqlen: int, +): + assert not torch.isnan(dq_cute).any(), "dQ contains NaN" + assert not torch.isnan(dk_cute).any(), "dK contains NaN" + assert not torch.isnan(dv_cute).any(), "dV contains NaN" + + bwd_rtol = 2 + bwd_atol_floor = 1e-5 if min_seqlen >= 64 else 3e-5 + dq_atol = max(bwd_atol_floor, 2 * (dq_ref_fp32 + 0.3 - 0.3 - dq_ref_fp32).abs().max().item()) + dk_atol = max(bwd_atol_floor, 2 * (dk_ref_fp32 + 0.3 - 0.3 - dk_ref_fp32).abs().max().item()) + dv_atol = max(bwd_atol_floor, 2 * (dv_ref_fp32 + 0.3 - 0.3 - dv_ref_fp32).abs().max().item()) + + dq_ref = dq_ref_fp32.to(dtype) + dk_ref = dk_ref_fp32.to(dtype) + dv_ref = dv_ref_fp32.to(dtype) + + pt_dq_err = (dq_pt - dq_ref).abs().max().item() + pt_dk_err = (dk_pt - dk_ref).abs().max().item() + pt_dv_err = (dv_pt - dv_ref).abs().max().item() + + cute_dq_err = (dq_cute - dq_ref).abs().max().item() + cute_dk_err = (dk_cute - dk_ref).abs().max().item() + cute_dv_err = (dv_cute - dv_ref).abs().max().item() + + print(" Backward comparison:") + print(f" dQ: PT err={pt_dq_err:.2e}, CuTE err={cute_dq_err:.2e}, atol={dq_atol:.2e}") + print(f" dK: PT err={pt_dk_err:.2e}, CuTE err={cute_dk_err:.2e}, atol={dk_atol:.2e}") + print(f" dV: PT err={pt_dv_err:.2e}, CuTE err={cute_dv_err:.2e}, atol={dv_atol:.2e}") + + assert cute_dq_err <= bwd_rtol * pt_dq_err + dq_atol, f"dQ error too large: {cute_dq_err:.2e}" + assert cute_dk_err <= bwd_rtol * pt_dk_err + dk_atol, f"dK error too large: {cute_dk_err:.2e}" + assert cute_dv_err <= bwd_rtol * pt_dv_err + dv_atol, f"dV error too large: {cute_dv_err:.2e}" + + def get_coarse_block_mask_pair(sparse_tile_m: int, tile_n: int, last_block: int): @fast_sampling @cute.jit @@ -349,8 +419,7 @@ def mask_mod_flex(b, h, q_idx, kv_idx, bias=bias): window_size_left=window_left, window_size_right=window_right, learnable_sink=None, - m_block_size=tile_m, - n_block_size=tile_n, + tile_mn=(tile_m, tile_n), pack_gqa=pack_gqa, _arch=None, score_mod=None, @@ -371,18 +440,8 @@ def mask_mod_flex(b, h, q_idx, kv_idx, bias=bias): out_ref_fp32 = compute_reference_flex_attn(tensors_fp32, mask_mod_flex, block_size) out_ref = compute_reference_flex_attn(tensors, mask_mod_flex, block_size) out_pt = out_ref.clone() - - # Check for invalid values - assert out_cute.shape == out_ref_fp32.shape == out_ref.shape - assert not torch.isnan(out_cute).any() - assert not torch.isnan(out_ref_fp32).any() - assert torch.isfinite(out_cute).all() - assert torch.isfinite(out_ref_fp32).all() - - # Compute numerical tolerance (matching flash attention tests) fwd_atol = 2 * (out_ref_fp32 + 0.3 - 0.3 - out_ref_fp32).abs().max().item() rtol = 2 - ref_error = (out_ref - out_ref_fp32).abs().max().item() pt_error = (out_pt - out_ref_fp32).abs().max().item() cute_error = (out_cute - out_ref_fp32).abs().max().item() @@ -413,10 +472,7 @@ def mask_mod_flex(b, h, q_idx, kv_idx, bias=bias): print(f" DEBUG: Kernel value: {out_cute[max_diff_coords]:.6f}") print(f" DEBUG: Reference value: {out_ref_fp32[max_diff_coords]:.6f}") - # Use the same assertion logic as FlashAttention tests - assert cute_error <= rtol * pt_error + fwd_atol, ( - f"Kernel error {cute_error:.2e} exceeds {rtol}x PyTorch error {pt_error:.2e} + {fwd_atol:.2e}" - ) + assert_fwd_matches_reference(out_cute, out_ref_fp32, out_pt, mask_desc) if needs_backward: q = tensors["q"] @@ -444,38 +500,19 @@ def mask_mod_flex(b, h, q_idx, kv_idx, bias=bias): q, k, v, flex_block_mask, grad_out ) - # Check for invalid values - assert not torch.isnan(dq_cute).any(), "dQ contains NaN" - assert not torch.isnan(dk_cute).any(), "dK contains NaN" - assert not torch.isnan(dv_cute).any(), "dV contains NaN" - - bwd_rtol = 2 - min_seqlen = min(seqlen_q, seqlen_k) - bwd_atol_floor = 1e-5 if min_seqlen >= 64 else 3e-5 - dq_atol = max(bwd_atol_floor, 2 * (dq_ref_fp32 + 0.3 - 0.3 - dq_ref_fp32).abs().max().item()) - dk_atol = max(bwd_atol_floor, 2 * (dk_ref_fp32 + 0.3 - 0.3 - dk_ref_fp32).abs().max().item()) - dv_atol = max(bwd_atol_floor, 2 * (dv_ref_fp32 + 0.3 - 0.3 - dv_ref_fp32).abs().max().item()) - - dq_ref = dq_ref_fp32.to(dtype) - dk_ref = dk_ref_fp32.to(dtype) - dv_ref = dv_ref_fp32.to(dtype) - - pt_dq_err = (dq_pt - dq_ref).abs().max().item() - pt_dk_err = (dk_pt - dk_ref).abs().max().item() - pt_dv_err = (dv_pt - dv_ref).abs().max().item() - - cute_dq_err = (dq_cute - dq_ref).abs().max().item() - cute_dk_err = (dk_cute - dk_ref).abs().max().item() - cute_dv_err = (dv_cute - dv_ref).abs().max().item() - - print(" Backward comparison:") - print(f" dQ: PT err={pt_dq_err:.2e}, CuTE err={cute_dq_err:.2e}, atol={dq_atol:.2e}") - print(f" dK: PT err={pt_dk_err:.2e}, CuTE err={cute_dk_err:.2e}, atol={dk_atol:.2e}") - print(f" dV: PT err={pt_dv_err:.2e}, CuTE err={cute_dv_err:.2e}, atol={dv_atol:.2e}") - - assert cute_dq_err <= bwd_rtol * pt_dq_err + dq_atol, f"dQ error too large: {cute_dq_err:.2e}" - assert cute_dk_err <= bwd_rtol * pt_dk_err + dk_atol, f"dK error too large: {cute_dk_err:.2e}" - assert cute_dv_err <= bwd_rtol * pt_dv_err + dv_atol, f"dV error too large: {cute_dv_err:.2e}" + assert_bwd_matches_reference( + dq_cute, + dk_cute, + dv_cute, + dq_ref_fp32, + dk_ref_fp32, + dv_ref_fp32, + dq_pt, + dk_pt, + dv_pt, + dtype, + min(seqlen_q, seqlen_k), + ) def test_mask_mod_ima_partial_block(): @@ -622,7 +659,7 @@ def mask_mod_flex(b, h, q_idx, kv_idx, doc_ids=doc_ids): seqused_q=None, seqused_k=None, page_table=None, causal=False, softcap=None, window_size_left=-1, window_size_right=-1, - m_block_size=tile_m, n_block_size=tile_n, pack_gqa=False, + tile_mn=(tile_m, tile_n), pack_gqa=False, _arch=None, score_mod=None, mask_mod=mask_mod_cute, block_sparse_tensors=block_sparse_mask_fwd, @@ -790,8 +827,7 @@ def test_sm100_block_sparse_sink_all_masked(): window_size_left=None, window_size_right=None, learnable_sink=learnable_sink, - m_block_size=128, - n_block_size=128, + tile_mn=(128, 128), num_threads=384, pack_gqa=False, block_sparse_tensors=sparse, @@ -908,8 +944,7 @@ def test_sm100_block_sparse_coarse_blocks(): window_size_left=None, window_size_right=None, learnable_sink=None, - m_block_size=tile_m, - n_block_size=tile_n, + tile_mn=(tile_m, tile_n), pack_gqa=False, _arch=None, score_mod=None, @@ -997,7 +1032,7 @@ def wrapped_normalize(*args, **kwargs): observed["q_subtile_factor"] = q_subtile_factor return normalized, pattern, q_subtile_factor - with mock.patch("flash_attn4.interface.normalize_block_sparse_config", wrapped_normalize): + with mock.patch("flash_attn.cute.interface.normalize_block_sparse_config", wrapped_normalize): out_cute, _ = _flash_attn_fwd( q=tensors["q"], k=tensors["k"], @@ -1015,8 +1050,7 @@ def wrapped_normalize(*args, **kwargs): window_size_left=None, window_size_right=None, learnable_sink=None, - m_block_size=tile_m, - n_block_size=tile_n, + tile_mn=(tile_m, tile_n), pack_gqa=False, _arch=None, score_mod=None, @@ -1144,6 +1178,9 @@ def test_sm90_block_sparse_bwd_mismatched_q_block_granularity_error_message(): tensors = create_tensors(batch_size, seqlen_q, seqlen_k, nheads, nheads_kv, headdim, headdim, dtype) mask_mod_cute, mask_mod_flex = get_mask_pair("block_diagonal", seqlen_q=seqlen_q, seqlen_k=seqlen_k) + + # Use a block_size whose Q dimension doesn't divide m_block_size (100 % 80 != 0) + bad_block_size_q = 100 bm = create_block_mask( mask_mod_flex, batch_size, @@ -1151,7 +1188,7 @@ def test_sm90_block_sparse_bwd_mismatched_q_block_granularity_error_message(): seqlen_q, seqlen_k, device="cuda", - BLOCK_SIZE=(tile_m, tile_n), + BLOCK_SIZE=(bad_block_size_q, tile_n), ) ( _seq_q, @@ -1172,7 +1209,7 @@ def test_sm90_block_sparse_bwd_mismatched_q_block_granularity_error_message(): mask_block_idx=q_mask_idx, full_block_cnt=full_q_cnt, full_block_idx=full_q_idx, - block_size=(tile_m, tile_n), + block_size=(bad_block_size_q, tile_n), ) softmax_scale = 1.0 / math.sqrt(headdim) @@ -1182,7 +1219,7 @@ def test_sm90_block_sparse_bwd_mismatched_q_block_granularity_error_message(): with pytest.raises( ValueError, - match=r"Block sparsity expects sparse_block_size_q=128 for subtile_factor=2\.", + match=r"Block sparsity expects sparse_block_size_q=", ): _flash_attn_bwd( q=tensors["q"], @@ -1200,6 +1237,209 @@ def test_sm90_block_sparse_bwd_mismatched_q_block_granularity_error_message(): ) +@pytest.mark.skipif(COMPUTE_CAPABILITY != 9, reason="SM90-only test") +def test_sm90_block_sparse_infers_block_size(): + torch.manual_seed(0) + + batch_size = 1 + nheads = 4 + seqlen_q = 128 + seqlen_k = 128 + headdim = 64 + tile_m = 128 + tile_n = 128 + dtype = torch.bfloat16 + softmax_scale = 1.0 / math.sqrt(headdim) + + q = torch.randn(batch_size, seqlen_q, nheads, headdim, device="cuda", dtype=dtype) + k = torch.randn(batch_size, seqlen_k, nheads, headdim, device="cuda", dtype=dtype) + v = torch.randn(batch_size, seqlen_k, nheads, headdim, device="cuda", dtype=dtype) + + def block_causal(batch, head, q_idx, kv_idx): + return kv_idx // tile_n <= q_idx // tile_m + + bm = create_block_mask( + block_causal, + batch_size, + nheads, + seqlen_q, + seqlen_k, + device="cuda", + BLOCK_SIZE=(tile_m, tile_n), + ) + ( + _seq_q, + _seq_k, + kv_mask_cnt, + kv_mask_idx, + full_kv_cnt, + full_kv_idx, + q_mask_cnt, + q_mask_idx, + full_q_cnt, + full_q_idx, + *_, + ) = bm.as_tuple() + + block_sparse_mask_fwd = BlockSparseTensorsTorch( + mask_block_cnt=kv_mask_cnt, + mask_block_idx=kv_mask_idx, + full_block_cnt=full_kv_cnt, + full_block_idx=full_kv_idx, + block_size=None, + ) + block_sparse_mask_bwd = BlockSparseTensorsTorch( + mask_block_cnt=q_mask_cnt, + mask_block_idx=q_mask_idx, + full_block_cnt=full_q_cnt, + full_block_idx=full_q_idx, + block_size=None, + ) + + out, lse = _flash_attn_fwd( + q=q, + k=k, + v=v, + softmax_scale=softmax_scale, + causal=False, + block_sparse_tensors=block_sparse_mask_fwd, + return_lse=True, + ) + grad_out = torch.randn_like(out) + dq, dk, dv = run_cute_mask_bwd( + q, + k, + v, + out, + lse, + grad_out, + None, + block_sparse_mask_bwd=block_sparse_mask_bwd, + tile_m=tile_m, + tile_n=tile_n, + ) + + out_ref, dq_ref, dk_ref, dv_ref = run_flex_reference_bwd( + q, k, v, bm, grad_out, dtype=torch.float32 + ) + out_pt, dq_pt, dk_pt, dv_pt = run_flex_reference_bwd(q, k, v, bm, grad_out) + assert_fwd_matches_reference(out, out_ref, out_pt) + assert_bwd_matches_reference( + dq, + dk, + dv, + dq_ref, + dk_ref, + dv_ref, + dq_pt, + dk_pt, + dv_pt, + dtype, + min(seqlen_q, seqlen_k), + ) + + +@pytest.mark.skipif(COMPUTE_CAPABILITY != 9, reason="SM90-only test") +def test_sm90_block_sparse_explicit_192_block_size(): + torch.manual_seed(0) + + batch_size = 1 + nheads = 4 + seqlen_q = 384 + seqlen_k = 384 + headdim = 96 + block_size_q = 192 + block_size_kv = 128 + dtype = torch.bfloat16 + softmax_scale = 1.0 / math.sqrt(headdim) + + q = torch.randn(batch_size, seqlen_q, nheads, headdim, device="cuda", dtype=dtype) + k = torch.randn(batch_size, seqlen_k, nheads, headdim, device="cuda", dtype=dtype) + v = torch.randn(batch_size, seqlen_k, nheads, headdim, device="cuda", dtype=dtype) + + def block_causal(batch, head, q_idx, kv_idx): + return (q_idx >= block_size_q) & (kv_idx < block_size_kv) + + bm = create_block_mask( + block_causal, + batch_size, + nheads, + seqlen_q, + seqlen_k, + device="cuda", + BLOCK_SIZE=(block_size_q, block_size_kv), + ) + ( + _seq_q, + _seq_k, + kv_mask_cnt, + kv_mask_idx, + full_kv_cnt, + full_kv_idx, + q_mask_cnt, + q_mask_idx, + full_q_cnt, + full_q_idx, + *_, + ) = bm.as_tuple() + + block_sparse_mask_fwd = BlockSparseTensorsTorch( + mask_block_cnt=kv_mask_cnt, + mask_block_idx=kv_mask_idx, + full_block_cnt=full_kv_cnt, + full_block_idx=full_kv_idx, + block_size=(block_size_q, block_size_kv), + ) + block_sparse_mask_bwd = BlockSparseTensorsTorch( + mask_block_cnt=q_mask_cnt, + mask_block_idx=q_mask_idx, + full_block_cnt=full_q_cnt, + full_block_idx=full_q_idx, + block_size=(block_size_q, block_size_kv), + ) + + out, lse = _flash_attn_fwd( + q=q, + k=k, + v=v, + softmax_scale=softmax_scale, + causal=True, + block_sparse_tensors=block_sparse_mask_fwd, + return_lse=True, + ) + grad_out = torch.randn_like(out) + dq, dk, dv = _flash_attn_bwd( + q=q, + k=k, + v=v, + out=out, + dout=grad_out, + lse=lse, + softmax_scale=softmax_scale, + causal=True, + block_sparse_tensors=block_sparse_mask_bwd, + ) + + out_ref, dq_ref, dk_ref, dv_ref = run_flex_reference_bwd( + q, k, v, bm, grad_out, dtype=torch.float32 + ) + out_pt, dq_pt, dk_pt, dv_pt = run_flex_reference_bwd(q, k, v, bm, grad_out) + assert_fwd_matches_reference(out, out_ref, out_pt) + assert_bwd_matches_reference( + dq, + dk, + dv, + dq_ref, + dk_ref, + dv_ref, + dq_pt, + dk_pt, + dv_pt, + dtype, + min(seqlen_q, seqlen_k), + ) + + def test_gqa_block_sparse_broadcast_pattern_recompilation(): """Test that different block sparse broadcast patterns trigger recompilation. @@ -1269,7 +1509,7 @@ def run_with_block_mask_nheads(block_mask_nheads: int) -> tuple[torch.Tensor, to q=q, k=k, v=v, out=out, lse=lse, softmax_scale=softmax_scale, causal=False, window_size_left=-1, window_size_right=-1, - m_block_size=tile_m, n_block_size=tile_n, pack_gqa=False, + tile_mn=(tile_m, tile_n), pack_gqa=False, mask_mod=mask_mod_cute, block_sparse_tensors=block_sparse_fwd, return_lse=True, ) @@ -1344,7 +1584,7 @@ def test_gqa_expand_stride_zero_bug(): q=q, k=k, v=v, out=out, lse=lse, softmax_scale=softmax_scale, causal=True, - m_block_size=128, n_block_size=128, + tile_mn=(128, 128), return_lse=True, ) out_fwd, lse_fwd = out_tuple[0], out_tuple[1] @@ -1460,7 +1700,7 @@ def test_persistent_blocksparse_empty_tiles(): causal=False, softcap=None, window_size_left=None, window_size_right=None, learnable_sink=None, - m_block_size=tile_m, n_block_size=tile_n, + tile_mn=(tile_m, tile_n), pack_gqa=False, _arch=None, score_mod=None, mask_mod=mask_mod_cute, block_sparse_tensors=block_sparse_mask_fwd, @@ -1472,5 +1712,91 @@ def test_persistent_blocksparse_empty_tiles(): +def test_compact_block_sparse_indices(): + """Test that compact block sparse index tensors (idx.shape[3] < n_blocks) work correctly. + + FA4 only accesses indices 0..cnt-1 per query tile, so the index tensor's last + dimension does not need to be as large as ceil(seqlen_k / block_size_n). This + test verifies that truncated (compact) index tensors produce identical output + to full-sized ones. + """ + torch.manual_seed(42) + batch_size = 1 + nheads = 4 + seqlen_q = 1024 + seqlen_k = 1024 + headdim = 128 + tile_m = 128 + tile_n = 128 + dtype = torch.bfloat16 + + sparse_tile_m = 2 * tile_m if COMPUTE_CAPABILITY == 10 else tile_m + + mask_mod_cute, mask_mod_flex = get_mask_pair( + "block_diagonal", seqlen_q=seqlen_q, seqlen_k=seqlen_k, window_size=None + ) + tensors = create_tensors( + batch_size, seqlen_q, seqlen_k, nheads, nheads, headdim, headdim, dtype + ) + + bm = create_block_mask( + mask_mod_flex, batch_size, nheads, seqlen_q, seqlen_k, + device="cuda", BLOCK_SIZE=(sparse_tile_m, tile_n), + ) + (_, _, kv_mask_cnt, kv_mask_idx, full_kv_cnt, full_kv_idx, *_) = bm.as_tuple() + + # Determine the max count across all query tiles โ€” this is the compact last dim + max_mask_k = kv_mask_cnt.max().item() if kv_mask_cnt is not None else 0 + max_full_k = full_kv_cnt.max().item() if full_kv_cnt is not None else 0 + max_k = max(max_mask_k, max_full_k, 1) + + # Truncate index tensors to compact size + kv_mask_idx_compact = kv_mask_idx[:, :, :, :max_k].contiguous() + full_kv_idx_compact = full_kv_idx[:, :, :, :max_k].contiguous() if full_kv_idx is not None else None + + block_sparse_compact = BlockSparseTensorsTorch( + mask_block_cnt=kv_mask_cnt, + mask_block_idx=kv_mask_idx_compact, + full_block_cnt=full_kv_cnt, + full_block_idx=full_kv_idx_compact, + block_size=(sparse_tile_m, tile_n), + ) + + out_compact, _ = _flash_attn_fwd( + q=tensors["q"], k=tensors["k"], v=tensors["v"], + out=tensors["out"].clone(), lse=tensors["lse"].clone(), + softmax_scale=1.0 / math.sqrt(headdim), + causal=False, mask_mod=mask_mod_cute, + block_sparse_tensors=block_sparse_compact, + return_lse=True, + ) + + # Reference: use full-sized index tensors + block_sparse_full = BlockSparseTensorsTorch( + mask_block_cnt=kv_mask_cnt, + mask_block_idx=kv_mask_idx, + full_block_cnt=full_kv_cnt, + full_block_idx=full_kv_idx, + block_size=(sparse_tile_m, tile_n), + ) + + out_full, _ = _flash_attn_fwd( + q=tensors["q"], k=tensors["k"], v=tensors["v"], + out=tensors["out"].clone(), lse=tensors["lse"].clone(), + softmax_scale=1.0 / math.sqrt(headdim), + causal=False, mask_mod=mask_mod_cute, + block_sparse_tensors=block_sparse_full, + return_lse=True, + ) + + assert not torch.isnan(out_compact).any(), "Compact output has NaN" + assert torch.isfinite(out_compact).all(), "Compact output has Inf" + # Compact and full should produce bit-identical results + assert torch.equal(out_compact, out_full), ( + f"Compact and full block sparse outputs differ: " + f"max diff = {(out_compact - out_full).abs().max().item():.2e}" + ) + + if __name__ == "__main__": pytest.main([__file__, "-v", "-s"]) diff --git a/flash-attn4/tests/cute/test_score_mod.py b/flash-attn4/tests/cute/test_score_mod.py index 105ff590..f03b8996 100644 --- a/flash-attn4/tests/cute/test_score_mod.py +++ b/flash-attn4/tests/cute/test_score_mod.py @@ -4,8 +4,9 @@ import cutlass.cute as cute from cutlass._mlir.dialects import math as mlir_math import operator -from torch.nn.attention.flex_attention import flex_attention -from flash_attn4.interface import _flash_attn_fwd, _flash_attn_bwd +from torch.nn.attention.flex_attention import create_block_mask, flex_attention +from flash_attn4.interface import _flash_attn_fwd, _flash_attn_bwd, _tile_size_bwd_sm90 +from flash_attn4.block_sparsity import BlockSparseTensorsTorch COMPUTE_CAPABILITY = torch.cuda.get_device_capability()[0] @@ -107,7 +108,7 @@ (4224, 4224), ] -VEC_SIZES_TO_CHECK_EQUALITY = [1, 4] +VEC_SIZES_TO_CHECK_EQUALITY = [1, 2, 4] if COMPUTE_CAPABILITY == 10 else [1, 2] def create_tensors( @@ -225,7 +226,6 @@ def test_cute_score_mod_vectorized( for vec_size in VEC_SIZES_TO_CHECK_EQUALITY: cute_vectorized_score_mod.__vec_size__ = vec_size out = run_cute_flash(q, k, v, cute_vectorized_score_mod, pack_gqa=pack_gqa) - assert torch.equal(out, out_ref) @@ -342,9 +342,13 @@ def test_cute_score_mod_with_aux_tensors_vectorized( for vec_size in VEC_SIZES_TO_CHECK_EQUALITY: cute_vectorized_score_mod.__vec_size__ = vec_size out = run_cute_flash( - q, k, v, cute_vectorized_score_mod, aux_tensors=aux_tensors, pack_gqa=pack_gqa + q, + k, + v, + cute_vectorized_score_mod, + aux_tensors=aux_tensors, + pack_gqa=pack_gqa, ) - assert torch.equal(out, out_ref) @@ -809,6 +813,158 @@ def run_flex_reference_bwd(q, k, v, eager_score_mod, grad_out, dtype=None): return out, dq, dk, dv +@pytest.mark.skipif(COMPUTE_CAPABILITY != 9, reason="SM90-only test") +def test_sm90_block_sparse_score_mod_backward_with_dq_swapab(): + torch.random.manual_seed(42) + + batch_size = 1 + num_heads = 4 + seqlen_q = 640 + seqlen_kv = 640 + dim = 128 + block_size_q = 640 + block_size_kv = 128 + dtype = torch.bfloat16 + + cfg = _tile_size_bwd_sm90( + dim, + dim, + causal=False, + local=False, + sparse_block_size_q=block_size_q, + ) + assert cfg.m_block_size == 80 + assert cfg.dQ_swapAB + + q, k, v = create_tensors( + batch_size=batch_size, + num_heads=num_heads, + seqlen_q=seqlen_q, + seqlen_kv=seqlen_kv, + dim=dim, + dtype=dtype, + ) + + def prefix_visible(batch, head, q_idx, kv_idx): + return kv_idx < 3 * block_size_kv + + block_mask = create_block_mask( + prefix_visible, + B=batch_size, + H=num_heads, + Q_LEN=seqlen_q, + KV_LEN=seqlen_kv, + device=q.device, + BLOCK_SIZE=(block_size_q, block_size_kv), + ) + ( + _seq_q, + _seq_k, + kv_mask_cnt, + kv_mask_idx, + full_kv_cnt, + full_kv_idx, + q_mask_cnt, + q_mask_idx, + full_q_cnt, + full_q_idx, + *_, + ) = block_mask.as_tuple() + + block_sparse_fwd = BlockSparseTensorsTorch( + mask_block_cnt=kv_mask_cnt, + mask_block_idx=kv_mask_idx, + full_block_cnt=full_kv_cnt, + full_block_idx=full_kv_idx, + block_size=(block_size_q, block_size_kv), + ) + block_sparse_bwd = BlockSparseTensorsTorch( + mask_block_cnt=q_mask_cnt, + mask_block_idx=q_mask_idx, + full_block_cnt=full_q_cnt, + full_block_idx=full_q_idx, + block_size=(block_size_q, block_size_kv), + ) + + q_t = q.transpose(1, 2) + k_t = k.transpose(1, 2) + v_t = v.transpose(1, 2) + out, lse = _flash_attn_fwd( + q_t, + k_t, + v_t, + return_lse=True, + score_mod=score_mod_squared, + block_sparse_tensors=block_sparse_fwd, + ) + grad_out = torch.randn_like(out) + dq, dk, dv = _flash_attn_bwd( + q_t, + k_t, + v_t, + out, + grad_out, + lse, + score_mod=score_mod_squared, + score_mod_bwd=score_mod_bwd_squared, + block_sparse_tensors=block_sparse_bwd, + ) + + def run_flex_block_sparse_score_mod_ref(q_ref, k_ref, v_ref, grad_out_ref, ref_dtype=None): + if ref_dtype is not None: + q_ref = q_ref.to(ref_dtype).requires_grad_(True) + k_ref = k_ref.to(ref_dtype).requires_grad_(True) + v_ref = v_ref.to(ref_dtype).requires_grad_(True) + grad_out_ref = grad_out_ref.to(ref_dtype) + else: + q_ref = q_ref.requires_grad_(True) + k_ref = k_ref.requires_grad_(True) + v_ref = v_ref.requires_grad_(True) + + compiled_flex = torch.compile(flex_attention) + out_ref = compiled_flex( + q_ref, + k_ref, + v_ref, + block_mask=block_mask, + score_mod=score_squared_eager, + enable_gqa=False, + ) + dq_ref, dk_ref, dv_ref = torch.autograd.grad(out_ref, (q_ref, k_ref, v_ref), grad_out_ref) + return out_ref, dq_ref, dk_ref, dv_ref + + out_ref_fp32, dq_ref_fp32, dk_ref_fp32, dv_ref_fp32 = run_flex_block_sparse_score_mod_ref( + q, k, v, grad_out.transpose(1, 2), ref_dtype=torch.float32 + ) + out_pt, dq_pt, dk_pt, dv_pt = run_flex_block_sparse_score_mod_ref( + q, k, v, grad_out.transpose(1, 2) + ) + + rtol = 2 + out_atol = 2 * (out_ref_fp32 + 0.3 - 0.3 - out_ref_fp32).abs().max().item() + dq_atol = 2 * (dq_ref_fp32 + 0.3 - 0.3 - dq_ref_fp32).abs().max().item() + dk_atol = 2 * (dk_ref_fp32 + 0.3 - 0.3 - dk_ref_fp32).abs().max().item() + dv_atol = 2 * (dv_ref_fp32 + 0.3 - 0.3 - dv_ref_fp32).abs().max().item() + + out_ref = out_ref_fp32.to(dtype) + dq_ref = dq_ref_fp32.to(dtype) + dk_ref = dk_ref_fp32.to(dtype) + dv_ref = dv_ref_fp32.to(dtype) + + assert (out.transpose(1, 2) - out_ref).abs().max().item() <= rtol * ( + out_pt - out_ref + ).abs().max().item() + out_atol + assert (dq.transpose(1, 2) - dq_ref).abs().max().item() <= rtol * ( + dq_pt - dq_ref + ).abs().max().item() + dq_atol + assert (dk.transpose(1, 2) - dk_ref).abs().max().item() <= rtol * ( + dk_pt - dk_ref + ).abs().max().item() + dk_atol + assert (dv.transpose(1, 2) - dv_ref).abs().max().item() <= rtol * ( + dv_pt - dv_ref + ).abs().max().item() + dv_atol + + @pytest.mark.parametrize( "seqlen_q,seqlen_kv", [ diff --git a/flash-attn4/tests/cute/test_score_mod_varlen.py b/flash-attn4/tests/cute/test_score_mod_varlen.py index 4a222212..ac7678ff 100644 --- a/flash-attn4/tests/cute/test_score_mod_varlen.py +++ b/flash-attn4/tests/cute/test_score_mod_varlen.py @@ -65,6 +65,7 @@ ) IS_SM90 = torch.cuda.get_device_capability()[0] == 9 +IS_SM100 = torch.cuda.get_device_capability()[0] == 10 # ============================================================================= # Test pairs @@ -172,7 +173,7 @@ ([1, 1, 1], [256 * 1024] * 3), ] -VEC_SIZES_TO_CHECK_EQUALITY = [1, 4] +VEC_SIZES_TO_CHECK_EQUALITY = [1, 2, 4] if IS_SM100 else [1, 2] # ============================================================================= # Helper functions @@ -590,7 +591,6 @@ def test_varlen_with_score_mod_vectorized( cu_seqlens_q=cu_seqlens_q, cu_seqlens_k=cu_seqlens_k, ) - assert torch.equal(out, out_ref) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) diff --git a/flash-attn4/tests/cute/test_utils.py b/flash-attn4/tests/cute/test_utils.py index 10f50e18..59a8685b 100644 --- a/flash-attn4/tests/cute/test_utils.py +++ b/flash-attn4/tests/cute/test_utils.py @@ -1,4 +1,4 @@ -"""Unit tests for flash_attn4.utils module.""" +"""Unit tests for flash_attn.cute.utils module.""" import functools diff --git a/flash-attn4/torch-ext/flash_attn4/AUTHORS b/flash-attn4/torch-ext/flash_attn4/AUTHORS index bc3991c6..055e75b6 100644 --- a/flash-attn4/torch-ext/flash_attn4/AUTHORS +++ b/flash-attn4/torch-ext/flash_attn4/AUTHORS @@ -1,5 +1,8 @@ -Tri Dao, tri@tridao.me +Tri Dao Jay Shah Ted Zadouri Markus Hoehnerbach -Vijay Thakkar \ No newline at end of file +Vijay Thakkar +Timmy Liu +Driss Guessous +Reuben Stern \ No newline at end of file diff --git a/flash-attn4/torch-ext/flash_attn4/README.md b/flash-attn4/torch-ext/flash_attn4/README.md index 61aa412c..c7f1b32e 100644 --- a/flash-attn4/torch-ext/flash_attn4/README.md +++ b/flash-attn4/torch-ext/flash_attn4/README.md @@ -8,6 +8,12 @@ FlashAttention-4 is a CuTeDSL-based implementation of FlashAttention for Hopper pip install flash-attn-4 ``` +If you're on CUDA 13, install with the `cu13` extra for best performance: + +```sh +pip install "flash-attn-4[cu13]" +``` + ## Usage ```python @@ -21,6 +27,7 @@ out = flash_attn_func(q, k, v, causal=True) ```sh git clone https://github.com/Dao-AILab/flash-attention.git cd flash-attention -pip install -e "flash_attn/cute[dev]" +pip install -e "flash_attn/cute[dev]" # CUDA 12.x +pip install -e "flash_attn/cute[dev,cu13]" # CUDA 13.x (e.g. B200) pytest tests/cute/ ``` diff --git a/flash-attn4/torch-ext/flash_attn4/__init__.py b/flash-attn4/torch-ext/flash_attn4/__init__.py index 563d5a77..149e45bf 100644 --- a/flash-attn4/torch-ext/flash_attn4/__init__.py +++ b/flash-attn4/torch-ext/flash_attn4/__init__.py @@ -1,19 +1,15 @@ """Flash Attention CUTE (CUDA Template Engine) implementation.""" -from importlib.metadata import PackageNotFoundError, version - -# Update when syncing again. -__version__ = "4.0.0.beta4" +__version__ = "4.0.0.beta8" import cutlass.cute as cute +from .cute_dsl_utils import cute_compile_patched from .interface import ( flash_attn_func, flash_attn_varlen_func, ) -from .cute_dsl_utils import cute_compile_patched - # Patch cute.compile to optionally dump SASS cute.compile = cute_compile_patched diff --git a/flash-attn4/torch-ext/flash_attn4/bench_utils.py b/flash-attn4/torch-ext/flash_attn4/bench_utils.py new file mode 100644 index 00000000..45cbcf1a --- /dev/null +++ b/flash-attn4/torch-ext/flash_attn4/bench_utils.py @@ -0,0 +1,196 @@ +"""Shared benchmark utilities: attention_ref, cuDNN helpers, flops calculation.""" + +import math +import torch + +try: + import cudnn +except ImportError: + cudnn = None + + +# โ”€โ”€ FLOPS calculation โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ + + +def flops( + batch, nheads, seqlen_q, seqlen_k, headdim, headdim_v, causal=False, window_size=(None, None) +): + if causal: + avg_seqlen = (max(0, seqlen_k - seqlen_q) + seqlen_k) / 2 + else: + if window_size == (None, None): + avg_seqlen = seqlen_k + else: + row_idx = torch.arange(seqlen_q, device="cuda") + col_left = ( + torch.maximum(row_idx + seqlen_k - seqlen_q - window_size[0], torch.tensor(0)) + if window_size[0] is not None + else torch.zeros_like(row_idx) + ) + col_right = ( + torch.minimum( + row_idx + seqlen_k - seqlen_q + window_size[1], torch.tensor(seqlen_k - 1) + ) + if window_size[1] is not None + else torch.full_like(row_idx, seqlen_k - 1) + ) + avg_seqlen = (col_right - col_left + 1).float().mean().item() + return batch * nheads * 2 * seqlen_q * avg_seqlen * (headdim + headdim_v) + + +# โ”€โ”€ Reference attention โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ + +_attention_ref_mask_cache = {} + + +def attention_ref(q, k, v, causal=False): + """Standard attention reference implementation. + + Args: + q, k, v: (batch, seqlen, nheads, headdim) tensors. + causal: whether to apply causal mask. + """ + softmax_scale = 1.0 / math.sqrt(q.shape[-1]) + scores = torch.einsum("bthd,bshd->bhts", q * softmax_scale, k) + if causal: + if scores.shape[-2] not in _attention_ref_mask_cache: + mask = torch.tril( + torch.ones(scores.shape[-2:], device=scores.device, dtype=torch.bool), diagonal=0 + ) + _attention_ref_mask_cache[scores.shape[-2]] = mask + else: + mask = _attention_ref_mask_cache[scores.shape[-2]] + scores = scores.masked_fill(mask, float("-inf")) + attn = torch.softmax(scores, dim=-1) + return torch.einsum("bhts,bshd->bthd", attn, v) + + +# โ”€โ”€ cuDNN graph helpers โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ + +_TORCH_TO_CUDNN_DTYPE = { + torch.float16: "HALF", + torch.bfloat16: "BFLOAT16", + torch.float32: "FLOAT", + torch.int32: "INT32", + torch.int64: "INT64", +} + + +def _build_cudnn_graph(io_dtype, tensors, build_fn): + """Build a cuDNN graph. Returns (graph, variant_pack, workspace).""" + assert cudnn is not None, "cuDNN is not available" + cudnn_dtype = getattr(cudnn.data_type, _TORCH_TO_CUDNN_DTYPE[io_dtype]) + graph = cudnn.pygraph( + io_data_type=cudnn_dtype, + intermediate_data_type=cudnn.data_type.FLOAT, + compute_data_type=cudnn.data_type.FLOAT, + ) + graph_tensors = {name: graph.tensor_like(t.detach()) for name, t in tensors.items()} + variant_pack = build_fn(graph, graph_tensors) + graph.validate() + graph.build_operation_graph() + graph.create_execution_plans([cudnn.heur_mode.A, cudnn.heur_mode.FALLBACK]) + graph.check_support() + graph.build_plans() + workspace = torch.empty(graph.get_workspace_size(), device="cuda", dtype=torch.uint8) + return graph, variant_pack, workspace + + +def cudnn_fwd_setup(q, k, v, causal=False, window_size_left=None): + """Build a cuDNN forward SDPA graph. + + Args: + q, k, v: (batch, nheads, seqlen, headdim) tensors (cuDNN layout). + causal: whether to apply causal mask. + window_size_left: sliding window size (None for no window). + + Returns: + (fwd_fn, o_gpu, stats_gpu) where fwd_fn is a zero-arg callable. + """ + b, nheads, seqlen_q, headdim = q.shape + headdim_v = v.shape[-1] + o_gpu = torch.empty(b, nheads, seqlen_q, headdim_v, dtype=q.dtype, device=q.device) + stats_gpu = torch.empty(b, nheads, seqlen_q, 1, dtype=torch.float32, device=q.device) + + def build(graph, gt): + o, stats = graph.sdpa( + name="sdpa", + q=gt["q"], + k=gt["k"], + v=gt["v"], + is_inference=False, + attn_scale=1.0 / math.sqrt(headdim), + use_causal_mask=causal or window_size_left is not None, + sliding_window_length=window_size_left + if window_size_left is not None and not causal + else None, + ) + o.set_output(True).set_dim(o_gpu.shape).set_stride(o_gpu.stride()) + stats.set_output(True).set_data_type(cudnn.data_type.FLOAT) + return {gt["q"]: q, gt["k"]: k, gt["v"]: v, o: o_gpu, stats: stats_gpu} + + graph, variant_pack, workspace = _build_cudnn_graph(q.dtype, {"q": q, "k": k, "v": v}, build) + + def fwd_fn(): + graph.execute(variant_pack, workspace) + return o_gpu + + return fwd_fn, o_gpu, stats_gpu + + +def cudnn_bwd_setup(q, k, v, o, g, lse, causal=False, window_size_left=None): + """Build a cuDNN backward SDPA graph. + + Args: + q, k, v, o, g, lse: (batch, nheads, seqlen, dim) tensors (cuDNN layout). + causal: whether to apply causal mask. + window_size_left: sliding window size (None for no window). + + Returns: + bwd_fn: zero-arg callable that returns (dq, dk, dv). + """ + headdim = q.shape[-1] + dq_gpu, dk_gpu, dv_gpu = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v) + + def build(graph, gt): + dq, dk, dv = graph.sdpa_backward( + name="sdpa_backward", + q=gt["q"], + k=gt["k"], + v=gt["v"], + o=gt["o"], + dO=gt["g"], + stats=gt["lse"], + attn_scale=1.0 / math.sqrt(headdim), + use_causal_mask=causal or window_size_left is not None, + sliding_window_length=window_size_left + if window_size_left is not None and not causal + else None, + use_deterministic_algorithm=False, + ) + dq.set_output(True).set_dim(dq_gpu.shape).set_stride(dq_gpu.stride()) + dk.set_output(True).set_dim(dk_gpu.shape).set_stride(dk_gpu.stride()) + dv.set_output(True).set_dim(dv_gpu.shape).set_stride(dv_gpu.stride()) + return { + gt["q"]: q, + gt["k"]: k, + gt["v"]: v, + gt["o"]: o, + gt["g"]: g, + gt["lse"]: lse, + dq: dq_gpu, + dk: dk_gpu, + dv: dv_gpu, + } + + graph, variant_pack, workspace = _build_cudnn_graph( + q.dtype, + {"q": q, "k": k, "v": v, "o": o, "g": g, "lse": lse}, + build, + ) + + def bwd_fn(): + graph.execute(variant_pack, workspace) + return dq_gpu, dk_gpu, dv_gpu + + return bwd_fn diff --git a/flash-attn4/torch-ext/flash_attn4/block_info.py b/flash-attn4/torch-ext/flash_attn4/block_info.py index a5a2544a..b6eaeadd 100644 --- a/flash-attn4/torch-ext/flash_attn4/block_info.py +++ b/flash-attn4/torch-ext/flash_attn4/block_info.py @@ -6,7 +6,7 @@ import cutlass.cute as cute from cutlass import Int32, const_expr -from .seqlen_info import SeqlenInfoQK +from .seqlen_info import SeqlenInfoQK, SeqlenInfoQKNewK @dataclass(frozen=True) @@ -25,8 +25,8 @@ def get_n_block_min_max( self, seqlen_info: SeqlenInfoQK, m_block: Int32, - split_idx: cutlass.Int32 = 0, - num_splits: cutlass.Int32 = 1, + split_idx: Int32 = 0, + num_splits: Int32 = 1, ) -> Tuple[Int32, Int32]: n_block_max = cute.ceil_div(seqlen_info.seqlen_k, self.tile_n) if const_expr(self.is_causal or (self.is_local and self.window_size_right is not None)): @@ -46,7 +46,7 @@ def get_n_block_min_max( n_block_min = cutlass.max(n_idx_left // self.tile_n, 0) if cutlass.const_expr(self.is_split_kv): num_n_blocks_per_split = ( - cutlass.Int32(0) + Int32(0) if n_block_max <= n_block_min else (n_block_max - n_block_min + num_splits - 1) // num_splits ) @@ -70,6 +70,37 @@ def get_m_block_min_max(self, seqlen_info: SeqlenInfoQK, n_block: Int32) -> Tupl m_block_max = min(m_block_max, cute.ceil_div(m_idx_left, self.tile_m)) return m_block_min, m_block_max + @cute.jit + def get_n_block_k_new_min_max( + self, + seqlen_info: SeqlenInfoQKNewK, + m_block: Int32, + split_idx: Int32 = 0, + num_splits: Int32 = 1, + ) -> Tuple[Int32, Int32]: + """Get the block range for new K tokens (append KV). + + First computes the full n_block range via get_n_block_min_max, then maps + those blocks into the new-K index space by subtracting seqlen_k_og. + """ + n_block_min, n_block_max = self.get_n_block_min_max( + seqlen_info, + m_block, + split_idx, + num_splits, + ) + idx_k_new_min = cutlass.max(n_block_min * self.tile_n - seqlen_info.seqlen_k_og, 0) + idx_k_new_max = cutlass.min( + n_block_max * self.tile_n - seqlen_info.seqlen_k_og, seqlen_info.seqlen_k_new + ) + n_block_new_min = idx_k_new_min // self.tile_n + n_block_new_max = ( + cute.ceil_div(idx_k_new_max, self.tile_n) + if idx_k_new_max > idx_k_new_min + else n_block_new_min + ) + return n_block_new_min, n_block_new_max + @cute.jit def get_n_block_min_causal_local_mask( self, diff --git a/flash-attn4/torch-ext/flash_attn4/block_sparse_utils.py b/flash-attn4/torch-ext/flash_attn4/block_sparse_utils.py index 10213d57..c4785124 100644 --- a/flash-attn4/torch-ext/flash_attn4/block_sparse_utils.py +++ b/flash-attn4/torch-ext/flash_attn4/block_sparse_utils.py @@ -72,24 +72,22 @@ def load_block_list( block_indices: cute.Tensor, block_count, - load_q_with_first: cutlass.Constexpr, first_block_preloaded: cutlass.Constexpr, kv_producer_state, - load_Q, load_K, load_V, pipeline_k, pipeline_v, - use_tma_q: cutlass.Constexpr, - tma_q_bytes: cutlass.Constexpr, intra_wg_overlap: cutlass.Constexpr, ): - """Iterate over the sparse blocks and load K, V (and Q) into the pipeline. - for the intra_wg_overlap case, we overlap the loads of K and V. And this + """Iterate over the sparse blocks and load K, V into the pipeline. + For the intra_wg_overlap case, we overlap the loads of K and V. And this means we need to pipeline the last V load from the partial block case, with the loads for the full blocks. Set first_block_preloaded when the caller has already issued the first K load for the list. + Q is loaded separately on its own mbarrier before this function is called. + Note: we iterate along the block_n indices in reverse. @@ -99,21 +97,7 @@ def load_block_list( """ if block_count > 0: if const_expr(not intra_wg_overlap): - # Peel first iteration: the first block may need to load Q alongside K, - # Parameters are already Constexpr, so no need to wrap in const_expr() - n_block_first = block_indices[block_count - 1] - extra_tx = tma_q_bytes if const_expr(load_q_with_first) and const_expr(use_tma_q) else 0 - pipeline_k.producer_acquire(kv_producer_state, extra_tx_count=extra_tx) - - if const_expr(load_q_with_first and use_tma_q): - load_Q(tma_bar_ptr=pipeline_k.producer_get_barrier(kv_producer_state)) - - load_K(src_idx=n_block_first, producer_state=kv_producer_state) - pipeline_v.producer_acquire(kv_producer_state) - load_V(src_idx=n_block_first, producer_state=kv_producer_state) - kv_producer_state.advance() - - for offset in cutlass.range(1, block_count): + for offset in cutlass.range(block_count): n_block = block_indices[block_count - 1 - offset] pipeline_k.producer_acquire(kv_producer_state) load_K(src_idx=n_block, producer_state=kv_producer_state) @@ -123,14 +107,7 @@ def load_block_list( else: n_block_first = block_indices[block_count - 1] if const_expr(not first_block_preloaded): - extra_tx = ( - tma_q_bytes if const_expr(load_q_with_first) and const_expr(use_tma_q) else 0 - ) - pipeline_k.producer_acquire(kv_producer_state, extra_tx_count=extra_tx) - - if const_expr(load_q_with_first and use_tma_q): - load_Q(tma_bar_ptr=pipeline_k.producer_get_barrier(kv_producer_state)) - + pipeline_k.producer_acquire(kv_producer_state) load_K(src_idx=n_block_first, producer_state=kv_producer_state) for idx in cutlass.range(block_count - 1, unroll=1): @@ -186,19 +163,18 @@ def produce_block_sparse_loads( head_idx, m_block, kv_producer_state, - load_Q, load_K, load_V, pipeline_k, pipeline_v, - use_tma_q: cutlass.Constexpr, - tma_q_bytes: cutlass.Constexpr, intra_wg_overlap: cutlass.Constexpr, qhead_per_kvhead: cutlass.Constexpr[int] = 1, q_subtile_factor: cutlass.Constexpr[int] = 1, ): """Iterate over the mask and full block lists for a single tile. + Q is loaded separately on its own mbarrier before this function is called. + The masked (partial) list may leave the last V load pending when intra-warp-group overlap is enabled. The first full block must consume that pending V while issuing its own K load on the next pipeline stage. @@ -230,20 +206,16 @@ def produce_block_sparse_loads( full_empty = curr_full_block_cnt == 0 if mask_empty: - # No masked blocks: the full list owns the initial Q+K load. + # No masked blocks: the full list owns the initial K load. kv_producer_state = load_block_list( curr_full_block_idx, curr_full_block_cnt, - load_q_with_first=True, first_block_preloaded=False, kv_producer_state=kv_producer_state, - load_Q=load_Q, load_K=load_K, load_V=load_V, pipeline_k=pipeline_k, pipeline_v=pipeline_v, - use_tma_q=use_tma_q, - tma_q_bytes=tma_q_bytes, intra_wg_overlap=intra_wg_overlap, ) @@ -256,21 +228,16 @@ def produce_block_sparse_loads( kv_producer_state, ) else: - # Masked blocks present: load Q together with the first masked K so consumers can - # start immediately. When overlap is disabled this fully drains the list. + # Masked blocks present. When overlap is disabled this fully drains the list. kv_producer_state = load_block_list( curr_mask_block_idx, curr_mask_block_cnt, - load_q_with_first=True, first_block_preloaded=False, kv_producer_state=kv_producer_state, - load_Q=load_Q, load_K=load_K, load_V=load_V, pipeline_k=pipeline_k, pipeline_v=pipeline_v, - use_tma_q=use_tma_q, - tma_q_bytes=tma_q_bytes, intra_wg_overlap=intra_wg_overlap, ) @@ -299,16 +266,12 @@ def produce_block_sparse_loads( kv_producer_state = load_block_list( curr_full_block_idx, curr_full_block_cnt, - load_q_with_first=False, first_block_preloaded=True, kv_producer_state=kv_producer_state, - load_Q=load_Q, load_K=load_K, load_V=load_V, pipeline_k=pipeline_k, pipeline_v=pipeline_v, - use_tma_q=use_tma_q, - tma_q_bytes=tma_q_bytes, intra_wg_overlap=intra_wg_overlap, ) @@ -320,21 +283,16 @@ def produce_block_sparse_loads( kv_producer_state, ) else: - # Non-overlap path with both lists: run the full list normally (skipping the Q - # reload because the masked list already issued it). + # Non-overlap path with both lists: run the full list normally. kv_producer_state = load_block_list( curr_full_block_idx, curr_full_block_cnt, - load_q_with_first=False, first_block_preloaded=False, kv_producer_state=kv_producer_state, - load_Q=load_Q, load_K=load_K, load_V=load_V, pipeline_k=pipeline_k, pipeline_v=pipeline_v, - use_tma_q=use_tma_q, - tma_q_bytes=tma_q_bytes, intra_wg_overlap=intra_wg_overlap, ) @@ -1390,18 +1348,18 @@ def _store_one_dQaccum_sm90( m_block, sdQaccum: cute.Tensor, gdQaccum: cute.Tensor, - num_mma_warp_groups: cutlass.Constexpr, + num_dQ_warp_groups: cutlass.Constexpr, num_threads_per_warp_group: cutlass.Constexpr, tma_copy_bytes_dQ, ): """Store dQaccum for a single m_block.""" - for warp_group_idx in cutlass.range_constexpr(num_mma_warp_groups): - cute.arch.cp_async_bulk_wait_group(num_mma_warp_groups - 1 - warp_group_idx, read=True) + for warp_group_idx in cutlass.range_constexpr(num_dQ_warp_groups): + cute.arch.cp_async_bulk_wait_group(num_dQ_warp_groups - 1 - warp_group_idx, read=True) cute.arch.barrier_arrive( barrier_id=int(NamedBarrierBwd.dQEmptyWG0) + warp_group_idx, number_of_threads=num_threads_per_warp_group + cute.arch.WARP_SIZE, ) - for warp_group_idx in cutlass.range_constexpr(num_mma_warp_groups): + for warp_group_idx in cutlass.range_constexpr(num_dQ_warp_groups): cute.arch.barrier( barrier_id=int(NamedBarrierBwd.dQFullWG0) + warp_group_idx, number_of_threads=num_threads_per_warp_group + cute.arch.WARP_SIZE, @@ -1409,7 +1367,7 @@ def _store_one_dQaccum_sm90( with cute.arch.elect_one(): copy_utils.cpasync_reduce_bulk_add_f32( sdQaccum[None, warp_group_idx].iterator, - gdQaccum[None, warp_group_idx, m_block].iterator, + gdQaccum[(None, warp_group_idx), m_block].iterator, tma_copy_bytes_dQ, ) cute.arch.cp_async_bulk_commit_group() @@ -1425,7 +1383,7 @@ def dQaccum_store_block_sparse_bwd_sm90( gdQaccum: cute.Tensor, subtile_factor: cutlass.Constexpr, m_block_max: int, - num_mma_warp_groups: cutlass.Constexpr, + num_dQ_warp_groups: cutlass.Constexpr, num_threads_per_warp_group: cutlass.Constexpr, tma_copy_bytes_dQ, ): @@ -1454,7 +1412,7 @@ def dQaccum_store_block_sparse_bwd_sm90( m_block, sdQaccum, gdQaccum, - num_mma_warp_groups, + num_dQ_warp_groups, num_threads_per_warp_group, tma_copy_bytes_dQ, ) @@ -1470,7 +1428,7 @@ def dQaccum_store_block_sparse_bwd_sm90( m_block, sdQaccum, gdQaccum, - num_mma_warp_groups, + num_dQ_warp_groups, num_threads_per_warp_group, tma_copy_bytes_dQ, ) diff --git a/flash-attn4/torch-ext/flash_attn4/block_sparsity.py b/flash-attn4/torch-ext/flash_attn4/block_sparsity.py index 8ba66e95..55389957 100644 --- a/flash-attn4/torch-ext/flash_attn4/block_sparsity.py +++ b/flash-attn4/torch-ext/flash_attn4/block_sparsity.py @@ -34,6 +34,23 @@ class BlockSparseTensorsTorch(NamedTuple): block_size: tuple[int, int] | None = None +def get_sparse_q_block_size( + tensors: BlockSparseTensorsTorch | None, + seqlen_q: int, +) -> int | None: + """Return the Q sparse block size, or None when sparsity is unset or ambiguous.""" + if tensors is None: + return None + if tensors.block_size is not None: + return tensors.block_size[0] + num_m_blocks = tensors.mask_block_idx.shape[2] + min_block_size = ceildiv(seqlen_q, num_m_blocks) + max_block_size = seqlen_q if num_m_blocks == 1 else (seqlen_q - 1) // (num_m_blocks - 1) + if min_block_size != max_block_size: + return None + return min_block_size + + def _expand_sparsity_tensor( tensor: torch.Tensor, expected_shape: Tuple[int, ...], @@ -81,6 +98,12 @@ def _check_and_expand_block( expanded_cnt = _expand_sparsity_tensor( cnt, expected_count_shape, f"{name}_block_cnt", context, hint ) + # [Note] Allow Compact block sparse indices + # Allow the last dimension (n_blocks) of idx to be <= expected, since + # FA4 only accesses indices 0..cnt-1 per query tile. This enables compact + # index tensors that avoid O(N^2) memory at long sequence lengths. + if idx.ndim == 4 and idx.shape[3] <= expected_index_shape[3]: + expected_index_shape = (*expected_index_shape[:3], idx.shape[3]) expanded_idx = _expand_sparsity_tensor( idx, expected_index_shape, f"{name}_block_idx", context, hint ) @@ -140,17 +163,14 @@ def infer_block_sparse_expected_shapes( num_m_blocks = tensors.mask_block_idx.shape[2] if sparse_block_size_q is None: - min_block_size = ceildiv(seqlen_q, num_m_blocks) - if num_m_blocks == 1: - max_block_size = seqlen_q - else: - max_block_size = (seqlen_q - 1) // (num_m_blocks - 1) - if max_block_size != min_block_size and base_m_block != 1: + sparse_block_size_q = get_sparse_q_block_size(tensors, seqlen_q) + if sparse_block_size_q is None and base_m_block != 1: raise ValueError( f"Block sparse tensors{context} require explicit sparse_block_size[0] " f"to disambiguate block size for seqlen_q={seqlen_q} and num_m_blocks={num_m_blocks}." ) - sparse_block_size_q = min_block_size + if sparse_block_size_q is None: + sparse_block_size_q = ceildiv(seqlen_q, num_m_blocks) if sparse_block_size_q % base_m_block != 0: raise ValueError( @@ -186,9 +206,11 @@ def infer_block_sparse_expected_shapes( raise ValueError(f"Block sparse tensors{context} {dim_name} dim must be {tgt} or 1.") if mask_block_cnt.shape[2] != mask_block_idx.shape[2]: raise ValueError(f"Block sparse tensors{context} must share the same m-block dimension.") - if mask_block_idx.shape[3] != expected_n_blocks: + # [Note] Allow Compact block sparse indices: FA4 only accesses indices 0..cnt-1 + # per query tile, so idx.shape[3] can be <= expected_n_blocks. + if mask_block_idx.shape[3] > expected_n_blocks: raise ValueError( - f"Block sparse tensors{context} n-block dimension must be {expected_n_blocks}." + f"Block sparse tensors{context} n-block dimension must be <= {expected_n_blocks}." ) if expected_m_blocks != num_m_blocks: raise ValueError( @@ -314,7 +336,7 @@ def normalize_block_sparse_config( ) -> tuple[BlockSparseTensorsTorch, Tuple[Tuple[bool, ...], ...] | None, int]: m_block_size, n_block_size = block_size if tensors.block_size is None: - sparse_block_size_q, sparse_block_size_kv = q_stage * m_block_size, n_block_size + sparse_block_size_q, sparse_block_size_kv = None, n_block_size else: sparse_block_size_q, sparse_block_size_kv = tensors.block_size if sparse_block_size_kv != n_block_size: @@ -401,6 +423,7 @@ def to_cute_block_sparse_tensors( """Convert torch block sparsity tensors to CuTe tensors, optionally for tvm ffi""" if not is_block_sparsity_enabled(tensors): return None + ( mask_block_cnt, mask_block_idx, diff --git a/flash-attn4/torch-ext/flash_attn4/cache_utils.py b/flash-attn4/torch-ext/flash_attn4/cache_utils.py index 8606f04b..6f5b23f2 100644 --- a/flash-attn4/torch-ext/flash_attn4/cache_utils.py +++ b/flash-attn4/torch-ext/flash_attn4/cache_utils.py @@ -1,7 +1,6 @@ # Manage Ahead-of-Time (AOT) compiled kernels import fcntl import hashlib -import logging import os import pickle import sys @@ -18,6 +17,7 @@ import cutlass.cute as cute import tvm_ffi from cutlass.cutlass_dsl import JitCompiledFunction +from .fa_logging import fa_log # Pre-load cute DSL runtime libraries with RTLD_GLOBAL so that their symbols # (e.g. _cudaLibraryLoadData) are visible to .so modules loaded later via dlopen. @@ -30,12 +30,6 @@ CompileKeyType: TypeAlias = tuple[Hashable, ...] CallableFunction: TypeAlias = JitCompiledFunction | tvm_ffi.Function -logger = logging.getLogger(__name__) -_handler = logging.StreamHandler() -_handler.setFormatter(logging.Formatter("%(asctime)s.%(msecs)03d %(levelname)s %(message)s", datefmt="%Y-%m-%d %H:%M:%S")) -logger.addHandler(_handler) -logger.setLevel(logging.DEBUG) - # Enable cache via `FLASH_ATTENTION_CUTE_DSL_CACHE_ENABLED=1` CUTE_DSL_CACHE_ENABLED: bool = os.getenv("FLASH_ATTENTION_CUTE_DSL_CACHE_ENABLED", "0") == "1" @@ -222,13 +216,13 @@ def _try_load_from_storage(self, key: CompileKeyType) -> bool: label=sha256_hex, ): if obj_path.exists(): - logger.debug("Loading compiled function from disk: %s", obj_path) + fa_log(1, f"Loading compiled function from disk: {obj_path}") m = cute.runtime.load_module(str(obj_path), enable_tvm_ffi=True) fn = getattr(m, self.EXPORT_FUNCTION_PREFIX) JITCache.__setitem__(self, key, fn) return True else: - logger.debug("Cache miss on disk for key hash %s", sha256_hex) + fa_log(1, f"Cache miss on disk for key hash {sha256_hex}") return False def _try_export_to_storage(self, key: CompileKeyType, fn: JitCompiledFunction) -> None: @@ -243,14 +237,14 @@ def _try_export_to_storage(self, key: CompileKeyType, fn: JitCompiledFunction) - obj_path = self.cache_path / f"{sha256_hex}.o" if obj_path.exists(): # Another process already exported. - logger.debug("Skipping export, already on disk: %s", obj_path) + fa_log(1, f"Skipping export, already on disk: {obj_path}") return - logger.debug("Exporting compiled function to disk: %s", obj_path) + fa_log(1, f"Exporting compiled function to disk: {obj_path}") fn.export_to_c( object_file_path=str(obj_path), function_name=self.EXPORT_FUNCTION_PREFIX, ) - logger.debug("Successfully exported compiled function to disk: %s", obj_path) + fa_log(1, f"Successfully exported compiled function to disk: {obj_path}") def _key_to_hash(self, key: CompileKeyType) -> str: return hashlib.sha256(pickle.dumps(key)).hexdigest() @@ -262,7 +256,7 @@ def clear(self) -> None: """ Not only clear the in-memory cache. Also purge persistent compilation cache. """ - logger.debug("Clearing persistent cache at %s", self.cache_path) + fa_log(1, f"Clearing persistent cache at {self.cache_path}") super().clear() for child in self.cache_path.iterdir(): child.unlink() @@ -281,8 +275,8 @@ def get_jit_cache(name: str | None = None) -> JITCache: path = get_cache_path() / _compute_source_fingerprint() if name: path = path / name - logger.debug("Creating persistent JIT cache at %s", path) + fa_log(1, f"Creating persistent JIT cache at {path}") return JITPersistentCache(path) else: - logger.debug("Persistent cache disabled, using in-memory JIT cache") + fa_log(1, "Persistent cache disabled, using in-memory JIT cache") return JITCache() diff --git a/flash-attn4/torch-ext/flash_attn4/cute_dsl_utils.py b/flash-attn4/torch-ext/flash_attn4/cute_dsl_utils.py index 0cf0b605..79ebd9df 100644 --- a/flash-attn4/torch-ext/flash_attn4/cute_dsl_utils.py +++ b/flash-attn4/torch-ext/flash_attn4/cute_dsl_utils.py @@ -4,7 +4,6 @@ import pathlib from typing import Tuple from functools import partial, lru_cache -from dataclasses import dataclass, fields import torch @@ -15,7 +14,6 @@ import cutlass import cutlass.cute as cute -from cutlass.base_dsl.typing import JitArgument from cutlass.cutlass_dsl import NumericMeta from cutlass.cute.runtime import from_dlpack @@ -43,42 +41,6 @@ def get_device_capacity(device: torch.device = None) -> Tuple[int, int]: return torch.cuda.get_device_capability(device) -@dataclass -class ArgumentsBase(JitArgument): - def __c_pointers__(self): - all_fields = [getattr(self, field.name) for field in fields(self)] - non_constexpr_fields = [f for f in all_fields if not isinstance(f, StaticTypes)] - c_ptrs = [] - for obj in non_constexpr_fields: - if hasattr(obj, "__c_pointers__"): - c_ptrs.extend(obj.__c_pointers__()) - return c_ptrs - - def __get_mlir_types__(self): - all_fields = [getattr(self, field.name) for field in fields(self)] - non_constexpr_fields = [f for f in all_fields if not isinstance(f, StaticTypes)] - types, self._values_pos = [], [] - for obj in non_constexpr_fields: - if hasattr(obj, "__get_mlir_types__"): - obj_types = obj.__get_mlir_types__() - types.extend(obj_types) - self._values_pos.append(len(obj_types)) - else: - self._values_pos.append(0) - return types - - def __new_from_mlir_values__(self, values): - all_fields = {field.name: getattr(self, field.name) for field in fields(self)} - constexpr_fields = {n: f for n, f in all_fields.items() if isinstance(f, StaticTypes)} - non_constexpr_fields = { - n: f for n, f in all_fields.items() if not isinstance(f, StaticTypes) - } - for (name, field), n_items in zip(non_constexpr_fields.items(), self._values_pos): - non_constexpr_fields[name] = cutlass.new_from_mlir_values(field, values[:n_items]) - values = values[n_items:] - return self.__class__(**non_constexpr_fields, **constexpr_fields) - - def load_cubin_module_data_patched(cubin_data, filepath): pathlib.Path(filepath).write_bytes(cubin_data) return load_cubin_module_data_og(cubin_data) diff --git a/flash-attn4/torch-ext/flash_attn4/fa_logging.py b/flash-attn4/torch-ext/flash_attn4/fa_logging.py new file mode 100644 index 00000000..63189cd5 --- /dev/null +++ b/flash-attn4/torch-ext/flash_attn4/fa_logging.py @@ -0,0 +1,97 @@ +# Copyright (c) 2025, Tri Dao. + +"""Unified FlashAttention logging controlled by a single ``FA_LOG_LEVEL`` env var. + +Host-side messages go through Python ``logging`` (logger name ``flash_attn``). +A default ``StreamHandler`` is attached automatically when ``FA_LOG_LEVEL >= 1`` +so that standalone scripts get output without extra setup; applications that +configure their own logging can remove or replace it via the standard API. + +FA_LOG_LEVEL mapping:: + + 0 off nothing logged + 1 host host-side summaries only (no kernel printf) + 2 kernel host + curated kernel traces + 3 max host + all kernel traces (noisy, perf hit) + +Set via environment variable:: + + FA_LOG_LEVEL=1 python train.py + +Device-side ``cute.printf`` calls are compile-time eliminated via +``cutlass.const_expr`` when the log level is below the callsite threshold, +so there is zero performance cost when device logging is off. +Changing the log level after kernel compilation requires a recompile +(the level participates in the forward compile key). +""" + +import logging +import os +import sys + +import cutlass.cute as cute +from cutlass import const_expr + +_LOG_LEVEL_NAMES = {"off": 0, "host": 1, "kernel": 2, "max": 3} + + +def _parse_log_level(raw: str) -> int: + if raw in _LOG_LEVEL_NAMES: + return _LOG_LEVEL_NAMES[raw] + try: + level = int(raw) + except ValueError: + return 0 + return max(0, min(level, 3)) + + +_fa_log_level: int = _parse_log_level(os.environ.get("FA_LOG_LEVEL", "0")) + +_logger = logging.getLogger("flash_attn") +_logger.addHandler(logging.NullHandler()) +_default_handler: logging.Handler | None = None + + +def _configure_default_handler() -> None: + global _default_handler + if _fa_log_level >= 1: + if _default_handler is None: + _default_handler = logging.StreamHandler(sys.stdout) + _default_handler.setFormatter(logging.Formatter("[FA] %(message)s")) + _logger.addHandler(_default_handler) + _logger.setLevel(logging.DEBUG) + else: + if _default_handler is not None: + _logger.removeHandler(_default_handler) + _default_handler = None + _logger.setLevel(logging.WARNING) + + +_configure_default_handler() + + +def get_fa_log_level() -> int: + return _fa_log_level + + +def set_fa_log_level(level: int | str) -> None: + """Set the FA log level programmatically. + + Host logging takes effect immediately. Device logging changes only + affect kernels compiled after this call (new compile-key selection). + """ + global _fa_log_level + if isinstance(level, str): + level = _parse_log_level(level) + _fa_log_level = max(0, min(int(level), 3)) + _configure_default_handler() + + +def fa_log(level: int, msg: str): + if _fa_log_level >= level: + _logger.info(msg) + + +def fa_printf(level: int, fmt, *args): + if const_expr(_fa_log_level >= level): + cute.printf(fmt, *args) diff --git a/flash-attn4/torch-ext/flash_attn4/flash_bwd.py b/flash-attn4/torch-ext/flash_attn4/flash_bwd.py index 82408121..7e2be4ed 100644 --- a/flash-attn4/torch-ext/flash_attn4/flash_bwd.py +++ b/flash-attn4/torch-ext/flash_attn4/flash_bwd.py @@ -22,6 +22,7 @@ from .seqlen_info import SeqlenInfoQK from .quack.cute_dsl_utils import ParamsBase from .tile_scheduler import SingleTileScheduler, SingleTileVarlenScheduler, TileSchedulerArguments +from .block_sparsity import BlockSparseTensors class FlashAttentionBackwardSm80: @@ -372,7 +373,6 @@ def __call__( mdK: cute.Tensor, mdV: cute.Tensor, softmax_scale: cutlass.Float32, - stream: cuda.CUstream, mCuSeqlensQ: Optional[cute.Tensor] = None, mCuSeqlensK: Optional[cute.Tensor] = None, mSeqUsedQ: Optional[cute.Tensor] = None, @@ -381,8 +381,16 @@ def __call__( window_size_left: Int32 | int | None = None, window_size_right: Int32 | int | None = None, mdQ_semaphore: Optional[cute.Tensor] = None, + mdK_semaphore: Optional[cute.Tensor] = None, + mdV_semaphore: Optional[cute.Tensor] = None, + aux_tensors: Optional[list] = None, + blocksparse_tensors: Optional[BlockSparseTensors] = None, + # Always keep stream as the last parameter (EnvStream: obtained implicitly via TVM FFI). + stream: cuda.CUstream = None, ): - assert mdQ_semaphore is None, "semaphore not supported yet" + assert mdQ_semaphore is None and mdK_semaphore is None and mdV_semaphore is None, ( + "determinism not supported yet for Sm80" + ) # Get the data type and check if it is fp16 or bf16 self._check_type(*(t.element_type if t is not None else None for t in (mQ, mK, mV, mdO, mLSE, mdPsum, mdQaccum, mdK, mdV, mCuSeqlensQ, mCuSeqlensK, mSeqUsedQ, mSeqUsedK))) @@ -512,7 +520,17 @@ def kernel( n_block, head_idx, batch_idx, _ = work_tile.tile_idx if work_tile.is_valid_tile: - seqlen = SeqlenInfoQK.create(batch_idx, mQ.shape[1], mK.shape[1], mCuSeqlensQ=mCuSeqlensQ, mCuSeqlensK=mCuSeqlensK, mSeqUsedQ=mSeqUsedQ, mSeqUsedK=mSeqUsedK) + seqlen = SeqlenInfoQK.create( + batch_idx, + mQ.shape[1], + mK.shape[1], + mCuSeqlensQ=mCuSeqlensQ, + mCuSeqlensK=mCuSeqlensK, + mSeqUsedQ=mSeqUsedQ, + mSeqUsedK=mSeqUsedK, + tile_m=self.m_block_size, + tile_n=self.n_block_size, + ) m_block_max = cute.ceil_div(seqlen.seqlen_q, self.m_block_size) m_block_min = 0 @@ -538,7 +556,7 @@ def kernel( mdPsum_cur = mdPsum[batch_idx, head_idx, None] mdQaccum_cur = mdQaccum[batch_idx, head_idx, None] else: - padded_offset_q = seqlen.offset_q + batch_idx * self.m_block_size + padded_offset_q = seqlen.padded_offset_q mQ_cur = cute.domain_offset((seqlen.offset_q, 0), mQ[None, head_idx, None]) mLSE_cur = cute.domain_offset((padded_offset_q,), mLSE[head_idx, None]) mdO_cur = cute.domain_offset((seqlen.offset_q, 0), mdO[None, head_idx, None]) @@ -794,9 +812,10 @@ def kernel( # Mainloop # /////////////////////////////////////////////////////////////////////////////// # Start processing of the first n-block. - mask = AttentionMask(self.m_block_size, self.n_block_size, seqlen.seqlen_q, seqlen.seqlen_k) + mask = AttentionMask(self.m_block_size, self.n_block_size, seqlen) mask_fn = partial( mask.apply_mask, n_block=n_block, thr_mma=thr_mma_sdp, + batch_idx=batch_idx, head_idx=head_idx, mask_seqlen=True, mask_causal=self.is_causal ) smem_pipe_read_q = cutlass.Int32(0) @@ -968,7 +987,7 @@ def dQ_mma(hook_fn): # MMA dK if cutlass.const_expr(self.Mma_dKV_is_RS): - tdVrP = layout_utils.reshape_acc_to_frgA(rdS) + tdKrdS = layout_utils.reshape_acc_to_frgA(rdS) else: tdKrdS = mma_params.tdKrdS sm80_utils.gemm( diff --git a/flash-attn4/torch-ext/flash_attn4/flash_bwd_postprocess.py b/flash-attn4/torch-ext/flash_attn4/flash_bwd_postprocess.py index 68404565..5ef8feb6 100644 --- a/flash-attn4/torch-ext/flash_attn4/flash_bwd_postprocess.py +++ b/flash-attn4/torch-ext/flash_attn4/flash_bwd_postprocess.py @@ -2,7 +2,7 @@ # A reimplementation of https://github.com/Dao-AILab/flash-attention/blob/main/hopper/flash_bwd_postprocess_kernel.h # from Cutlass C++ to Cute-DSL. import math -from typing import Callable, Optional, Type, Literal +from typing import Callable, Optional, Type import cuda.bindings.driver as cuda @@ -36,7 +36,7 @@ def __init__( self, dtype: Type[cutlass.Numeric], head_dim: int, - arch: Literal[80, 90, 100], + arch: int, tile_m: int = 128, num_threads: int = 256, AtomLayoutMdQ: int = 1, @@ -52,8 +52,8 @@ def __init__( """ self.dtype = dtype self.tile_m = tile_m - assert arch // 10 in [8, 9, 10, 11], ( - "Only Ampere (8.x), Hopper (9.x), and Blackwell (10.x, 11.x) are supported" + assert arch // 10 in [8, 9, 10, 11, 12], ( + "Only Ampere (8.x), Hopper (9.x), and Blackwell (10.x, 11.x, 12.x) are supported" ) self.arch = arch # padding head_dim to a multiple of 32 as k_block_size @@ -63,7 +63,7 @@ def __init__( self.num_threads = num_threads self.AtomLayoutMdQ = AtomLayoutMdQ self.dQ_swapAB = dQ_swapAB - self.use_2cta_instrs = use_2cta_instrs and arch == 100 and head_dim != 64 + self.use_2cta_instrs = use_2cta_instrs and arch // 10 == 10 and head_dim != 64 self.cluster_size = cluster_size @staticmethod @@ -89,7 +89,7 @@ def can_implement(dtype, head_dim, tile_m, num_threads) -> bool: return True def _get_tiled_mma(self): - if const_expr(self.arch == 80): + if const_expr(self.arch // 10 in [8, 12]): num_mma_warps = self.num_threads // 32 atom_layout_dQ = ( (self.AtomLayoutMdQ, num_mma_warps // self.AtomLayoutMdQ, 1) @@ -101,9 +101,9 @@ def _get_tiled_mma(self): atom_layout_dQ, permutation_mnk=(atom_layout_dQ[0] * 16, atom_layout_dQ[1] * 16, 16), ) - elif const_expr(self.arch == 90): - num_mma_warp_groups = self.num_threads // 128 - atom_layout_dQ = (self.AtomLayoutMdQ, num_mma_warp_groups // self.AtomLayoutMdQ) + elif const_expr(self.arch // 10 == 9): + num_wg_mma = self.num_threads // 128 + atom_layout_dQ = (self.AtomLayoutMdQ, num_wg_mma // self.AtomLayoutMdQ) tiler_mn_dQ = (self.tile_m // atom_layout_dQ[0], self.tile_hdim // atom_layout_dQ[1]) tiled_mma = sm90_utils_basic.make_trivial_tiled_mma( self.dtype, @@ -125,7 +125,7 @@ def _get_tiled_mma(self): cta_group, (self.tile_m, self.tile_hdim), ) - if const_expr(self.arch in [80, 90]): + if const_expr(self.arch // 10 in [8, 9, 12]): assert self.num_threads == tiled_mma.size return tiled_mma @@ -148,22 +148,22 @@ def _setup_attributes(self): cute.make_layout(self.num_threads), cute.make_layout(async_copy_elems_accum), ) - num_s2r_copy_elems = 1 if const_expr(self.arch == 80) else 4 - if const_expr(self.arch == 80): + num_s2r_copy_elems = 1 if const_expr(self.arch // 10 in [8, 12]) else 4 + if const_expr(self.arch // 10 in [8, 12]): self.s2r_tiled_copy_dQaccum = copy_utils.tiled_copy_1d( Float32, self.num_threads, num_s2r_copy_elems ) self.sdQaccum_layout = cute.make_layout(self.tile_m * self.tile_hdim) - elif const_expr(self.arch == 90): + elif const_expr(self.arch // 10 == 9): num_threads_per_warp_group = 128 - num_mma_warp_groups = self.num_threads // 128 + num_wg_mma = self.num_threads // 128 self.s2r_tiled_copy_dQaccum = cute.make_tiled_copy_tv( cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), Float32, num_bits_per_copy=128), - cute.make_layout((num_threads_per_warp_group, num_mma_warp_groups)), # thr_layout + cute.make_layout((num_threads_per_warp_group, num_wg_mma)), # thr_layout cute.make_layout(128 // Float32.width), # val_layout ) self.sdQaccum_layout = cute.make_layout( - (self.tile_m * self.tile_hdim // num_mma_warp_groups, num_mma_warp_groups) + (self.tile_m * self.tile_hdim // num_wg_mma, num_wg_mma) ) else: self.dQ_reduce_ncol = 32 @@ -188,14 +188,18 @@ def _setup_attributes(self): # then setting kBlockKSmem to 32 will cause "Static shape_div failure". # We want to treat it as 64 x 48, so kBlockKSmem should be 16. mma_shape_n = self.tiled_mma.get_tile_size(1) - if const_expr(self.arch == 80): + if const_expr(self.arch // 10 in [8, 12]): sdQ_layout_atom = sm80_utils.get_smem_layout_atom(self.dtype, mma_shape_n) self.sdQ_layout = cute.tile_to_shape( sdQ_layout_atom, (self.tile_m, self.tile_hdim), (0, 1) ) - elif const_expr(self.arch == 90): + elif const_expr(self.arch // 10 == 9): + wg_d_dQ = num_wg_mma // self.AtomLayoutMdQ self.sdQ_layout = sm90_utils.make_smem_layout( - self.dtype, LayoutEnum.ROW_MAJOR, (self.tile_m, self.tile_hdim) + self.dtype, + LayoutEnum.ROW_MAJOR, + (self.tile_m, self.tile_hdim), + major_mode_size=self.tile_hdim // wg_d_dQ, ) else: # TODO: this is hard-coded for hdim 128 @@ -211,7 +215,8 @@ def __call__( scale: cutlass.Float32, mCuSeqlensQ: Optional[cute.Tensor], mSeqUsedQ: Optional[cute.Tensor], - stream: cuda.CUstream, + # Always keep stream as the last parameter (EnvStream: obtained implicitly via TVM FFI). + stream: cuda.CUstream = None, ): # Get the data type and check if it is fp16 or bf16 if const_expr(mdQ.element_type not in [cutlass.Float16, cutlass.BFloat16]): @@ -305,7 +310,7 @@ def kernel( smem = cutlass.utils.SmemAllocator() sdQaccum = smem.allocate_tensor(cutlass.Float32, sdQaccum_layout, byte_alignment=1024) sdQaccum_flat = cute.make_tensor(sdQaccum.iterator, cute.make_layout(cute.size(sdQaccum))) - if const_expr(self.arch in [80, 90]): + if const_expr(self.arch // 10 in [8, 9, 12]): sdQ = cute.make_tensor(cute.recast_ptr(sdQaccum.iterator, dtype=self.dtype), sdQ_layout) else: # extra stage dimension @@ -343,10 +348,7 @@ def kernel( mdQaccum_cur = mdQaccum[batch_idx, head_idx, None] head_dim = mdQ.shape[3] else: - if cutlass.const_expr(self.arch >= 90): - padded_offset_q = seqlen.padded_offset_q - else: - padded_offset_q = seqlen.offset_q + batch_idx * self.tile_m + padded_offset_q = seqlen.padded_offset_q mdQ_cur = cute.domain_offset((seqlen.offset_q, 0), mdQ[None, head_idx, None]) mdQaccum_cur = cute.domain_offset( (padded_offset_q * self.tile_hdim,), mdQaccum[head_idx, None] @@ -371,7 +373,7 @@ def kernel( seqlen_q = seqlen.seqlen_q seqlen_q_rounded = cute.round_up(seqlen_q, self.tile_m) - if const_expr(self.arch == 100 and self.use_2cta_instrs): + if const_expr(self.arch // 10 == 10 and self.use_2cta_instrs): # 2-CTA: remap dQaccum layout into TMEM view before writing sdQ num_reduce_threads = self.num_threads thr_mma_dsk = tiled_mma.get_slice(tidx) @@ -502,7 +504,7 @@ def kernel( tile_shape = (self.tile_m, self.tile_hdim) acc = None tiled_copy_t2r = None - if const_expr(self.arch in [80, 90]): + if const_expr(self.arch // 10 in [8, 9, 12]): acc_shape = tiled_mma.partition_shape_C( tile_shape if const_expr(not dQ_swapAB) else tile_shape[::-1] ) @@ -531,7 +533,7 @@ def kernel( # Step 3: Copy dQ from register to smem cute.arch.barrier() # make sure all threads have finished loading dQaccum - if const_expr(self.arch in [80, 90]): + if const_expr(self.arch // 10 in [8, 9, 12]): copy_atom_r2s_dQ = utils.get_smem_store_atom( self.arch, self.dtype, transpose=self.dQ_swapAB ) @@ -553,7 +555,7 @@ def kernel( ) thr_copy_r2s_dQ = tiled_copy_r2s_dQ.get_slice(tidx) cdQ = cute.make_identity_tensor((self.tile_m, self.tile_hdim)) - if const_expr(self.arch in [80, 90]): + if const_expr(self.arch // 10 in [8, 9, 12]): taccdQrdQ = thr_copy_r2s_dQ.retile(rdQ) else: taccdQcdQ_shape = thr_copy_r2s_dQ.partition_S(cdQ).shape diff --git a/flash-attn4/torch-ext/flash_attn4/flash_bwd_preprocess.py b/flash-attn4/torch-ext/flash_attn4/flash_bwd_preprocess.py index 62988898..0d9628ae 100644 --- a/flash-attn4/torch-ext/flash_attn4/flash_bwd_preprocess.py +++ b/flash-attn4/torch-ext/flash_attn4/flash_bwd_preprocess.py @@ -1,21 +1,32 @@ # Copyright (c) 2025, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. # A reimplementation of https://github.com/Dao-AILab/flash-attention/blob/main/hopper/flash_bwd_preprocess_kernel.h # from Cutlass C++ to Cute-DSL. +# +# Computes D_i = (dO_i * O_i).sum(dim=-1), optionally adjusted for LSE gradient: +# D'_i = D_i - dLSE_i +# This works because in the backward pass: +# dS_ij = P_ij * (dP_ij - D_i) [standard] +# When LSE is differentiable, d(loss)/d(S_ij) gets an extra term dLSE_i * P_ij +# (since d(LSE_i)/d(S_ij) = P_ij), giving: +# dS_ij = P_ij * (dP_ij - D_i) + dLSE_i * P_ij +# = P_ij * (dP_ij - (D_i - dLSE_i)) +# So the main backward kernel is unchanged; we just replace D with D' = D - dLSE here. import math import operator -from typing import Callable, Type, Optional, Literal +from functools import partial +from typing import Callable, Type, Optional import cuda.bindings.driver as cuda import cutlass import cutlass.cute as cute -from cutlass import Float32 +from cutlass import Float32, const_expr +from cutlass.cutlass_dsl import Arch, BaseDSL -from .quack import copy_utils +from .quack import copy_utils, layout_utils from . import utils -from .cute_dsl_utils import assume_tensor_aligned -from .seqlen_info import SeqlenInfoQK +from .seqlen_info import SeqlenInfo from .quack.cute_dsl_utils import ParamsBase from .tile_scheduler import ( SingleTileScheduler, @@ -30,9 +41,8 @@ def __init__( dtype: Type[cutlass.Numeric], head_dim: int, head_dim_v: int, - arch: Literal[80, 90, 100], - m_block_size: int = 128, - num_threads: int = 128, + tile_m: int = 128, + num_threads: int = 256, ): """ All contiguous dimensions must be at least 16 bytes aligned which indicates the head dimension @@ -40,14 +50,14 @@ def __init__( :param head_dim: head dimension :type head_dim: int - :param m_block_size: m block size - :type m_block_size: int + :param tile_m: m block size + :type tile_m: int :param num_threads: number of threads :type num_threads: int """ + self.use_pdl = BaseDSL._get_dsl().get_arch_enum() >= Arch.sm_90a self.dtype = dtype - self.m_block_size = m_block_size - self.arch = arch + self.tile_m = tile_m # padding head_dim to a multiple of 32 as k_block_size hdim_multiple_of = 32 self.head_dim_padded = int(math.ceil(head_dim / hdim_multiple_of) * hdim_multiple_of) @@ -56,15 +66,15 @@ def __init__( self.num_threads = num_threads @staticmethod - def can_implement(dtype, head_dim, m_block_size, num_threads) -> bool: + def can_implement(dtype, head_dim, tile_m, num_threads) -> bool: """Check if the kernel can be implemented with the given parameters. :param dtype: data type :type dtype: cutlass.Numeric :param head_dim: head dimension :type head_dim: int - :param m_block_size: m block size - :type m_block_size: int + :param tile_m: m block size + :type tile_m: int :param num_threads: number of threads :type num_threads: int @@ -77,7 +87,7 @@ def can_implement(dtype, head_dim, m_block_size, num_threads) -> bool: return False if num_threads % 32 != 0: return False - if num_threads < m_block_size: # For multiplying lse with log2 + if num_threads < tile_m: # For multiplying lse with log2 return False return True @@ -105,7 +115,7 @@ def _setup_attributes(self): universal_copy_bits = 128 num_copy_elems_dQaccum = universal_copy_bits // Float32.width assert ( - self.m_block_size * self.head_dim_padded // num_copy_elems_dQaccum + self.tile_m * self.head_dim_padded // num_copy_elems_dQaccum ) % self.num_threads == 0 self.gmem_tiled_copy_dQaccum = copy_utils.tiled_copy_1d( Float32, self.num_threads, num_copy_elems_dQaccum @@ -114,38 +124,53 @@ def _setup_attributes(self): @cute.jit def __call__( self, - mO: cute.Tensor, - mdO: cute.Tensor, - mdPsum: cute.Tensor, - mLSE: Optional[cute.Tensor], - mLSElog2: Optional[cute.Tensor], + mO: cute.Tensor, # (batch, seqlen, nheads, head_dim_v) or (total_q, nheads, head_dim_v) + mdO: cute.Tensor, # same shape as mO + mPdPsum: cute.Tensor, # (batch, nheads, seqlen_padded) or (nheads, total_q_padded) + mLSE: Optional[cute.Tensor], # (batch, nheads, seqlen) or (nheads, total_q) + mLSElog2: Optional[cute.Tensor], # same shape as mPdPsum + # (batch, nheads, seqlen_padded * head_dim_v) or (nheads, total_q_padded * head_dim_v) mdQaccum: Optional[cute.Tensor], - mCuSeqlensQ: Optional[cute.Tensor], - mSeqUsedQ: Optional[cute.Tensor], - stream: cuda.CUstream, + mCuSeqlensQ: Optional[cute.Tensor], # (batch + 1,) + mSeqUsedQ: Optional[cute.Tensor], # (batch,) + mdLSE: Optional[cute.Tensor], # (batch, nheads, seqlen) or (nheads, total_q) + # Always keep stream as the last parameter (EnvStream: obtained implicitly via TVM FFI). + stream: cuda.CUstream = None, ): # Get the data type and check if it is fp16 or bf16 - if cutlass.const_expr(not (mO.element_type == mdO.element_type)): + if const_expr(not (mO.element_type == mdO.element_type)): raise TypeError("All tensors must have the same data type") - if cutlass.const_expr(mO.element_type not in [cutlass.Float16, cutlass.BFloat16]): + if const_expr(mO.element_type not in [cutlass.Float16, cutlass.BFloat16]): raise TypeError("Only Float16 or BFloat16 is supported") - if cutlass.const_expr(mdPsum.element_type not in [Float32]): - raise TypeError("dPsum tensor must be Float32") - if cutlass.const_expr(mdQaccum is not None): - if cutlass.const_expr(mdQaccum.element_type not in [Float32]): + if const_expr(mPdPsum.element_type not in [Float32]): + raise TypeError("PdPsum tensor must be Float32") + if const_expr(mdQaccum is not None): + if const_expr(mdQaccum.element_type not in [Float32]): raise TypeError("dQaccum tensor must be Float32") - if cutlass.const_expr(mLSE is not None): + if const_expr(mLSE is not None): assert mLSElog2 is not None, "If mLSE is provided, mLSElog2 must also be provided" - if cutlass.const_expr(mLSE.element_type not in [Float32]): + if const_expr(mLSE.element_type not in [Float32]): raise TypeError("LSE tensor must be Float32") - if cutlass.const_expr(mLSElog2.element_type not in [Float32]): + if const_expr(mLSElog2.element_type not in [Float32]): raise TypeError("LSElog2 tensor must be Float32") - - mO, mdO, mdQaccum = [assume_tensor_aligned(t) for t in (mO, mdO, mdQaccum)] + if const_expr(mdLSE is not None): + if const_expr(mdLSE.element_type not in [Float32]): + raise TypeError("dLSE tensor must be Float32") self._setup_attributes() - if cutlass.const_expr(mCuSeqlensQ is not None): + # (batch, nheads, seqlen) -> (seqlen, nheads, batch) or (total_q, nheads) -> (nheads, total_q) + transpose = [2, 1, 0] if const_expr(mCuSeqlensQ is None) else [1, 0] + mPdPsum = layout_utils.select(mPdPsum, transpose) + if const_expr(mLSE is not None): + mLSE = layout_utils.select(mLSE, transpose) + mLSElog2 = layout_utils.select(mLSElog2, transpose) + if const_expr(mdLSE is not None): + mdLSE = layout_utils.select(mdLSE, transpose) + if const_expr(mdQaccum is not None): + mdQaccum = layout_utils.select(mdQaccum, transpose) + + if const_expr(mCuSeqlensQ is not None): TileScheduler = SingleTileVarlenScheduler num_head = mO.shape[1] num_batch = mCuSeqlensQ.shape[0] - 1 @@ -155,7 +180,7 @@ def __call__( num_batch = mO.shape[0] tile_sched_args = TileSchedulerArguments( - num_block=cute.ceil_div(mO.shape[1], self.m_block_size), + num_block=cute.ceil_div(mO.shape[1], self.tile_m), num_head=num_head, num_batch=num_batch, num_splits=1, @@ -163,7 +188,7 @@ def __call__( headdim=0, headdim_v=mO.shape[2], total_q=mO.shape[0], - tile_shape_mn=(self.m_block_size, 1), + tile_shape_mn=(self.tile_m, 1), mCuSeqlensQ=mCuSeqlensQ, mSeqUsedQ=mSeqUsedQ, ) @@ -174,12 +199,13 @@ def __call__( self.kernel( mO, mdO, - mdPsum, + mPdPsum, mLSE, mLSElog2, mdQaccum, mCuSeqlensQ, mSeqUsedQ, + mdLSE, self.gmem_tiled_copy_O, self.gmem_tiled_copy_dQaccum, tile_sched_params, @@ -188,6 +214,7 @@ def __call__( grid=grid_dim, block=[self.num_threads, 1, 1], stream=stream, + use_pdl=self.use_pdl, ) @cute.kernel @@ -195,12 +222,13 @@ def kernel( self, mO: cute.Tensor, mdO: cute.Tensor, - mdPsum: cute.Tensor, + mPdPsum: cute.Tensor, mLSE: Optional[cute.Tensor], mLSElog2: Optional[cute.Tensor], mdQaccum: Optional[cute.Tensor], mCuSeqlensQ: Optional[cute.Tensor], mSeqUsedQ: Optional[cute.Tensor], + mdLSE: Optional[cute.Tensor], gmem_tiled_copy_O: cute.TiledCopy, gmem_tiled_copy_dQaccum: cute.TiledCopy, tile_sched_params: ParamsBase, @@ -217,145 +245,106 @@ def kernel( # /////////////////////////////////////////////////////////////////////////////// # Get the appropriate tiles for this thread block. # /////////////////////////////////////////////////////////////////////////////// - seqlen = SeqlenInfoQK.create( - batch_idx, - mO.shape[1], - 0, - mCuSeqlensQ=mCuSeqlensQ, - mCuSeqlensK=None, - mSeqUsedQ=mSeqUsedQ, - mSeqUsedK=None, + seqlen = SeqlenInfo.create( + batch_idx, mO.shape[1], mCuSeqlensQ, mSeqUsedQ, tile=self.tile_m ) + mO_cur = seqlen.offset_batch(mO, batch_idx, dim=0)[None, head_idx, None] + mdO_cur = seqlen.offset_batch(mdO, batch_idx, dim=0)[None, head_idx, None] + mPdPsum_cur = seqlen.offset_batch(mPdPsum, batch_idx, dim=2, padded=True)[ + None, head_idx + ] + headdim_v = mO_cur.shape[cute.rank(mO_cur) - 1] + seqlen_q = seqlen.seqlen + seqlen_q_rounded = cute.round_up(seqlen_q, self.tile_m) + seqlen_limit = seqlen_q - m_block * self.tile_m + + lse = None + if const_expr(mLSE is not None): + mLSE_cur = seqlen.offset_batch(mLSE, batch_idx, dim=2)[None, head_idx] + gLSE = cute.local_tile(mLSE_cur, (self.tile_m,), (m_block,)) + lse = Float32.inf + if tidx < seqlen_limit: + lse = gLSE[tidx] - if cutlass.const_expr(not seqlen.has_cu_seqlens_q): - mO_cur = mO[batch_idx, None, head_idx, None] - mdO_cur = mdO[batch_idx, None, head_idx, None] - mdPsum_cur = mdPsum[batch_idx, head_idx, None] - headdim_v = mO.shape[3] - else: - mO_cur = cute.domain_offset((seqlen.offset_q, 0), mO[None, head_idx, None]) - mdO_cur = cute.domain_offset((seqlen.offset_q, 0), mdO[None, head_idx, None]) - - padded_offset_q = seqlen.offset_q + batch_idx * self.m_block_size - if cutlass.const_expr(self.arch >= 90): - padded_offset_q = padded_offset_q // self.m_block_size * self.m_block_size - mdPsum_cur = cute.domain_offset((padded_offset_q,), mdPsum[head_idx, None]) - headdim_v = mO.shape[2] - - blkOdO_shape = (self.m_block_size, self.head_dim_v_padded) - # (m_block_size, head_dim_v) - gO = cute.local_tile(mO_cur, blkOdO_shape, (m_block, 0)) - gdO = cute.local_tile(mdO_cur, blkOdO_shape, (m_block, 0)) - + blk_shape = (self.tile_m, self.head_dim_v_padded) + gO = cute.local_tile(mO_cur, blk_shape, (m_block, 0)) + gdO = cute.local_tile(mdO_cur, blk_shape, (m_block, 0)) gmem_thr_copy_O = gmem_tiled_copy_O.get_slice(tidx) # (CPY_Atom, CPY_M, CPY_K) tOgO = gmem_thr_copy_O.partition_S(gO) tOgdO = gmem_thr_copy_O.partition_S(gdO) - - # /////////////////////////////////////////////////////////////////////////////// - # Predicate: Mark indices that need to copy when problem_shape isn't a multiple - # of tile_shape - # /////////////////////////////////////////////////////////////////////////////// - # Construct identity layout for KV - cO = cute.make_identity_tensor((self.m_block_size, self.head_dim_v_padded)) + cO = cute.make_identity_tensor(blk_shape) tOcO = gmem_thr_copy_O.partition_S(cO) t0OcO = gmem_thr_copy_O.get_slice(0).partition_S(cO) - tOpO = utils.predicate_k(tOcO, limit=headdim_v) - tOpdO = utils.predicate_k(tOcO, limit=headdim_v) - - seqlen_q = seqlen.seqlen_q - seqlen_q_rounded = cute.round_up(seqlen_q, self.m_block_size) - - if cutlass.const_expr(mLSE is not None): - if cutlass.const_expr(not seqlen.has_cu_seqlens_q): - mLSE_cur = mLSE[batch_idx, head_idx, None] - else: - mLSE_cur = cute.domain_offset((seqlen.offset_q,), mLSE[head_idx, None]) - - gLSE = cute.local_tile(mLSE_cur, (self.m_block_size,), (m_block,)) - lse = Float32.inf - if tidx < seqlen_q - m_block * self.m_block_size: - lse = gLSE[tidx] - - tOrO = cute.make_fragment_like(tOgO) - tOrdO = cute.make_fragment_like(tOgdO) - assert cute.size(tOgO, mode=[0]) == cute.size(tOgdO, mode=[0]) - assert cute.size(tOgO, mode=[1]) == cute.size(tOgdO, mode=[1]) - assert cute.size(tOgO, mode=[2]) == cute.size(tOgdO, mode=[2]) + tOpO = None + if const_expr(self.check_hdim_v_oob): + tOpO = copy_utils.predicate_k(tOcO, limit=headdim_v) + # Each copy will use the same predicate + copy = partial(copy_utils.copy, pred=tOpO) + + tOrO = cute.make_rmem_tensor_like(tOgO) + tOrdO = cute.make_rmem_tensor_like(tOgdO) + if const_expr(self.check_hdim_v_oob): + tOrO.fill(0.0) + tOrdO.fill(0.0) + assert tOgO.shape == tOgdO.shape for m in cutlass.range(cute.size(tOrO.shape[1]), unroll_full=True): - # Instead of using tOcO, we using t0OcO and subtract the offset from the limit - # (seqlen_q - m_block * kBlockM). This is because the entries of t0OcO are known at compile time. - if t0OcO[0, m, 0][0] < seqlen_q - m_block * self.m_block_size - tOcO[0][0]: - cute.copy( - gmem_thr_copy_O, - tOgO[None, m, None], - tOrO[None, m, None], - pred=tOpO[None, m, None] - if cutlass.const_expr(self.check_hdim_v_oob) - else None, - ) - cute.copy( - gmem_thr_copy_O, - tOgdO[None, m, None], - tOrdO[None, m, None], - pred=tOpdO[None, m, None] - if cutlass.const_expr(self.check_hdim_v_oob) - else None, - ) + # Instead of using tOcO, we using t0OcO and subtract the offset from the limit. + # This is bc the entries of t0OcO are known at compile time. + if t0OcO[0, m, 0][0] < seqlen_limit - tOcO[0][0]: + copy(tOgO[None, m, None], tOrO[None, m, None]) + copy(tOgdO[None, m, None], tOrdO[None, m, None]) + # O and dO loads are done; signal that the next kernel can start. + # Correctness is ensured by griddepcontrol_wait() in bwd_sm90 before it reads our outputs. + if const_expr(self.use_pdl): + cute.arch.griddepcontrol_launch_dependents() # Sum across the "k" dimension - dpsum = (tOrO.load().to(Float32) * tOrdO.load().to(Float32)).reduce( + pdpsum = (tOrO.load().to(Float32) * tOrdO.load().to(Float32)).reduce( cute.ReductionOp.ADD, init_val=0.0, reduction_profile=(0, None, 1) ) threads_per_row = gmem_tiled_copy_O.layout_src_tv_tiled[0].shape[0] assert cute.arch.WARP_SIZE % threads_per_row == 0 - dpsum = utils.warp_reduce(dpsum, operator.add, width=threads_per_row) - dP_sum = cute.make_fragment(cute.size(tOrO, mode=[1]), Float32) - dP_sum.store(dpsum) - - # Write dPsum from rmem -> gmem - gdPsum = cute.local_tile(mdPsum_cur, (self.m_block_size,), (m_block,)) - # Only the thread corresponding to column 0 writes out the dPsum to gmem + pdpsum = utils.warp_reduce(pdpsum, operator.add, width=threads_per_row) + PdP_sum = cute.make_rmem_tensor(cute.size(tOrO, mode=[1]), Float32) + PdP_sum.store(pdpsum) + + # If dLSE is provided, compute D' = D - dLSE (see module docstring for derivation). + gdLSE = None + if const_expr(mdLSE is not None): + mdLSE_cur = seqlen.offset_batch(mdLSE, batch_idx, dim=2)[None, head_idx] + gdLSE = cute.local_tile(mdLSE_cur, (self.tile_m,), (m_block,)) + + # Write PdPsum from rmem -> gmem + gPdPsum = cute.local_tile(mPdPsum_cur, (self.tile_m,), (m_block,)) + # Only the thread corresponding to column 0 writes out the PdPsum to gmem if tOcO[0, 0, 0][1] == 0: - for m in cutlass.range(cute.size(dP_sum), unroll_full=True): + for m in cutlass.range(cute.size(PdP_sum), unroll_full=True): row = tOcO[0, m, 0][0] - gdPsum[row] = dP_sum[m] if row < seqlen_q - m_block * self.m_block_size else 0.0 + PdPsum_val = 0.0 + if row < seqlen_limit: + PdPsum_val = PdP_sum[m] + if const_expr(mdLSE is not None): + PdPsum_val -= gdLSE[row] + gPdPsum[row] = PdPsum_val # Clear dQaccum - if cutlass.const_expr(mdQaccum is not None): - if cutlass.const_expr(not seqlen.has_cu_seqlens_q): - mdQaccum_cur = mdQaccum[batch_idx, head_idx, None] - else: - mdQaccum_cur = cute.domain_offset( - (padded_offset_q * self.head_dim_padded,), mdQaccum[head_idx, None] - ) - - # HACK: Compiler doesn't seem to recognize that padding - # by padded_offset_q * self.head_dim_padded keeps alignment - # since statically divisible by 4 - - mdQaccum_cur_ptr = cute.make_ptr( - dtype=mdQaccum_cur.element_type, - value=mdQaccum_cur.iterator.toint(), - mem_space=mdQaccum_cur.iterator.memspace, - assumed_align=mdQaccum.iterator.alignment, - ) - mdQaccum_cur = cute.make_tensor(mdQaccum_cur_ptr, mdQaccum_cur.layout) - - blkdQaccum_shape = (self.m_block_size * self.head_dim_padded,) + if const_expr(mdQaccum is not None): + mdQaccum_cur = seqlen.offset_batch( + mdQaccum, batch_idx, dim=2, padded=True, multiple=self.head_dim_padded + )[None, head_idx] + blkdQaccum_shape = (self.tile_m * self.head_dim_padded,) gdQaccum = cute.local_tile(mdQaccum_cur, blkdQaccum_shape, (m_block,)) gmem_thr_copy_dQaccum = gmem_tiled_copy_dQaccum.get_slice(tidx) tdQgdQaccum = gmem_thr_copy_dQaccum.partition_S(gdQaccum) - zero = cute.make_fragment_like(tdQgdQaccum) + zero = cute.make_rmem_tensor_like(tdQgdQaccum) zero.fill(0.0) cute.copy(gmem_tiled_copy_dQaccum, zero, tdQgdQaccum) - if cutlass.const_expr(mLSE is not None): - if cutlass.const_expr(not seqlen.has_cu_seqlens_q): - mLSElog2_cur = mLSElog2[batch_idx, head_idx, None] - else: - mLSElog2_cur = cute.domain_offset((padded_offset_q,), mLSElog2[head_idx, None]) - - gLSElog2 = cute.local_tile(mLSElog2_cur, (self.m_block_size,), (m_block,)) + if const_expr(mLSE is not None): + mLSElog2_cur = seqlen.offset_batch(mLSElog2, batch_idx, dim=2, padded=True)[ + None, head_idx + ] + gLSElog2 = cute.local_tile(mLSElog2_cur, (self.tile_m,), (m_block,)) LOG2_E = math.log2(math.e) - if tidx < seqlen_q_rounded - m_block * self.m_block_size: + if tidx < seqlen_q_rounded - m_block * self.tile_m: gLSElog2[tidx] = lse * LOG2_E if lse != -Float32.inf else 0.0 diff --git a/flash-attn4/torch-ext/flash_attn4/flash_bwd_sm100.py b/flash-attn4/torch-ext/flash_attn4/flash_bwd_sm100.py index 1e99dbab..96edc004 100644 --- a/flash-attn4/torch-ext/flash_attn4/flash_bwd_sm100.py +++ b/flash-attn4/torch-ext/flash_attn4/flash_bwd_sm100.py @@ -84,7 +84,6 @@ def __init__( self.use_2cta_instrs = bool( use_2cta_instrs and cluster_size == 2 - and not is_local and score_mod is None and score_mod_bwd is None and mask_mod is None @@ -453,7 +452,6 @@ def __call__( mdK: cute.Tensor, mdV: cute.Tensor, softmax_scale: Float32, - stream: cuda.CUstream, mCuSeqlensQ: Optional[cute.Tensor] = None, mCuSeqlensK: Optional[cute.Tensor] = None, mSeqUsedQ: Optional[cute.Tensor] = None, @@ -467,6 +465,8 @@ def __call__( aux_tensors: Optional[list] = None, # Block-sparse tensors (Q direction - for iterating m_blocks per n_block): blocksparse_tensors: Optional[BlockSparseTensors] = None, + # Always keep stream as the last parameter (EnvStream: obtained implicitly via TVM FFI). + stream: cuda.CUstream = None, ): self.q_dtype = mQ.element_type self.k_dtype = mK.element_type @@ -927,10 +927,6 @@ class SharedStorage: "2-CTA mode does not support block sparsity. " "Please create kernel with use_2cta_instrs=False for block sparse attention." ) - assert window_size_left is None and window_size_right is None, ( - "2-CTA mode does not support window attention. " - "Please create kernel with use_2cta_instrs=False for window attention." - ) # 2-CTA: 231424 and 1-CTA: 232448 # print("SMEM: ", self.shared_storage.size_in_bytes()) if const_expr(self.use_block_sparsity or aux_tensors is not None): @@ -3143,6 +3139,8 @@ def compute_loop( with cute.arch.elect_one(): pipeline_S_P.consumer_release(consumer_state_S_P_dP) # pipeline_S_P.sync_object_empty.arrive(0, pipeline_S_P.consumer_mask) + # Normally we'd need syncwarp here since only 1 thread will signal in + # consumer_release, but we already have the self.compute_sync_barrier before this pipeline_LSE.consumer_release(consumer_state_LSE) consumer_state_LSE.advance() # --------------------------------------------- @@ -3253,6 +3251,8 @@ def compute_loop( cute.arch.fence_view_async_shared() self.compute_sync_barrier.arrive_and_wait() + # Normally we'd need syncwarp here since only 1 thread will signal in + # consumer_release, but we already have the self.compute_sync_barrier before this pipeline_dPsum.consumer_release(consumer_state_dPsum) consumer_state_dPsum.advance() # when 2cta hdim 128, pipeline_dS also signals S tmem load completion so is deferred @@ -3650,6 +3650,9 @@ def dQacc_reduce( tile_scheduler.advance_to_next_work() work_tile = tile_scheduler.get_current_work() + if const_expr(not self.deterministic): + cute.arch.cp_async_bulk_wait_group(0, read=True) + @cute.jit def epilogue_dKV( self, diff --git a/flash-attn4/torch-ext/flash_attn4/flash_bwd_sm120.py b/flash-attn4/torch-ext/flash_attn4/flash_bwd_sm120.py new file mode 100644 index 00000000..297504a2 --- /dev/null +++ b/flash-attn4/torch-ext/flash_attn4/flash_bwd_sm120.py @@ -0,0 +1,55 @@ +# Copyright (c) 2025, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +# SM120 (Blackwell GeForce / DGX Spark) backward pass. +# +# SM120 uses the same SM80-era MMA instructions (mma.sync.aligned.m16n8k16) but has +# a smaller shared memory capacity (99 KB vs 163 KB on SM80). This module subclasses +# FlashAttentionBackwardSm80 and overrides the SMEM capacity check accordingly. + +import cutlass +import cutlass.utils as utils_basic + +from .flash_bwd import FlashAttentionBackwardSm80 + + +class FlashAttentionBackwardSm120(FlashAttentionBackwardSm80): + @staticmethod + def can_implement( + dtype, + head_dim, + head_dim_v, + m_block_size, + n_block_size, + num_stages_Q, + num_stages_dO, + num_threads, + is_causal, + V_in_regs=False, + ) -> bool: + """Check if the kernel can be implemented on SM120. + + Same logic as SM80 but uses SM120's shared memory capacity (99 KB). + """ + if dtype not in [cutlass.Float16, cutlass.BFloat16]: + return False + if head_dim % 8 != 0: + return False + if head_dim_v % 8 != 0: + return False + if n_block_size % 16 != 0: + return False + if num_threads % 32 != 0: + return False + # Shared memory usage: Q tile + dO tile + K tile + V tile + smem_usage_Q = m_block_size * head_dim * num_stages_Q * 2 + smem_usage_dO = m_block_size * head_dim_v * num_stages_dO * 2 + smem_usage_K = n_block_size * head_dim * 2 + smem_usage_V = n_block_size * head_dim_v * 2 + smem_usage_QV = ( + (smem_usage_Q + smem_usage_V) if not V_in_regs else max(smem_usage_Q, smem_usage_V) + ) + smem_usage = smem_usage_QV + smem_usage_dO + smem_usage_K + # SM120 has 99 KB shared memory (vs 163 KB on SM80) + smem_capacity = utils_basic.get_smem_capacity_in_bytes("sm_120") + if smem_usage > smem_capacity: + return False + return True diff --git a/flash-attn4/torch-ext/flash_attn4/flash_bwd_sm90.py b/flash-attn4/torch-ext/flash_attn4/flash_bwd_sm90.py index be4affab..07ddb12c 100644 --- a/flash-attn4/torch-ext/flash_attn4/flash_bwd_sm90.py +++ b/flash-attn4/torch-ext/flash_attn4/flash_bwd_sm90.py @@ -24,7 +24,13 @@ from .block_info import BlockInfo from . import pipeline from .quack.cute_dsl_utils import ParamsBase -from .tile_scheduler import TileSchedulerArguments, SingleTileScheduler +from .tile_scheduler import ( + TileSchedulerArguments, + SingleTileScheduler, + SingleTileLPTBwdScheduler, + SingleTileVarlenScheduler, +) +from . import barrier from .named_barrier import NamedBarrierBwd from .softmax import apply_score_mod_inner, apply_score_mod_bwd_inner from .block_sparsity import BlockSparseTensors @@ -46,6 +52,8 @@ def __init__( head_dim_v: Optional[int] = None, qhead_per_kvhead: int = 1, is_causal: bool = False, + is_local: bool = False, + deterministic: bool = False, tile_m: int = 64, tile_n: int = 128, Q_stage: int = 2, @@ -64,6 +72,7 @@ def __init__( mask_mod: cutlass.Constexpr | None = None, has_aux_tensors: cutlass.Constexpr = False, subtile_factor: cutlass.Constexpr[int] = 1, + dQ_single_wg: bool = False, ): self.dtype = dtype # padding head_dim to a multiple of 16 as k_block_size @@ -77,7 +86,8 @@ def __init__( self.check_hdim_v_oob = head_dim_v != self.tile_hdimv self.qhead_per_kvhead = qhead_per_kvhead self.is_causal = is_causal - self.is_local = False + self.is_local = is_local + self.deterministic = deterministic self.tile_m = tile_m self.tile_n = tile_n self.num_threads = num_threads @@ -92,23 +102,23 @@ def __init__( self.AtomLayoutMSdP = AtomLayoutMSdP self.AtomLayoutNdKV = AtomLayoutNdKV self.AtomLayoutMdQ = AtomLayoutMdQ - self.num_mma_warp_groups = (self.num_threads // 128) - 1 + self.num_wg_mma = (self.num_threads // 128) - 1 self.mma_dkv_is_rs = ( AtomLayoutMSdP == 1 - and AtomLayoutNdKV == self.num_mma_warp_groups + and AtomLayoutNdKV == self.num_wg_mma and SdP_swapAB and not dKV_swapAB ) self.V_in_regs = V_in_regs + # May be overridden in __call__ for varlen inputs. if qhead_per_kvhead > 1: assert self.same_hdim_kv, "GQA backward requires head_dim == head_dim_v" - assert self.num_mma_warp_groups == 2, "GQA backward assumes 2 warp groups" + assert self.num_wg_mma == 2, "GQA backward assumes 2 warp groups" # These are tuned for speed # Do we keep the LSE and dPsum in each thread, or split them across 8 threads that share # them and then shuffle to get the value whenever we need? This can reduce register # pressure when SdP_swapAB, where each thread needs to keep statistics for (kBlockM / 4) # rows. If !SdP_swapAB, each thread only needs to keep statistics for 2 rows. - # TODO: impl these for hdim 64 self.shuffle_LSE = self.SdP_swapAB and self.tile_hdim <= 64 self.shuffle_dPsum = self.SdP_swapAB and self.tile_hdim <= 64 @@ -124,6 +134,12 @@ def __init__( else: self.vec_size: cutlass.Constexpr = 4 self.qk_acc_dtype = Float32 + # dQ_single_wg: WG0 computes the full dQ GEMM, WG1 skips it. + # Only valid for 2 MMA warp groups. + # Credit: Ben Spector + if dQ_single_wg: + assert self.num_wg_mma == 2, "dQ_single_wg only supports 2 warp groups" + self.num_wg_dQ = 1 if dQ_single_wg else self.num_wg_mma @staticmethod def can_implement( @@ -182,32 +198,58 @@ def _check_type( assert mQ_type == self.dtype def _setup_attributes(self): - self.sQ_layout, self.sK_layout, self.sV_layout, self.sdO_layout, self.sPdS_layout = [ - sm90_utils.make_smem_layout(self.dtype, LayoutEnum.ROW_MAJOR, shape, stage) - for shape, stage in [ - ((self.tile_m, self.tile_hdim), self.Q_stage), - ((self.tile_n, self.tile_hdim), None), - ((self.tile_n, self.tile_hdimv), None), - ((self.tile_m, self.tile_hdimv), self.dO_stage), - ((self.tile_m, self.tile_n), self.PdS_stage), + # We need to accommodate both Q and Q^T (and dO and dO^T) in shared memory. + # Q & dO are used in the SdP Mma and Q^T and dO^T are used in the dKV Mma. + # The M dimension (tile_m) doesn't matter for the layout, only the K dimension + wg_d_dKV = self.num_wg_mma // self.AtomLayoutNdKV + self.sQ_layout, self.sdO_layout = [ + # Need to set major_mode_size (mms) to accommodate Q and Q.T + sm90_utils.make_smem_layout(self.dtype, LayoutEnum.ROW_MAJOR, shape, stage, mms) + for shape, stage, mms in [ + ((self.tile_m, self.tile_hdim), self.Q_stage, self.tile_hdim // wg_d_dKV), + ((self.tile_m, self.tile_hdimv), self.dO_stage, self.tile_hdim // wg_d_dKV), ] ] + wg_d_dQ = self.num_wg_dQ // self.AtomLayoutMdQ + # Accomodate both K and K.T + self.sK_layout = sm90_utils.make_smem_layout( + self.dtype, + LayoutEnum.ROW_MAJOR, + (self.tile_n, self.tile_hdim), + stage=None, + major_mode_size=self.tile_hdim // wg_d_dQ, + ) + # There's only V, no V.T, so layout is normal + self.sV_layout = sm90_utils.make_smem_layout( + self.dtype, LayoutEnum.ROW_MAJOR, (self.tile_n, self.tile_hdimv), None + ) + # Accomodate both S and S.T + wg_n_SdP = self.num_wg_mma // self.AtomLayoutMSdP + wg_n_dKV = self.AtomLayoutNdKV + self.sPdS_layout = sm90_utils.make_smem_layout( + self.dtype, + LayoutEnum.ROW_MAJOR, + (self.tile_m, self.tile_n), + stage=self.PdS_stage, + major_mode_size=math.gcd(self.tile_n // wg_n_SdP, self.tile_n // wg_n_dKV), + ) self.sdQaccum_layout = cute.make_layout( - (self.tile_m * self.tile_hdim // self.num_mma_warp_groups, self.num_mma_warp_groups) + (self.tile_m * self.tile_hdim // self.num_wg_dQ, self.num_wg_dQ) ) # dQaccum R->S self.r2s_tiled_copy_dQaccum = cute.make_tiled_copy_tv( cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), Float32, num_bits_per_copy=128), # thr_layout - cute.make_layout((self.num_threads_per_warp_group, self.num_mma_warp_groups)), + cute.make_layout((self.num_threads_per_warp_group, self.num_wg_dQ)), cute.make_layout(128 // Float32.width), # val_layout ) # dKVaccum for GQA epilogue - reuses sV+sK memory recast as f32 # TODO: assert that sVaccum and sKaccum don't overflow smem def _get_tiled_mma(self): + maybe_swap_mn = lambda shape, swap: (shape[1], shape[0], *shape[2:]) if swap else shape # S = Q @ K.T, dP = dO @ V.T - atom_layout_SdP = (self.AtomLayoutMSdP, self.num_mma_warp_groups // self.AtomLayoutMSdP) + atom_layout_SdP = (self.AtomLayoutMSdP, self.num_wg_mma // self.AtomLayoutMSdP, 1) tiler_mn_SdP = (self.tile_m // atom_layout_SdP[0], self.tile_n // atom_layout_SdP[1]) tiled_mma_SdP = sm90_utils_basic.make_trivial_tiled_mma( self.dtype, @@ -215,12 +257,11 @@ def _get_tiled_mma(self): warpgroup.OperandMajorMode.K, warpgroup.OperandMajorMode.K, Float32, - atom_layout_mnk=(atom_layout_SdP if not self.SdP_swapAB else atom_layout_SdP[::-1]) - + (1,), - tiler_mn=tiler_mn_SdP if not self.SdP_swapAB else tiler_mn_SdP[::-1], + atom_layout_mnk=maybe_swap_mn(atom_layout_SdP, self.SdP_swapAB), + tiler_mn=(64, tiler_mn_SdP[1] if not self.SdP_swapAB else tiler_mn_SdP[0]), ) # dV = P.T @ dO, dK = dS.T @ Q - atom_layout_dKV = (self.AtomLayoutNdKV, self.num_mma_warp_groups // self.AtomLayoutNdKV) + atom_layout_dKV = (self.AtomLayoutNdKV, self.num_wg_mma // self.AtomLayoutNdKV, 1) tiler_mn_dK = (self.tile_n // atom_layout_dKV[0], self.tile_hdim // atom_layout_dKV[1]) tiler_mn_dV = (self.tile_n // atom_layout_dKV[0], self.tile_hdimv // atom_layout_dKV[1]) tiled_mma_dK, tiled_mma_dV = [ @@ -232,9 +273,8 @@ def _get_tiled_mma(self): else warpgroup.OperandMajorMode.K, warpgroup.OperandMajorMode.MN, Float32, - atom_layout_mnk=(atom_layout_dKV if not self.dKV_swapAB else atom_layout_dKV[::-1]) - + (1,), - tiler_mn=tiler_mn_d if not self.dKV_swapAB else tiler_mn_d[::-1], + atom_layout_mnk=maybe_swap_mn(atom_layout_dKV, self.dKV_swapAB), + tiler_mn=(64, tiler_mn_d[1] if not self.dKV_swapAB else tiler_mn_d[0]), a_source=warpgroup.OperandSource.RMEM if self.mma_dkv_is_rs else warpgroup.OperandSource.SMEM, @@ -242,7 +282,8 @@ def _get_tiled_mma(self): for tiler_mn_d in (tiler_mn_dK, tiler_mn_dV) ] # dQ = dS @ K - atom_layout_dQ = (self.AtomLayoutMdQ, self.num_mma_warp_groups // self.AtomLayoutMdQ) + assert self.num_wg_dQ % self.AtomLayoutMdQ == 0 + atom_layout_dQ = (self.AtomLayoutMdQ, self.num_wg_dQ // self.AtomLayoutMdQ, 1) tiler_mn_dQ = (self.tile_m // atom_layout_dQ[0], self.tile_hdim // atom_layout_dQ[1]) tiled_mma_dQ = sm90_utils_basic.make_trivial_tiled_mma( self.dtype, @@ -250,8 +291,8 @@ def _get_tiled_mma(self): warpgroup.OperandMajorMode.K if not self.dQ_swapAB else warpgroup.OperandMajorMode.MN, warpgroup.OperandMajorMode.MN if not self.dQ_swapAB else warpgroup.OperandMajorMode.K, Float32, - atom_layout_mnk=(atom_layout_dQ if not self.dQ_swapAB else atom_layout_dQ[::-1]) + (1,), - tiler_mn=tiler_mn_dQ if not self.dQ_swapAB else tiler_mn_dQ[::-1], + atom_layout_mnk=maybe_swap_mn(atom_layout_dQ, self.dQ_swapAB), + tiler_mn=(64, tiler_mn_dQ[1] if not self.dQ_swapAB else tiler_mn_dQ[0]), ) return tiled_mma_SdP, tiled_mma_dK, tiled_mma_dV, tiled_mma_dQ @@ -305,7 +346,6 @@ def __call__( mdK: cute.Tensor, mdV: cute.Tensor, softmax_scale: Float32, - stream: cuda.CUstream, mCuSeqlensQ: Optional[cute.Tensor] = None, mCuSeqlensK: Optional[cute.Tensor] = None, mSeqUsedQ: Optional[cute.Tensor] = None, @@ -318,10 +358,13 @@ def __call__( mdV_semaphore: Optional[cute.Tensor] = None, aux_tensors: Optional[list] = None, blocksparse_tensors: Optional[BlockSparseTensors] = None, + # Always keep stream as the last parameter (EnvStream: obtained implicitly via TVM FFI). + stream: cuda.CUstream = None, ): - assert mdQ_semaphore is None and mdK_semaphore is None and mdV_semaphore is None, ( - "determinism not supported yet for Sm90" - ) + # For GQA (qhead_per_kvhead > 1), multiple Q heads accumulate into the same dK/dV, + # so we need the float32 accum path + postprocess. + # For varlen_k with qhead_per_kvhead == 1, we use ragged TMA tensors. + self.varlen_k = mCuSeqlensK is not None or mSeqUsedK is not None self._check_type( *( @@ -330,23 +373,36 @@ def __call__( ) ) + self.is_varlen_q = mCuSeqlensQ is not None or mSeqUsedQ is not None + mQ, mK, mV, mdO, mLSE, mdPsum, mdQaccum, mdK, mdV = [ assume_tensor_aligned(t) for t in (mQ, mK, mV, mdO, mLSE, mdPsum, mdQaccum, mdK, mdV) ] - layout_transpose = [1, 3, 2, 0] # (b, s, n, h) --> (s, h, n, b) - mQ, mK, mV, mdO = [layout_utils.select(t, layout_transpose) for t in (mQ, mK, mV, mdO)] + # Non-varlen inputs are (b, s, n, h), varlen inputs are (s, n, h). + # We convert both to a seqlen-major view with head-dim second. + # Each tensor may have different rank when Q is padded (seqused_q) but K/V are unpadded (cu_seqlens_k). + def _qkv_transpose(t): + return layout_utils.select(t, [1, 3, 2, 0] if cute.rank(t.shape) == 4 else [0, 2, 1]) + + mQ, mK, mV, mdO = [_qkv_transpose(t) for t in (mQ, mK, mV, mdO)] if const_expr(self.qhead_per_kvhead == 1): - mdK, mdV = [layout_utils.select(t, layout_transpose) for t in (mdK, mdV)] + mdK, mdV = [_qkv_transpose(t) for t in (mdK, mdV)] else: - accum_transpose = [2, 1, 0] # (b, n, s*h) -> (s*h, n, b) + # Accum tensors are (b, n, s*h) for non-varlen and (n, s*h) for varlen. + accum_transpose = [2, 1, 0] if cute.rank(mdK.shape) == 3 else [1, 0] mdK, mdV = [layout_utils.select(t, accum_transpose) for t in (mdK, mdV)] - LSE_dPsum_dQaccum_transpose = [2, 1, 0] # (b, n, s) -> (s, n, b) + # Non-varlen stats are (b, n, s), varlen stats are (n, s). + LSE_dPsum_dQaccum_transpose = [2, 1, 0] if cute.rank(mLSE.shape) == 3 else [1, 0] mLSE, mdPsum, mdQaccum = [ layout_utils.select(t, LSE_dPsum_dQaccum_transpose) for t in (mLSE, mdPsum, mdQaccum) ] tiled_mma_SdP, tiled_mma_dK, tiled_mma_dV, tiled_mma_dQ = self._get_tiled_mma() + # (batch, num_head, num_m_blocks, cluster_size) -> (num_m_blocks, cluster_size, num_head, batch) + if const_expr(self.deterministic): + assert mdQ_semaphore is not None + mdQ_semaphore = layout_utils.select(mdQ_semaphore, mode=[2, 3, 1, 0]) self.num_mma_threads = tiled_mma_SdP.size assert self.num_mma_threads + 128 == self.num_threads @@ -354,10 +410,25 @@ def __call__( self.num_threads_per_warp_group = 128 self.num_producer_threads = 32 - self.num_mma_regs = 240 - self.num_producer_regs = 24 - # self.num_mma_regs = 232 - # self.num_producer_regs = 40 + REG_LIMIT = 504 if self.num_wg_mma == 2 else 512 + if const_expr(self.num_wg_mma == 2): + if const_expr(self.num_wg_dQ == 1): + self.num_mma_regs_wg0 = 256 + self.num_mma_regs_wg1 = 224 + else: + self.num_mma_regs_wg0 = 240 + self.num_mma_regs_wg1 = 240 + self.num_mma_regs = self.num_mma_regs_wg0 # for backward compat + self.num_producer_regs = 24 + assert ( + self.num_mma_regs_wg0 + self.num_mma_regs_wg1 + self.num_producer_regs <= REG_LIMIT + ) + else: # 3 warp groups + self.num_mma_regs_wg0 = 160 + self.num_mma_regs_wg1 = 160 + self.num_mma_regs = 160 + self.num_producer_regs = 32 + assert self.num_mma_regs_wg0 * self.num_wg_mma + self.num_producer_regs <= REG_LIMIT self._setup_attributes() SharedStorage = self._get_shared_storage_cls() @@ -374,7 +445,7 @@ def __call__( self.tma_copy_bytes["LSE"] = self.tile_m * Float32.width // 8 self.tma_copy_bytes["dPsum"] = self.tile_m * Float32.width // 8 self.tma_copy_bytes["dQ"] = ( - self.tile_m * self.tile_hdim * Float32.width // 8 // self.num_mma_warp_groups + self.tile_m * self.tile_hdim * Float32.width // 8 // self.num_wg_dQ ) self.tma_copy_bytes["dKacc"] = self.tile_n * self.tile_hdim * Float32.width // 8 self.tma_copy_bytes["dVacc"] = self.tile_n * self.tile_hdimv * Float32.width // 8 @@ -404,38 +475,59 @@ def __call__( (self.tile_m, self.tile_hdimv), ) if const_expr(self.qhead_per_kvhead == 1): + mdK_tma = ( + copy_utils.create_ragged_tensor_for_tma(mdK, ragged_dim=0, ptr_shift=True) + if self.varlen_k + else mdK + ) + mdV_tma = ( + copy_utils.create_ragged_tensor_for_tma(mdV, ragged_dim=0, ptr_shift=True) + if self.varlen_k + else mdV + ) tma_atom_dK, tma_tensor_dK = cpasync.make_tiled_tma_atom( cpasync.CopyBulkTensorTileS2GOp(), - mdK, + mdK_tma, cute.select(self.sK_layout, mode=[0, 1]), (self.tile_n, self.tile_hdim), ) tma_atom_dV, tma_tensor_dV = cpasync.make_tiled_tma_atom( cpasync.CopyBulkTensorTileS2GOp(), - mdV, + mdV_tma, cute.select(self.sV_layout, mode=[0, 1]), (self.tile_n, self.tile_hdimv), ) else: tma_atom_dK = tma_atom_dV = tma_tensor_dK = tma_tensor_dV = None - TileScheduler = SingleTileScheduler + if const_expr(mCuSeqlensK is not None or mSeqUsedK is not None): + TileScheduler = SingleTileVarlenScheduler + elif const_expr(self.deterministic): + TileScheduler = SingleTileLPTBwdScheduler + else: + TileScheduler = SingleTileScheduler + self.spt = (self.is_causal or self.is_local) and self.deterministic tile_sched_args = TileSchedulerArguments( cute.ceil_div(cute.size(mK.shape[0]), self.tile_n), cute.size(mQ.shape[2]), - cute.size(mQ.shape[3]), + cute.size(mK.shape[3]) + if const_expr(mCuSeqlensK is None) + else cute.size(mCuSeqlensK.shape[0] - 1), # num_batch 1, # num_splits - cute.size(mK.shape[0]), - mQ.shape[1], - mV.shape[1], - total_q=cute.size(mQ.shape[0]) * cute.size(mQ.shape[3]), - tile_shape_mn=(self.tile_m, self.tile_n), - mCuSeqlensQ=None, - mSeqUsedQ=None, + cute.size(mQ.shape[0]), # pass seqlen_q or total_q for seqlen_k + mQ.shape[1], # headdim + mV.shape[1], # headdim_v + total_q=cute.size(mK.shape[0]) + if const_expr(mCuSeqlensK is not None) + else cute.size(mK.shape[0]) * cute.size(mK.shape[3]), + tile_shape_mn=(self.tile_n, self.tile_m), # Swapping the role of Q & K + mCuSeqlensQ=mCuSeqlensK, + mSeqUsedQ=mSeqUsedK, qhead_per_kvhead_packgqa=1, element_size=self.dtype.width // 8, is_persistent=False, - lpt=False, + lpt=self.spt, + head_swizzle=self.deterministic, ) tile_sched_params = TileScheduler.to_underlying_arguments(tile_sched_args) @@ -461,6 +553,11 @@ def __call__( self.use_block_sparsity = cutlass.const_expr(blocksparse_tensors is not None) + if const_expr(window_size_left is not None): + window_size_left = Int32(window_size_left) + if const_expr(window_size_right is not None): + window_size_right = Int32(window_size_right) + self.kernel( tma_tensor_Q, tma_tensor_K, @@ -477,6 +574,10 @@ def __call__( mLSE, mdPsum, mdQaccum, + mCuSeqlensQ, + mCuSeqlensK, + mSeqUsedQ, + mSeqUsedK, self.sQ_layout, self.sK_layout, self.sV_layout, @@ -497,11 +598,15 @@ def __call__( fastdiv_mods, blocksparse_tensors, qhead_per_kvhead_divmod, + mdQ_semaphore, + window_size_left, + window_size_right, ).launch( grid=grid_dim, block=[self.num_threads, 1, 1], stream=stream, min_blocks_per_mp=1, + use_pdl=True, ) @cute.kernel @@ -522,6 +627,10 @@ def kernel( mLSE: cute.Tensor, mdPsum: cute.Tensor, mdQaccum: cute.Tensor, + mCuSeqlensQ: Optional[cute.Tensor], + mCuSeqlensK: Optional[cute.Tensor], + mSeqUsedQ: Optional[cute.Tensor], + mSeqUsedK: Optional[cute.Tensor], sQ_layout: cute.ComposedLayout, sK_layout: cute.ComposedLayout, sV_layout: cute.ComposedLayout, @@ -542,15 +651,17 @@ def kernel( fastdiv_mods=(None, None), blocksparse_tensors: Optional[BlockSparseTensors] = None, qhead_per_kvhead_divmod: Optional[FastDivmodDivisor] = None, + mdQ_semaphore: Optional[cute.Tensor] = None, + window_size_left: Optional[Int32] = None, + window_size_right: Optional[Int32] = None, ): warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) # prefetch TMA descriptors if warp_idx == 0: - cpasync.prefetch_descriptor(tma_atom_Q) - cpasync.prefetch_descriptor(tma_atom_K) - cpasync.prefetch_descriptor(tma_atom_V) - cpasync.prefetch_descriptor(tma_atom_dO) + for atom in [tma_atom_Q, tma_atom_K, tma_atom_V, tma_atom_dO, tma_atom_dK, tma_atom_dV]: + if const_expr(atom is not None): + cpasync.prefetch_descriptor(atom) smem = cutlass.utils.SmemAllocator() storage = smem.allocate(SharedStorage) @@ -604,25 +715,27 @@ def kernel( self.is_causal, self.is_local, False, # is_split_kv - None, - None, + window_size_left, + window_size_right, qhead_per_kvhead_packgqa=1, ) SeqlenInfoCls = partial( SeqlenInfoQK.create, seqlen_q_static=mQ.shape[0], seqlen_k_static=mK.shape[0], - mCuSeqlensQ=None, - mCuSeqlensK=None, - mSeqUsedQ=None, - mSeqUsedK=None, + mCuSeqlensQ=mCuSeqlensQ, + mCuSeqlensK=mCuSeqlensK, + mSeqUsedQ=mSeqUsedQ, + mSeqUsedK=mSeqUsedK, + tile_m=self.tile_m, + tile_n=self.tile_n, ) AttentionMaskCls = partial( AttentionMask, self.tile_m, self.tile_n, - window_size_left=None, - window_size_right=None, + window_size_left=window_size_left, + window_size_right=window_size_right, swap_AB=self.SdP_swapAB, ) TileSchedulerCls = partial(TileScheduler.create, tile_sched_params) @@ -663,12 +776,12 @@ def kernel( TileSchedulerCls, SeqlenInfoCls, blocksparse_tensors, + mdQ_semaphore, ) else: - cute.arch.setmaxregister_increase(self.num_mma_regs) tidx, _, _ = cute.arch.thread_idx() tidx = tidx - 128 - self.mma( + mma_args = ( tiled_mma_SdP, tiled_mma_dK, tiled_mma_dV, @@ -702,6 +815,19 @@ def kernel( blocksparse_tensors, qhead_per_kvhead_divmod, ) + if const_expr(self.num_wg_dQ == self.num_wg_mma): + # Both WGs compute dQ + cute.arch.setmaxregister_increase(self.num_mma_regs_wg0) + self.mma(*mma_args, is_dQ_wg=True) + else: + # WG0 computes dQ, WG1 skips it + warp_idx_in_mma = cute.arch.make_warp_uniform(cute.arch.warp_idx()) - 4 + if warp_idx_in_mma < 4: + cute.arch.setmaxregister_increase(self.num_mma_regs_wg0) + self.mma(*mma_args, is_dQ_wg=True) + else: + cute.arch.setmaxregister_increase(self.num_mma_regs_wg1) + self.mma(*mma_args, is_dQ_wg=False) @cute.jit def load( @@ -749,18 +875,22 @@ def load( if const_expr(self.qhead_per_kvhead == 1) else head_idx // qhead_per_kvhead_divmod ) - mK_cur = mK[None, None, head_idx_kv, batch_idx] + mK_cur = seqlen.offset_batch_K(mK, batch_idx, dim=3)[None, None, head_idx_kv] + mV_cur = seqlen.offset_batch_K(mV, batch_idx, dim=3)[None, None, head_idx_kv] gK = cute.local_tile(mK_cur, (self.tile_n, self.tile_hdim), (n_block, 0)) - mV_cur = mV[None, None, head_idx_kv, batch_idx] gV = cute.local_tile(mV_cur, (self.tile_n, self.tile_hdimv), (n_block, 0)) - mQ_cur = mQ[None, None, head_idx, batch_idx] + mQ_cur = seqlen.offset_batch_Q(mQ, batch_idx, dim=3)[None, None, head_idx] + mLSE_cur = seqlen.offset_batch_Q(mLSE, batch_idx, dim=2, padded=True)[ + None, head_idx + ] + mdO_cur = seqlen.offset_batch_Q(mdO, batch_idx, dim=3)[None, None, head_idx] + mdPsum_cur = seqlen.offset_batch_Q(mdPsum, batch_idx, dim=2, padded=True)[ + None, head_idx + ] gQ = cute.local_tile(mQ_cur, (self.tile_m, self.tile_hdim), (None, 0)) - mdO_cur = mdO[None, None, head_idx, batch_idx] gdO = cute.local_tile(mdO_cur, (self.tile_m, self.tile_hdimv), (None, 0)) - mLSE_cur = mLSE[None, head_idx, batch_idx] gLSE = cute.local_tile(mLSE_cur, (self.tile_m,), (None,)) - mdPsum_cur = mdPsum[None, head_idx, batch_idx] gdPsum = cute.local_tile(mdPsum_cur, (self.tile_m,), (None,)) load_K, _, _ = copy_utils.tma_get_copy_fn( @@ -786,7 +916,10 @@ def load( if const_expr(not self.use_block_sparsity): total_m_block_cnt = m_block_max - m_block_min - process_tile = const_expr(not self.is_local) or m_block_min < m_block_max + process_tile = ( + const_expr(not self.is_local and not self.is_varlen_q) + or m_block_min < m_block_max + ) else: total_m_block_cnt = get_total_q_block_count_bwd( blocksparse_tensors, @@ -806,6 +939,8 @@ def load( ) load_K(tma_bar_ptr=pipeline_Q.producer_get_barrier(producer_state_Q)) load_Q(first_m_block, producer_state=producer_state_Q) + # Wait for bwd preprocess to finish writing LSE and dPsum + cute.arch.griddepcontrol_wait() load_LSE(first_m_block, producer_state=producer_state_Q) producer_state_dO_cur = ( producer_state_dO @@ -984,16 +1119,20 @@ def mma( fastdiv_mods=(None, None), blocksparse_tensors: Optional[BlockSparseTensors] = None, qhead_per_kvhead_divmod: Optional[FastDivmodDivisor] = None, + is_dQ_wg: cutlass.Constexpr[bool] = True, ): warp_group_idx = cute.arch.make_warp_uniform(tidx // self.num_threads_per_warp_group) warp_group_thread_layout = cute.make_layout( - self.num_mma_warp_groups, stride=self.num_threads_per_warp_group + self.num_wg_mma, stride=self.num_threads_per_warp_group ) thr_mma_SdP = tiled_mma_SdP.get_slice(tidx) wg_mma_SdP = tiled_mma_SdP.get_slice(warp_group_thread_layout(warp_group_idx)) wg_mma_dK = tiled_mma_dK.get_slice(warp_group_thread_layout(warp_group_idx)) wg_mma_dV = tiled_mma_dV.get_slice(warp_group_thread_layout(warp_group_idx)) - wg_mma_dQ = tiled_mma_dQ.get_slice(warp_group_thread_layout(warp_group_idx)) + wg_mma_dQ = None + if const_expr(is_dQ_wg): + wg_idx_dQ = warp_group_idx if const_expr(self.num_wg_dQ > 1) else 0 + wg_mma_dQ = tiled_mma_dQ.get_slice(warp_group_thread_layout(wg_idx_dQ)) # S = Q @ K.T shape_mnk_S = (self.tile_m, self.tile_n, self.tile_hdim) _, tSrQ, tSrK = sm90_utils.partition_fragment_ABC( @@ -1039,23 +1178,43 @@ def mma( # dQ = dS @ K sKt = layout_utils.transpose_view(sK) shape_mnk_dQ = (self.tile_m, self.tile_hdim, self.tile_n) - _, tdQrdS, tdQrKt = sm90_utils.partition_fragment_ABC( - wg_mma_dQ, shape_mnk_dQ, sdS, sKt, swap_AB=self.dQ_swapAB - ) - mma_dsk_fn = partial( - gemm_zero_init, tiled_mma_dQ, shape_mnk_dQ[:2], tdQrdS, tdQrKt, swap_AB=self.dQ_swapAB - ) + mma_dsk_fn = None + if const_expr(is_dQ_wg): + _, tdQrdS, tdQrKt = sm90_utils.partition_fragment_ABC( + wg_mma_dQ, shape_mnk_dQ, sdS, sKt, swap_AB=self.dQ_swapAB + ) + mma_dsk_fn = partial( + gemm_zero_init, + tiled_mma_dQ, + shape_mnk_dQ[:2], + tdQrdS, + tdQrKt, + swap_AB=self.dQ_swapAB, + ) - # Smem copy atom tiling + # Smem copy atom tiling for P/dS R2S copy_P_r2s = None + mms_PdS = self.tile_n // (self.num_wg_mma // self.AtomLayoutMSdP) if const_expr(sP is not None): sP_cpy = sP if const_expr(not self.SdP_swapAB) else sPt copy_P_r2s, _, _ = copy_utils.get_smem_store_C( - tiled_mma_SdP, sP_cpy, tidx, self.arch, transpose=self.SdP_swapAB + tiled_mma_SdP, + sP_cpy, + tidx, + self.arch, + transpose=self.SdP_swapAB, + position_independent=True, + major_mode_size=mms_PdS, ) sdS_cpy = sdS if const_expr(not self.SdP_swapAB) else sdSt copy_dS_r2s, _, _ = copy_utils.get_smem_store_C( - tiled_mma_SdP, sdS_cpy, tidx, self.arch, transpose=self.SdP_swapAB + tiled_mma_SdP, + sdS_cpy, + tidx, + self.arch, + transpose=self.SdP_swapAB, + position_independent=True, + major_mode_size=mms_PdS, ) tLSEsLSE = layout_utils.mma_partition_C_vec( @@ -1064,9 +1223,21 @@ def mma( tLSEsdPsum = layout_utils.mma_partition_C_vec( sdPsum, thr_mma_SdP, expand_shape=self.tile_n, is_colvec=not self.SdP_swapAB ) - - smem_thr_copy_dQaccum = r2s_tiled_copy_dQaccum.get_slice(tidx) - tdQsdQaccum = smem_thr_copy_dQaccum.partition_D(sdQaccum) + # When shuffle=True, rows are distributed across 8 quads (4 threads each) within a warp. + # Each thread loads only ceil(num_rows/8) values; + shfl_copy = copy_utils.tiled_copy_1d(sLSE.element_type, num_threads=8, num_copy_elems=2) + if const_expr(self.shuffle_LSE): + tLSEsLSE = shfl_copy.get_slice(cute.arch.lane_idx() // 4).partition_S(tLSEsLSE) + # ((2, 1), 1, 2) -> (((2, 1), 1), 2) + tLSEsLSE = cute.group_modes(tLSEsLSE, 0, 2) + if const_expr(self.shuffle_dPsum): + tLSEsdPsum = shfl_copy.get_slice(cute.arch.lane_idx() // 4).partition_S(tLSEsdPsum) + tLSEsdPsum = cute.group_modes(tLSEsdPsum, 0, 2) + + tdQsdQaccum = None + if const_expr(is_dQ_wg): + smem_thr_copy_dQaccum = r2s_tiled_copy_dQaccum.get_slice(tidx) + tdQsdQaccum = smem_thr_copy_dQaccum.partition_D(sdQaccum) PdS_barrier = cutlass.pipeline.NamedBarrier( barrier_id=int(NamedBarrierBwd.PdS), num_threads=self.num_mma_threads @@ -1105,6 +1276,7 @@ def mma( PdS_barrier=PdS_barrier, # acc_dV=acc_dV, # acc_dK=acc_dK, + is_dQ_wg=is_dQ_wg, ) consumer_state_Q = cutlass.pipeline.make_pipeline_state( @@ -1136,7 +1308,10 @@ def mma( m_block_min, m_block_max = block_info.get_m_block_min_max(seqlen, n_block) if const_expr(not self.use_block_sparsity): - process_tile = const_expr(not self.is_local) or m_block_min < m_block_max + process_tile = ( + const_expr(not self.is_local and not self.is_varlen_q) + or m_block_min < m_block_max + ) else: total_m_block_cnt = get_total_q_block_count_bwd( blocksparse_tensors, @@ -1218,8 +1393,8 @@ def mma( qhead_per_kvhead_divmod, ) else: - # Block sparsity: KV tile with zero Q blocks produces no dK/dV; write zeros. - if const_expr(self.use_block_sparsity): + # KV tile with zero Q blocks produces no dK/dV; write zeros. + if const_expr(self.use_block_sparsity or self.is_local or self.is_varlen_q): acc_dK.fill(0.0) acc_dV.fill(0.0) self.epilogue_dKV( @@ -1248,6 +1423,22 @@ def mma( if warp_idx == 4: cute.arch.cp_async_bulk_wait_group(0, read=True) + @staticmethod + @cute.jit + def _get_stat(tSrS: cute.Tensor, row: Int32, lane: Int32, shuffle: bool) -> Float32: + """Retrieve the statistic for a given accumulator row. + + When shuffle=False, direct register indexing. + When shuffle=True, warp shuffle from the thread group that holds the value. + """ + if const_expr(not shuffle): + return tSrS[row] + # tSrS: (((2, 1), 1), 1)), distributed across 8 threads in the warp + vecsize = cute.size(tSrS, mode=[0, 0]) # 2 + idx0, off, idx1 = cute.idx2crd(row, (vecsize, 8, cute.shape(tSrS, mode=[0, 1]))) + # register index: 0, 1, 0, 1, ..., 2, 3, 2, 3, ... + return utils.shuffle_sync(tSrS[idx0 + idx1 * vecsize], offset=off * 4 + (lane % 4)) + @cute.jit def mma_one_m_block( self, @@ -1266,16 +1457,17 @@ def mma_one_m_block( pipeline_dO: cutlass.pipeline.PipelineAsync, tLSEsLSE: cute.Tensor, tLSEsdPsum: cute.Tensor, - tdQsdQaccum: cute.Tensor, + tdQsdQaccum: Optional[cute.Tensor], softmax_scale_log2: Float32, PdS_barrier: cutlass.pipeline.NamedBarrier, + is_dQ_wg: cutlass.Constexpr[bool] = True, mask_fn: Optional[Callable] = None, score_mod_fn: Optional[Callable] = None, score_mod_bwd_fn: Optional[Callable] = None, dKV_accumulate: Boolean = True, ): consumer_state_dO_cur = ( - consumer_state_dO if const_expr(self.Q_stage == self.dO_stage) else consumer_state_Q + consumer_state_Q if const_expr(self.Q_stage == self.dO_stage) else consumer_state_dO ) smem_idx_Q = consumer_state_Q.index smem_idx_dO = consumer_state_dO_cur.index if const_expr(self.dO_stage > 1) else 0 @@ -1283,6 +1475,7 @@ def mma_one_m_block( # (1) [GEMM 1] S = Q @ K^T pipeline_Q.consumer_wait(consumer_state_Q, pipeline_Q.consumer_try_wait(consumer_state_Q)) acc_S = mma_qk_fn(A_idx=smem_idx_Q, wg_wait=-1) + # If shuffle_LSE, OOB reads are OK since sLSE is already padded tLSErLSE = copy_utils.load_s2r(tLSEsLSE[None, smem_idx_Q]) # (2) [GEMM 2] dP = dO @ V.T pipeline_dO.consumer_wait( @@ -1301,10 +1494,12 @@ def mma_one_m_block( if cutlass.const_expr(mask_fn is not None): mask_fn(acc_S, m_block=m_block) acc_S_mn = layout_utils.reshape_acc_to_mn(acc_S, transpose=self.SdP_swapAB) + lane_idx = cute.arch.lane_idx() for r in cutlass.range_constexpr(cute.size(acc_S_mn, mode=[0])): + lse_val = self._get_stat(tLSErLSE, r, lane_idx, shuffle=self.shuffle_LSE) for c in cutlass.range(cute.size(acc_S_mn, mode=[1]), unroll_full=True): acc_S_mn[r, c] = cute.math.exp2( - acc_S_mn[r, c] * softmax_scale_log2 - tLSErLSE[r], fastmath=True + acc_S_mn[r, c] * softmax_scale_log2 - lse_val, fastmath=True ) tLSErdPsum = copy_utils.load_s2r(tLSEsdPsum[None, smem_idx_dO]) @@ -1321,8 +1516,9 @@ def mma_one_m_block( warpgroup.wait_group(0) acc_dP_mn = layout_utils.reshape_acc_to_mn(acc_dP, transpose=self.SdP_swapAB) for r in cutlass.range_constexpr(cute.size(acc_dP_mn, mode=[0])): + dpsum_val = self._get_stat(tLSErdPsum, r, lane_idx, shuffle=self.shuffle_dPsum) for c in cutlass.range(cute.size(acc_dP_mn, mode=[1]), unroll_full=True): - acc_dP_mn[r, c] = acc_S_mn[r, c] * (acc_dP_mn[r, c] - tLSErdPsum[r]) + acc_dP_mn[r, c] = acc_S_mn[r, c] * (acc_dP_mn[r, c] - dpsum_val) if const_expr(self.score_mod_bwd is not None): score_mod_bwd_fn(acc_dP, acc_S_pre, m_block=m_block) @@ -1354,36 +1550,50 @@ def mma_one_m_block( # smem fence to make sure sdS is written before it's read by WGMMA cute.arch.fence_view_async_shared() PdS_barrier.arrive_and_wait() - # (6) [GEMM 4] dQ = dS @ K - acc_dQ = mma_dsk_fn(A_idx=smem_idx_PdS, wg_wait=1) - # if cute.arch.thread_idx()[0] == 128: cute.print_tensor(acc_dV) - pipeline_dO.consumer_release(consumer_state_dO_cur) # release dO as dV mma is done - # (7) [GEMM 5] dK += dS.T @ Q - if const_expr(not self.mma_dkv_is_rs): - mma_dsq_fn( - A_idx=smem_idx_PdS, B_idx=smem_idx_Q, zero_init=not dKV_accumulate, wg_wait=1 - ) - else: - mma_dsq_fn(tCrA=tdKrdS, B_idx=smem_idx_Q, zero_init=not dKV_accumulate, wg_wait=1) - # if cute.arch.thread_idx()[0] == 128: cute.print_tensor(acc_dQ) + if const_expr(is_dQ_wg): + # (6) [GEMM 4] dQ = dS @ K + acc_dQ = mma_dsk_fn(A_idx=smem_idx_PdS, wg_wait=1) + pipeline_dO.consumer_release(consumer_state_dO_cur) # release dO as dV mma is done - cute.arch.barrier( - barrier_id=int(NamedBarrierBwd.dQEmptyWG0) + warp_group_idx, - number_of_threads=self.num_threads_per_warp_group + cute.arch.WARP_SIZE, - ) - tdQrdQaccum_flat = cute.make_tensor(acc_dQ.iterator, cute.make_layout(tdQsdQaccum.shape)) - cute.autovec_copy(tdQrdQaccum_flat, tdQsdQaccum) - cute.arch.fence_view_async_shared() - cute.arch.barrier_arrive( - barrier_id=int(NamedBarrierBwd.dQFullWG0) + warp_group_idx, - number_of_threads=self.num_threads_per_warp_group + cute.arch.WARP_SIZE, - ) + # (7) [GEMM 5] dK += dS.T @ Q + if const_expr(not self.mma_dkv_is_rs): + mma_dsq_fn( + A_idx=smem_idx_PdS, B_idx=smem_idx_Q, zero_init=not dKV_accumulate, wg_wait=1 + ) + else: + mma_dsq_fn(tCrA=tdKrdS, B_idx=smem_idx_Q, zero_init=not dKV_accumulate, wg_wait=1) - warpgroup.wait_group(0) - # if cute.arch.thread_idx()[0] == 128: cute.print_tensor(acc_dK) - pipeline_Q.consumer_release(consumer_state_Q) - # if cute.arch.thread_idx()[0] % 32 == 0: cute.printf("tidx = {}, m_block = {}, after pipeline_Q consumer release", cute.arch.thread_idx()[0], m_block) + # dQ R2S: wait for dQaccum_store to free the smem buffer, then write dQ to smem + # When dQ_single_wg, only WG0 enters here so warp_group_idx == 0 + cute.arch.barrier( + barrier_id=int(NamedBarrierBwd.dQEmptyWG0) + warp_group_idx, + number_of_threads=self.num_threads_per_warp_group + cute.arch.WARP_SIZE, + ) + tdQrdQaccum_flat = cute.make_tensor( + acc_dQ.iterator, cute.make_layout(tdQsdQaccum.shape) + ) + cute.autovec_copy(tdQrdQaccum_flat, tdQsdQaccum) + cute.arch.fence_view_async_shared() + cute.arch.barrier_arrive( + barrier_id=int(NamedBarrierBwd.dQFullWG0) + warp_group_idx, + number_of_threads=self.num_threads_per_warp_group + cute.arch.WARP_SIZE, + ) + + warpgroup.wait_group(0) + pipeline_Q.consumer_release(consumer_state_Q) + else: + # dQ_single_wg: WG1 skips dQ, only does dV wait + dK + # (7) [GEMM 5] dK += dS.T @ Q + if const_expr(not self.mma_dkv_is_rs): + mma_dsq_fn( + A_idx=smem_idx_PdS, B_idx=smem_idx_Q, zero_init=not dKV_accumulate, wg_wait=1 + ) + else: + mma_dsq_fn(tCrA=tdKrdS, B_idx=smem_idx_Q, zero_init=not dKV_accumulate, wg_wait=1) + pipeline_dO.consumer_release(consumer_state_dO_cur) + warpgroup.wait_group(0) + pipeline_Q.consumer_release(consumer_state_Q) consumer_state_Q.advance() consumer_state_dO.advance() @@ -1415,8 +1625,12 @@ def epilogue_dKV( warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) if const_expr(self.qhead_per_kvhead == 1): - mdV_cur = mdV[None, None, head_idx, batch_idx] - mdK_cur = mdK[None, None, head_idx, batch_idx] + mdK_cur = seqlen.offset_batch_K(mdK, batch_idx, dim=3, ragged=self.varlen_k)[ + None, None, head_idx + ] + mdV_cur = seqlen.offset_batch_K(mdV, batch_idx, dim=3, ragged=self.varlen_k)[ + None, None, head_idx + ] gdK = cute.local_tile(mdK_cur, (self.tile_n, self.tile_hdim), (n_block, 0)) gdV = cute.local_tile(mdV_cur, (self.tile_n, self.tile_hdimv), (n_block, 0)) store_dK, _, _ = copy_utils.tma_get_copy_fn( @@ -1428,10 +1642,20 @@ def epilogue_dKV( sdV = sV if const_expr(not self.dKV_swapAB) else layout_utils.transpose_view(sV) sdK = sK if const_expr(not self.dKV_swapAB) else layout_utils.transpose_view(sK) copy_dV_r2s, _, _ = copy_utils.get_smem_store_C( - tiled_mma_dV, sdV, tidx, self.arch, transpose=self.dKV_swapAB + tiled_mma_dV, + sdV, + tidx, + self.arch, + transpose=self.dKV_swapAB, + position_independent=True, ) copy_dK_r2s, _, _ = copy_utils.get_smem_store_C( - tiled_mma_dK, sdK, tidx, self.arch, transpose=self.dKV_swapAB + tiled_mma_dK, + sdK, + tidx, + self.arch, + transpose=self.dKV_swapAB, + position_independent=True, ) cute.arch.cp_async_bulk_wait_group(1, read=True) epi_barrier.arrive_and_wait() @@ -1450,15 +1674,19 @@ def epilogue_dKV( store_dK() cute.arch.cp_async_bulk_commit_group() else: - sdKaccum_shape0 = self.tile_n * self.tile_hdim // self.num_mma_warp_groups - sdVaccum_shape0 = self.tile_n * self.tile_hdimv // self.num_mma_warp_groups - sdKaccum_layout = cute.make_layout((sdKaccum_shape0, self.num_mma_warp_groups)) - sdVaccum_layout = cute.make_layout((sdVaccum_shape0, self.num_mma_warp_groups)) + sdKaccum_shape0 = self.tile_n * self.tile_hdim // self.num_wg_mma + sdVaccum_shape0 = self.tile_n * self.tile_hdimv // self.num_wg_mma + sdKaccum_layout = cute.make_layout((sdKaccum_shape0, self.num_wg_mma)) + sdVaccum_layout = cute.make_layout((sdVaccum_shape0, self.num_wg_mma)) head_idx_kv = head_idx // qhead_per_kvhead_divmod - mdKaccum_cur = mdK[None, head_idx_kv, batch_idx] + mdKaccum_cur = seqlen.offset_batch_K( + mdK, batch_idx, dim=2, padded=True, multiple=self.tile_hdim + )[None, head_idx_kv] + mdVaccum_cur = seqlen.offset_batch_K( + mdV, batch_idx, dim=2, padded=True, multiple=self.tile_hdimv + )[None, head_idx_kv] gdKaccum_ = cute.local_tile(mdKaccum_cur, (self.tile_n * self.tile_hdim,), (n_block,)) gdKaccum = cute.flat_divide(gdKaccum_, (sdKaccum_shape0,)) - mdVaccum_cur = mdV[None, head_idx_kv, batch_idx] gdVaccum_ = cute.local_tile(mdVaccum_cur, (self.tile_n * self.tile_hdimv,), (n_block,)) gdVaccum = cute.flat_divide(gdVaccum_, (sdVaccum_shape0,)) # These two overlap each other @@ -1467,7 +1695,7 @@ def epilogue_dKV( sdVaccum = cute.make_tensor(sVaccum_ptr, sdVaccum_layout) tiled_copy_dKVaccum_r2s = cute.make_tiled_copy_tv( cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), Float32, num_bits_per_copy=128), - cute.make_layout((self.num_threads_per_warp_group, self.num_mma_warp_groups)), + cute.make_layout((self.num_threads_per_warp_group, self.num_wg_mma)), cute.make_layout(128 // Float32.width), ) thr_copy_dKVaccum_r2s = tiled_copy_dKVaccum_r2s.get_slice(tidx) @@ -1482,11 +1710,11 @@ def epilogue_dKV( epi_barrier.arrive_and_wait() if warp_idx == 4: with cute.arch.elect_one(): - for wg_idx in cutlass.range_constexpr(self.num_mma_warp_groups): + for wg_idx in cutlass.range_constexpr(self.num_wg_mma): copy_utils.cpasync_reduce_bulk_add_f32( sdKaccum[None, wg_idx].iterator, gdKaccum[None, wg_idx].iterator, - self.tma_copy_bytes["dKacc"] // self.num_mma_warp_groups, + self.tma_copy_bytes["dKacc"] // self.num_wg_mma, ) cute.arch.cp_async_bulk_commit_group() @@ -1498,11 +1726,11 @@ def epilogue_dKV( epi_barrier.arrive_and_wait() if warp_idx == 4: with cute.arch.elect_one(): - for wg_idx in cutlass.range_constexpr(self.num_mma_warp_groups): + for wg_idx in cutlass.range_constexpr(self.num_wg_mma): copy_utils.cpasync_reduce_bulk_add_f32( sdVaccum[None, wg_idx].iterator, gdVaccum[None, wg_idx].iterator, - self.tma_copy_bytes["dVacc"] // self.num_mma_warp_groups, + self.tma_copy_bytes["dVacc"] // self.num_wg_mma, ) cute.arch.cp_async_bulk_commit_group() @@ -1515,21 +1743,45 @@ def dQaccum_store( TileSchedulerCls: cutlass.Constexpr[Callable], SeqlenInfoCls: cutlass.Constexpr[Callable], blocksparse_tensors: Optional[BlockSparseTensors] = None, + mdQ_semaphore: Optional[cute.Tensor] = None, ): + tidx, _, _ = cute.arch.thread_idx() + # warp-local thread index (dQaccum_store runs on warp 1, global tidx 32-63) + warp_local_tidx = tidx % cute.arch.WARP_SIZE + read_flag = const_expr(not self.deterministic) + tile_scheduler = TileSchedulerCls() work_tile = tile_scheduler.initial_work_tile_info() while work_tile.is_valid_tile: n_block, head_idx, batch_idx, _ = work_tile.tile_idx seqlen = SeqlenInfoCls(batch_idx) - mdQaccum_cur = mdQaccum[None, head_idx, batch_idx] - gdQaccum_ = cute.local_tile(mdQaccum_cur, (self.tile_m * self.tile_hdim,), (None,)) - # (M * K / WG, WG, _) - gdQaccum = cute.flat_divide( - gdQaccum_, (self.tile_m * self.tile_hdim // self.num_mma_warp_groups,) + if const_expr(not seqlen.has_cu_seqlens_q): + mdQaccum_cur = mdQaccum[None, head_idx, batch_idx] + else: + mdQaccum_cur = cute.domain_offset( + (seqlen.padded_offset_q * self.tile_hdim,), mdQaccum[None, head_idx] + ) + # ((M * K / num_wg_dQ, num_wg_dQ), num_m_blocks) + gdQaccum = cute.local_tile( + mdQaccum_cur, + ( + cute.make_layout( + (self.tile_m * self.tile_hdim // self.num_wg_dQ, self.num_wg_dQ) + ), + ), + (None,), ) + + if const_expr(mdQ_semaphore is not None): + # mdQ_semaphore is (num_m_blocks, cluster_size, num_head, batch) after transpose + mdQ_semaphore_cur = mdQ_semaphore[None, None, head_idx, batch_idx] + m_block_min, m_block_max = block_info.get_m_block_min_max(seqlen, n_block) if const_expr(not self.use_block_sparsity): - process_tile = const_expr(not self.is_local) or m_block_min < m_block_max + process_tile = ( + const_expr(not self.is_local and not self.is_varlen_q) + or m_block_min < m_block_max + ) loop_count = m_block_max - m_block_min else: total_block_cnt = get_total_q_block_count_bwd( @@ -1548,17 +1800,36 @@ def dQaccum_store( m_block = m_block_min + iter_idx m_block_safe = m_block - for warp_group_idx in cutlass.range_constexpr(self.num_mma_warp_groups): - cute.arch.cp_async_bulk_wait_group( - self.num_mma_warp_groups - 1 - warp_group_idx, read=True - ) + num_dQ_chunks = self.num_wg_dQ + for warp_group_idx in cutlass.range_constexpr(num_dQ_chunks): + if const_expr(not self.deterministic): + # If deterministic, we already waited at the end of the prev iter + cute.arch.cp_async_bulk_wait_group( + num_dQ_chunks - 1 - warp_group_idx, read=read_flag + ) cute.arch.barrier_arrive( barrier_id=int(NamedBarrierBwd.dQEmptyWG0) + warp_group_idx, number_of_threads=self.num_threads_per_warp_group + cute.arch.WARP_SIZE, ) - for warp_group_idx in cutlass.range_constexpr(self.num_mma_warp_groups): + # Semaphore acquire: wait for prior n_blocks to finish writing this m_block + if const_expr(self.deterministic): + if const_expr(self.spt): + _, n_block_max_for_m_block = block_info.get_n_block_min_max( + seqlen, m_block_safe + ) + lock_value = n_block_max_for_m_block - 1 - n_block + else: + lock_value = n_block + barrier.wait_eq( + mdQ_semaphore_cur[(m_block_safe, None)].iterator, + warp_local_tidx, + 0, # flag_offset + lock_value, + ) + + for warp_group_idx in cutlass.range_constexpr(num_dQ_chunks): cute.arch.barrier( barrier_id=int(NamedBarrierBwd.dQFullWG0) + warp_group_idx, number_of_threads=self.num_threads_per_warp_group @@ -1567,11 +1838,24 @@ def dQaccum_store( with cute.arch.elect_one(): copy_utils.cpasync_reduce_bulk_add_f32( sdQaccum[None, warp_group_idx].iterator, - gdQaccum[None, warp_group_idx, m_block_safe].iterator, + gdQaccum[(None, warp_group_idx), m_block_safe].iterator, self.tma_copy_bytes["dQ"], ) cute.arch.cp_async_bulk_commit_group() + + # Semaphore release: signal that this n_block is done with this m_block + if const_expr(self.deterministic): + cute.arch.cp_async_bulk_wait_group(0, read=read_flag) + barrier.arrive_inc( + mdQ_semaphore_cur[(m_block_safe, None)].iterator, + warp_local_tidx, + 0, # flag_offset + 1, + ) else: + assert not self.deterministic, ( + "Deterministic not implemented for block-sparse backward" + ) dQaccum_store_block_sparse_bwd_sm90( blocksparse_tensors, batch_idx, @@ -1581,11 +1865,27 @@ def dQaccum_store( gdQaccum, subtile_factor=self.subtile_factor, m_block_max=m_block_max, - num_mma_warp_groups=self.num_mma_warp_groups, + num_dQ_warp_groups=self.num_wg_dQ, num_threads_per_warp_group=self.num_threads_per_warp_group, tma_copy_bytes_dQ=self.tma_copy_bytes["dQ"], ) + + # For local masking + deterministic (non-spt): signal remaining m_blocks + # that this n_block won't visit, so they don't deadlock waiting. + if const_expr( + self.deterministic and not self.spt and block_info.window_size_left is not None + ): + m_block_global_max = cute.ceil_div(seqlen.seqlen_q, self.tile_m) + for m_block in cutlass.range(m_block_max, m_block_global_max, unroll=1): + barrier.arrive_inc( + mdQ_semaphore_cur[(m_block, None)].iterator, + warp_local_tidx, + 0, # flag_offset + 1, + ) + tile_scheduler.advance_to_next_work() work_tile = tile_scheduler.get_current_work() - cute.arch.cp_async_bulk_wait_group(0, read=True) + if const_expr(not self.deterministic): + cute.arch.cp_async_bulk_wait_group(0, read=True) diff --git a/flash-attn4/torch-ext/flash_attn4/flash_fwd.py b/flash-attn4/torch-ext/flash_attn4/flash_fwd.py index 1824cf84..18b532e7 100644 --- a/flash-attn4/torch-ext/flash_attn4/flash_fwd.py +++ b/flash-attn4/torch-ext/flash_attn4/flash_fwd.py @@ -15,42 +15,28 @@ import cutlass import cutlass.cute as cute from cutlass import Constexpr, Float32, Int32, const_expr, Boolean -from cutlass.cute.nvgpu import cpasync, warp, warpgroup +from cutlass.cute.nvgpu import cpasync, warp import cutlass.utils as utils_basic -from cutlass.utils import LayoutEnum -import cutlass.utils.hopper_helpers as sm90_utils_basic +from cutlass.base_dsl.arch import Arch +from cutlass.cutlass_dsl import BaseDSL from .quack import copy_utils from .quack import layout_utils -from .quack import sm90_utils from . import ampere_helpers as sm80_utils from .cute_dsl_utils import assume_tensor_aligned from . import utils from .mask import AttentionMask -from .softmax import Softmax, apply_score_mod_inner +from .softmax import Softmax from .seqlen_info import SeqlenInfoQK from .block_info import BlockInfo -from .block_sparsity import BlockSparseTensors -from .block_sparse_utils import ( - produce_block_sparse_loads, - consume_block_sparse_loads, -) -from . import pipeline from .pack_gqa import PackGQA from .named_barrier import NamedBarrierFwd -from .quack.cute_dsl_utils import ParamsBase -from .tile_scheduler import ( - TileSchedulerArguments, - SingleTileScheduler, - SingleTileLPTScheduler, - SingleTileVarlenScheduler, -) -from cutlass.cute import FastDivmodDivisor +from .block_sparsity import BlockSparseTensors +from .tile_scheduler import SingleTileScheduler, SingleTileVarlenScheduler, TileSchedulerArguments class FlashAttentionForwardBase: - arch: int = 80 def __init__( self, @@ -116,6 +102,12 @@ def __init__( self.vec_size: cutlass.Constexpr = getattr( score_mod, "__vec_size__", 1 if cutlass.const_expr(has_aux_tensors) else 2 ) + if self.vec_size > 2: + raise ValueError( + f"score_mod vec_size {self.vec_size} not supported on Sm80/90/120 " + "due to accumulator thread ownership pattern." + ) + self.arch = BaseDSL._get_dsl().get_arch_enum() @staticmethod def can_implement( @@ -318,7 +310,8 @@ def __call__( mO: cute.Tensor, mLSE: Optional[cute.Tensor], softmax_scale: Float32, - stream: cuda.CUstream, + # Always keep stream as the last parameter (EnvStream: obtained implicitly via TVM FFI). + stream: cuda.CUstream = None, ): """Configures and launches the flash attention kernel. @@ -351,7 +344,7 @@ def epilogue( cute.arch.barrier( barrier_id=int(NamedBarrierFwd.Epilogue), number_of_threads=self.num_epilogue_threads ) - smem_copy_atom_O = utils.get_smem_store_atom(self.arch, self.dtype) + smem_copy_atom_O = utils.get_smem_store_atom(self.arch.major * 10 + self.arch.minor, self.dtype) smem_thr_copy_O = cute.make_tiled_copy_C(smem_copy_atom_O, tiled_mma).get_slice(tidx) taccOrO = smem_thr_copy_O.retile(rO) taccOsO = smem_thr_copy_O.partition_D(sO) @@ -366,11 +359,7 @@ def epilogue( # Write LSE from rmem -> gmem if const_expr(mLSE is not None): - if const_expr(not seqlen.has_cu_seqlens_q): - mLSE_cur = mLSE[None, head_idx, batch_idx] - else: - offset = seqlen.offset_q if const_expr(not self.pack_gqa) else (0, seqlen.offset_q) - mLSE_cur = cute.domain_offset((offset,), mLSE[None, head_idx]) + mLSE_cur = seqlen.offset_batch_Q(mLSE, batch_idx, dim=2)[None, head_idx] if const_expr(not self.pack_gqa): gLSE = cute.local_tile(mLSE_cur, (self.tile_m,), (m_block,)) gLSE_expanded_layout = cute.append( @@ -384,7 +373,7 @@ def epilogue( t0accOcO = layout_utils.reshape_acc_to_mn(thr_mma.get_slice(0).partition_C(cO)) # Only the thread corresponding to column 0 writes out the lse to gmem if taccOcO[0][1] == 0: - for m in cutlass.range_constexpr(cute.size(taccOgLSE.shape[1])): + for m in cutlass.range(cute.size(taccOgLSE.shape[1]), unroll_full=True): if ( t0accOcO[m, 0][0] < seqlen.seqlen_q - m_block * self.tile_m - taccOcO[0][0] @@ -393,11 +382,8 @@ def epilogue( else: pack_gqa.store_LSE(mLSE_cur, lse, tiled_mma, tidx, m_block, seqlen.seqlen_q) - if const_expr(not seqlen.has_cu_seqlens_q): - mO_cur = mO[None, None, head_idx, batch_idx] - else: - offset = seqlen.offset_q if const_expr(not self.pack_gqa) else (0, seqlen.offset_q) - mO_cur = cute.domain_offset((offset, 0), mO[None, None, head_idx]) + ragged = self.use_tma_O and (seqlen.has_cu_seqlens_q or seqlen.has_seqused_q) + mO_cur = seqlen.offset_batch_Q(mO, batch_idx, dim=3, ragged=ragged)[None, None, head_idx] # thr_mma = tiled_mma.get_slice(tidx) # taccOgO = thr_mma.partition_C(gO) # cute.autovec_copy(rO, taccOgO) @@ -634,12 +620,19 @@ def __call__( mV: cute.Tensor, mO: cute.Tensor, mLSE: Optional[cute.Tensor], - stream: cuda.CUstream, - softmax_scale: Optional[Float32] = None, + softmax_scale: Float32, + mCuSeqlensQ: Optional[cute.Tensor] = None, + mCuSeqlensK: Optional[cute.Tensor] = None, + mSeqUsedQ: Optional[cute.Tensor] = None, + mSeqUsedK: Optional[cute.Tensor] = None, + mPageTable: Optional[cute.Tensor] = None, window_size_left: Optional[Int32] = None, window_size_right: Optional[Int32] = None, learnable_sink: Optional[cute.Tensor] = None, + blocksparse_tensors: Optional[BlockSparseTensors] = None, aux_tensors=None, + # Always keep stream as the last parameter (EnvStream: obtained implicitly via TVM FFI). + stream: cuda.CUstream = None, ): """Configures and launches the flash attention kernel. @@ -648,7 +641,7 @@ def __call__( """ assert learnable_sink is None, "Learnable sink is not supported in this kernel" self._check_type( - *(t.element_type if t is not None else None for t in (mQ, mK, mV, mO, mLSE)) + *(t.element_type if t is not None else None for t in (mQ, mK, mV, mO, mLSE, mCuSeqlensQ, mCuSeqlensK, mSeqUsedQ, mSeqUsedK)) ) tiled_mma_qk, tiled_mma_pv = self._get_tiled_mma() self.num_mma_threads = tiled_mma_pv.size @@ -656,41 +649,54 @@ def __call__( self.num_Q_load_threads = self.num_threads self.num_epilogue_threads = self.num_threads # self.use_tma_O = self.arch >= 90 and mCuSeqlensQ is None - self.use_tma_O = self.arch >= 90 + self.use_tma_O = self.arch >= Arch.sm_90 self._setup_attributes() SharedStorage = self._get_shared_storage_cls() mQ, mK, mV, mO = [assume_tensor_aligned(t) for t in (mQ, mK, mV, mO)] - mQ, mK, mV, mO = [ - cute.make_tensor(t.iterator, cute.select(t.layout, mode=[1, 3, 2, 0])) - for t in (mQ, mK, mV, mO) + # Layout permutation: 4D non-varlen vs 3D varlen + QO_layout_transpose = [1, 3, 2, 0] if const_expr(mCuSeqlensQ is None) else [0, 2, 1] + KV_layout_transpose = [1, 3, 2, 0] if const_expr(mCuSeqlensK is None) else [0, 2, 1] + mQ, mO = [ + cute.make_tensor(t.iterator, cute.select(t.layout, mode=QO_layout_transpose)) + for t in (mQ, mO) ] - mLSE = cute.make_tensor(mLSE.iterator, cute.select(mLSE.layout, mode=[2, 1, 0])) - # grid_dim: (m_block, num_head, batch_size) - grid_dim = ( - cute.ceil_div(mQ.shape[0], self.tile_m), - cute.size(mQ.shape[2]), - cute.size(mQ.shape[3]), - ) - LOG2_E = math.log2(math.e) - if const_expr(self.score_mod is None): - softmax_scale_log2 = Float32(softmax_scale * LOG2_E) - softmax_scale = None + mK, mV = [ + cute.make_tensor(t.iterator, cute.select(t.layout, mode=KV_layout_transpose)) + for t in (mK, mV) + ] + if const_expr(mLSE is not None): + LSE_layout_transpose = [2, 1, 0] if const_expr(mCuSeqlensQ is None) else [1, 0] + mLSE = cute.make_tensor(mLSE.iterator, cute.select(mLSE.layout, mode=LSE_layout_transpose)) + # TileScheduler for varlen, simple grid for non-varlen + if const_expr(mCuSeqlensQ is not None or mSeqUsedQ is not None): + TileScheduler = SingleTileVarlenScheduler else: - # NB: If a user passes in a score mod, we want to apply the score-mod in the sm_scaled qk - # But in the original base 10. We hijack softmax_scale_log2 to just be the change of base - # and correctly apply the softmax_scale prior to score_mod in the softmax step - softmax_scale_log2 = Float32(LOG2_E) - softmax_scale = Float32(softmax_scale) - - fastdiv_mods = None - if const_expr(aux_tensors is not None): - seqlen_q = cute.size(mQ.shape[0]) // ( - self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1 - ) - seqlen_k = cute.size(mK.shape[0]) - seqlen_q_divmod = FastDivmodDivisor(seqlen_q) - seqlen_k_divmod = FastDivmodDivisor(seqlen_k) - fastdiv_mods = (seqlen_q_divmod, seqlen_k_divmod) + TileScheduler = SingleTileScheduler + num_batch = ( + mCuSeqlensQ.shape[0] - 1 + if const_expr(mCuSeqlensQ is not None) + else mQ.shape[3] + ) + tile_sched_args = TileSchedulerArguments( + num_block=cute.ceil_div(mQ.shape[0], self.tile_m), + num_head=cute.size(mQ.shape[2]), + num_batch=num_batch, + num_splits=1, + seqlen_k=0, + headdim=mQ.shape[1], + headdim_v=mV.shape[1], + total_q=cute.size(mQ.shape[0]) + if const_expr(mCuSeqlensQ is not None) + else cute.size(mQ.shape[0]) * cute.size(mQ.shape[3]), + tile_shape_mn=(self.tile_m, self.tile_n), + qhead_per_kvhead_packgqa=self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1, + mCuSeqlensQ=mCuSeqlensQ, + mSeqUsedQ=mSeqUsedQ, + ) + tile_sched_params = TileScheduler.to_underlying_arguments(tile_sched_args) + grid_dim = TileScheduler.get_grid_shape(tile_sched_params) + softmax_scale_log2, softmax_scale = utils.compute_softmax_scale_log2(softmax_scale, self.score_mod) + fastdiv_mods = utils.compute_fastdiv_mods(mQ, mK, self.qhead_per_kvhead, self.pack_gqa, aux_tensors) self.kernel( mQ, @@ -698,6 +704,10 @@ def __call__( mV, mO, mLSE, + mCuSeqlensQ, + mCuSeqlensK, + mSeqUsedQ, + mSeqUsedK, softmax_scale_log2, softmax_scale, window_size_left, @@ -714,6 +724,8 @@ def __call__( tiled_mma_qk, tiled_mma_pv, SharedStorage, + tile_sched_params, + TileScheduler, aux_tensors, fastdiv_mods, ).launch( @@ -731,6 +743,10 @@ def kernel( mV: cute.Tensor, mO: cute.Tensor, mLSE: Optional[cute.Tensor], + mCuSeqlensQ: Optional[cute.Tensor], + mCuSeqlensK: Optional[cute.Tensor], + mSeqUsedQ: Optional[cute.Tensor], + mSeqUsedK: Optional[cute.Tensor], softmax_scale_log2: Float32, softmax_scale: Optional[Float32], window_size_left: Optional[Int32], @@ -747,12 +763,17 @@ def kernel( tiled_mma_qk: cute.TiledMma, tiled_mma_pv: cute.TiledMma, SharedStorage: cutlass.Constexpr, + tile_sched_params, + TileScheduler: cutlass.Constexpr[Callable], aux_tensors=None, fastdiv_mods=None, ): # Thread index, block index tidx, _, _ = cute.arch.thread_idx() - m_block, num_head, batch_size = cute.arch.block_idx() + + tile_scheduler = TileScheduler.create(tile_sched_params) + work_tile = tile_scheduler.initial_work_tile_info() + m_block, num_head, batch_size, _ = work_tile.tile_idx block_info = BlockInfo( self.tile_m, @@ -764,13 +785,21 @@ def kernel( window_size_right, qhead_per_kvhead_packgqa=self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1, ) - seqlen = SeqlenInfoQK.create(seqlen_q_static=mQ.shape[0], seqlen_k_static=mK.shape[0]) + seqlen = SeqlenInfoQK.create( + batch_idx=batch_size, + seqlen_q_static=mQ.shape[0], + seqlen_k_static=mK.shape[0], + mCuSeqlensQ=mCuSeqlensQ, + mCuSeqlensK=mCuSeqlensK, + mSeqUsedQ=mSeqUsedQ, + mSeqUsedK=mSeqUsedK, + ) n_block_min, n_block_max = block_info.get_n_block_min_max(seqlen, m_block) - # TODO: return early if n_block_max == 0 - # if self.is_causal: - # if n_block_max <= 0: - # return - n_block = n_block_max - 1 + # For varlen, wasted grid tiles (where batch_idx >= num_batch) will have + # seqlen_q=seqlen_k=0 and n_block_max=0. Clamp to 0 so we don't use a + # negative block index for K/V loads; the load/store predicates already + # guard all memory accesses when seqlen is 0. + n_block = cutlass.max(n_block_max - 1, 0) # /////////////////////////////////////////////////////////////////////////////// # Get the appropriate tiles for this thread block. @@ -778,10 +807,20 @@ def kernel( blkQ_shape = (self.tile_m, self.tile_hdim) blkK_shape = (self.tile_n, self.tile_hdim) blkV_shape = (self.tile_n, self.tile_hdimv) - gQ = cute.local_tile(mQ[None, None, num_head, batch_size], blkQ_shape, (m_block, 0)) num_head_kv = num_head // self.qhead_per_kvhead - gK = cute.local_tile(mK[None, None, num_head_kv, batch_size], blkK_shape, (None, 0)) - gV = cute.local_tile(mV[None, None, num_head_kv, batch_size], blkV_shape, (None, 0)) + if const_expr(not seqlen.has_cu_seqlens_q): + mQ_cur = mQ[None, None, num_head, batch_size] + else: + mQ_cur = cute.domain_offset((seqlen.offset_q, 0), mQ[None, None, num_head]) + if const_expr(not seqlen.has_cu_seqlens_k): + mK_cur = mK[None, None, num_head_kv, batch_size] + mV_cur = mV[None, None, num_head_kv, batch_size] + else: + mK_cur = cute.domain_offset((seqlen.offset_k, 0), mK[None, None, num_head_kv]) + mV_cur = cute.domain_offset((seqlen.offset_k, 0), mV[None, None, num_head_kv]) + gQ = cute.local_tile(mQ_cur, blkQ_shape, (m_block, 0)) + gK = cute.local_tile(mK_cur, blkK_shape, (None, 0)) + gV = cute.local_tile(mV_cur, blkV_shape, (None, 0)) # /////////////////////////////////////////////////////////////////////////////// # Get shared memory buffer @@ -953,18 +992,20 @@ def preprocess_Q(): mask = AttentionMask( self.tile_m, self.tile_n, - seqlen.seqlen_q, - seqlen.seqlen_k, + seqlen, window_size_left, window_size_right, self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1, ) mask_fn = partial( mask.apply_mask, + batch_idx=batch_size, + head_idx=num_head, m_block=m_block, thr_mma=thr_mma_qk, mask_causal=self.is_causal, mask_local=self.is_local, + aux_tensors=aux_tensors, fastdiv_mods=fastdiv_mods if const_expr(self.mask_mod is not None) else None, ) @@ -976,8 +1017,8 @@ def preprocess_Q(): smem_pipe_read, smem_pipe_write, is_first_n_block=True, - check_inf=True, - mask_fn=partial(mask_fn, mask_seqlen=True), + seqlen=seqlen, + mask_fn=partial(mask_fn, mask_mod=self.mask_mod, mask_seqlen=True), ) smem_pipe_read = self.advance_pipeline(smem_pipe_read) smem_pipe_write = self.advance_pipeline(smem_pipe_write) @@ -992,15 +1033,17 @@ def preprocess_Q(): n_block, smem_pipe_read, smem_pipe_write, - check_inf=True, - mask_fn=partial(mask_fn, mask_seqlen=False), + seqlen=seqlen, + mask_fn=partial(mask_fn, mask_mod=self.mask_mod, mask_seqlen=True), ) smem_pipe_read = self.advance_pipeline(smem_pipe_read) smem_pipe_write = self.advance_pipeline(smem_pipe_write) # The remaining iterations have no masking for n_tile in cutlass.range(n_block, unroll=1): compute_one_n_block( - n_block - n_tile - 1, smem_pipe_read, smem_pipe_write, check_inf=True + n_block - n_tile - 1, smem_pipe_read, smem_pipe_write, + seqlen=seqlen, is_first_n_block=False, + mask_fn=partial(mask_fn, mask_mod=self.mask_mod, mask_seqlen=False) ) smem_pipe_read = self.advance_pipeline(smem_pipe_read) smem_pipe_write = self.advance_pipeline(smem_pipe_write) @@ -1144,1283 +1187,9 @@ def load_K_next(): # load_K_next() -class FlashAttentionForwardSm90(FlashAttentionForwardBase): - arch = 90 - - def __init__( - self, - *args, - intra_wg_overlap: bool = True, - mma_pv_is_rs: bool = True, - **kwargs, - ): - super().__init__(*args, **kwargs) - self.intra_wg_overlap = intra_wg_overlap - self.mma_pv_is_rs = mma_pv_is_rs - self.buffer_align_bytes = 1024 - - def _get_smem_layout_atom(self): - sQ_layout_atom = warpgroup.make_smem_layout_atom( - sm90_utils_basic.get_smem_layout_atom(LayoutEnum.ROW_MAJOR, self.dtype, self.tile_hdim), - self.dtype, - ) - sK_layout_atom = sQ_layout_atom - sV_layout_atom = warpgroup.make_smem_layout_atom( - sm90_utils_basic.get_smem_layout_atom( - LayoutEnum.ROW_MAJOR, self.dtype, self.tile_hdimv - ), - self.dtype, - ) - sO_layout_atom = sV_layout_atom - if not self.mma_pv_is_rs: - sP_layout_atom = warpgroup.make_smem_layout_atom( - sm90_utils_basic.get_smem_layout_atom( - LayoutEnum.ROW_MAJOR, self.dtype, self.tile_n - ), - self.dtype, - ) - else: - sP_layout_atom = None - return sQ_layout_atom, sK_layout_atom, sV_layout_atom, sO_layout_atom, sP_layout_atom - - def _get_tiled_mma(self): - tiled_mma_qk = sm90_utils_basic.make_trivial_tiled_mma( - self.dtype, - self.dtype, - warpgroup.OperandMajorMode.K, - warpgroup.OperandMajorMode.K, - Float32, - atom_layout_mnk=(self.tile_m // 64, 1, 1), # Might need (1, 2, 1) for hdim 512 - tiler_mn=(64, self.tile_n), - ) - tiled_mma_pv = sm90_utils_basic.make_trivial_tiled_mma( - self.dtype, - self.dtype, - warpgroup.OperandMajorMode.K, - warpgroup.OperandMajorMode.MN, - Float32, - atom_layout_mnk=(self.tile_m // 64, 1, 1), # Might need (1, 2, 1) for hdim 512 - tiler_mn=(64, self.tile_hdimv), - a_source=warpgroup.OperandSource.RMEM - if self.mma_pv_is_rs - else warpgroup.OperandSource.SMEM, - ) - return tiled_mma_qk, tiled_mma_pv - - def _get_shared_storage_cls(self): - sQ_struct, sK_struct, sV_struct = [ - cute.struct.Align[cute.struct.MemRange[self.dtype, cute.cosize(layout)], self.buffer_align_bytes] - for layout in (self.sQ_layout, self.sK_layout, self.sV_layout) - - ] - cosize_sQV = max(cute.cosize(self.sQ_layout), cute.cosize(self.sV_layout)) - sQV_struct = cute.struct.Align[cute.struct.MemRange[self.dtype, cosize_sQV], 1024] - cosize_sP = cute.cosize(self.sP_layout) if const_expr(self.sP_layout is not None) else 0 - sP_struct = cute.struct.Align[cute.struct.MemRange[self.dtype, cosize_sP], 1024] - # 1 for Q, 1 for O, self.num_stages*2 for K, self.num_stages*2 for V, - mbar_ptr_QO_struct = cute.struct.MemRange[cutlass.Int64, 2] - mbar_ptr_K_struct = cute.struct.MemRange[cutlass.Int64, self.num_stages * 2] - mbar_ptr_V_struct = cute.struct.MemRange[cutlass.Int64, self.num_stages * 2] - - @cute.struct - class SharedStorageQKV: - mbar_ptr: mbar_ptr_QO_struct - mbar_ptr_K: mbar_ptr_K_struct - mbar_ptr_V: mbar_ptr_V_struct - sV: sV_struct - sQ: sQ_struct - sK: sK_struct - sP: sP_struct - - @cute.struct - class SharedStorageSharedQV: - mbar_ptr: mbar_ptr_QO_struct - mbar_ptr_K: mbar_ptr_K_struct - mbar_ptr_V: mbar_ptr_V_struct - sQ: sQV_struct - sK: sK_struct - sP: sP_struct - - return SharedStorageQKV if const_expr(not self.Q_in_regs) else SharedStorageSharedQV - - @cute.jit - def __call__( - self, - mQ: cute.Tensor, # (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q - mK: cute.Tensor, # (b_k, s_k, h_k, d) or (total_k, h_k, d) if there is cu_seqlens_k or (num_pages, page_size, h_k, d) if there is page_table - mV: cute.Tensor, # (b_k, s_k, h_k, dv) or (total_k, h_k, dv) if there is cu_seqlens_k or (num_pages, page_size, h_k, dv) if there is page_table - mO: cute.Tensor, # (b, s_q, h, dv) or (total_q, h, dv) if there is cu_seqlens_q - mLSE: Optional[cute.Tensor], - softmax_scale: Float32, - stream: cuda.CUstream, - mCuSeqlensQ: Optional[cute.Tensor] = None, - mCuSeqlensK: Optional[cute.Tensor] = None, - mSeqUsedQ: Optional[cute.Tensor] = None, - mSeqUsedK: Optional[cute.Tensor] = None, - mPageTable: Optional[cute.Tensor] = None, # (b_k, max_num_pages_per_seq) - window_size_left: Int32 | int | None = None, - window_size_right: Int32 | int | None = None, - learnable_sink: Optional[cute.Tensor] = None, - blocksparse_tensors: Optional[BlockSparseTensors] = None, - aux_tensors: Optional[list] = None, - ): - """Configures and launches the flash attention kernel. - - mQ/mK/mV/mO has same data types(supports fp16 and bf16) and same layout: - (batch_size, seqlen_q, num_head, head_dim):(_, _, _, 1) - """ - - self._check_type( - *( - t.element_type if t is not None else None - for t in (mQ, mK, mV, mO, mLSE, mCuSeqlensQ, mCuSeqlensK, mSeqUsedQ, mSeqUsedK) - ) - ) - - mQ, mK, mV, mO = [assume_tensor_aligned(t) for t in (mQ, mK, mV, mO)] - QO_layout_transpose = [1, 3, 2, 0] if const_expr(mCuSeqlensQ is None) else [0, 2, 1] - mQ, mO = [layout_utils.select(t, QO_layout_transpose) for t in (mQ, mO)] - KV_layout_transpose = [1, 3, 2, 0] if const_expr(mCuSeqlensK is None) else [0, 2, 1] - mK, mV = [layout_utils.select(t, KV_layout_transpose) for t in (mK, mV)] - LSE_layout_transpose = [2, 1, 0] if const_expr(mCuSeqlensQ is None) else [1, 0] - mLSE = layout_utils.select(mLSE, LSE_layout_transpose) if const_expr(mLSE is not None) else None - - tiled_mma_qk, tiled_mma_pv = self._get_tiled_mma() - self.num_mma_threads = tiled_mma_qk.size - self.num_threads_per_warp_group = 128 - self.num_mma_warp_groups = self.num_mma_threads // self.num_threads_per_warp_group - self.num_threads = self.num_threads_per_warp_group * (self.num_mma_warp_groups + 1) - self.num_producer_threads = 32 - self.num_Q_load_threads = self.num_mma_threads # If not TMA_Q, MMA threads load Q - self.num_epilogue_threads = self.num_mma_threads - self.num_mma_regs = ( - 256 - if self.num_mma_warp_groups == 1 - else (240 if self.num_mma_warp_groups == 2 else 160) - ) - self.num_producer_regs = ( - 56 if self.num_mma_warp_groups == 1 else (24 if self.num_mma_warp_groups == 2 else 32) - ) - # self.num_mma_regs = 232 - # self.num_producer_regs = 40 - self.use_block_sparsity = cutlass.const_expr(blocksparse_tensors is not None) - - self.use_scheduler_barrier = ( - (self.num_mma_warp_groups >= 2 and self.tile_hdim <= 128) - if const_expr(self.intra_wg_overlap) - else (self.num_mma_warp_groups == 2) - ) - self.use_tma_Q = self.arch >= 90 and not ( - self.pack_gqa and self.tile_m % self.qhead_per_kvhead != 0 - ) - self.use_tma_O = ( - self.arch >= 90 and mCuSeqlensQ is None and mSeqUsedQ is None and not self.pack_gqa - ) - # TODO: rescale_O_before_gemm - self._setup_attributes() - # TODO: we prob don't need most of what's in _setup_attributes - self.sQ_layout, self.sK_layout, self.sV_layout, self.sO_layout = [ - sm90_utils.make_smem_layout(mX.element_type, LayoutEnum.ROW_MAJOR, shape, stage) - for mX, shape, stage in [ - (mQ, (self.tile_m, self.tile_hdim), None), - (mK, (self.tile_n, self.tile_hdim), self.num_stages), - (mV, (self.tile_n, self.tile_hdimv), self.num_stages), - (mO, (self.tile_m, self.tile_hdimv), None), - ] - ] - self.sP_layout = None - if const_expr(not self.mma_pv_is_rs): - self.sP_layout = sm90_utils.make_smem_layout( - mV.element_type, LayoutEnum.ROW_MAJOR, (self.tile_m, self.tile_n) - ) - - SharedStorage = self._get_shared_storage_cls() - - if const_expr(self.pack_gqa): - shape_Q_packed = ( - (self.qhead_per_kvhead, mQ.shape[0]), - mQ.shape[1], - mK.shape[2], - *mQ.shape[3:], - ) - stride_Q_packed = ( - (mQ.stride[2], mQ.stride[0]), - mQ.stride[1], - mQ.stride[2] * self.qhead_per_kvhead, - *mQ.stride[3:], - ) - mQ = cute.make_tensor( - mQ.iterator, cute.make_layout(shape_Q_packed, stride=stride_Q_packed) - ) - shape_O_packed = ( - (self.qhead_per_kvhead, mO.shape[0]), - mK.shape[1], - mK.shape[2], - *mO.shape[3:], - ) - stride_O_packed = ( - (mO.stride[2], mO.stride[0]), - mO.stride[1], - mO.stride[2] * self.qhead_per_kvhead, - *mO.stride[3:], - ) - mO = cute.make_tensor( - mO.iterator, cute.make_layout(shape_O_packed, stride=stride_O_packed) - ) - if const_expr(mLSE is not None): - shape_LSE_packed = ( - (self.qhead_per_kvhead, mLSE.shape[0]), - mK.shape[2], - *mLSE.shape[2:], - ) - stride_LSE_packed = ( - (mLSE.stride[1], mLSE.stride[0]), - mLSE.stride[1] * self.qhead_per_kvhead, - *mLSE.stride[2:], - ) - mLSE = cute.make_tensor( - mLSE.iterator, cute.make_layout(shape_LSE_packed, stride=stride_LSE_packed) - ) - - # TMA - gmem_tiled_copy_Q = cpasync.CopyBulkTensorTileG2SOp() - gmem_tiled_copy_KV = cpasync.CopyBulkTensorTileG2SOp() # Might multicast - gmem_tiled_copy_O = cpasync.CopyBulkTensorTileS2GOp() - self.tma_copy_bytes = { - name: cute.size_in_bytes(mX.element_type, cute.select(layout, mode=[0, 1])) - for name, mX, layout in [ - ("Q", mQ, self.sQ_layout), - ("K", mK, self.sK_layout), - ("V", mV, self.sV_layout), - ] - } - tma_atom_Q, tma_tensor_Q = None, None - if const_expr(self.use_tma_Q): - tma_atom_Q, tma_tensor_Q = cpasync.make_tiled_tma_atom( - gmem_tiled_copy_Q, - mQ, - self.sQ_layout, - (self.tile_m, self.tile_hdim), # No mcast - ) - tma_atom_K, tma_tensor_K = cpasync.make_tiled_tma_atom( - gmem_tiled_copy_KV, - mK, - cute.select(self.sK_layout, mode=[0, 1]), - (self.tile_n, self.tile_hdim), - 1, # No mcast for now - ) - tma_atom_V, tma_tensor_V = cpasync.make_tiled_tma_atom( - gmem_tiled_copy_KV, - mV, - cute.select(self.sV_layout, mode=[0, 1]), - (self.tile_n, self.tile_hdimv), - 1, # No mcast for now - ) - tma_atom_O, tma_tensor_O = None, None - if const_expr(self.use_tma_O): - tma_atom_O, tma_tensor_O = cpasync.make_tiled_tma_atom( - gmem_tiled_copy_O, - mO, - self.sO_layout, - (self.tile_m, self.tile_hdimv), # No mcast - ) - if const_expr(mCuSeqlensQ is not None or mSeqUsedQ is not None): - TileScheduler = SingleTileVarlenScheduler - else: - TileScheduler = ( - SingleTileScheduler - if const_expr(not self.is_causal or self.is_local) - else SingleTileLPTScheduler - ) - tile_sched_args = TileSchedulerArguments( - cute.ceil_div(cute.size(mQ.shape[0]), self.tile_m), - cute.size(mQ.shape[2]), - cute.size(mQ.shape[3]) - if const_expr(mCuSeqlensQ is None) - else cute.size(mCuSeqlensQ.shape[0] - 1), - 1, # num_splits - cute.size(mK.shape[0]), - mQ.shape[1], - mV.shape[1], - total_q=cute.size(mQ.shape[0]) - if const_expr(mCuSeqlensQ is not None) - else cute.size(mQ.shape[0]) * cute.size(mQ.shape[3]), - tile_shape_mn=(self.tile_m, self.tile_n), - mCuSeqlensQ=mCuSeqlensQ, - mSeqUsedQ=mSeqUsedQ, - qhead_per_kvhead_packgqa=self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1, - element_size=self.dtype.width // 8, - is_persistent=False, - lpt=self.is_causal or self.is_local, - ) - tile_sched_params = TileScheduler.to_underlying_arguments(tile_sched_args) - grid_dim = TileScheduler.get_grid_shape(tile_sched_params) - LOG2_E = math.log2(math.e) - if const_expr(self.score_mod is None): - softmax_scale_log2 = softmax_scale * LOG2_E - softmax_scale = None - else: - # NB: If a user passes in a score mod, we want to apply the score-mod in the sm_scaled qk - # But in the original base 10. We hijack softmax_scale_log2 to just be the change of base - # and correctly apply the softmax_scale prior to score_mod in the softmax step - softmax_scale_log2 = LOG2_E - softmax_scale = softmax_scale - if const_expr(window_size_left is not None): - window_size_left = Int32(window_size_left) - if const_expr(window_size_right is not None): - window_size_right = Int32(window_size_right) - - fastdiv_mods = None - if const_expr(aux_tensors is not None): - seqlen_q = cute.size(mQ.shape[0]) // ( - self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1 - ) - seqlen_k = ( - cute.size(mK.shape[0]) - if const_expr(mPageTable is None) - else mK.shape[0] * mPageTable.shape[1] - ) - seqlen_q_divmod = FastDivmodDivisor(seqlen_q) - seqlen_k_divmod = FastDivmodDivisor(seqlen_k) - fastdiv_mods = (seqlen_q_divmod, seqlen_k_divmod) - - self.kernel( - tma_tensor_Q if const_expr(self.use_tma_Q) else mQ, - tma_tensor_K, - tma_tensor_V, - tma_tensor_O if const_expr(self.use_tma_O) else mO, - mLSE, - mCuSeqlensQ, - mCuSeqlensK, - mSeqUsedQ, - mSeqUsedK, - tma_atom_Q, - tma_atom_K, - tma_atom_V, - tma_atom_O, - softmax_scale_log2, - softmax_scale, - window_size_left, - window_size_right, - learnable_sink, - blocksparse_tensors, - self.sQ_layout, - self.sK_layout, - self.sV_layout, - self.sO_layout, - self.sP_layout, - self.gmem_tiled_copy_Q, - self.gmem_tiled_copy_K, - self.gmem_tiled_copy_V, - self.gmem_tiled_copy_O, - tiled_mma_qk, - tiled_mma_pv, - tile_sched_params, - TileScheduler, - SharedStorage, - aux_tensors, - fastdiv_mods, - ).launch( - grid=grid_dim, - block=[self.num_threads, 1, 1], - stream=stream, - min_blocks_per_mp=1, - ) - - @cute.kernel - def kernel( - self, - mQ: cute.Tensor, - mK: cute.Tensor, - mV: cute.Tensor, - mO: cute.Tensor, - mLSE: Optional[cute.Tensor], - mCuSeqlensQ: Optional[cute.Tensor], - mCuSeqlensK: Optional[cute.Tensor], - mSeqUsedQ: Optional[cute.Tensor], - mSeqUsedK: Optional[cute.Tensor], - tma_atom_Q: Optional[cute.CopyAtom], - tma_atom_K: Optional[cute.CopyAtom], - tma_atom_V: Optional[cute.CopyAtom], - tma_atom_O: Optional[cute.CopyAtom], - softmax_scale_log2: Float32, - softmax_scale: Optional[Float32], - window_size_left: Optional[Int32], - window_size_right: Optional[Int32], - learnable_sink: Optional[cute.Tensor], - blocksparse_tensors: Optional[BlockSparseTensors], - sQ_layout: cute.ComposedLayout, - sK_layout: cute.ComposedLayout, - sV_layout: cute.ComposedLayout, - sO_layout: cute.ComposedLayout, - sP_layout: cute.ComposedLayout | None, - gmem_tiled_copy_Q: cute.TiledCopy, - gmem_tiled_copy_K: cute.TiledCopy, - gmem_tiled_copy_V: cute.TiledCopy, - gmem_tiled_copy_O: cute.TiledCopy, - tiled_mma_qk: cute.TiledMma, - tiled_mma_pv: cute.TiledMma, - tile_sched_params: ParamsBase, - TileScheduler: cutlass.Constexpr[Callable], - SharedStorage: cutlass.Constexpr[Callable], - aux_tensors=Optional[list[cute.Tensor]], - fastdiv_mods=None, - ): - warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) - # Prefetch tma descriptor - if warp_idx == 0: - for tma_atom in (tma_atom_Q, tma_atom_K, tma_atom_V, tma_atom_O): - if const_expr(tma_atom is not None): - cpasync.prefetch_descriptor(tma_atom) - - smem = cutlass.utils.SmemAllocator() - storage = smem.allocate(SharedStorage) - - # Mbarrier init - mbar_ptr_Q = storage.mbar_ptr.data_ptr() - if warp_idx == 1: - # if tidx < 2: - # # barrierO num threads should be self.num_mma_threads - # cute.arch.mbarrier_init(mbar_ptr_Q + tidx, 1 if tidx == 0 else self.num_mma_threads) - if const_expr(not self.use_tma_Q): - cute.arch.mbarrier_init(mbar_ptr_Q, self.num_Q_load_threads) - # cute.arch.mbarrier_init(mbar_ptr_Q + 1, self.num_mma_threads) - # We rely on pipeline_k and pipeline_v to initialize the mbarrier fence and sync - pipeline_kv_producer_group = cutlass.pipeline.CooperativeGroup( - cutlass.pipeline.Agent.Thread - ) - pipeline_kv_consumer_group = cutlass.pipeline.CooperativeGroup( - cutlass.pipeline.Agent.Thread, self.num_mma_threads // cute.arch.WARP_SIZE - ) - pipeline_k = pipeline.PipelineTmaAsync.create( - barrier_storage=storage.mbar_ptr_K.data_ptr(), - num_stages=self.num_stages, - producer_group=pipeline_kv_producer_group, - consumer_group=pipeline_kv_consumer_group, - tx_count=self.tma_copy_bytes["K"], - defer_sync=True, - ) - pipeline_v = pipeline.PipelineTmaAsync.create( - barrier_storage=storage.mbar_ptr_V.data_ptr(), - num_stages=self.num_stages, - producer_group=pipeline_kv_producer_group, - consumer_group=pipeline_kv_consumer_group, - tx_count=self.tma_copy_bytes["V"], - defer_sync=False - ) - - # /////////////////////////////////////////////////////////////////////////////// - # Get shared memory buffer - # /////////////////////////////////////////////////////////////////////////////// - sQ = storage.sQ.get_tensor(sQ_layout.outer, swizzle=sQ_layout.inner) - sK = storage.sK.get_tensor(sK_layout.outer, swizzle=sK_layout.inner) - if const_expr(not self.Q_in_regs): - sV = storage.sV.get_tensor(sV_layout.outer, swizzle=sV_layout.inner) - else: - sV = storage.sQ.get_tensor( - sV_layout.outer, swizzle=sV_layout.inner, dtype=mV.element_type - ) - # Transpose view of V to tensor with layout (head_dim_v, tile_n) for tiled mma - sVt = layout_utils.transpose_view(sV) - sP = None - if const_expr(sP_layout is not None): - sP = storage.sP.get_tensor(sP_layout.outer, swizzle=sP_layout.inner) - # reuse sQ's data iterator - sO = storage.sQ.get_tensor(sO_layout.outer, swizzle=sO_layout.inner, dtype=self.dtype) - - block_info = BlockInfo( - self.tile_m, - self.tile_n, - self.is_causal, - self.is_local, - False, # is_split_kv - window_size_left, - window_size_right, - qhead_per_kvhead_packgqa=self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1, - ) - SeqlenInfoCls = partial( - SeqlenInfoQK.create, - seqlen_q_static=mQ.shape[0] if const_expr(not self.pack_gqa) else mQ.shape[0][1], - seqlen_k_static=mK.shape[0], - mCuSeqlensQ=mCuSeqlensQ, - mCuSeqlensK=mCuSeqlensK, - mSeqUsedQ=mSeqUsedQ, - mSeqUsedK=mSeqUsedK, - ) - AttentionMaskCls = partial( - AttentionMask, - self.tile_m, - self.tile_n, - window_size_left=window_size_left, - window_size_right=window_size_right, - qhead_per_kvhead_packgqa=self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1, - ) - TileSchedulerCls = partial(TileScheduler.create, tile_sched_params) - - if warp_idx < 4: # Producer - cute.arch.setmaxregister_decrease(self.num_producer_regs) - self.load( - mQ, - mK, - mV, - sQ, - sK, - sV, - tma_atom_Q, - tma_atom_K, - tma_atom_V, - pipeline_k, - pipeline_v, - mbar_ptr_Q, - blocksparse_tensors, - block_info, - SeqlenInfoCls, - TileSchedulerCls, - ) - - else: # Consumer - cute.arch.setmaxregister_increase(self.num_mma_regs) - # /////////////////////////////////////////////////////////////////////////////// - # Tile MMA compute thread partitions and allocate accumulators - # /////////////////////////////////////////////////////////////////////////////// - tidx, _, _ = cute.arch.thread_idx() - tidx = tidx - 128 - self.mma( - tiled_mma_qk, - tiled_mma_pv, - mQ, - mO, - mLSE, - sQ, - sK, - sVt, - sP, - sO, - learnable_sink, - pipeline_k, - pipeline_v, - mbar_ptr_Q, - gmem_tiled_copy_Q, - gmem_tiled_copy_O, - tma_atom_O, - tidx, - softmax_scale_log2, - softmax_scale, - block_info, - SeqlenInfoCls, - AttentionMaskCls, - TileSchedulerCls, - blocksparse_tensors, - aux_tensors, - fastdiv_mods, - ) - - @cute.jit - def load( - self, - mQ: cute.Tensor, - mK: cute.Tensor, - mV: cute.Tensor, - sQ: cute.Tensor, - sK: cute.Tensor, - sV: cute.Tensor, - tma_atom_Q: cute.CopyAtom, - tma_atom_K: cute.CopyAtom, - tma_atom_V: cute.CopyAtom, - pipeline_k: cutlass.pipeline.PipelineAsync, - pipeline_v: cutlass.pipeline.PipelineAsync, - mbar_ptr_Q: cutlass.Pointer, - blocksparse_tensors: Optional[BlockSparseTensors], - block_info: BlockInfo, - SeqlenInfoCls: Callable, - TileSchedulerCls: Callable, - ): - warp_idx_in_wg = cute.arch.make_warp_uniform(cute.arch.warp_idx()) % 4 - if warp_idx_in_wg == 0: - q_producer_phase = Int32(1) - kv_producer_state = pipeline.make_pipeline_state( - cutlass.pipeline.PipelineUserType.Producer, self.num_stages - ) - tile_scheduler = TileSchedulerCls() - work_tile = tile_scheduler.initial_work_tile_info() - while work_tile.is_valid_tile: - # if work_tile.is_valid_tile: - m_block, head_idx, batch_idx, _ = work_tile.tile_idx - seqlen = SeqlenInfoCls(batch_idx) - mQ_cur = seqlen.offset_batch_Q(mQ, batch_idx, dim=3)[None, None, head_idx] - head_idx_kv = ( - head_idx // self.qhead_per_kvhead if const_expr(not self.pack_gqa) else head_idx - ) - mK_cur = seqlen.offset_batch_K(mK, batch_idx, dim=3)[None, None, head_idx_kv] - mV_cur = seqlen.offset_batch_K(mV, batch_idx, dim=3)[None, None, head_idx_kv] - gK = cute.local_tile(mK_cur, (self.tile_n, self.tile_hdim), (None, 0)) - gV = cute.local_tile(mV_cur, (self.tile_n, self.tile_hdimv), (None, 0)) - load_Q = None - if const_expr(self.use_tma_Q): - gQ = cute.local_tile(mQ_cur, (self.tile_m, self.tile_hdim), (m_block, 0)) - load_Q, _, _ = copy_utils.tma_get_copy_fn( - tma_atom_Q, 0, cute.make_layout(1), gQ, sQ, single_stage=True - ) - # TODO: mcast - # TODO check warp_idx if we have 128 producer threads - load_K, _, _ = copy_utils.tma_get_copy_fn( - tma_atom_K, 0, cute.make_layout(1), gK, sK - ) - load_K = copy_utils.tma_producer_copy_fn(load_K, pipeline_k) - load_V, _, _ = copy_utils.tma_get_copy_fn( - tma_atom_V, 0, cute.make_layout(1), gV, sV - ) - load_V = copy_utils.tma_producer_copy_fn(load_V, pipeline_v) - - if const_expr(not self.use_block_sparsity): - n_block_min, n_block_max = block_info.get_n_block_min_max(seqlen, m_block) - # if cute.arch.thread_idx()[0] == 0: - # cute.printf("m_block = %d, n_block_min: %d, n_block_max: %d", m_block, n_block_min, n_block_max) - # First iteration: load both Q & K with the same mbarrier - n_block = n_block_max - 1 - pipeline_k.producer_acquire( - kv_producer_state, - extra_tx_count=self.tma_copy_bytes["Q"] - if const_expr(self.use_tma_Q) - else 0, - ) - if const_expr(self.use_tma_Q): - load_Q(tma_bar_ptr=pipeline_k.producer_get_barrier(kv_producer_state)) - load_K(src_idx=n_block, producer_state=kv_producer_state) - - if const_expr(not self.intra_wg_overlap): - pipeline_v.producer_acquire(kv_producer_state) - load_V(src_idx=n_block, producer_state=kv_producer_state) - kv_producer_state.advance() - for i in cutlass.range(n_block_max - 1 - n_block_min, unroll=1): - n_block = n_block_max - 1 - i - 1 - pipeline_k.producer_acquire(kv_producer_state) - load_K(src_idx=n_block, producer_state=kv_producer_state) - pipeline_v.producer_acquire(kv_producer_state) - load_V(src_idx=n_block, producer_state=kv_producer_state) - kv_producer_state.advance() - else: - for i in cutlass.range(n_block_max - 1 - n_block_min, unroll=1): - n_block_prev = n_block_max - i - 1 - n_block = n_block_prev - 1 - kv_producer_state_prev = kv_producer_state.clone() - kv_producer_state.advance() - pipeline_k.producer_acquire(kv_producer_state) - load_K(src_idx=n_block, producer_state=kv_producer_state) - pipeline_v.producer_acquire(kv_producer_state_prev) - load_V(src_idx=n_block_prev, producer_state=kv_producer_state_prev) - n_block = n_block_min - pipeline_v.producer_acquire(kv_producer_state) - load_V(src_idx=n_block, producer_state=kv_producer_state) - kv_producer_state.advance() - else: - kv_producer_state = produce_block_sparse_loads( - blocksparse_tensors, - batch_idx, - head_idx, - m_block, - kv_producer_state, - load_Q, - load_K, - load_V, - pipeline_k, - pipeline_v, - self.use_tma_Q, - self.tma_copy_bytes["Q"], - self.intra_wg_overlap, - self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1, - self.q_subtile_factor if self.q_subtile_factor is not None else 1, - ) - - tile_scheduler.prefetch_next_work() - tile_scheduler.advance_to_next_work() - work_tile = tile_scheduler.get_current_work() - # End of persistent scheduler loop - - @cute.jit - def mma( - self, - tiled_mma_qk: cute.TiledMma, - tiled_mma_pv: cute.TiledMma, - # softmax: Softmax, - # acc_O: cute.Tensor, - mQ: cute.Tensor, - mO: cute.Tensor, - mLSE: Optional[cute.Tensor], - sQ: cute.Tensor, - sK: cute.Tensor, - sVt: cute.Tensor, - sP: Optional[cute.Tensor], - sO: cute.Tensor, - learnable_sink: Optional[cute.Tensor], - pipeline_k: cutlass.pipeline.PipelineAsync, - pipeline_v: cutlass.pipeline.PipelineAsync, - mbar_ptr_Q: cutlass.Pointer, - gmem_tiled_copy_Q: cute.TiledCopy, - gmem_tiled_copy_O: cute.TiledCopy, - tma_atom_O: Optional[cute.CopyAtom], - tidx: Int32, - softmax_scale_log2: Float32, - softmax_scale: Optional[Float32], - block_info: BlockInfo, - SeqlenInfoCls: Callable, - AttentionMaskCls: Callable, - TileSchedulerCls: Callable, - blocksparse_tensors: Optional[BlockSparseTensors], - aux_tensors: Optional[list], - fastdiv_mods=None, - ): - warp_group_idx = cute.arch.make_warp_uniform(tidx // self.num_threads_per_warp_group) - warp_group_thread_layout = cute.make_layout( - self.num_mma_warp_groups, stride=self.num_threads_per_warp_group - ) - thr_mma_qk = tiled_mma_qk.get_slice(tidx) - wg_mma_qk = tiled_mma_qk.get_slice(warp_group_thread_layout(warp_group_idx)) - wg_mma_pv = tiled_mma_pv.get_slice(warp_group_thread_layout(warp_group_idx)) - _, tSrQ, tSrK = sm90_utils.partition_fragment_ABC( - wg_mma_qk, (self.tile_m, self.tile_n, self.tile_hdim), sQ, sK - ) - mma_qk_fn = partial( - sm90_utils.gemm_zero_init, tiled_mma_qk, (self.tile_m, self.tile_n), tSrQ, tSrK - ) - acc_O, tOrP, tOrVt = sm90_utils.partition_fragment_ABC( - wg_mma_pv, (self.tile_m, self.tile_hdimv, self.tile_n), sP, sVt - ) - mma_pv_fn = partial(sm90_utils.gemm_w_idx, tiled_mma_pv, acc_O, tOrP, tOrVt) - - # /////////////////////////////////////////////////////////////////////////////// - # Smem copy atom tiling - # /////////////////////////////////////////////////////////////////////////////// - smem_copy_atom_P = utils.get_smem_store_atom(self.arch, self.dtype) - smem_thr_copy_P = cute.make_tiled_copy_C(smem_copy_atom_P, tiled_mma_qk).get_slice(tidx) - tPsP = smem_thr_copy_P.partition_D(sP) if const_expr(sP is not None) else None - smem_copy_params = SimpleNamespace(smem_thr_copy_P=smem_thr_copy_P, tPsP=tPsP) - - self.mma_init() - - mma_one_n_block_all = partial( - self.mma_one_n_block_intrawg_overlap - if const_expr(self.intra_wg_overlap) - else self.mma_one_n_block, - mma_qk_fn=mma_qk_fn, - pipeline_k=pipeline_k, - pipeline_v=pipeline_v, - acc_O=acc_O, - tOrP=tOrP, - smem_copy_params=smem_copy_params, - check_inf=True, - ) - - q_consumer_phase = Int32(0) - kv_consumer_state = pipeline.make_pipeline_state( - cutlass.pipeline.PipelineUserType.Consumer, self.num_stages - ) - - tile_scheduler = TileSchedulerCls() - work_tile = tile_scheduler.initial_work_tile_info() - softmax = Softmax.create( - softmax_scale_log2, - num_rows=acc_O.shape[0][0] * acc_O.shape[1], - softmax_scale=softmax_scale, - ) - - process_first_half_block = partial( - self.first_half_block_overlap, - mma_qk_fn=mma_qk_fn, - pipeline_k=pipeline_k, - tOrP=tOrP, - smem_copy_params=smem_copy_params, - softmax=softmax, - ) - process_last_half_block = partial( - self.last_half_block_overlap, - pipeline_v=pipeline_v, - mma_pv_fn=mma_pv_fn, - ) - while work_tile.is_valid_tile: - # if work_tile.is_valid_tile: - - # shape: (atom_v_m * rest_m) - m_block, head_idx, batch_idx, _ = work_tile.tile_idx - seqlen = SeqlenInfoCls(batch_idx) - - # Recompute fastdiv_mods if necessary for varlen with aux_tensors - recompute_fastdiv_mods_q = cutlass.const_expr( - aux_tensors is not None and (seqlen.has_cu_seqlens_q or seqlen.has_seqused_q) - ) - recompute_fastdiv_mods_k = cutlass.const_expr( - aux_tensors is not None and (seqlen.has_cu_seqlens_k or seqlen.has_seqused_k) - ) - if cutlass.const_expr(fastdiv_mods is not None): - seqlen_q_divmod, seqlen_k_divmod = fastdiv_mods - fastdiv_mods = ( - seqlen_q_divmod - if not recompute_fastdiv_mods_q - else FastDivmodDivisor(seqlen.seqlen_q), - seqlen_k_divmod - if not recompute_fastdiv_mods_k - else FastDivmodDivisor(seqlen.seqlen_k), - ) - - mask = AttentionMaskCls(seqlen) - mask_fn = partial( - mask.apply_mask, - batch_idx=batch_idx, - head_idx=head_idx, - m_block=m_block, - thr_mma=thr_mma_qk, - mask_causal=self.is_causal, - mask_local=self.is_local, - aux_tensors=aux_tensors, - fastdiv_mods=fastdiv_mods, - ) - score_mod_fn = None - if const_expr(self.score_mod is not None): - score_mod_fn = partial( - self.apply_score_mod, - thr_mma_qk, - batch_idx, - head_idx, - m_block, - softmax_scale=softmax_scale, - aux_tensors=aux_tensors, - fastdiv_mods=fastdiv_mods, - ) - mma_one_n_block = partial( - mma_one_n_block_all, - seqlen=seqlen, - softmax=softmax, - score_mod_fn=score_mod_fn, - ) - # Load Q if not TMA_Q - if const_expr(not self.use_tma_Q): - pack_gqa = PackGQA( - self.tile_m, self.tile_hdim, self.check_hdim_oob, self.qhead_per_kvhead - ) - mQ_cur = seqlen.offset_batch_Q(mQ, batch_idx, dim=3)[None, None, head_idx] - # gmem_thr_copy_Q = gmem_tiled_copy_Q.get_slice(tidx) - # gQ = cute.local_tile(mQ_cur, (self.tile_m, self.tile_hdim), (m_block, 0)) - # self.load_Q(gmem_thr_copy_Q, gQ, sQ, m_block, seqlen=seqlen.seqlen_q, - # headdim=mQ.shape[1]) - pack_gqa.load_Q(mQ_cur, sQ, gmem_tiled_copy_Q, tidx, m_block, seqlen.seqlen_q) - cute.arch.cp_async_mbarrier_arrive_noinc(mbar_ptr_Q) - - n_block_min, n_block_max = block_info.get_n_block_min_max(seqlen, m_block) - if const_expr(not self.use_tma_Q): - cute.arch.mbarrier_wait(mbar_ptr_Q, phase=q_consumer_phase) - q_consumer_phase ^= 1 - # For performance reason, we separate out two kinds of iterations: - # those that need masking on S, and those that don't. - # We need masking on S for the very last block when K and V has length not multiple of tile_n. - # We also need masking on S if it's causal, for the last several blocks. - # softmax.reset() # Don't need reset as we explicitly call softmax w is_first=True - O_should_accumulate = False - - # ========================================== - # MAINLOOP - # ========================================== - if const_expr(not self.use_block_sparsity): - # ========================================== - # No block-sparsity (original path) - # ========================================== - # First iteration with seqlen masking - if const_expr(self.intra_wg_overlap): - kv_consumer_state = process_first_half_block( - n_block=n_block_max - 1, - seqlen=seqlen, - kv_consumer_state=kv_consumer_state, - mask_fn=partial(mask_fn, mask_mod=self.mask_mod), - score_mod_fn=score_mod_fn, - is_first_block=True, - ) - # Need to initialize tOrO in the case of RescaleOBeforeGemm where we will scale tOrO even in the 1st iter - # acc_O.fill(0.0) - else: - self.warp_scheduler_barrier_sync() - kv_consumer_state = mma_one_n_block( - kv_consumer_state, - n_block=n_block_max - 1, - seqlen=seqlen, - mma_pv_fn=partial(mma_pv_fn, zero_init=True), - is_first_n_block=True, - mask_fn=partial(mask_fn, mask_mod=self.mask_mod, mask_seqlen=True), - ) - O_should_accumulate = True - # if cute.arch.thread_idx()[0] == 128: cute.printf("m_block = {}, n_block_max = {}, n_block_min = {}", m_block, n_block_max, n_block_min) - n_block_max -= 1 - # Next couple of iterations with causal masking - if const_expr(self.is_causal or self.is_local): - n_block_min_causal_local_mask = block_info.get_n_block_min_causal_local_mask( - seqlen, m_block, n_block_min - ) - # if cute.arch.thread_idx()[0] == 128: cute.printf("n_block_min_causal_local_mask = {}", n_block_min_causal_local_mask) - for n_tile in cutlass.range( - n_block_max - n_block_min_causal_local_mask, unroll=1 - ): - kv_consumer_state = mma_one_n_block( - kv_consumer_state, - n_block=n_block_max - 1 - n_tile, - seqlen=seqlen, - mma_pv_fn=partial(mma_pv_fn, zero_init=not O_should_accumulate), - mask_fn=partial(mask_fn, mask_mod=self.mask_mod, mask_seqlen=False), - ) - O_should_accumulate = True - n_block_max = cutlass.min(n_block_max, n_block_min_causal_local_mask) - # The remaining iterations have no masking - n_block_min_before_local_mask = block_info.get_n_block_min_before_local_mask( - seqlen, m_block, n_block_min - ) - # if cute.arch.thread_idx()[0] == 128: cute.printf("n_block_min_before_local_mask = {}, n_block_min = {}", n_block_min_before_local_mask, n_block_min) - for n_tile in cutlass.range(n_block_max - n_block_min_before_local_mask, unroll=1): - kv_consumer_state = mma_one_n_block( - kv_consumer_state, - n_block=n_block_max - 1 - n_tile, - seqlen=seqlen, - mma_pv_fn=partial(mma_pv_fn, zero_init=not O_should_accumulate), - mask_fn=partial(mask_fn, mask_mod=self.mask_mod, mask_seqlen=False), - ) - O_should_accumulate = True - # Separate iterations with local masking on the left - if const_expr(self.is_local and block_info.window_size_left is not None): - n_block_max = cutlass.min(n_block_max, n_block_min_before_local_mask) - for n_tile in cutlass.range(n_block_max - n_block_min, unroll=1): - kv_consumer_state = mma_one_n_block( - kv_consumer_state, - n_block=n_block_max - 1 - n_tile, - seqlen=seqlen, - mma_pv_fn=partial(mma_pv_fn, zero_init=not O_should_accumulate), - mask_fn=partial(mask_fn, mask_mod=self.mask_mod, mask_seqlen=False), - ) - O_should_accumulate = True - # Last "half" iteration - if const_expr(self.intra_wg_overlap): - kv_consumer_state = process_last_half_block( - kv_consumer_state=kv_consumer_state, - zero_init=not O_should_accumulate, - ) - O_should_accumulate = True - else: - self.warp_scheduler_barrier_arrive() - - else: - # ========================================== - # Block sparsity - # ========================================== - kv_consumer_state, O_should_accumulate, processed_any = consume_block_sparse_loads( - blocksparse_tensors, - batch_idx, - head_idx, - m_block, - seqlen, - kv_consumer_state, - mma_pv_fn, - mma_one_n_block, - process_first_half_block, - process_last_half_block, - mask_fn, - score_mod_fn, - O_should_accumulate, - self.mask_mod, - fastdiv_mods, - self.intra_wg_overlap, - self.warp_scheduler_barrier_sync, - self.warp_scheduler_barrier_arrive, - self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1, - self.q_subtile_factor if self.q_subtile_factor is not None else 1, - ) - - # Handle empty case (when no blocks to process) - if not processed_any: - softmax.reset() - acc_O.fill(0.0) - - sink_val = None - if const_expr(learnable_sink is not None): - if const_expr(not self.pack_gqa): - sink_val = Float32(learnable_sink[head_idx]) - else: # Each thread might have a different sink value due to different q_head - sink_val = cute.make_fragment_like(softmax.row_max, Float32) - cS = cute.make_identity_tensor((self.tile_m, self.tile_n)) - tScS_mn = layout_utils.reshape_acc_to_mn(thr_mma_qk.partition_C(cS)) - for r in cutlass.range(cute.size(sink_val), unroll_full=True): - row = m_block * self.tile_m + tScS_mn[r][0] - q_head_idx = row % self.qhead_per_kvhead + head_idx * self.qhead_per_kvhead - sink_val[r] = Float32(learnable_sink[q_head_idx]) - - # normalize acc_O by row_sum and calculate the lse - row_scale = softmax.finalize(sink_val=sink_val) - softmax.rescale_O(acc_O, row_scale) - - # /////////////////////////////////////////////////////////////////////////////// - # Epilogue - # /////////////////////////////////////////////////////////////////////////////// - self.epilogue( - acc_O, - softmax.row_sum, - mO, - mLSE, - sO, - seqlen, - gmem_tiled_copy_O, - tma_atom_O, - tiled_mma_pv, - tidx, - m_block, - head_idx, - batch_idx, - ) - - tile_scheduler.advance_to_next_work() - work_tile = tile_scheduler.get_current_work() - - - @cute.jit - def first_half_block_overlap( - self, - n_block: Int32, - mma_qk_fn: Callable, - kv_consumer_state, - pipeline_k, - tOrP: cute.Tensor, - smem_copy_params: SimpleNamespace, - softmax: Softmax, - seqlen: SeqlenInfoQK, - mask_fn: Callable = None, - score_mod_fn: Optional[Callable] = None, - is_first_block: bool = False, - ): - """Processes the first half block when using intra-warpgroup-overlap""" - - pipeline_k.consumer_wait(kv_consumer_state, pipeline_k.consumer_try_wait(kv_consumer_state)) - acc_S = mma_qk_fn(B_idx=kv_consumer_state.index, wg_wait=0) - pipeline_k.consumer_release(kv_consumer_state) - - # Apply score modification if present - if const_expr(score_mod_fn is not None): - score_mod_fn(acc_S, n_block=n_block, seqlen=seqlen) - - # Apply mask; mask_seqlen always True for first block - # Caveat: if full block further right than mask block, seqlen masking is redundant; - # however, masking is being applied anyway, so essentially no perf hit - mask_fn(acc_S, n_block=n_block, mask_seqlen=True) - - softmax.online_softmax(acc_S, is_first=is_first_block) - - tOrP_acc = layout_utils.reshape_acc_to_frgA(acc_S) - tOrP_cur = ( - tOrP if const_expr(self.mma_pv_is_rs) else cute.make_fragment_like(tOrP_acc, self.dtype) - ) - tOrP_cur.store(tOrP_acc.load().to(self.dtype)) - - # if pv gemm not rs - if const_expr(not self.mma_pv_is_rs): - tPrP = smem_copy_params.smem_thr_copy_P.retile(tOrP_cur) - cute.copy(smem_copy_params.smem_thr_copy_P, tPrP, smem_copy_params.tPsP) - # Fence and barrier to make smem store visible to WGMMA - cute.arch.fence_view_async_shared() - cute.arch.sync_warp() - - return kv_consumer_state - - @cute.jit - def last_half_block_overlap( - self, - kv_consumer_state, - pipeline_v, - mma_pv_fn: Callable, - zero_init: bool, - ): - """Processes the final PV GEMM when using intra-warpgroup-overlap""" - - pipeline_v.consumer_wait(kv_consumer_state, pipeline_v.consumer_try_wait(kv_consumer_state)) - mma_pv_fn(B_idx=kv_consumer_state.index, zero_init=zero_init, wg_wait=0) - pipeline_v.consumer_release(kv_consumer_state) - kv_consumer_state.advance() - return kv_consumer_state - - @cute.jit - def mma_one_n_block( - self, - smem_pipe_read: cutlass.pipeline.PipelineState | pipeline.PipelineStateSimple, - n_block: Int32, - mma_qk_fn: Callable, - mma_pv_fn: Callable, - pipeline_k: cutlass.pipeline.PipelineAsync, - pipeline_v: cutlass.pipeline.PipelineAsync, - acc_O: cute.Tensor, - tOrP: cute.Tensor, - smem_copy_params: SimpleNamespace, - softmax: Softmax, - seqlen: SeqlenInfoQK, - score_mod_fn: Optional[Callable] = None, - mask_fn: Optional[Callable] = None, - is_first_n_block: cutlass.Constexpr = False, - check_inf: cutlass.Constexpr = True, - ): - pipeline_k.consumer_wait(smem_pipe_read, pipeline_k.consumer_try_wait(smem_pipe_read)) - # S = Q @ K.T - acc_S = mma_qk_fn(B_idx=smem_pipe_read.index, wg_wait=-1) - self.warp_scheduler_barrier_arrive() - warpgroup.wait_group(0) - pipeline_k.consumer_release(smem_pipe_read) - - # handle score mods and masking - if const_expr(score_mod_fn is not None): - score_mod_fn(acc_S, n_block=n_block, seqlen=seqlen) - if const_expr(mask_fn is not None): - mask_fn(acc_S=acc_S, n_block=n_block) - - row_scale = softmax.online_softmax(acc_S, is_first=is_first_n_block, check_inf=check_inf) - # if cute.arch.thread_idx()[0] == 0: cute.print_tensor(layout_utils.reshape_acc_to_mn(acc_S)) - tOrP_acc = layout_utils.reshape_acc_to_frgA(acc_S) - tOrP_cur = ( - tOrP if const_expr(self.mma_pv_is_rs) else cute.make_fragment_like(tOrP_acc, self.dtype) - ) - # tOrP.store(tOrP_acc.load().to(self.dtype)) - # the "to(self.dtype)" conversion fails to vectorize for block sizes other - # than 128 x 128, i.e. it calls convert on 1 fp32 element at a time instead of - # 2 elements. So we just call ptx directly. - utils.cvt_f16(tOrP_acc, tOrP_cur) - if const_expr(not self.mma_pv_is_rs): - tPrP = smem_copy_params.smem_thr_copy_P.retile(tOrP_cur) - cute.copy(smem_copy_params.smem_thr_copy_P, tPrP, smem_copy_params.tPsP) - softmax.rescale_O(acc_O, row_scale) - if const_expr(not self.mma_pv_is_rs): - # Fence and barrier to make sure smem store is visible to WGMMA - cute.arch.fence_view_async_shared() - cute.arch.sync_warp() # Only need syncwarp since each warp is using its own P values for MmaPV - pipeline_v.consumer_wait(smem_pipe_read, pipeline_v.consumer_try_wait(smem_pipe_read)) - self.warp_scheduler_barrier_sync() - # O += P @ V - mma_pv_fn(B_idx=smem_pipe_read.index, wg_wait=0) - pipeline_v.consumer_release(smem_pipe_read) - smem_pipe_read.advance() - return smem_pipe_read - - @cute.jit - def mma_one_n_block_intrawg_overlap( - self, - smem_pipe_read: cutlass.pipeline.PipelineState | pipeline.PipelineStateSimple, - n_block: Int32, - mma_qk_fn: Callable, - mma_pv_fn: Callable, - pipeline_k: cutlass.pipeline.PipelineAsync, - pipeline_v: cutlass.pipeline.PipelineAsync, - acc_O: cute.Tensor, - tOrP: cute.Tensor, - smem_copy_params: SimpleNamespace, - softmax: Softmax, - seqlen: SeqlenInfoQK, - score_mod_fn: Optional[Callable] = None, - mask_fn: Optional[Callable] = None, - check_inf: cutlass.Constexpr = True, - ): - smem_pipe_read_v = smem_pipe_read.clone() - smem_pipe_read.advance() - pipeline_k.consumer_wait(smem_pipe_read, pipeline_k.consumer_try_wait(smem_pipe_read)) - self.warp_scheduler_barrier_sync() - # S = Q @ K.T - acc_S = mma_qk_fn(B_idx=smem_pipe_read.index, wg_wait=-1) - pipeline_v.consumer_wait(smem_pipe_read_v, pipeline_v.consumer_try_wait(smem_pipe_read_v)) - # O += P @ V - mma_pv_fn(B_idx=smem_pipe_read_v.index, wg_wait=-1) - self.warp_scheduler_barrier_arrive() - warpgroup.wait_group(1) - pipeline_k.consumer_release(smem_pipe_read) - - # handle score mods and masking - if const_expr(score_mod_fn is not None): - score_mod_fn(acc_S, n_block=n_block, seqlen=seqlen) - if const_expr(mask_fn is not None): - mask_fn(acc_S=acc_S, n_block=n_block) - # if cute.arch.thread_idx()[0] == 128: cute.print_tensor(layout_utils.reshape_acc_to_mn(acc_S)) - - row_scale = softmax.online_softmax(acc_S, check_inf=check_inf) - warpgroup.wait_group(0) - pipeline_v.consumer_release(smem_pipe_read_v) - tOrP_acc = layout_utils.reshape_acc_to_frgA(acc_S) - tOrP_cur = ( - tOrP if const_expr(self.mma_pv_is_rs) else cute.make_fragment_like(tOrP_acc, self.dtype) - ) - # tOrP_cur.store(tOrP_acc.load().to(self.dtype)) - # the "to(self.dtype)" conversion fails to vectorize for block sizes other - # than 128 x 128, i.e. it calls convert on 1 fp32 element at a time instead of - # 2 elements. So we just call ptx directly. - utils.cvt_f16(tOrP_acc, tOrP_cur) - if const_expr(not self.mma_pv_is_rs): - tPrP = smem_copy_params.smem_thr_copy_P.retile(tOrP_cur) - cute.copy(smem_copy_params.smem_thr_copy_P, tPrP, smem_copy_params.tPsP) - softmax.rescale_O(acc_O, row_scale) - if const_expr(not self.mma_pv_is_rs): - # Fence and barrier to make sure smem store is visible to WGMMA - cute.arch.fence_view_async_shared() - cute.arch.sync_warp() # Only need syncwarp since each warp is using its own P values for MmaPV - return smem_pipe_read - - @cute.jit - def mma_init(self): - warp_group_idx = utils.canonical_warp_group_idx(sync=False) - if const_expr(self.use_scheduler_barrier): - if warp_group_idx == 1: - cute.arch.barrier_arrive( - barrier_id=int(NamedBarrierFwd.WarpSchedulerWG1), - number_of_threads=2 * self.num_threads_per_warp_group, - ) - - @cute.jit - def apply_score_mod( - self, - thr_mma_qk, - batch_idx, - head_idx, - m_block, - acc_S, - n_block, - softmax_scale, - seqlen, - aux_tensors: Optional[list] = None, - fastdiv_mods=None, - ): - # Prepare index tensor - cS = cute.make_identity_tensor((self.tile_m, self.tile_n)) - cS = cute.domain_offset((m_block * self.tile_m, n_block * self.tile_n), cS) - tScS = thr_mma_qk.partition_C(cS) - - apply_score_mod_inner( - acc_S, - tScS, - self.score_mod, - batch_idx, - head_idx, - softmax_scale, - self.vec_size, - self.qk_acc_dtype, - aux_tensors, - fastdiv_mods, - seqlen_info=seqlen, - constant_q_idx=None, - qhead_per_kvhead=self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1, - ) - - def warp_scheduler_barrier_sync(self): - if const_expr(self.use_scheduler_barrier): - cute.arch.barrier( - barrier_id=int(NamedBarrierFwd.WarpSchedulerWG1) - - 1 - + utils.canonical_warp_group_idx(sync=False), - number_of_threads=2 * self.num_threads_per_warp_group, - ) - - def warp_scheduler_barrier_arrive(self): - if const_expr(self.use_scheduler_barrier): - assert self.num_mma_warp_groups in [2, 3] - cur_wg = utils.canonical_warp_group_idx(sync=False) - 1 - if const_expr(self.num_mma_warp_groups == 2): - next_wg = 1 - cur_wg - else: - t = cur_wg + 1 - next_wg = t % self.num_mma_warp_groups - cute.arch.barrier_arrive( - barrier_id=int(NamedBarrierFwd.WarpSchedulerWG1) + next_wg, - number_of_threads=2 * self.num_threads_per_warp_group, - ) +# SM90 forward pass moved to flash_fwd_sm90.py; re-export for backward compatibility +def __getattr__(name): + if name == "FlashAttentionForwardSm90": + from .flash_fwd_sm90 import FlashAttentionForwardSm90 + return FlashAttentionForwardSm90 + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/flash-attn4/torch-ext/flash_attn4/flash_fwd_combine.py b/flash-attn4/torch-ext/flash_attn4/flash_fwd_combine.py index 478218b1..e6411368 100644 --- a/flash-attn4/torch-ext/flash_attn4/flash_fwd_combine.py +++ b/flash-attn4/torch-ext/flash_attn4/flash_fwd_combine.py @@ -10,7 +10,7 @@ import cutlass import cutlass.cute as cute from cutlass.cute.nvgpu import cpasync -from cutlass import Float32, Int32, const_expr +from cutlass import Float32, Int32, Boolean, const_expr from . import utils from .cute_dsl_utils import assume_tensor_aligned @@ -24,7 +24,7 @@ def __init__( dtype: Type[cutlass.Numeric], dtype_partial: Type[cutlass.Numeric], head_dim: int, - m_block_size: int = 8, + tile_m: int = 8, k_block_size: int = 64, log_max_splits: int = 4, num_threads: int = 256, @@ -36,7 +36,7 @@ def __init__( :param dtype: output data type :param dtype_partial: partial accumulation data type :param head_dim: head dimension - :param m_block_size: m block size + :param tile_m: m block size :param k_block_size: k block size :param log_max_splits: log2 of maximum splits :param num_threads: number of threads @@ -46,7 +46,7 @@ def __init__( self.dtype = dtype self.dtype_partial = dtype_partial self.head_dim = head_dim - self.m_block_size = m_block_size + self.tile_m = tile_m self.k_block_size = k_block_size self.max_splits = 1 << log_max_splits self.num_threads = num_threads @@ -58,7 +58,7 @@ def can_implement( dtype, dtype_partial, head_dim, - m_block_size, + tile_m, k_block_size, log_max_splits, num_threads, @@ -72,12 +72,12 @@ def can_implement( return False if num_threads % 32 != 0: return False - if m_block_size % 8 != 0: + if tile_m % 8 != 0: return False max_splits = 1 << log_max_splits if max_splits > 256: return False - if (m_block_size * max_splits) % num_threads != 0: + if (tile_m * max_splits) % num_threads != 0: return False return True @@ -124,15 +124,11 @@ def _setup_attributes(self): lse_copy_bits = Float32.width # 1 element per copy, width is in bits m_block_smem = ( 128 - if self.m_block_size % 128 == 0 + if self.tile_m % 128 == 0 else ( 64 - if self.m_block_size % 64 == 0 - else ( - 32 - if self.m_block_size % 32 == 0 - else (16 if self.m_block_size % 16 == 0 else 8) - ) + if self.tile_m % 64 == 0 + else (32 if self.tile_m % 32 == 0 else (16 if self.tile_m % 16 == 0 else 8)) ) ) gmem_threads_per_row_lse = m_block_smem @@ -183,12 +179,12 @@ def _setup_attributes(self): smem_lse_swizzle, 0, cute.make_ordered_layout((8, m_block_smem), order=(1, 0)) ) self.smem_layout_lse = cute.tile_to_shape( - smem_layout_atom_lse, (self.max_splits, self.m_block_size), (0, 1) + smem_layout_atom_lse, (self.max_splits, self.tile_m), (0, 1) ) # O partial shared memory layout (simple layout for pipeline stages) self.smem_layout_o = cute.make_ordered_layout( - (self.m_block_size, self.k_block_size, self.stages), order=(1, 0, 2) + (self.tile_m, self.k_block_size, self.stages), order=(1, 0, 2) ) @cute.jit @@ -201,7 +197,9 @@ def __call__( cu_seqlens: Optional[cute.Tensor] = None, seqused: Optional[cute.Tensor] = None, num_splits_dynamic_ptr: Optional[cute.Tensor] = None, + varlen_batch_idx: Optional[cute.Tensor] = None, semaphore_to_reset: Optional[cute.Tensor] = None, + # Always keep stream as the last parameter (EnvStream: obtained implicitly via TVM FFI). stream: cuda.CUstream = None, ): # Type checking @@ -269,7 +267,7 @@ class SharedStorage: sLSE: cute.struct.Align[ cute.struct.MemRange[Float32, cute.cosize(self.smem_layout_lse)], 128 ] - sMaxValidSplit: cute.struct.Align[cute.struct.MemRange[Int32, self.m_block_size], 128] + sMaxValidSplit: cute.struct.Align[cute.struct.MemRange[Int32, self.tile_m], 128] sO: cute.struct.Align[ cute.struct.MemRange[self.dtype_partial, cute.cosize(self.smem_layout_o)], 128 ] @@ -290,7 +288,7 @@ class SharedStorage: head_divmod = FastDivmodDivisor(num_head) grid_dim = ( - cute.ceil_div(seqlen * num_head, self.m_block_size), + cute.ceil_div(seqlen * num_head, self.tile_m), cute.ceil_div(self.head_dim, self.k_block_size), batch_size, ) @@ -303,6 +301,7 @@ class SharedStorage: cu_seqlens, seqused, num_splits_dynamic_ptr, + varlen_batch_idx, semaphore_to_reset, SharedStorage, self.smem_layout_lse, @@ -331,6 +330,7 @@ def kernel( cu_seqlens: Optional[cute.Tensor], seqused: Optional[cute.Tensor], num_splits_dynamic_ptr: Optional[cute.Tensor], + varlen_batch_idx: Optional[cute.Tensor], semaphore_to_reset: Optional[cute.Tensor], SharedStorage: cutlass.Constexpr, smem_layout_lse: cute.Layout | cute.ComposedLayout, @@ -345,7 +345,14 @@ def kernel( ): # Thread and block indices tidx, _, _ = cute.arch.thread_idx() - m_block, k_block, batch_idx = cute.arch.block_idx() + m_block, k_block, maybe_virtual_batch = cute.arch.block_idx() + + # Map virtual batch index to real batch index (for persistent tile schedulers) + batch_idx = ( + varlen_batch_idx[maybe_virtual_batch] + if const_expr(varlen_batch_idx is not None) + else maybe_virtual_batch + ) # /////////////////////////////////////////////////////////////////////////////// # Get shared memory buffer @@ -353,22 +360,23 @@ def kernel( smem = cutlass.utils.SmemAllocator() storage = smem.allocate(SharedStorage) sLSE = storage.sLSE.get_tensor(smem_layout_lse) - sMaxValidSplit = storage.sMaxValidSplit.get_tensor((self.m_block_size,)) + sMaxValidSplit = storage.sMaxValidSplit.get_tensor((self.tile_m,)) sO = storage.sO.get_tensor(smem_layout_o) - # Handle semaphore reset + # Handle semaphore reset โ€” wait for dependent grids first if const_expr(semaphore_to_reset is not None): if ( tidx == 0 and m_block == cute.arch.grid_dim()[0] - 1 and k_block == cute.arch.grid_dim()[1] - 1 - and batch_idx == cute.arch.grid_dim()[2] - 1 + and maybe_virtual_batch == cute.arch.grid_dim()[2] - 1 ): + cute.arch.griddepcontrol_wait() semaphore_to_reset[0] = 0 - # Get number of splits + # Get number of splits (use maybe_virtual_batch for per-batch-slot splits) num_splits = ( - num_splits_dynamic_ptr[batch_idx] + num_splits_dynamic_ptr[maybe_virtual_batch] if const_expr(num_splits_dynamic_ptr is not None) else mLSE_partial.shape[1] ) @@ -378,6 +386,7 @@ def kernel( seqlen_static=mO_partial.shape[0], cu_seqlens=cu_seqlens, seqused=seqused, + # Don't need to pass in tile size since we won't use offset_padded ) seqlen, offset = seqlen_info.seqlen, seqlen_info.offset @@ -387,29 +396,27 @@ def kernel( # Early exit for single split if dynamic if (const_expr(num_splits_dynamic_ptr is None) or num_splits > 1) and ( - const_expr(not varlen) or m_block * self.m_block_size < max_idx + const_expr(not varlen) or m_block * self.tile_m < max_idx ): + # Wait for dependent grids (e.g., the main attention kernel that produces O_partial/LSE_partial) + cute.arch.griddepcontrol_wait() + # =============================== # Step 1: Load LSE_partial from gmem to shared memory # =============================== - if const_expr(cu_seqlens is None): - mLSE_partial_cur = mLSE_partial[None, None, None, batch_idx] - else: - mLSE_partial_cur = cute.domain_offset((offset, 0, 0), mLSE_partial) + mLSE_partial_cur = seqlen_info.offset_batch(mLSE_partial, batch_idx, dim=3) mLSE_partial_copy = cute.tiled_divide(mLSE_partial_cur, (1,)) - gmem_thr_copy_LSE = gmem_tiled_copy_LSE.get_slice(tidx) tLSEsLSE = gmem_thr_copy_LSE.partition_D(sLSE) - # Create identity tensor for coordinate tracking - cLSE = cute.make_identity_tensor((self.max_splits, self.m_block_size)) + cLSE = cute.make_identity_tensor((self.max_splits, self.tile_m)) tLSEcLSE = gmem_thr_copy_LSE.partition_S(cLSE) # Load LSE partial values for m in cutlass.range(cute.size(tLSEcLSE, mode=[2]), unroll_full=True): mi = tLSEcLSE[0, 0, m][1] # Get m coordinate - idx = m_block * self.m_block_size + mi + idx = m_block * self.tile_m + mi if idx < max_idx: # Calculate actual sequence position and head using FastDivmodDivisor if const_expr(not varlen): @@ -436,22 +443,19 @@ def kernel( # =============================== gmem_thr_copy_O_partial = gmem_tiled_copy_O_partial.get_slice(tidx) - cO = cute.make_identity_tensor((self.m_block_size, self.k_block_size)) + cO = cute.make_identity_tensor((self.tile_m, self.k_block_size)) tOcO = gmem_thr_copy_O_partial.partition_D(cO) tOsO_partial = gmem_thr_copy_O_partial.partition_D(sO) - if const_expr(cu_seqlens is None): - mO_partial_cur = mO_partial[None, None, None, None, batch_idx] - else: - mO_partial_cur = cute.domain_offset((offset, 0, 0, 0), mO_partial) + mO_partial_cur = seqlen_info.offset_batch(mO_partial, batch_idx, dim=4) # Precompute these values to avoid recomputing them in the loop num_rows = const_expr(cute.size(tOcO, mode=[1])) - tOmidx = cute.make_fragment(num_rows, cutlass.Int32) - tOhidx = cute.make_fragment(num_rows, cutlass.Int32) - tOrOptr = cute.make_fragment(num_rows, cutlass.Int64) + tOmidx = cute.make_rmem_tensor(num_rows, cutlass.Int32) + tOhidx = cute.make_rmem_tensor(num_rows, cutlass.Int32) + tOrOptr = cute.make_rmem_tensor(num_rows, cutlass.Int64) for m in cutlass.range(num_rows, unroll_full=True): mi = tOcO[0, m, 0][0] # m coordinate - idx = m_block * self.m_block_size + mi + idx = m_block * self.tile_m + mi if const_expr(not varlen): tOhidx[m], tOmidx[m] = divmod(idx, seqlen_divmod) else: @@ -463,11 +467,12 @@ def kernel( if idx >= max_idx: tOhidx[m] = -1 - tOpO = cute.make_fragment(cute.size(tOcO, [2]), cutlass.Boolean) + tOpO = None if const_expr(not self.is_even_k): + tOpO = cute.make_rmem_tensor(cute.size(tOcO, mode=[2]), Boolean) for k in cutlass.range(cute.size(tOpO), unroll_full=True): tOpO[k] = tOcO[0, 0, k][1] < mO_partial.shape[1] - k_block * self.k_block_size - # if cute.arch.thread_idx()[0] == 0 and k_block == 1: cute.print_tensor(tOpO) + # if cute.arch.thread_idx()[0] == 0 and k_block == 1: cute.print_tensor(tOpO) load_O_partial = partial( self.load_O_partial, @@ -501,17 +506,17 @@ def kernel( s2r_thr_copy_LSE = s2r_tiled_copy_LSE.get_slice(tidx) ts2rsLSE = s2r_thr_copy_LSE.partition_S(sLSE) - ts2rrLSE = cute.make_fragment_like(ts2rsLSE) + ts2rrLSE = cute.make_rmem_tensor_like(ts2rsLSE) cute.copy(s2r_tiled_copy_LSE, ts2rsLSE, ts2rrLSE) # =============================== # Step 4: Compute final LSE along split dimension # =============================== - lse_sum = cute.make_fragment(cute.size(ts2rrLSE, mode=[2]), Float32) + lse_sum = cute.make_rmem_tensor(cute.size(ts2rrLSE, mode=[2]), Float32) ts2rcLSE = s2r_thr_copy_LSE.partition_D(cLSE) # We compute the max valid split for each row to short-circuit the computation later - max_valid_split = cute.make_fragment(cute.size(ts2rrLSE, mode=[2]), Int32) + max_valid_split = cute.make_rmem_tensor(cute.size(ts2rrLSE, mode=[2]), Int32) assert cute.size(ts2rrLSE, mode=[0]) == 1 # Compute max, scales, and final LSE for each row for m in cutlass.range(cute.size(ts2rrLSE, mode=[2]), unroll_full=True): @@ -561,7 +566,7 @@ def kernel( for m in cutlass.range(cute.size(ts2rrLSE, mode=[2]), unroll_full=True): if ts2rcLSE[0, 0, m][0] == 0: # Only thread responsible for s=0 writes mi = ts2rcLSE[0, 0, m][1] - if mi < self.m_block_size: + if mi < self.tile_m: sMaxValidSplit[mi] = max_valid_split[m] # =============================== @@ -577,7 +582,7 @@ def kernel( for m in cutlass.range(cute.size(ts2rrLSE, mode=[2]), unroll_full=True): if ts2rcLSE[0, 0, m][0] == 0: # Only thread responsible for s=0 writes mi = ts2rcLSE[0, 0, m][1] - idx = m_block * self.m_block_size + mi + idx = m_block * self.tile_m + mi if idx < max_idx: if const_expr(not varlen): head_idx, m_idx = divmod(idx, seqlen_divmod) @@ -594,11 +599,11 @@ def kernel( # Get max valid split for this thread thr_max_valid_split = sMaxValidSplit[tOcO[0, 0, 0][0]] - for m in cutlass.range(1, cute.size(tOcO, mode=[1])): + for m in cutlass.range(1, cute.size(tOcO, mode=[1]), unroll_full=True): thr_max_valid_split = max(thr_max_valid_split, sMaxValidSplit[tOcO[0, m, 0][0]]) - tOrO_partial = cute.make_fragment_like(tOsO_partial[None, None, None, 0]) - tOrO = cute.make_fragment_like(tOrO_partial, Float32) + tOrO_partial = cute.make_rmem_tensor_like(tOsO_partial[None, None, None, 0]) + tOrO = cute.make_rmem_tensor_like(tOrO_partial, Float32) tOrO.fill(0.0) stage_load = self.stages - 1 @@ -607,7 +612,7 @@ def kernel( # Main accumulation loop for s in cutlass.range(thr_max_valid_split + 1, unroll=4): # Get scales for this split - scale = cute.make_fragment(num_rows, Float32) + scale = cute.make_rmem_tensor(num_rows, Float32) for m in cutlass.range(num_rows, unroll_full=True): scale[m] = sLSE[s, tOcO[0, m, 0][0]] # Get scale from smem @@ -637,8 +642,9 @@ def kernel( # Step 7: Write final O to gmem # =============================== - rO = cute.make_fragment_like(tOrO, self.dtype) + rO = cute.make_rmem_tensor_like(tOrO, self.dtype) rO.store(tOrO.load().to(self.dtype)) + mO_cur = seqlen_info.offset_batch(mO, batch_idx, dim=3) if const_expr(cu_seqlens is None): mO_cur = mO[None, None, None, batch_idx] else: @@ -665,7 +671,7 @@ def load_O_partial( tOrOptr: cute.Tensor, tOsO_partial: cute.Tensor, tOhidx: cute.Tensor, - tOpO: cute.Tensor, + tOpO: Optional[cute.Tensor], tOcO: cute.Tensor, mO_cur_partial_layout: cute.Layout, split: Int32, @@ -684,7 +690,7 @@ def load_O_partial( mO_partial_cur_copy = cute.tiled_divide(mO_partial_cur, (elems_per_load,)) for k in cutlass.range(cute.size(tOcO, mode=[2]), unroll_full=True): k_idx = tOcO[0, 0, k][1] // elems_per_load - if const_expr(self.is_even_k) or tOpO[k]: + if const_expr(tOpO is None) or tOpO[k]: cute.copy( gmem_tiled_copy_O_partial, mO_partial_cur_copy[None, k_idx, split], diff --git a/flash-attn4/torch-ext/flash_attn4/flash_fwd_sm100.py b/flash-attn4/torch-ext/flash_attn4/flash_fwd_sm100.py index e9ba465f..14183a9d 100644 --- a/flash-attn4/torch-ext/flash_attn4/flash_fwd_sm100.py +++ b/flash-attn4/torch-ext/flash_attn4/flash_fwd_sm100.py @@ -13,9 +13,8 @@ # https://github.com/NVIDIA/cutlass/tree/main/examples/77_blackwell_fmha # https://github.com/NVIDIA/cutlass/blob/main/examples/python/CuTeDSL/blackwell/fmha.py -import enum import math -from typing import Type, Tuple, Callable, Optional, Literal +from typing import Tuple, Callable, Optional, Literal from functools import partial import cuda.bindings.driver as cuda @@ -28,6 +27,7 @@ import cutlass.utils.blackwell_helpers as sm100_utils_basic from cutlass import pipeline from cutlass.pipeline import pipeline_init_arrive, pipeline_init_wait +from cutlass.utils import ClcDynamicPersistentTileScheduler from cutlass.base_dsl.arch import Arch from cutlass.cutlass_dsl import BaseDSL @@ -35,7 +35,9 @@ from .paged_kv import PagedKVManager from .cute_dsl_utils import assume_tensor_aligned +from . import utils from . import pipeline as pipeline_custom +import cutlass.pipeline as cutlass_pipeline from .mask import AttentionMask from .softmax import SoftmaxSm100, apply_score_mod_inner from .seqlen_info import SeqlenInfoQK @@ -47,33 +49,45 @@ softmax_block_sparse_sm100, handle_block_sparse_empty_tile_correction_sm100, ) -from .pack_gqa import PackGQA +from .pack_gqa import PackGQA, pack_gqa_layout from . import mma_sm100_desc as sm100_desc from . import blackwell_helpers as sm100_utils +from .named_barrier import NamedBarrierFwdSm100 from cutlass.cute import FastDivmodDivisor from .quack.cute_dsl_utils import ParamsBase from .tile_scheduler import ( + ClcState, + SchedulingMode, TileSchedulerArguments, + TileSchedulerProtocol, SingleTileScheduler, StaticPersistentTileScheduler, SingleTileLPTScheduler, SingleTileVarlenScheduler, ) - - -class NamedBarrierFwd(enum.IntEnum): - Epilogue = enum.auto() # starts from 1 as barrier 0 is reserved for sync_threads() - TmemPtr = enum.auto() - SoftmaxStatsW0 = enum.auto() - SoftmaxStatsW1 = enum.auto() - SoftmaxStatsW2 = enum.auto() - SoftmaxStatsW3 = enum.auto() - SoftmaxStatsW4 = enum.auto() - SoftmaxStatsW5 = enum.auto() - SoftmaxStatsW6 = enum.auto() - SoftmaxStatsW7 = enum.auto() -# WarpSchedulerWG1 = enum.auto() -# WarpSchedulerWG2 = enum.auto() +from .fa_logging import fa_log, fa_printf +from .utils import smid + +# === TUNING KNOBS (agent-editable) === +# Keys: (use_2cta_instrs: bool, is_causal: bool, head_dim_padded: int, is_sm103: bool) +# Values: +# ex2_emu_freq: int โ€” how often to use emulated exp2 (0=all hardware exp2, higher=more emulation). +# SM103 has fast native exp2, so set freq=0 there. +# ex2_emu_start_frg: int โ€” fragment index to start emulation from +# num_regs_softmax: int โ€” register count for softmax warps (multiple of 8) +# num_regs_correction: int โ€” register count for correction warps (multiple of 8) +# num_regs_other is derived: 512 - num_regs_softmax * 2 - num_regs_correction +_TUNING_CONFIG = { + (True, False, 128, False): {'ex2_emu_freq': 10, 'ex2_emu_start_frg': 1, 'num_regs_softmax': 176, 'num_regs_correction': 88}, + (False, True, 128, False): {'ex2_emu_freq': 16, 'ex2_emu_start_frg': 1, 'num_regs_softmax': 192, 'num_regs_correction': 72}, + (True, False, 192, False): {"ex2_emu_freq": 16, "ex2_emu_start_frg": 0, "num_regs_softmax": 184, "num_regs_correction": 80}, + (False, True, 192, False): {"ex2_emu_freq": 32, "ex2_emu_start_frg": 1, "num_regs_softmax": 192, "num_regs_correction": 72}, + (True, False, 128, True): {"ex2_emu_freq": 0, "ex2_emu_start_frg": 0, "num_regs_softmax": 176, "num_regs_correction": 80}, + (False, True, 128, True): {"ex2_emu_freq": 0, "ex2_emu_start_frg": 0, "num_regs_softmax": 176, "num_regs_correction": 64}, + (True, False, 192, True): {"ex2_emu_freq": 0, "ex2_emu_start_frg": 0, "num_regs_softmax": 176, "num_regs_correction": 64}, + (False, True, 192, True): {"ex2_emu_freq": 0, "ex2_emu_start_frg": 0, "num_regs_softmax": 176, "num_regs_correction": 72}, +} +# === END TUNING KNOBS === class FlashAttentionForwardSm100: @@ -99,6 +113,7 @@ def __init__( paged_kv_non_tma: bool = False, is_varlen_q: bool = False, use_2cta_instrs: bool = False, + use_clc_scheduler: bool = False, ): self.use_tma_KV = not paged_kv_non_tma # self.dtype = dtype @@ -145,10 +160,6 @@ def __init__( self.is_split_kv = is_split_kv self.pack_gqa = pack_gqa self.q_subtile_factor = q_subtile_factor - if pack_gqa: - assert m_block_size % self.qhead_per_kvhead == 0, ( - "For PackGQA, m_block_size must be divisible by qhead_per_kvhead" - ) assert not (self.is_split_kv and self.head_dim_v_padded >= 192), ( "SplitKV is not supported for hdim >= 192" ) @@ -160,8 +171,10 @@ def __init__( # Does S1 need to wait for S0 to finish # self.s0_s1_barrier = self.head_dim_padded in [64, 96] and (not self.is_causal and not self.is_local) is_sm103 = self.arch >= Arch.sm_103 and self.arch <= Arch.sm_103f - # self.enable_ex2_emu = self.head_dim_padded <= 128 and not is_sm103 - self.enable_ex2_emu = (self.head_dim_padded <= 128 or (self.head_dim_padded == 192 and self.use_2cta_instrs and not self.is_causal and not self.is_local)) and not is_sm103 + self.is_sm103 = is_sm103 + # enable_ex2_emu is derived: True if tuning config has freq > 0, else fallback to default logic + _default_enable_ex2_emu = (self.head_dim_padded <= 128 or (self.head_dim_padded == 192 and self.use_2cta_instrs and not self.is_causal and not self.is_local)) and not is_sm103 + self.enable_ex2_emu = _default_enable_ex2_emu self.s0_s1_barrier = False self.overlap_sO_sQ = ( (self.head_dim_padded == 192 and self.head_dim_v_padded >= 64) or @@ -174,6 +187,32 @@ def __init__( "Paged KV does not support irregular head dim" ) + self.use_clc_scheduler = ( + use_clc_scheduler + and self.use_tma_KV + and not self.overlap_sO_sQ + ) + self.sched_stages = 1 + if self.use_clc_scheduler: + assert self.cluster_shape_mn[1] == 1, f"CLC requires cluster N == 1: {self.cluster_shape_mn}" + assert self.cluster_shape_mn[0] in (1, 2), f"bad CLC cluster M: {self.cluster_shape_mn}" + assert self.cluster_shape_mn[0] == self.cta_group_size, ( + f"CLC cluster M != cta_group_size: {self.cluster_shape_mn}, {self.cta_group_size}" + ) + + self.scheduling_mode = SchedulingMode.CLC if self.use_clc_scheduler else SchedulingMode.STATIC + + if is_varlen_q: + self.TileScheduler = SingleTileVarlenScheduler + elif self.is_causal or self.is_local or self.use_clc_scheduler: + self.TileScheduler = SingleTileLPTScheduler + elif self.is_persistent: + self.TileScheduler = StaticPersistentTileScheduler + else: + self.TileScheduler = SingleTileScheduler + + fa_log(1, f"TileScheduler={self.TileScheduler.__name__}, scheduling_mode={self.scheduling_mode.name}, USE_2CTA={self.use_2cta_instrs}") + self.softmax0_warp_ids = (0, 1, 2, 3) self.softmax1_warp_ids = (4, 5, 6, 7) self.correction_warp_ids = (8, 9, 10, 11) @@ -195,8 +234,10 @@ def __init__( ) ) + self.use_tma_Q = not (self.pack_gqa and self.m_block_size % self.qhead_per_kvhead != 0) + if self.q_stage == 1: - if not self.use_tma_KV: + if not self.use_tma_KV or not self.use_tma_Q: self.empty_warp_ids = self.empty_warp_ids + self.load_warp_ids self.load_warp_ids = self.softmax1_warp_ids else: @@ -212,6 +253,8 @@ def __init__( elif self.is_varlen_q: # fallback self.epilogue_warp_ids = (13, 14) + self.clc_scheduler_warp_id = self.empty_warp_ids[0] if self.use_clc_scheduler else None + self.tmem_s_offset = [0, self.n_block_size] # e.g., 0, 128 self.tmem_o_offset = [ self.tmem_s_offset[-1] + self.n_block_size + i * self.head_dim_v_padded @@ -227,31 +270,26 @@ def __init__( # vec buffer for row_max & row_sum self.tmem_vec_offset = self.tmem_s_offset + # Look up tuning config for register counts and ex2_emu params + _tune_key = (self.use_2cta_instrs, self.is_causal, self.head_dim_padded, self.is_sm103) + self._tune = _TUNING_CONFIG.get(_tune_key, {}) + if "ex2_emu_freq" in self._tune: + self.enable_ex2_emu = self._tune["ex2_emu_freq"] > 0 if self.head_dim_padded < 96: self.num_regs_softmax = 200 if not paged_kv_non_tma else 184 self.num_regs_correction = 64 self.num_regs_other = 48 if not paged_kv_non_tma else 80 else: - # self.num_regs_softmax = 192 if self.is_causal or self.is_local else 184 - if not self.enable_ex2_emu: - self.num_regs_softmax = 192 if not paged_kv_non_tma else 184 + if not paged_kv_non_tma and "num_regs_softmax" in self._tune: + self.num_regs_softmax = self._tune["num_regs_softmax"] + self.num_regs_correction = self._tune["num_regs_correction"] + elif not paged_kv_non_tma: + self.num_regs_softmax = 192 + self.num_regs_correction = 80 else: - # self.num_regs_softmax = 200 if not paged_kv_non_tma else 184 - self.num_regs_softmax = 192 if not paged_kv_non_tma else 184 - # self.num_regs_softmax = 176 - # self.num_regs_correction = 96 - # self.num_regs_correction = 64 if self.is_causal or self.is_local else 80 - if not self.enable_ex2_emu: - self.num_regs_correction = 80 if not paged_kv_non_tma else 64 - else: - # self.num_regs_correction = 64 - self.num_regs_correction = 80 if not paged_kv_non_tma else 64 - # self.num_regs_other = 32 - # self.num_regs_other = 64 - # self.num_regs_other = 80 - self.num_regs_other = 48 if not paged_kv_non_tma else 80 - # self.num_regs_other = 96 if self.is_causal or self.is_local else 80 - # self.num_regs_other = 64 if self.is_causal or self.is_local else 80 + self.num_regs_softmax = 184 + self.num_regs_correction = 64 + self.num_regs_other = 512 - self.num_regs_softmax * 2 - self.num_regs_correction self.buffer_align_bytes = 1024 @@ -289,7 +327,7 @@ def _setup_attributes(self): self.head_dim_padded == 192 and self.head_dim_v_padded == 128 and self.kv_stage == 3 ) self.uneven_kv_smem_offset = ( - self.m_block_size * (self.head_dim_padded - self.head_dim_v_padded) // 2 + self.n_block_size * (self.head_dim_padded - self.head_dim_v_padded) // 2 if self.uneven_kv_smem else 0 ) @@ -304,7 +342,6 @@ def __call__( mO: cute.Tensor, # (b, s_q, h, dv) or (total_q, h, dv) if there is cu_seqlens_q mLSE: Optional[cute.Tensor], softmax_scale: Float32, - stream: cuda.CUstream, mCuSeqlensQ: Optional[cute.Tensor] = None, mCuSeqlensK: Optional[cute.Tensor] = None, mSeqUsedQ: Optional[cute.Tensor] = None, @@ -315,6 +352,8 @@ def __call__( learnable_sink: Optional[cute.Tensor] = None, blocksparse_tensors: Optional[BlockSparseTensors] = None, aux_tensors: Optional[list] = None, + # Always keep stream as the last parameter (EnvStream: obtained implicitly via TVM FFI). + stream: cuda.CUstream = None, ): """Execute the Fused Multi-Head Attention operation on the provided tensors. @@ -367,22 +406,21 @@ def __call__( if const_expr(self.q_dtype != self.v_dtype): raise TypeError(f"Type mismatch: {self.q_dtype} != {self.v_dtype}") self._setup_attributes() - self.use_tma_O = self.arch >= Arch.sm_90 and mCuSeqlensQ is None and mSeqUsedQ is None - # This can be tuned - # This is currently very ad-hoc, we should tune it systematically + self.use_tma_O = ( + self.arch >= Arch.sm_90 + and mCuSeqlensQ is None + and mSeqUsedQ is None + and not (self.pack_gqa and self.m_block_size % self.qhead_per_kvhead != 0) + and not (self.pack_gqa and self.is_split_kv) + ) self.ex2_emu_freq = 0 - # self.ex2_emu_start_frg = 1 if self.is_causal else 0 - self.ex2_emu_start_frg = 1 + self.ex2_emu_start_frg = self._tune.get("ex2_emu_start_frg", 1) if const_expr(self.enable_ex2_emu): - self.ex2_emu_freq = 16 - if const_expr(self.head_dim_padded == 128 and self.use_2cta_instrs): - self.ex2_emu_freq = 12 + self.ex2_emu_freq = self._tune.get("ex2_emu_freq", 16) if const_expr( self.pack_gqa and self.head_dim_padded > 64 and not self.is_causal and not self.is_local ): - self.ex2_emu_freq = 32 if mCuSeqlensQ is not None or mSeqUsedQ is not None else 10 - if const_expr(self.head_dim_padded > 64 and self.is_causal): - self.ex2_emu_freq = 10 + self.ex2_emu_freq = 32 if mCuSeqlensQ is not None or mSeqUsedQ is not None else self._tune.get("ex2_emu_freq", 10) cta_group = tcgen05.CtaGroup.TWO if self.use_2cta_instrs else tcgen05.CtaGroup.ONE q_major_mode = tcgen05.OperandMajorMode.K @@ -462,50 +500,11 @@ def __call__( ) if const_expr(self.pack_gqa): - shape_Q_packed = ( - (self.qhead_per_kvhead, mQ.shape[0]), - mQ.shape[1], - mK.shape[2], - *mQ.shape[3:], - ) - stride_Q_packed = ( - (mQ.stride[2], mQ.stride[0]), - mQ.stride[1], - mQ.stride[2] * self.qhead_per_kvhead, - *mQ.stride[3:], - ) - mQ = cute.make_tensor( - mQ.iterator, cute.make_layout(shape_Q_packed, stride=stride_Q_packed) - ) - shape_O_packed = ( - (self.qhead_per_kvhead, mO.shape[0]), - mO.shape[1], - mK.shape[2], - *mO.shape[3:], - ) - stride_O_packed = ( - (mO.stride[2], mO.stride[0]), - mO.stride[1], - mO.stride[2] * self.qhead_per_kvhead, - *mO.stride[3:], - ) - mO = cute.make_tensor( - mO.iterator, cute.make_layout(shape_O_packed, stride=stride_O_packed) - ) + nheads_kv = mK.shape[2] + mQ = pack_gqa_layout(mQ, self.qhead_per_kvhead, nheads_kv, head_idx=2) + mO = pack_gqa_layout(mO, self.qhead_per_kvhead, nheads_kv, head_idx=2) if const_expr(mLSE is not None): - shape_LSE_packed = ( - (self.qhead_per_kvhead, mLSE.shape[0]), - mK.shape[2], - *mLSE.shape[2:], - ) - stride_LSE_packed = ( - (mLSE.stride[1], mLSE.stride[0]), - mLSE.stride[1] * self.qhead_per_kvhead, - *mLSE.stride[2:], - ) - mLSE = cute.make_tensor( - mLSE.iterator, cute.make_layout(shape_LSE_packed, stride=stride_LSE_packed) - ) + mLSE = pack_gqa_layout(mLSE, self.qhead_per_kvhead, nheads_kv, head_idx=1) self.tma_copy_bytes = { name: cute.size_in_bytes(mX.element_type, cute.select(layout, mode=[0, 1, 2])) @@ -522,14 +521,24 @@ def __call__( tma_load_op = cpasync.CopyBulkTensorTileG2SOp(cta_group) tma_store_op = cpasync.CopyBulkTensorTileS2GOp() - tma_atom_Q, mQ = cute.nvgpu.make_tiled_tma_atom_A( - tma_load_op, - mQ, - cute.select(sQ_layout, mode=[0, 1, 2]), - self.mma_tiler_qk, - tiled_mma_qk, - cta_layout_vmnk.shape, - ) + if const_expr(self.use_tma_Q): + tma_atom_Q, mQ = cute.nvgpu.make_tiled_tma_atom_A( + tma_load_op, + mQ, + cute.select(sQ_layout, mode=[0, 1, 2]), + self.mma_tiler_qk, + tiled_mma_qk, + cta_layout_vmnk.shape, + ) + gmem_tiled_copy_Q = None + else: + tma_atom_Q = None + async_copy_elems = 128 // self.q_dtype.width + num_load_threads = cute.arch.WARP_SIZE * len(self.load_warp_ids) + threads_per_row = math.gcd(self.head_dim_padded // async_copy_elems, num_load_threads) + gmem_tiled_copy_Q = copy_utils.tiled_copy_2d( + self.q_dtype, threads_per_row, num_load_threads, async_copy_elems, is_async=True + ) tma_atom_K = None tma_atom_V = None @@ -578,19 +587,10 @@ def __call__( vO_layout = cute.make_layout((1, async_copy_elems)) gmem_tiled_copy_O = cute.make_tiled_copy_tv(atom_universal_copy, tO_layout, vO_layout) - if const_expr(mCuSeqlensQ is not None or mSeqUsedQ is not None): - TileScheduler = SingleTileVarlenScheduler - else: - if const_expr(self.is_causal or self.is_local): - TileScheduler = SingleTileLPTScheduler - else: - TileScheduler = ( - SingleTileScheduler - if const_expr(not self.is_persistent) - else StaticPersistentTileScheduler - ) + TileScheduler = self.TileScheduler + _num_block_divisor = self.cta_tiler[0] * (self.cta_group_size if not self.is_persistent and self.cta_group_size > 1 else 1) tile_sched_args = TileSchedulerArguments( - cute.ceil_div(cute.size(mQ.shape[0]), self.cta_tiler[0]), + cute.ceil_div(cute.size(mQ.shape[0]), _num_block_divisor), cute.size(mQ.shape[2]), cute.size(mQ.shape[3]) if const_expr(mCuSeqlensQ is None) @@ -613,8 +613,11 @@ def __call__( lpt=self.is_causal or self.is_local, is_split_kv=self.is_split_kv, cluster_shape_mn=self.cluster_shape_mn, + use_cluster_idx=not self.is_persistent and self.cta_group_size > 1, + ) + tile_sched_params = TileScheduler.to_underlying_arguments( + tile_sched_args, scheduling_mode=self.scheduling_mode ) - tile_sched_params = TileScheduler.to_underlying_arguments(tile_sched_args) self.tile_scheduler_cls = TileScheduler grid_dim = TileScheduler.get_grid_shape(tile_sched_params) @@ -624,6 +627,9 @@ def __call__( cutlass.max(cute.cosize(sQ_layout), cute.cosize(sO_layout) * self.o_dtype.width // self.q_dtype.width) ) + clc_response_size = self.sched_stages * 4 if self.use_clc_scheduler else 0 + clc_mbar_size = self.sched_stages * 2 if self.use_clc_scheduler else 0 + @cute.struct class SharedStorage: # m_barriers for pipelines @@ -643,6 +649,13 @@ class SharedStorage: # Smem tensors # store row max and row sum sScale: cute.struct.MemRange[Float32, self.q_stage * self.m_block_size * 2] + # CLC buffers placed here to utilize padding before sO's 1024-byte alignment. + # This avoids adding bytes at the end when we're at the smem limit. + # PipelineClcFetchAsync expects 2 * sched_stages mbarriers (full + empty). + clc_mbar_ptr: cute.struct.MemRange[cutlass.Int64, clc_mbar_size] + # CLC response storage (16 bytes per stage, stored as 4 Int32s). + clc_response: cute.struct.MemRange[Int32, clc_response_size] + # Large TMA buffers with 1024-byte alignment sO: cute.struct.Align[ cute.struct.MemRange[self.o_dtype, sO_size], self.buffer_align_bytes ] @@ -657,35 +670,10 @@ class SharedStorage: self.shared_storage = SharedStorage - LOG2_E = math.log2(math.e) - if const_expr(self.score_mod is None): - softmax_scale_log2 = softmax_scale * LOG2_E - softmax_scale = None - else: - # NB: If a users passes in a score mod, we want to apply the score-mod in the sm_scaled qk - # But in the original base 10. We hijack softmax_scale_log2 to just be the change of base - # and correctly apply the softmax_scale prior to score_mod in the softmax step - softmax_scale_log2 = LOG2_E - softmax_scale = softmax_scale - - if const_expr(window_size_left is not None): - window_size_left = Int32(window_size_left) - if const_expr(window_size_right is not None): - window_size_right = Int32(window_size_right) - - fastdiv_mods = None - if cutlass.const_expr(aux_tensors is not None): - seqlen_q = cute.size(mQ.shape[0]) // ( - self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1 - ) - seqlen_k = ( - cute.size(mK.shape[0]) - if const_expr(mPageTable is None) - else mK.shape[0] * mPageTable.shape[1] - ) - seqlen_q_divmod = FastDivmodDivisor(seqlen_q) - seqlen_k_divmod = FastDivmodDivisor(seqlen_k) - fastdiv_mods = (seqlen_q_divmod, seqlen_k_divmod) + softmax_scale_log2, softmax_scale = utils.compute_softmax_scale_log2(softmax_scale, self.score_mod) + window_size_left = Int32(window_size_left) if window_size_left is not None else None + window_size_right = Int32(window_size_right) if window_size_right is not None else None + fastdiv_mods = utils.compute_fastdiv_mods(mQ, mK, self.qhead_per_kvhead, self.pack_gqa, aux_tensors, mPageTable) head_divmod = None if cutlass.const_expr(self.pack_gqa): @@ -722,6 +710,7 @@ class SharedStorage: tP_layout, sV_layout, sO_layout, + gmem_tiled_copy_Q, gmem_tiled_copy_O, tiled_mma_qk, tiled_mma_pv, @@ -752,7 +741,7 @@ def kernel( mSeqUsedQ: Optional[cute.Tensor], mSeqUsedK: Optional[cute.Tensor], mPageTable: Optional[cute.Tensor], - tma_atom_Q: cute.CopyAtom, + tma_atom_Q: Optional[cute.CopyAtom], tma_atom_K: Optional[cute.CopyAtom], tma_atom_V: Optional[cute.CopyAtom], tma_atom_O: Optional[cute.CopyAtom], @@ -767,6 +756,7 @@ def kernel( tP_layout: cute.ComposedLayout, sV_layout: cute.ComposedLayout, sO_layout: cute.ComposedLayout, + gmem_tiled_copy_Q: Optional[cute.TiledCopy], gmem_tiled_copy_O: Optional[cute.TiledCopy], tiled_mma_qk: cute.TiledMma, tiled_mma_pv: cute.TiledMma, @@ -814,7 +804,7 @@ def kernel( storage = smem.allocate(self.shared_storage) tmem_alloc_barrier = pipeline.NamedBarrier( - barrier_id=int(NamedBarrierFwd.TmemPtr), + barrier_id=int(NamedBarrierFwdSm100.TmemPtr), num_threads=cute.arch.WARP_SIZE * len( (self.mma_warp_id, *self.softmax0_warp_ids, @@ -833,8 +823,8 @@ def kernel( ThreadCooperativeGroup = partial(pipeline.CooperativeGroup, pipeline.Agent.Thread) mma_warp = ThreadCooperativeGroup(len([self.mma_warp_id])) - load_warps = ThreadCooperativeGroup(len(self.load_warp_ids)) tma_warp = ThreadCooperativeGroup(1) + load_threads = ThreadCooperativeGroup(len(self.load_warp_ids) * cute.arch.WARP_SIZE) softmax_warps = ThreadCooperativeGroup(len(self.softmax0_warp_ids)) softmax_threads = ThreadCooperativeGroup(cute.arch.WARP_SIZE * len(self.softmax0_warp_ids)) # softmax_threads = ThreadCooperativeGroup(cute.arch.WARP_SIZE) @@ -857,15 +847,25 @@ def kernel( softmax_correction_threads_cluster = ThreadCooperativeGroup( cute.arch.WARP_SIZE * len(self.softmax0_warp_ids + self.correction_warp_ids) * self.cta_group_size ) - pipeline_q = pipeline_custom.PipelineTmaUmma.create( - barrier_storage=storage.mbar_load_Q.data_ptr(), - num_stages=self.q_stage, - producer_group=tma_warp, - consumer_group=mma_warp, - tx_count=self.tma_copy_bytes["Q"], - cta_layout_vmnk=cta_layout_vmnk, - defer_sync=True, - ) + if const_expr(self.use_tma_Q): + pipeline_q = pipeline_custom.PipelineTmaUmma.create( + barrier_storage=storage.mbar_load_Q.data_ptr(), + num_stages=self.q_stage, + producer_group=tma_warp, + consumer_group=mma_warp, + tx_count=self.tma_copy_bytes["Q"], + cta_layout_vmnk=cta_layout_vmnk, + defer_sync=True, + ) + else: + pipeline_q = pipeline_custom.PipelineAsyncUmma.create( + barrier_storage=storage.mbar_load_Q.data_ptr(), + num_stages=self.q_stage, + producer_group=load_threads, + consumer_group=mma_warp, + cta_layout_vmnk=cta_layout_vmnk, + defer_sync=True, + ) if const_expr(self.use_tma_KV): pipeline_kv = pipeline_custom.PipelineTmaUmma.create( barrier_storage=storage.mbar_load_KV.data_ptr(), @@ -877,13 +877,10 @@ def kernel( defer_sync=True, ) else: - cpasync_producer_group = pipeline.CooperativeGroup( - pipeline.Agent.Thread, len(self.load_warp_ids) * cute.arch.WARP_SIZE - ) pipeline_kv = pipeline.PipelineAsyncUmma.create( barrier_storage=storage.mbar_load_KV.data_ptr(), num_stages=self.kv_stage, - producer_group=cpasync_producer_group, + producer_group=load_threads, consumer_group=mma_warp, cta_layout_vmnk=cta_layout_vmnk, defer_sync=True, @@ -938,7 +935,7 @@ def kernel( ) # Should put the NamedBarrier inside the pipeline class so we'll just have pipeline_sm_stats sm_stats_barrier = pipeline_custom.NamedBarrier( - barrier_id=int(NamedBarrierFwd.SoftmaxStatsW0), num_threads=cute.arch.WARP_SIZE * 2 + barrier_id=int(NamedBarrierFwdSm100.SoftmaxStatsW0), num_threads=cute.arch.WARP_SIZE * 2 ) pipeline_o_epi = None if const_expr(not self.use_correction_warps_for_epi): @@ -1019,17 +1016,69 @@ def kernel( window_size_right=window_size_right, qhead_per_kvhead_packgqa=self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1, ) - TileSchedulerCls = partial(self.tile_scheduler_cls.create, tile_sched_params) - # Cluster wait before tensor memory alloc pipeline_init_wait(cluster_shape_mn=cta_layout_vmnk) + if const_expr(self.use_clc_scheduler): + clc_response_ptr = storage.clc_response.data_ptr() + clc_mbar_ptr = storage.clc_mbar_ptr.data_ptr() + + clc_pipeline_producer_group = cutlass_pipeline.CooperativeGroup( + cutlass_pipeline.Agent.Thread + ) + num_clc_consumer_warps_per_cta = self.threads_per_cta // cute.arch.WARP_SIZE + # NB on CTA0 warp15 == scheduler on CTA1 == empty but still both consume + num_clc_consumer_warps = num_clc_consumer_warps_per_cta * self.cta_group_size + clc_pipeline_consumer_group = cutlass_pipeline.CooperativeGroup( + cutlass_pipeline.Agent.Thread, cute.arch.WARP_SIZE * num_clc_consumer_warps + ) + + block_idx = cute.arch.block_idx() + clc = ClcState.create( + hw_scheduler=ClcDynamicPersistentTileScheduler.create( + self.tile_scheduler_cls.clc_problem_shape(tile_sched_params), + block_idx, + cute.arch.grid_dim(), + clc_response_ptr, + ), + pipeline=cutlass_pipeline.PipelineClcFetchAsync.create( + barrier_storage=clc_mbar_ptr, + num_stages=self.sched_stages, + producer_group=clc_pipeline_producer_group, + consumer_group=clc_pipeline_consumer_group, + tx_count=16, + cta_layout_vmnk=cta_layout_vmnk, + ), + consumer_state=cutlass_pipeline.make_pipeline_state( + cutlass_pipeline.PipelineUserType.Consumer, self.sched_stages + ), + producer_state=cutlass_pipeline.make_pipeline_state( + cutlass_pipeline.PipelineUserType.Producer, self.sched_stages + ), + ) + tile_scheduler = self.tile_scheduler_cls.create(tile_sched_params, clc=clc) + else: + tile_scheduler = self.tile_scheduler_cls.create(tile_sched_params) + assert isinstance(tile_scheduler, TileSchedulerProtocol), f"tile_scheduler is not a TileSchedulerProtocol: {type(tile_scheduler)}" + # /////////////////////////////////////////////////////////////////////////////// - # EMPTY + # EMPTY / CLC SCHEDULER WARP # /////////////////////////////////////////////////////////////////////////////// - for i in cutlass.range_constexpr(len(self.empty_warp_ids)): - if warp_idx == self.empty_warp_ids[i]: + if const_expr(self.use_clc_scheduler): + if warp_idx == self.clc_scheduler_warp_id: cute.arch.setmaxregister_decrease(self.num_regs_other) + if is_leader_cta: + self.clc_scheduler_warp(tile_scheduler) + else: + self.empty_warp(tile_scheduler) + for i in cutlass.range_constexpr(len(self.empty_warp_ids)): + if warp_idx == self.empty_warp_ids[i] and warp_idx != self.clc_scheduler_warp_id: + cute.arch.setmaxregister_decrease(self.num_regs_other) + self.empty_warp(tile_scheduler) + else: + for i in cutlass.range_constexpr(len(self.empty_warp_ids)): + if warp_idx == self.empty_warp_ids[i]: + cute.arch.setmaxregister_decrease(self.num_regs_other) # /////////////////////////////////////////////////////////////////////////////// # LOAD @@ -1049,13 +1098,14 @@ def kernel( tma_atom_Q, tma_atom_K, tma_atom_V, + gmem_tiled_copy_Q, pipeline_q, pipeline_kv, block_info, num_splits, SeqlenInfoCls, - TileSchedulerCls, blocksparse_tensors, + tile_scheduler=tile_scheduler, ) # /////////////////////////////////////////////////////////////////////////////// @@ -1085,8 +1135,8 @@ def kernel( block_info, num_splits, SeqlenInfoCls, - TileSchedulerCls, blocksparse_tensors, + tile_scheduler=tile_scheduler, ) # Dealloc the tensor memory buffer tmem.relinquish_alloc_permit() @@ -1108,8 +1158,8 @@ def kernel( block_info, num_splits, SeqlenInfoCls, - TileSchedulerCls, mma_tile_coord_v, + tile_scheduler=tile_scheduler, ) # /////////////////////////////////////////////////////////////////////////////// @@ -1141,11 +1191,11 @@ def kernel( num_splits=num_splits, SeqlenInfoCls=SeqlenInfoCls, AttentionMaskCls=AttentionMaskCls, - TileSchedulerCls=TileSchedulerCls, aux_tensors=aux_tensors, fastdiv_mods=fastdiv_mods, head_divmod=head_divmod, blocksparse_tensors=blocksparse_tensors, + tile_scheduler=tile_scheduler, ) if const_expr(not self.s0_s1_barrier): @@ -1189,8 +1239,8 @@ def kernel( block_info, num_splits, SeqlenInfoCls, - TileSchedulerCls, blocksparse_tensors, + tile_scheduler=tile_scheduler, ) tmem_alloc_barrier.arrive() @@ -1208,35 +1258,38 @@ def load( sK: cute.Tensor, sV: cute.Tensor, mPageTable: Optional[cute.Tensor], - tma_atom_Q: cute.CopyAtom, + tma_atom_Q: Optional[cute.CopyAtom], tma_atom_K: Optional[cute.CopyAtom], tma_atom_V: Optional[cute.CopyAtom], + gmem_tiled_copy_Q: Optional[cute.TiledCopy], pipeline_q: pipeline.PipelineAsync, pipeline_kv: pipeline.PipelineAsync, block_info: BlockInfo, num_splits: Int32, SeqlenInfoCls: Callable, - TileSchedulerCls: Callable, blocksparse_tensors: Optional[BlockSparseTensors], + tile_scheduler: TileSchedulerProtocol, ): num_load_threads = len(self.load_warp_ids) * cute.arch.WARP_SIZE tidx = cute.arch.thread_idx()[0] % num_load_threads warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) + issue_kv_for_this_warp = ( + const_expr(not self.use_tma_KV or len(self.load_warp_ids) == 1) or + warp_idx == self.load_warp_ids[0] + ) + issue_q_for_this_warp = ( + const_expr(not self.use_tma_Q or len(self.load_warp_ids) == 1) or + warp_idx == self.load_warp_ids[0] + ) q_producer_phase = Int32(1) kv_producer_state = pipeline.make_pipeline_state( pipeline.PipelineUserType.Producer, self.kv_stage ) - tile_scheduler = TileSchedulerCls() work_tile = tile_scheduler.initial_work_tile_info() while work_tile.is_valid_tile: m_block, head_idx, batch_idx, split_idx = work_tile.tile_idx seqlen = SeqlenInfoCls(batch_idx) mQ_cur = seqlen.offset_batch_Q(mQ, batch_idx, dim=3)[None, None, head_idx] - tiler_gQ = ((self.mma_tiler_qk[0] * self.q_stage), self.head_dim_padded) - gQ = cute.local_tile(mQ_cur, tiler_gQ, (m_block, 0)) # (128 * 2, 128) - gQ = layout_utils.select( - cute.flat_divide(gQ, (self.mma_tiler_qk[0],)), mode=[0, 2, 1] - ) # (128, 128, 2) head_idx_kv = ( head_idx // self.qhead_per_kvhead if const_expr(not self.pack_gqa) else head_idx @@ -1258,12 +1311,32 @@ def load( gV = cute.local_tile( mV_cur, cute.select(self.mma_tiler_pv, mode=[1, 2]), (0, None, None) ) - tSgQ = thr_mma_qk.partition_A(gQ) tSgK = thr_mma_qk.partition_B(gK) tOgV = thr_mma_pv.partition_B(gV) - load_Q_fn, _, _ = copy_utils.tma_get_copy_fn( - tma_atom_Q, 0, cute.make_layout(1), tSgQ, sQ - ) + if const_expr(self.use_tma_Q): + tiler_gQ = ((self.mma_tiler_qk[0] * self.q_stage), self.head_dim_padded) + gQ = cute.local_tile(mQ_cur, tiler_gQ, (m_block, 0)) # (128 * 2, 128) + gQ = layout_utils.select( + cute.flat_divide(gQ, (self.mma_tiler_qk[0],)), mode=[0, 2, 1] + ) # (128, 128, 2) + tSgQ = thr_mma_qk.partition_A(gQ) + load_Q_fn, _, _ = copy_utils.tma_get_copy_fn( + tma_atom_Q, 0, cute.make_layout(1), tSgQ, sQ + ) + load_Q = partial(self.load_Q, load_Q_fn, pipeline_q=pipeline_q, phase=q_producer_phase) + else: + assert gmem_tiled_copy_Q is not None + load_Q = partial( + self.load_Q_non_tma, + mQ_cur, + sQ, + gmem_tiled_copy_Q, + pipeline_q, + tidx, + seqlen.seqlen_q, + m_block, + phase=q_producer_phase, + ) if const_expr(self.use_tma_KV): tKsK, tKgK = cpasync.tma_partition( @@ -1302,7 +1375,6 @@ def load( tKsK, tKgK = None, None tVsV, tVgV = None, None - load_Q = partial(self.load_Q, load_Q_fn, pipeline_q=pipeline_q, phase=q_producer_phase) load_K = partial( self.load_KV, tma_atom_K, @@ -1337,24 +1409,19 @@ def load( ) if const_expr(not self.use_tma_KV): paged_kv_manager.load_page_table(n_block_first) - load_K(block=n_block_max - 1, producer_state=kv_producer_state, page_idx=page_idx) # K0 + if issue_kv_for_this_warp: + load_K(block=n_block_max - 1, producer_state=kv_producer_state, page_idx=page_idx) # K0 # load_K(block=n_block_max - 1, producer_state=kv_producer_state, page_idx=page_idx, extra_tx_count=self.tma_copy_bytes["Q"]) # K0 - if const_expr(len(self.load_warp_ids) == 1) or warp_idx == self.load_warp_ids[0]: - # load_Q(block=0, stage=0) # Q0 - pipeline_q.producer_acquire_w_index_phase(0, q_producer_phase) - # pipeline_q.sync_object_empty.wait(0, q_producer_phase) - tma_bar_ptr = pipeline_q.sync_object_full.get_barrier(0) - # tma_bar_ptr = pipeline_kv.producer_get_barrier(kv_producer_state) - load_Q_fn(src_idx=0, dst_idx=0, tma_bar_ptr=tma_bar_ptr) - kv_producer_state.advance() - if const_expr(self.q_stage == 2) and (const_expr(len(self.load_warp_ids) == 1) or warp_idx == self.load_warp_ids[0]): - # load_Q(block=1, stage=1) # Q1 - pipeline_q.producer_acquire_w_index_phase(1, q_producer_phase) - tma_bar_ptr = pipeline_q.sync_object_full.get_barrier(1) - load_Q_fn(src_idx=1, dst_idx=1, tma_bar_ptr=tma_bar_ptr) + if issue_q_for_this_warp: + load_Q(block=0, stage=0) + if issue_kv_for_this_warp: + kv_producer_state.advance() + if const_expr(self.q_stage == 2) and issue_q_for_this_warp: + load_Q(block=1, stage=1) q_producer_phase ^= 1 - load_V(block=n_block_max - 1, producer_state=kv_producer_state, page_idx=page_idx) # V0 - kv_producer_state.advance() + if issue_kv_for_this_warp: + load_V(block=n_block_max - 1, producer_state=kv_producer_state, page_idx=page_idx) # V0 + kv_producer_state.advance() for i in cutlass.range(n_block_max - 1 - n_block_min, unroll=1): n_block = n_block_max - 2 - i page_idx = ( @@ -1365,10 +1432,11 @@ def load( if const_expr(not self.use_tma_KV): paged_kv_manager.load_page_table(n_block) # if cute.arch.thread_idx()[0] % 32 == 0: cute.printf("n_block = {}, page_idx = {}", n_block, page_idx) - load_K(block=n_block, producer_state=kv_producer_state, page_idx=page_idx) # Ki - kv_producer_state.advance() - load_V(block=n_block, producer_state=kv_producer_state, page_idx=page_idx) # Vi - kv_producer_state.advance() + if issue_kv_for_this_warp: + load_K(block=n_block, producer_state=kv_producer_state, page_idx=page_idx) # Ki + kv_producer_state.advance() + load_V(block=n_block, producer_state=kv_producer_state, page_idx=page_idx) # Vi + kv_producer_state.advance() else: kv_producer_state, q_producer_phase = produce_block_sparse_loads_sm100( @@ -1387,14 +1455,14 @@ def load( self.q_subtile_factor if self.q_subtile_factor is not None else 1, ) - tile_scheduler.prefetch_next_work() - tile_scheduler.advance_to_next_work() - work_tile = tile_scheduler.get_current_work() + + work_tile = tile_scheduler.advance_to_next_work() # End of persistent scheduler loop - pipeline_kv.producer_tail(kv_producer_state) - # This is equivalent to pipeline_q.producer_tail - if const_expr(len(self.load_warp_ids) == 1) or warp_idx == self.load_warp_ids[0]: + if issue_kv_for_this_warp: + pipeline_kv.producer_tail(kv_producer_state) + # This is equivalent to pipeline_q.producer_tail for the TMA-Q producer warp. + if issue_q_for_this_warp: pipeline_q.producer_acquire_w_index_phase(self.q_stage - 1, q_producer_phase) @cute.jit @@ -1417,8 +1485,8 @@ def mma( block_info: BlockInfo, num_splits: Int32, SeqlenInfoCls: Callable, - TileSchedulerCls: Callable, blocksparse_tensors: Optional[BlockSparseTensors], + tile_scheduler=None, ): tSrQ = tiled_mma_qk.make_fragment_A(sQ) tSrK = tiled_mma_qk.make_fragment_B(sK) @@ -1507,7 +1575,6 @@ def mma( ) P_full_O_rescaled_phase = Int32(0) - tile_scheduler = TileSchedulerCls() work_tile = tile_scheduler.initial_work_tile_info() while work_tile.is_valid_tile: m_block, head_idx, batch_idx, split_idx = work_tile.tile_idx @@ -1678,8 +1745,7 @@ def mma( # End of GEMM_PV1(i_end) (P1 * Vi_end -> O1) # Advance to next tile - tile_scheduler.advance_to_next_work() - work_tile = tile_scheduler.get_current_work() + work_tile = tile_scheduler.advance_to_next_work() # End of persistent scheduler loop # We don't need pipeline_s_p_o.producer_tail() since there's no dangling mbarrier at the end @@ -1708,11 +1774,11 @@ def softmax_loop( num_splits: Int32, SeqlenInfoCls: Callable, AttentionMaskCls: Callable, - TileSchedulerCls: Callable, aux_tensors: Optional[list] = None, fastdiv_mods=(None, None), head_divmod=None, blocksparse_tensors: Optional[BlockSparseTensors] = None, + tile_scheduler=None, ): """Compute softmax on attention scores from QK matrix multiplication. @@ -1772,7 +1838,6 @@ def softmax_loop( warp_idx_in_wg = cute.arch.make_warp_uniform(cute.arch.warp_idx()) % 4 - tile_scheduler = TileSchedulerCls() work_tile = tile_scheduler.initial_work_tile_info() while work_tile.is_valid_tile: m_block, head_idx, batch_idx, split_idx = work_tile.tile_idx @@ -2015,8 +2080,7 @@ def softmax_loop( # gLSE[tidx] = lse # Advance to next tile - tile_scheduler.advance_to_next_work() - work_tile = tile_scheduler.get_current_work() + work_tile = tile_scheduler.advance_to_next_work() # End of persistent scheduler loop # This is equivalent to pipeline_sm_stats.producer_tail @@ -2186,8 +2250,8 @@ def correction_loop( block_info: BlockInfo, num_splits: Int32, SeqlenInfoCls: Callable, - TileSchedulerCls: Callable, blocksparse_tensors: Optional[BlockSparseTensors] = None, + tile_scheduler=None, ): tidx = cute.arch.thread_idx()[0] % (cute.arch.WARP_SIZE * len(self.correction_warp_ids)) warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) % 4 @@ -2217,7 +2281,6 @@ def correction_loop( o_corr_consumer_phase = Int32(0) corr_epi_producer_phase = Int32(1) - tile_scheduler = TileSchedulerCls() work_tile = tile_scheduler.initial_work_tile_info() while work_tile.is_valid_tile: m_block, head_idx, batch_idx, split_idx = work_tile.tile_idx @@ -2228,12 +2291,14 @@ def correction_loop( mO_cur = seqlen.offset_batch_Q(mO, batch_idx, dim=3)[None, None, head_idx, split_idx] else: mO_cur = seqlen.offset_batch_Q(mO, batch_idx, dim=3)[None, None, head_idx] - tiler_gO = ((self.mma_tiler_pv[0] * self.q_stage), self.head_dim_v_padded) - gO = cute.local_tile(mO_cur, tiler_gO, (m_block, 0)) # (128 * 2, 128) - gO = layout_utils.select( - cute.flat_divide(gO, (self.mma_tiler_pv[0],)), mode=[0, 2, 1] - ) # (128, 128, 2) - gO = cute.flat_divide(gO, (self.mma_tiler_pv[0] // self.cta_group_size,))[None, mma_tile_coord_v, None, None] + gO = None + if const_expr(self.use_tma_O or not self.pack_gqa): + tiler_gO = ((self.mma_tiler_pv[0] * self.q_stage), self.head_dim_v_padded) + gO = cute.local_tile(mO_cur, tiler_gO, (m_block, 0)) # (128 * 2, 128) + gO = layout_utils.select( + cute.flat_divide(gO, (self.mma_tiler_pv[0],)), mode=[0, 2, 1] + ) # (128, 128, 2) + gO = cute.flat_divide(gO, (self.mma_tiler_pv[0] // self.cta_group_size,))[None, mma_tile_coord_v, None, None] # Default LSE to -inf for invalid split_idx tiles stats = [(0.0, -Float32.inf if const_expr(mLSE is not None or learnable_sink is not None) else None, True)] * self.q_stage @@ -2334,6 +2399,7 @@ def correction_loop( pipeline_o_acc.consumer_wait_w_index_phase(stage, o_corr_consumer_phase) if const_expr(not self.use_correction_warps_for_epi): pipeline_o_epi.producer_acquire_w_index_phase(stage, corr_epi_producer_phase) + gO_stage = gO[None, None, stage] if const_expr(gO is not None) else None self.correction_epilogue( thr_mma_pv, tOtO[None, None, None, stage], @@ -2344,7 +2410,7 @@ def correction_loop( scale, sO[None, None, stage], mO_cur, - gO[None, None, stage], + gO_stage, gmem_tiled_copy_O, ) # Signal for the next work tile that O buffers in tmem are already read, so @@ -2414,7 +2480,6 @@ def correction_loop( mLSE_cur = cute.domain_offset((offset,), mLSE[None, head_idx]) for stage in cutlass.range_constexpr(self.q_stage): m_tile_idx = (m_block * self.q_stage + stage) * self.cta_group_size + mma_tile_coord_v - gLSE = cute.local_tile(mLSE_cur, (self.m_block_size,), (m_tile_idx,)) row_sum, row_max, acc_O_mn_row_is_zero_or_nan = stats[stage] # if tidx == 0 and stage <= 1: # cute.printf("row_sum = {}, row_max = {}, acc_O_mn_row_is_zero_or_nan = {}\n", row_sum, row_max, acc_O_mn_row_is_zero_or_nan) @@ -2429,13 +2494,24 @@ def correction_loop( if const_expr(not self.pack_gqa) else seqlen.seqlen_q * self.qhead_per_kvhead ) - if tidx < seqlen_q - m_tile_idx * self.m_block_size: - # This actually just works with PackGQA too - gLSE[tidx] = lse + if const_expr(not self.pack_gqa or self.m_block_size % self.qhead_per_kvhead == 0): + gLSE = cute.local_tile(mLSE_cur, (self.m_block_size,), (m_tile_idx,)) + if tidx < seqlen_q - m_tile_idx * self.m_block_size: + # This actually just works with PackGQA too + gLSE[tidx] = lse + else: + idx = m_tile_idx * self.m_block_size + tidx + if idx < seqlen_q: + m_idx = idx // self.qhead_per_kvhead + h_idx = idx - m_idx * self.qhead_per_kvhead + lse_ptr_i64 = utils.elem_pointer(mLSE_cur, ((h_idx, m_idx),)).toint() + lse_gmem_ptr = cute.make_ptr( + mLSE_cur.element_type, lse_ptr_i64, cute.AddressSpace.gmem, assumed_align=4 + ) + cute.make_tensor(lse_gmem_ptr, (1,))[0] = lse # Advance to next tile - tile_scheduler.advance_to_next_work() - work_tile = tile_scheduler.get_current_work() + work_tile = tile_scheduler.advance_to_next_work() # End of persistent scheduler loop # This is equivalent to pipeline_o_epi.consumer_tail() for the correction warps @@ -2574,7 +2650,7 @@ def correction_epilogue( if const_expr(self.use_correction_warps_for_epi): assert(not self.use_tma_O) assert(gmem_tiled_copy_O is not None) - cute.arch.barrier(barrier_id=int(NamedBarrierFwd.Epilogue), + cute.arch.barrier(barrier_id=int(NamedBarrierFwdSm100.Epilogue), number_of_threads=len(self.epilogue_warp_ids) * cute.arch.WARP_SIZE) mma_tile_coord_v = thr_mma.thr_idx m_tile_idx = (m_block * self.q_stage + stage) * self.cta_group_size + mma_tile_coord_v @@ -2586,7 +2662,7 @@ def correction_epilogue( def _store_O_to_gmem( self, sO_stage: cute.Tensor, - gO: cute.Tensor, + gO: Optional[cute.Tensor], mO_cur: cute.Tensor, gmem_tiled_copy_O: cute.TiledCopy, tidx: Int32, @@ -2597,7 +2673,6 @@ def _store_O_to_gmem( gmem_thr_copy_O = gmem_tiled_copy_O.get_slice(tidx) tOsO = gmem_thr_copy_O.partition_S(sO_stage) cO = cute.make_identity_tensor((self.m_block_size, self.head_dim_v_padded)) - tOgO = gmem_thr_copy_O.partition_D(gO) tOcO = gmem_thr_copy_O.partition_S(cO) t0OcO = gmem_tiled_copy_O.get_slice(0).partition_S(cO) tOpO = copy_utils.predicate_k(tOcO, limit=mO_cur.shape[1]) @@ -2613,6 +2688,8 @@ def _store_O_to_gmem( cute.autovec_copy(tOsO, tOrO) # copy acc O from rmem to gmem if const_expr(not self.pack_gqa): + assert gO is not None + tOgO = gmem_thr_copy_O.partition_D(gO) for rest_m in cutlass.range_constexpr(cute.size(tOrO.shape[1])): if ( t0OcO[0, rest_m, 0][0] < seqlen_q - m_tile_idx * self.m_block_size - tOcO[0][0] @@ -2641,11 +2718,10 @@ def epilogue_s2g( block_info: BlockInfo, num_splits: int, SeqlenInfoCls: Callable, - TileSchedulerCls: Callable, mma_tile_coord_v: Int32 = 0, + tile_scheduler=None, ): epi_consumer_phase = Int32(0) - tile_scheduler = TileSchedulerCls() work_tile = tile_scheduler.initial_work_tile_info() while work_tile.is_valid_tile: m_block, head_idx, batch_idx, split_idx = work_tile.tile_idx @@ -2657,12 +2733,14 @@ def epilogue_s2g( mO_cur = seqlen.offset_batch_Q(mO, batch_idx, dim=3)[None, None, head_idx, split_idx] else: mO_cur = seqlen.offset_batch_Q(mO, batch_idx, dim=3)[None, None, head_idx] - tiler_gO = ((self.mma_tiler_pv[0] * self.q_stage), self.head_dim_v_padded) - gO = cute.local_tile(mO_cur, tiler_gO, (m_block, 0)) # (128 * 2, 128) - gO = layout_utils.select( - cute.flat_divide(gO, (self.mma_tiler_pv[0],)), mode=[0, 2, 1] - ) # (128, 128, 2) - gO = cute.flat_divide(gO, (self.mma_tiler_pv[0] // self.cta_group_size,))[None, mma_tile_coord_v, None, None] + gO = None + if const_expr(self.use_tma_O or not self.pack_gqa): + tiler_gO = ((self.mma_tiler_pv[0] * self.q_stage), self.head_dim_v_padded) + gO = cute.local_tile(mO_cur, tiler_gO, (m_block, 0)) # (128 * 2, 128) + gO = layout_utils.select( + cute.flat_divide(gO, (self.mma_tiler_pv[0],)), mode=[0, 2, 1] + ) # (128, 128, 2) + gO = cute.flat_divide(gO, (self.mma_tiler_pv[0] // self.cta_group_size,))[None, mma_tile_coord_v, None, None] if const_expr(self.use_tma_O): store_O, _, _ = copy_utils.tma_get_copy_fn( @@ -2689,8 +2767,9 @@ def epilogue_s2g( pipeline_o_epi.consumer_wait_w_index_phase(stage, epi_consumer_phase) # 2. copy O0 / O1 to gmem m_tile_idx = (m_block * self.q_stage + stage) * self.cta_group_size + mma_tile_coord_v + gO_stage = gO[None, None, stage] if const_expr(gO is not None) else None self._store_O_to_gmem( - sO[None, None, stage], gO[None, None, stage], mO_cur, gmem_tiled_copy_O, + sO[None, None, stage], gO_stage, mO_cur, gmem_tiled_copy_O, tidx, seqlen.seqlen_q, m_tile_idx, ) pipeline_o_epi.consumer_release_w_index(stage) @@ -2698,8 +2777,39 @@ def epilogue_s2g( epi_consumer_phase ^= 1 # Advance to next tile - tile_scheduler.advance_to_next_work() - work_tile = tile_scheduler.get_current_work() + work_tile = tile_scheduler.advance_to_next_work() + + @cute.jit + def clc_scheduler_warp( + self, + tile_scheduler: TileSchedulerProtocol, + ): + work_tile = tile_scheduler.initial_work_tile_info() + while work_tile.is_valid_tile: + tile_scheduler.prefetch_next_work() + work_tile = tile_scheduler.advance_to_next_work() + if cute.arch.thread_idx()[0] == self.clc_scheduler_warp_id * cute.arch.WARP_SIZE: + fa_printf( + 3, + "[CLC] query sm={} cta={} (m_blk={},h={},b={},s={}) valid={}\n", + smid(), + cute.arch.block_idx()[0], + work_tile.tile_idx[0], + work_tile.tile_idx[1], + work_tile.tile_idx[2], + work_tile.tile_idx[3], + work_tile.is_valid_tile, + ) + tile_scheduler.producer_tail() + + @cute.jit + def empty_warp( + self, + tile_scheduler: TileSchedulerProtocol, + ): + work_tile = tile_scheduler.initial_work_tile_info() + while work_tile.is_valid_tile: + work_tile = tile_scheduler.advance_to_next_work() def load_Q( self, @@ -2712,6 +2822,39 @@ def load_Q( pipeline_q.producer_acquire_w_index_phase(stage, phase) load_Q_fn(src_idx=block, dst_idx=stage, tma_bar_ptr=pipeline_q.sync_object_full.get_barrier(stage)) + def load_Q_non_tma( + self, + mQ: cute.Tensor, + sQ: cute.Tensor, + gmem_tiled_copy_Q: cute.TiledCopy, + pipeline_q: pipeline.PipelineAsync, + tidx: Int32, + seqlen_q: Int32, + m_block: Int32, + block: Int32, + stage: int, + phase: Int32, + ): + assert self.cta_group_size == 1, "cta_group_size must be 1 for non-tma Q load" + pipeline_q.producer_acquire_w_index_phase(stage, phase) + pack_gqa = PackGQA( + self.m_block_size, + self.head_dim_padded, + self.check_hdim_oob, + self.qhead_per_kvhead, + ) + sQ_stage = sQ[None, None, None, stage] + sQ_pi = cute.make_tensor( + sQ_stage.iterator, + cute.make_layout( + (sQ_stage.shape[0][0], (sQ_stage.shape[0][1], sQ_stage.shape[2])), + stride=(sQ_stage.stride[0][0], (sQ_stage.stride[0][1], sQ_stage.stride[2])), + ), + ) + pack_gqa.load_Q(mQ, sQ_pi, gmem_tiled_copy_Q, tidx, m_block * self.q_stage + block, seqlen_q) + cute.arch.cp_async_commit_group() + pipeline_q.sync_object_full.arrive_cp_async_mbarrier(stage) + @cute.jit def load_KV( self, @@ -2754,7 +2897,10 @@ def load_KV( else: assert paged_kv_manager is not None assert extra_tx_count is None - paged_kv_manager.load_KV(block, sX[None, None, None, stage], K_or_V) + sX_cur = sX[None, None, None, stage] + if const_expr(self.uneven_kv_smem): + sX_cur = self.offset_kv_smem(sX_cur, stage, phase ^ 1) + paged_kv_manager.load_KV(block, sX_cur, K_or_V) cute.arch.cp_async_commit_group() pipeline_kv.sync_object_full.arrive_cp_async_mbarrier(stage) @@ -2765,6 +2911,9 @@ def offset_kv_smem(self, sX: cute.Tensor, stage: Int32, phase: Int32): # (smem_large + smem_small) // 2. So for stage == 1, move right by offset if # phase == 0, or left by offset if phase == 1. offset = 0 if stage != 1 else self.uneven_kv_smem_offset * (1 - 2 * phase) + # Hint that the offset is 128-bit aligned so that + # ptr + offset preserves the alignment needed by cp.async. + offset = cute.assume(offset, divby=128 // self.k_dtype.width) return cute.make_tensor(sX.iterator + offset, sX.layout) else: return sX @@ -2774,12 +2923,12 @@ def offset_kv_smem(self, sX: cute.Tensor, stage: Int32, phase: Int32): # warp_group_idx = utils.canonical_warp_group_idx(sync=False) # if warp_group_idx == 0: # cute.arch.barrier_arrive( - # barrier_id=int(NamedBarrierFwd.WarpSchedulerWG1), number_of_threads=2 * 128, + # barrier_id=int(NamedBarrierFwdSm100.WarpSchedulerWG1), number_of_threads=2 * 128, # ) # def warp_scheduler_barrier_sync(self): # cute.arch.barrier( - # barrier_id=int(NamedBarrierFwd.WarpSchedulerWG1) + utils.canonical_warp_group_idx(sync=False), + # barrier_id=int(NamedBarrierFwdSm100.WarpSchedulerWG1) + utils.canonical_warp_group_idx(sync=False), # number_of_threads=2 * 128 # ) @@ -2787,7 +2936,7 @@ def offset_kv_smem(self, sX: cute.Tensor, stage: Int32, phase: Int32): # cur_wg = utils.canonical_warp_group_idx(sync=False) # next_wg = 1 - cur_wg # cute.arch.barrier_arrive( - # barrier_id=int(NamedBarrierFwd.WarpSchedulerWG1) + next_wg, number_of_threads=2 * 128, + # barrier_id=int(NamedBarrierFwdSm100.WarpSchedulerWG1) + next_wg, number_of_threads=2 * 128, # ) @cute.jit diff --git a/flash-attn4/torch-ext/flash_attn4/flash_fwd_sm120.py b/flash-attn4/torch-ext/flash_attn4/flash_fwd_sm120.py new file mode 100644 index 00000000..ce2ab4e9 --- /dev/null +++ b/flash-attn4/torch-ext/flash_attn4/flash_fwd_sm120.py @@ -0,0 +1,59 @@ +# Copyright (c) 2025, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +# SM120 (Blackwell GeForce / DGX Spark) forward pass. +# +# SM120 uses the same SM80-era MMA instructions (mma.sync.aligned.m16n8k16) but has +# a smaller shared memory capacity (99 KB vs 163 KB on SM80). This module subclasses +# FlashAttentionForwardSm80 and overrides the SMEM capacity check accordingly. + +import cutlass +import cutlass.utils as utils_basic + +from .flash_fwd import FlashAttentionForwardSm80 + + +class FlashAttentionForwardSm120(FlashAttentionForwardSm80): + # Keep arch = 80 to use CpAsync code paths (no TMA for output). + # The compilation target is determined by the GPU at compile time, not this field. + arch = 80 + + @staticmethod + def can_implement( + dtype, + head_dim, + head_dim_v, + tile_m, + tile_n, + num_stages, + num_threads, + is_causal, + Q_in_regs=False, + ) -> bool: + """Check if the kernel can be implemented on SM120. + + Same logic as SM80 but uses SM120's shared memory capacity (99 KB). + """ + if dtype not in [cutlass.Float16, cutlass.BFloat16]: + return False + if head_dim % 8 != 0: + return False + if head_dim_v % 8 != 0: + return False + if tile_n % 16 != 0: + return False + if num_threads % 32 != 0: + return False + # Shared memory usage: Q tile + (K tile + V tile) + smem_usage_Q = tile_m * head_dim * 2 + smem_usage_K = tile_n * head_dim * num_stages * 2 + smem_usage_V = tile_n * head_dim_v * num_stages * 2 + smem_usage_QV = ( + (smem_usage_Q + smem_usage_V) if not Q_in_regs else max(smem_usage_Q, smem_usage_V) + ) + smem_usage = smem_usage_QV + smem_usage_K + # SM120 has 99 KB shared memory (vs 163 KB on SM80) + smem_capacity = utils_basic.get_smem_capacity_in_bytes("sm_120") + if smem_usage > smem_capacity: + return False + if (tile_m * 2) % num_threads != 0: + return False + return True diff --git a/flash-attn4/torch-ext/flash_attn4/flash_fwd_sm90.py b/flash-attn4/torch-ext/flash_attn4/flash_fwd_sm90.py new file mode 100644 index 00000000..9f4c4c00 --- /dev/null +++ b/flash-attn4/torch-ext/flash_attn4/flash_fwd_sm90.py @@ -0,0 +1,1534 @@ +# Copyright (c) 2025, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +# SM90 (Hopper) forward pass for flash attention, extracted from flash_fwd.py. + +from types import SimpleNamespace +from typing import Callable, Literal, Optional +from functools import partial + +import cuda.bindings.driver as cuda + +import cutlass +import cutlass.cute as cute +from cutlass import Float32, Int32, const_expr +from cutlass.cute.nvgpu import cpasync, warpgroup +from cutlass.utils import LayoutEnum +import cutlass.utils.hopper_helpers as sm90_utils_basic +from cutlass import pipeline +from cutlass.pipeline import pipeline_init_arrive, pipeline_init_wait +from cutlass.base_dsl.arch import Arch + +from .quack import copy_utils +from .quack import layout_utils +from .quack import sm90_utils + +from .cute_dsl_utils import assume_tensor_aligned +from . import utils +from .mask import AttentionMask +from .softmax import Softmax, apply_score_mod_inner +from .seqlen_info import SeqlenInfoQK +from .block_info import BlockInfo +from .block_sparsity import BlockSparseTensors +from .block_sparse_utils import ( + produce_block_sparse_loads, + consume_block_sparse_loads, +) +from . import pipeline as pipeline_custom +from .pack_gqa import PackGQA, pack_gqa_layout, make_packgqa_tiled_tma_atom +from .paged_kv import PagedKVManager +from .named_barrier import NamedBarrierFwd +from .quack.cute_dsl_utils import ParamsBase +from .tile_scheduler import ( + TileSchedulerArguments, + SingleTileScheduler, + SingleTileLPTScheduler, + SingleTileVarlenScheduler, +) +from cutlass.cute import FastDivmodDivisor + +from .flash_fwd import FlashAttentionForwardBase + + +class FlashAttentionForwardSm90(FlashAttentionForwardBase): + def __init__( + self, + *args, + intra_wg_overlap: bool = True, + mma_pv_is_rs: bool = True, + paged_kv_non_tma: bool = False, + **kwargs, + ): + super().__init__(*args, **kwargs) + self.intra_wg_overlap = intra_wg_overlap + self.mma_pv_is_rs = mma_pv_is_rs + self.buffer_align_bytes = 1024 + self.use_tma_KV = not paged_kv_non_tma + assert self.use_tma_KV or not (self.check_hdim_oob or self.check_hdim_v_oob), ( + "Paged KV does not support irregular head dim" + ) + self.cluster_shape_mn = (1, 1) + assert self.arch >= Arch.sm_90 and self.arch <= Arch.sm_90a, "Only SM 9.x is supported" + + def _get_smem_layout_atom(self): + sQ_layout_atom = warpgroup.make_smem_layout_atom( + sm90_utils_basic.get_smem_layout_atom(LayoutEnum.ROW_MAJOR, self.dtype, self.tile_hdim), + self.dtype, + ) + sK_layout_atom = sQ_layout_atom + sV_layout_atom = warpgroup.make_smem_layout_atom( + sm90_utils_basic.get_smem_layout_atom( + LayoutEnum.ROW_MAJOR, self.dtype, self.tile_hdimv + ), + self.dtype, + ) + sO_layout_atom = sV_layout_atom + if not self.mma_pv_is_rs: + sP_layout_atom = warpgroup.make_smem_layout_atom( + sm90_utils_basic.get_smem_layout_atom( + LayoutEnum.ROW_MAJOR, self.dtype, self.tile_n + ), + self.dtype, + ) + else: + sP_layout_atom = None + return sQ_layout_atom, sK_layout_atom, sV_layout_atom, sO_layout_atom, sP_layout_atom + + def _get_tiled_mma(self): + tiled_mma_qk = sm90_utils_basic.make_trivial_tiled_mma( + self.dtype, + self.dtype, + warpgroup.OperandMajorMode.K, + warpgroup.OperandMajorMode.K, + Float32, + atom_layout_mnk=(self.tile_m // 64, 1, 1), + tiler_mn=(64, self.tile_n), + ) + tiled_mma_pv = sm90_utils_basic.make_trivial_tiled_mma( + self.dtype, + self.dtype, + warpgroup.OperandMajorMode.K, + warpgroup.OperandMajorMode.MN, + Float32, + atom_layout_mnk=(self.tile_m // 64, 1, 1), # Might need (1, 2, 1) for hdim 512 + tiler_mn=(64, self.tile_hdimv), + a_source=warpgroup.OperandSource.RMEM + if self.mma_pv_is_rs + else warpgroup.OperandSource.SMEM, + ) + return tiled_mma_qk, tiled_mma_pv + + def _get_shared_storage_cls(self): + sQ_struct, sK_struct, sV_struct = [ + cute.struct.Align[ + cute.struct.MemRange[self.dtype, cute.cosize(layout)], self.buffer_align_bytes + ] + for layout in (self.sQ_layout, self.sK_layout, self.sV_layout) + ] + cosize_sQV = max(cute.cosize(self.sQ_layout), cute.cosize(self.sV_layout)) + sQV_struct = cute.struct.Align[cute.struct.MemRange[self.dtype, cosize_sQV], 1024] + cosize_sP = cute.cosize(self.sP_layout) if const_expr(self.sP_layout is not None) else 0 + sP_struct = cute.struct.Align[cute.struct.MemRange[self.dtype, cosize_sP], 1024] + # 1 stage * 2 for Q pipeline (full + empty), self.num_stages*2 for K, self.num_stages*2 for V, + mbar_ptr_Q_struct = cute.struct.MemRange[cutlass.Int64, 1 * 2] + mbar_ptr_K_struct = cute.struct.MemRange[cutlass.Int64, self.num_stages * 2] + mbar_ptr_V_struct = cute.struct.MemRange[cutlass.Int64, self.num_stages * 2] + + @cute.struct + class SharedStorageQKV: + mbar_ptr_Q: mbar_ptr_Q_struct + mbar_ptr_K: mbar_ptr_K_struct + mbar_ptr_V: mbar_ptr_V_struct + sV: sV_struct + sQ: sQ_struct + sK: sK_struct + sP: sP_struct + + @cute.struct + class SharedStorageSharedQV: + mbar_ptr_Q: mbar_ptr_Q_struct + mbar_ptr_K: mbar_ptr_K_struct + mbar_ptr_V: mbar_ptr_V_struct + sQ: sQV_struct + sK: sK_struct + sP: sP_struct + + return SharedStorageQKV if const_expr(not self.Q_in_regs) else SharedStorageSharedQV + + @cute.jit + def __call__( + self, + mQ: cute.Tensor, # (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q + mK: cute.Tensor, # (b_k, s_k, h_k, d) or (total_k, h_k, d) if there is cu_seqlens_k or (num_pages, page_size, h_k, d) if there is page_table + mV: cute.Tensor, # (b_k, s_k, h_k, dv) or (total_k, h_k, dv) if there is cu_seqlens_k or (num_pages, page_size, h_k, dv) if there is page_table + mO: cute.Tensor, # (b, s_q, h, dv) or (total_q, h, dv) if there is cu_seqlens_q + mLSE: Optional[cute.Tensor], + softmax_scale: Float32, + mCuSeqlensQ: Optional[cute.Tensor] = None, + mCuSeqlensK: Optional[cute.Tensor] = None, + mSeqUsedQ: Optional[cute.Tensor] = None, + mSeqUsedK: Optional[cute.Tensor] = None, + mPageTable: Optional[cute.Tensor] = None, # (b_k, max_num_pages_per_seq) + window_size_left: Int32 | int | None = None, + window_size_right: Int32 | int | None = None, + learnable_sink: Optional[cute.Tensor] = None, + blocksparse_tensors: Optional[BlockSparseTensors] = None, + aux_tensors: Optional[list] = None, + # Always keep stream as the last parameter (EnvStream: obtained implicitly via TVM FFI). + stream: cuda.CUstream = None, + ): + """Configures and launches the flash attention kernel. + + mQ/mK/mV/mO has same data types(supports fp16 and bf16) and same layout: + (batch_size, seqlen_q, num_head, head_dim):(_, _, _, 1) + """ + + self._check_type( + *( + t.element_type if t is not None else None + for t in (mQ, mK, mV, mO, mLSE, mCuSeqlensQ, mCuSeqlensK, mSeqUsedQ, mSeqUsedK) + ) + ) + + self.varlen_q = mCuSeqlensQ is not None or mSeqUsedQ is not None + + mQ, mK, mV, mO = [assume_tensor_aligned(t) for t in (mQ, mK, mV, mO)] + QO_layout_transpose = [1, 3, 2, 0] if const_expr(mCuSeqlensQ is None) else [0, 2, 1] + mQ, mO = [layout_utils.select(t, QO_layout_transpose) for t in (mQ, mO)] + KV_layout_transpose = [1, 3, 2, 0] if const_expr(mCuSeqlensK is None) else [0, 2, 1] + mK, mV = [layout_utils.select(t, KV_layout_transpose) for t in (mK, mV)] + LSE_layout_transpose = [2, 1, 0] if const_expr(mCuSeqlensQ is None) else [1, 0] + mLSE = ( + layout_utils.select(mLSE, LSE_layout_transpose) + if const_expr(mLSE is not None) + else None + ) + + tiled_mma_qk, tiled_mma_pv = self._get_tiled_mma() + self.num_mma_threads = tiled_mma_qk.size + self.num_threads_per_warp_group = 128 + self.num_wg_mma = self.num_mma_threads // self.num_threads_per_warp_group + assert self.num_wg_mma in [1, 2, 3] + self.num_threads = self.num_threads_per_warp_group * (self.num_wg_mma + 1) + self.num_producer_threads = 32 + self.num_Q_load_threads = self.num_threads_per_warp_group # If not TMA_Q + self.num_epilogue_threads = self.num_mma_threads + self.num_mma_regs, self.num_producer_regs = {1: (256, 56), 2: (240, 24), 3: (160, 32)}[ + self.num_wg_mma + ] + self.use_block_sparsity = cutlass.const_expr(blocksparse_tensors is not None) + + self.use_scheduler_barrier = ( + (self.num_wg_mma >= 2 and self.tile_hdim <= 128) + if const_expr(self.intra_wg_overlap) + else (self.num_wg_mma == 2) + ) + self.use_tma_Q = self.arch >= Arch.sm_90 and not ( + self.pack_gqa and self.tile_m % self.qhead_per_kvhead != 0 + ) + self.use_tma_O = self.use_tma_Q + # Producer needs more registers when doing cp.async Q or KV loads + if const_expr(self.num_wg_mma == 2 and (not self.use_tma_Q or not self.use_tma_KV)): + self.num_mma_regs, self.num_producer_regs = 224, 40 + self.rescale_O_before_gemm = self.tile_hdimv > 128 and self.intra_wg_overlap + self._setup_attributes() + # TODO: we prob don't need most of what's in _setup_attributes + self.sQ_layout, self.sK_layout, self.sV_layout, self.sO_layout = [ + sm90_utils.make_smem_layout(mX.element_type, LayoutEnum.ROW_MAJOR, shape, stage) + for mX, shape, stage in [ + (mQ, (self.tile_m, self.tile_hdim), None), + (mK, (self.tile_n, self.tile_hdim), self.num_stages), + (mV, (self.tile_n, self.tile_hdimv), self.num_stages), + (mO, (self.tile_m, self.tile_hdimv), None), + ] + ] + self.sP_layout = None + if const_expr(not self.mma_pv_is_rs): + self.sP_layout = sm90_utils.make_smem_layout( + mV.element_type, LayoutEnum.ROW_MAJOR, (self.tile_m, self.tile_n) + ) + + SharedStorage = self._get_shared_storage_cls() + + mQ_og, mO_og = mQ, mO + if const_expr(self.pack_gqa): + nheads_kv = mK.shape[2] + mQ = pack_gqa_layout(mQ, self.qhead_per_kvhead, nheads_kv, head_idx=2) + mO = pack_gqa_layout(mO, self.qhead_per_kvhead, nheads_kv, head_idx=2) + if const_expr(mLSE is not None): + mLSE = pack_gqa_layout(mLSE, self.qhead_per_kvhead, nheads_kv, head_idx=1) + + # TMA + gmem_tiled_copy_Q = cpasync.CopyBulkTensorTileG2SOp() + gmem_tiled_copy_KV = cpasync.CopyBulkTensorTileG2SOp() # Might multicast + gmem_tiled_copy_O = cpasync.CopyBulkTensorTileS2GOp() + self.tma_copy_bytes = { + name: cute.size_in_bytes(mX.element_type, cute.select(layout, mode=[0, 1])) + for name, mX, layout in [ + ("Q", mQ, self.sQ_layout), + ("K", mK, self.sK_layout), + ("V", mV, self.sV_layout), + ] + } + make_tiled_tma_atom_fn = ( + partial(make_packgqa_tiled_tma_atom, qhead_per_kvhead=self.qhead_per_kvhead, head_idx=2) + if const_expr(self.pack_gqa) + else cpasync.make_tiled_tma_atom + ) + tma_atom_Q, tma_tensor_Q = None, None + if const_expr(self.use_tma_Q): + tma_atom_Q, tma_tensor_Q = make_tiled_tma_atom_fn( + gmem_tiled_copy_Q, + mQ_og if const_expr(self.pack_gqa) else mQ, + self.sQ_layout, + (self.tile_m, self.tile_hdim), # No mcast + ) + tma_atom_K, tma_tensor_K = None, None + tma_atom_V, tma_tensor_V = None, None + if const_expr(self.use_tma_KV): + tma_atom_K, tma_tensor_K = cpasync.make_tiled_tma_atom( + gmem_tiled_copy_KV, + mK, + cute.select(self.sK_layout, mode=[0, 1]), + (self.tile_n, self.tile_hdim), + 1, # No mcast for now + ) + tma_atom_V, tma_tensor_V = cpasync.make_tiled_tma_atom( + gmem_tiled_copy_KV, + mV, + cute.select(self.sV_layout, mode=[0, 1]), + (self.tile_n, self.tile_hdimv), + 1, # No mcast for now + ) + tma_atom_O, tma_tensor_O = None, None + if const_expr(self.use_tma_O): + mO_tma = mO_og if const_expr(self.pack_gqa) else mO + if const_expr(self.varlen_q): + mO_tma = copy_utils.create_ragged_tensor_for_tma( + mO_tma, ragged_dim=0, ptr_shift=True + ) + tma_atom_O, tma_tensor_O = make_tiled_tma_atom_fn( + gmem_tiled_copy_O, + mO_tma, + self.sO_layout, + (self.tile_m, self.tile_hdimv), # No mcast + ) + if const_expr(mCuSeqlensQ is not None or mSeqUsedQ is not None): + TileScheduler = SingleTileVarlenScheduler + else: + TileScheduler = ( + SingleTileScheduler + if const_expr(not self.is_causal or self.is_local) + else SingleTileLPTScheduler + ) + tile_sched_args = TileSchedulerArguments( + cute.ceil_div(cute.size(mQ.shape[0]), self.tile_m), + cute.size(mQ.shape[2]), + cute.size(mQ.shape[3]) + if const_expr(mCuSeqlensQ is None) + else cute.size(mCuSeqlensQ.shape[0] - 1), + 1, # num_splits + cute.size(mK.shape[0]) + if const_expr(mPageTable is None) + else mK.shape[0] * mPageTable.shape[1], + mQ.shape[1], + mV.shape[1], + total_q=cute.size(mQ.shape[0]) + if const_expr(mCuSeqlensQ is not None) + else cute.size(mQ.shape[0]) * cute.size(mQ.shape[3]), + tile_shape_mn=(self.tile_m, self.tile_n), + mCuSeqlensQ=mCuSeqlensQ, + mSeqUsedQ=mSeqUsedQ, + qhead_per_kvhead_packgqa=self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1, + element_size=self.dtype.width // 8, + is_persistent=False, + lpt=self.is_causal or self.is_local, + ) + tile_sched_params = TileScheduler.to_underlying_arguments(tile_sched_args) + grid_dim = TileScheduler.get_grid_shape(tile_sched_params) + softmax_scale_log2, softmax_scale = utils.compute_softmax_scale_log2( + softmax_scale, self.score_mod + ) + window_size_left = Int32(window_size_left) if window_size_left is not None else None + window_size_right = Int32(window_size_right) if window_size_right is not None else None + fastdiv_mods = utils.compute_fastdiv_mods( + mQ, mK, self.qhead_per_kvhead, self.pack_gqa, aux_tensors, mPageTable + ) + + self.kernel( + tma_tensor_Q if const_expr(self.use_tma_Q) else mQ, + tma_tensor_K if const_expr(self.use_tma_KV) else mK, + tma_tensor_V if const_expr(self.use_tma_KV) else mV, + tma_tensor_O if const_expr(self.use_tma_O) else mO, + mLSE, + mCuSeqlensQ, + mCuSeqlensK, + mSeqUsedQ, + mSeqUsedK, + mPageTable, + tma_atom_Q, + tma_atom_K, + tma_atom_V, + tma_atom_O, + softmax_scale_log2, + softmax_scale, + window_size_left, + window_size_right, + learnable_sink, + blocksparse_tensors, + self.sQ_layout, + self.sK_layout, + self.sV_layout, + self.sO_layout, + self.sP_layout, + self.gmem_tiled_copy_Q, + self.gmem_tiled_copy_K, + self.gmem_tiled_copy_V, + self.gmem_tiled_copy_O, + tiled_mma_qk, + tiled_mma_pv, + tile_sched_params, + TileScheduler, + SharedStorage, + aux_tensors, + fastdiv_mods, + ).launch( + grid=grid_dim, + block=[self.num_threads, 1, 1], + stream=stream, + min_blocks_per_mp=1, + ) + + @cute.kernel + def kernel( + self, + mQ: cute.Tensor, + mK: cute.Tensor, + mV: cute.Tensor, + mO: cute.Tensor, + mLSE: Optional[cute.Tensor], + mCuSeqlensQ: Optional[cute.Tensor], + mCuSeqlensK: Optional[cute.Tensor], + mSeqUsedQ: Optional[cute.Tensor], + mSeqUsedK: Optional[cute.Tensor], + mPageTable: Optional[cute.Tensor], + tma_atom_Q: Optional[cute.CopyAtom], + tma_atom_K: Optional[cute.CopyAtom], + tma_atom_V: Optional[cute.CopyAtom], + tma_atom_O: Optional[cute.CopyAtom], + softmax_scale_log2: Float32, + softmax_scale: Optional[Float32], + window_size_left: Optional[Int32], + window_size_right: Optional[Int32], + learnable_sink: Optional[cute.Tensor], + blocksparse_tensors: Optional[BlockSparseTensors], + sQ_layout: cute.ComposedLayout, + sK_layout: cute.ComposedLayout, + sV_layout: cute.ComposedLayout, + sO_layout: cute.ComposedLayout, + sP_layout: cute.ComposedLayout | None, + gmem_tiled_copy_Q: cute.TiledCopy, + gmem_tiled_copy_K: cute.TiledCopy, + gmem_tiled_copy_V: cute.TiledCopy, + gmem_tiled_copy_O: cute.TiledCopy, + tiled_mma_qk: cute.TiledMma, + tiled_mma_pv: cute.TiledMma, + tile_sched_params: ParamsBase, + TileScheduler: cutlass.Constexpr[Callable], + SharedStorage: cutlass.Constexpr[Callable], + aux_tensors=Optional[list[cute.Tensor]], + fastdiv_mods=None, + ): + warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) + # Prefetch tma descriptor + if warp_idx == 0: + for tma_atom in (tma_atom_Q, tma_atom_K, tma_atom_V, tma_atom_O): + if const_expr(tma_atom is not None): + cpasync.prefetch_descriptor(tma_atom) + + smem = cutlass.utils.SmemAllocator() + storage = smem.allocate(SharedStorage) + + # Mbarrier / pipeline init + mbar_ptr_Q = storage.mbar_ptr_Q.data_ptr() + + ThreadCooperativeGroup = partial(pipeline.CooperativeGroup, pipeline.Agent.Thread) + tma_warp = ThreadCooperativeGroup(1) + load_threads = ThreadCooperativeGroup(self.num_threads_per_warp_group) + mma_warps = ThreadCooperativeGroup(self.num_mma_threads // cute.arch.WARP_SIZE) + if const_expr(self.use_tma_Q): + pipeline_q = pipeline_custom.PipelineTmaAsync.create( + barrier_storage=mbar_ptr_Q, + num_stages=1, + producer_group=tma_warp, + consumer_group=mma_warps, + tx_count=self.tma_copy_bytes["Q"], + defer_sync=True, + ) + else: + pipeline_q = pipeline_custom.PipelineCpAsync.create( + barrier_storage=mbar_ptr_Q, + num_stages=1, + producer_group=load_threads, + consumer_group=mma_warps, + defer_sync=True, + elect_one_release=True, + syncwarp_before_release=False, + ) + + if const_expr(self.use_tma_KV): + pipeline_k = pipeline_custom.PipelineTmaAsync.create( + barrier_storage=storage.mbar_ptr_K.data_ptr(), + num_stages=self.num_stages, + producer_group=tma_warp, + consumer_group=mma_warps, + tx_count=self.tma_copy_bytes["K"], + defer_sync=True, + ) + pipeline_v = pipeline_custom.PipelineTmaAsync.create( + barrier_storage=storage.mbar_ptr_V.data_ptr(), + num_stages=self.num_stages, + producer_group=tma_warp, + consumer_group=mma_warps, + tx_count=self.tma_copy_bytes["V"], + defer_sync=True, + ) + else: + pipeline_k = pipeline_custom.PipelineCpAsync.create( + barrier_storage=storage.mbar_ptr_K.data_ptr(), + num_stages=self.num_stages, + producer_group=load_threads, + consumer_group=mma_warps, + defer_sync=True, + elect_one_release=True, + syncwarp_before_release=False, + ) + pipeline_v = pipeline_custom.PipelineCpAsync.create( + barrier_storage=storage.mbar_ptr_V.data_ptr(), + num_stages=self.num_stages, + producer_group=load_threads, + consumer_group=mma_warps, + defer_sync=True, + elect_one_release=True, + syncwarp_before_release=False, + ) + + # Cluster arrive after barrier init + pipeline_init_arrive(cluster_shape_mn=self.cluster_shape_mn, is_relaxed=True) + + # /////////////////////////////////////////////////////////////////////////////// + # Get shared memory buffer + # /////////////////////////////////////////////////////////////////////////////// + sQ = storage.sQ.get_tensor(sQ_layout.outer, swizzle=sQ_layout.inner) + sK = storage.sK.get_tensor(sK_layout.outer, swizzle=sK_layout.inner) + if const_expr(not self.Q_in_regs): + sV = storage.sV.get_tensor(sV_layout.outer, swizzle=sV_layout.inner) + else: + sV = storage.sQ.get_tensor( + sV_layout.outer, swizzle=sV_layout.inner, dtype=mV.element_type + ) + # Transpose view of V to tensor with layout (head_dim_v, tile_n) for tiled mma + sVt = layout_utils.transpose_view(sV) + sP = None + if const_expr(sP_layout is not None): + sP = storage.sP.get_tensor(sP_layout.outer, swizzle=sP_layout.inner) + # reuse sQ's data iterator + sO = storage.sQ.get_tensor(sO_layout.outer, swizzle=sO_layout.inner, dtype=self.dtype) + + block_info = BlockInfo( + self.tile_m, + self.tile_n, + self.is_causal, + self.is_local, + False, # is_split_kv + window_size_left, + window_size_right, + qhead_per_kvhead_packgqa=self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1, + ) + SeqlenInfoCls = partial( + SeqlenInfoQK.create, + seqlen_q_static=mQ.shape[0] if const_expr(not self.pack_gqa) else mQ.shape[0][1], + seqlen_k_static=mK.shape[0] + if const_expr(mPageTable is None) + else mK.shape[0] * mPageTable.shape[1], + mCuSeqlensQ=mCuSeqlensQ, + mCuSeqlensK=mCuSeqlensK, + mSeqUsedQ=mSeqUsedQ, + mSeqUsedK=mSeqUsedK, + # Don't need to pass in tile_mn because we won't access offset_padded + ) + AttentionMaskCls = partial( + AttentionMask, + self.tile_m, + self.tile_n, + window_size_left=window_size_left, + window_size_right=window_size_right, + qhead_per_kvhead_packgqa=self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1, + ) + TileSchedulerCls = partial(TileScheduler.create, tile_sched_params) + + # Cluster wait before starting + pipeline_init_wait(cluster_shape_mn=self.cluster_shape_mn) + + if warp_idx < 4: # Producer + cute.arch.setmaxregister_decrease(self.num_producer_regs) + self.load( + mQ, + mK, + mV, + sQ, + sK, + sV, + tma_atom_Q, + tma_atom_K, + tma_atom_V, + pipeline_k, + pipeline_v, + pipeline_q, + gmem_tiled_copy_Q, + mPageTable, + blocksparse_tensors, + block_info, + SeqlenInfoCls, + TileSchedulerCls, + ) + + else: # Consumer + cute.arch.setmaxregister_increase(self.num_mma_regs) + # /////////////////////////////////////////////////////////////////////////////// + # Tile MMA compute thread partitions and allocate accumulators + # /////////////////////////////////////////////////////////////////////////////// + tidx, _, _ = cute.arch.thread_idx() + tidx = tidx - 128 + self.mma( + tiled_mma_qk, + tiled_mma_pv, + mO, + mLSE, + sQ, + sK, + sVt, + sP, + sO, + learnable_sink, + pipeline_k, + pipeline_v, + pipeline_q, + gmem_tiled_copy_O, + tma_atom_O, + tidx, + softmax_scale_log2, + softmax_scale, + block_info, + SeqlenInfoCls, + AttentionMaskCls, + TileSchedulerCls, + blocksparse_tensors, + aux_tensors, + fastdiv_mods, + ) + + @cute.jit + def load( + self, + mQ: cute.Tensor, + mK: cute.Tensor, + mV: cute.Tensor, + sQ: cute.Tensor, + sK: cute.Tensor, + sV: cute.Tensor, + tma_atom_Q: Optional[cute.CopyAtom], + tma_atom_K: Optional[cute.CopyAtom], + tma_atom_V: Optional[cute.CopyAtom], + pipeline_k: pipeline.PipelineAsync, + pipeline_v: pipeline.PipelineAsync, + pipeline_q: pipeline.PipelineAsync, + gmem_tiled_copy_Q: cute.TiledCopy, + mPageTable: Optional[cute.Tensor], + blocksparse_tensors: Optional[BlockSparseTensors], + block_info: BlockInfo, + SeqlenInfoCls: Callable, + TileSchedulerCls: Callable, + ): + warp_idx_in_wg = cute.arch.make_warp_uniform(cute.arch.warp_idx()) % 4 + tidx, _, _ = cute.arch.thread_idx() + + # TMA: only warp 0 loads. cp_async: all warps load. + # When not use_tma_Q, all 128 producer threads participate in Q loading. + is_load_warp = warp_idx_in_wg == 0 or const_expr(not self.use_tma_KV or not self.use_tma_Q) + # KV loading restricted to warp 0 for TMA, all warps for non-TMA KV + is_kv_load_warp = warp_idx_in_wg == 0 or const_expr(not self.use_tma_KV) + + if is_load_warp: + q_producer_phase = Int32(1) + kv_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.num_stages + ) + tile_scheduler = TileSchedulerCls() + work_tile = tile_scheduler.initial_work_tile_info() + while work_tile.is_valid_tile: + # if work_tile.is_valid_tile: + m_block, head_idx, batch_idx, _ = work_tile.tile_idx + seqlen = SeqlenInfoCls(batch_idx) + mQ_cur = seqlen.offset_batch_Q(mQ, batch_idx, dim=3)[None, None, head_idx] + head_idx_kv = ( + head_idx // self.qhead_per_kvhead if const_expr(not self.pack_gqa) else head_idx + ) + + load_Q = None + if const_expr(self.use_tma_Q): + gQ = cute.local_tile(mQ_cur, (self.tile_m, self.tile_hdim), (m_block, 0)) + load_Q, _, _ = copy_utils.tma_get_copy_fn( + tma_atom_Q, 0, cute.make_layout(1), gQ, sQ, single_stage=True + ) + + paged_kv_manager = None + tma_load_K_fn = None + tma_load_V_fn = None + if const_expr(self.use_tma_KV): + # === TMA path (non-paged and paged with page_size == n_block_size) === + if const_expr(mPageTable is not None): + # Paged TMA: keep page dimension indexable + mK_cur = mK[None, None, head_idx_kv, None] + mV_cur = mV[None, None, head_idx_kv, None] + gK = cute.local_tile(mK_cur, (self.tile_n, self.tile_hdim), (0, 0, None)) + gV = cute.local_tile(mV_cur, (self.tile_n, self.tile_hdimv), (0, 0, None)) + else: + # Non-paged TMA + mK_cur = seqlen.offset_batch_K(mK, batch_idx, dim=3)[ + None, None, head_idx_kv + ] + mV_cur = seqlen.offset_batch_K(mV, batch_idx, dim=3)[ + None, None, head_idx_kv + ] + gK = cute.local_tile(mK_cur, (self.tile_n, self.tile_hdim), (None, 0)) + gV = cute.local_tile(mV_cur, (self.tile_n, self.tile_hdimv), (None, 0)) + # TODO: mcast + tma_load_K_fn, _, _ = copy_utils.tma_get_copy_fn( + tma_atom_K, 0, cute.make_layout(1), gK, sK + ) + tma_load_K_fn = copy_utils.tma_producer_copy_fn(tma_load_K_fn, pipeline_k) + tma_load_V_fn, _, _ = copy_utils.tma_get_copy_fn( + tma_atom_V, 0, cute.make_layout(1), gV, sV + ) + tma_load_V_fn = copy_utils.tma_producer_copy_fn(tma_load_V_fn, pipeline_v) + else: + # === cp_async path (paged KV with page_size != n_block_size) === + paged_kv_manager = PagedKVManager.create( + mPageTable, + mK, + mV, + FastDivmodDivisor(mK.shape[0]), + batch_idx, + head_idx_kv, + tidx, + seqlen.seqlen_k, + 0, # leftpad_k + self.tile_n, + self.tile_hdim, + self.tile_hdimv, + self.num_threads_per_warp_group, + mK.element_type, + arch=self.arch.major * 10 + self.arch.minor, + ) + + load_K = partial( + self.load_KV, + tma_load_K_fn, + paged_kv_manager, + sK, + pipeline_kv=pipeline_k, + K_or_V="K", + ) + load_V = partial( + self.load_KV, + tma_load_V_fn, + paged_kv_manager, + sV, + pipeline_kv=pipeline_v, + K_or_V="V", + ) + + pack_gqa = None + if const_expr(not self.use_tma_Q): + pack_gqa = PackGQA( + self.tile_m, self.tile_hdim, self.check_hdim_oob, self.qhead_per_kvhead + ) + + if const_expr(not self.use_block_sparsity): + n_block_min, n_block_max = block_info.get_n_block_min_max(seqlen, m_block) + # if cute.arch.thread_idx()[0] == 0: + # cute.printf("m_block = %d, n_block_min: %d, n_block_max: %d", m_block, n_block_min, n_block_max) + # Clamp n_block to 0 when n_block_max == 0 (can happen with causal + # + pack_gqa when seqlen_k < tile_n). TMA handles n_block=-1 + # gracefully (fills zeros), but cp.async would crash on + # out-of-bounds page table access. + n_block = ( + n_block_max - 1 + if const_expr(self.use_tma_KV) + else cutlass.max(n_block_max - 1, 0) + ) + page_idx = ( + mPageTable[batch_idx, n_block] + if const_expr(mPageTable is not None and self.use_tma_KV) + else None + ) + + # First iteration: load K on pipeline_k, Q on pipeline_q + if is_kv_load_warp: + pipeline_k.producer_acquire(kv_producer_state) + if const_expr(not self.use_tma_KV): + paged_kv_manager.load_page_table(n_block) + load_K(block=n_block, producer_state=kv_producer_state, page_idx=page_idx) + if const_expr(self.use_tma_Q): + if warp_idx_in_wg == 0: + pipeline_q.producer_acquire_w_index_phase(0, q_producer_phase) + load_Q(tma_bar_ptr=pipeline_q.sync_object_full.get_barrier(0)) + q_producer_phase ^= 1 + else: + pipeline_q.producer_acquire_w_index_phase(0, q_producer_phase) + pack_gqa.load_Q( + mQ_cur, sQ, gmem_tiled_copy_Q, tidx, m_block, seqlen.seqlen_q + ) + cute.arch.cp_async_commit_group() + pipeline_q.producer_commit_w_index(0) + q_producer_phase ^= 1 + + if is_kv_load_warp: + if const_expr(not self.intra_wg_overlap or not self.use_tma_KV): + pipeline_v.producer_acquire(kv_producer_state) + load_V( + block=n_block, producer_state=kv_producer_state, page_idx=page_idx + ) + kv_producer_state.advance() + for i in cutlass.range(n_block_max - 1 - n_block_min, unroll=1): + n_block = n_block_max - 1 - i - 1 + page_idx = ( + mPageTable[batch_idx, n_block] + if const_expr(mPageTable is not None and self.use_tma_KV) + else None + ) + if const_expr(not self.use_tma_KV): + paged_kv_manager.load_page_table(n_block) + pipeline_k.producer_acquire(kv_producer_state) + load_K( + block=n_block, + producer_state=kv_producer_state, + page_idx=page_idx, + ) + pipeline_v.producer_acquire(kv_producer_state) + load_V( + block=n_block, + producer_state=kv_producer_state, + page_idx=page_idx, + ) + kv_producer_state.advance() + else: + for i in cutlass.range(n_block_max - 1 - n_block_min, unroll=1): + n_block_prev = n_block_max - i - 1 + n_block = n_block_prev - 1 + page_idx = ( + mPageTable[batch_idx, n_block] + if const_expr(mPageTable is not None) + else None + ) + page_idx_prev = ( + mPageTable[batch_idx, n_block_prev] + if const_expr(mPageTable is not None) + else None + ) + kv_producer_state_prev = kv_producer_state.clone() + kv_producer_state.advance() + pipeline_k.producer_acquire(kv_producer_state) + load_K( + block=n_block, + producer_state=kv_producer_state, + page_idx=page_idx, + ) + pipeline_v.producer_acquire(kv_producer_state_prev) + load_V( + block=n_block_prev, + producer_state=kv_producer_state_prev, + page_idx=page_idx_prev, + ) + n_block = n_block_min + page_idx = ( + mPageTable[batch_idx, n_block] + if const_expr(mPageTable is not None) + else None + ) + pipeline_v.producer_acquire(kv_producer_state) + load_V( + block=n_block, producer_state=kv_producer_state, page_idx=page_idx + ) + kv_producer_state.advance() + else: + # Block sparsity: use TMA closures directly (not paged) + # Load Q on pipeline_q, separate from K/V pipeline + if const_expr(self.use_tma_Q): + if warp_idx_in_wg == 0: + pipeline_q.producer_acquire_w_index_phase(0, q_producer_phase) + load_Q(tma_bar_ptr=pipeline_q.sync_object_full.get_barrier(0)) + q_producer_phase ^= 1 + else: + pipeline_q.producer_acquire_w_index_phase(0, q_producer_phase) + pack_gqa.load_Q( + mQ_cur, sQ, gmem_tiled_copy_Q, tidx, m_block, seqlen.seqlen_q + ) + cute.arch.cp_async_commit_group() + pipeline_q.producer_commit_w_index(0) + q_producer_phase ^= 1 + if is_kv_load_warp: + kv_producer_state = produce_block_sparse_loads( + blocksparse_tensors, + batch_idx, + head_idx, + m_block, + kv_producer_state, + tma_load_K_fn, + tma_load_V_fn, + pipeline_k, + pipeline_v, + self.intra_wg_overlap, + self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1, + self.q_subtile_factor if self.q_subtile_factor is not None else 1, + ) + + tile_scheduler.prefetch_next_work() + tile_scheduler.advance_to_next_work() + work_tile = tile_scheduler.get_current_work() + # End of persistent scheduler loop + + # Producer tail is only useful for cluster to avoid early exit of blocks. + # We only need producer_tail on V since that's the last that's loaded, we don't + # need it for Q (no cluster) and K. + if is_kv_load_warp: + pipeline_v.producer_tail(kv_producer_state) + + @cute.jit + def load_KV( + self, + tma_load_fn: Optional[Callable], + paged_kv_manager: Optional[PagedKVManager], + sX: cute.Tensor, + block: Int32, + pipeline_kv: pipeline.PipelineAsync, + producer_state: pipeline.PipelineState, + K_or_V: Literal["K", "V"], + page_idx: Optional[Int32] = None, + ): + if const_expr(self.use_tma_KV): + src_idx = block if const_expr(page_idx is None) else page_idx + tma_load_fn(src_idx=src_idx, producer_state=producer_state) + else: + paged_kv_manager.load_KV(block, sX[None, None, producer_state.index], K_or_V) + cute.arch.cp_async_commit_group() + pipeline_kv.producer_commit(producer_state) + + @cute.jit + def mma( + self, + tiled_mma_qk: cute.TiledMma, + tiled_mma_pv: cute.TiledMma, + mO: cute.Tensor, + mLSE: Optional[cute.Tensor], + sQ: cute.Tensor, + sK: cute.Tensor, + sVt: cute.Tensor, + sP: Optional[cute.Tensor], + sO: cute.Tensor, + learnable_sink: Optional[cute.Tensor], + pipeline_k: pipeline.PipelineAsync, + pipeline_v: pipeline.PipelineAsync, + pipeline_q: pipeline.PipelineAsync, + gmem_tiled_copy_O: cute.TiledCopy, + tma_atom_O: Optional[cute.CopyAtom], + tidx: Int32, + softmax_scale_log2: Float32, + softmax_scale: Optional[Float32], + block_info: BlockInfo, + SeqlenInfoCls: Callable, + AttentionMaskCls: Callable, + TileSchedulerCls: Callable, + blocksparse_tensors: Optional[BlockSparseTensors], + aux_tensors: Optional[list], + fastdiv_mods=None, + ): + warp_group_idx = cute.arch.make_warp_uniform(tidx // self.num_threads_per_warp_group) + warp_group_thread_layout = cute.make_layout( + self.num_wg_mma, stride=self.num_threads_per_warp_group + ) + thr_mma_qk = tiled_mma_qk.get_slice(tidx) + wg_mma_qk = tiled_mma_qk.get_slice(warp_group_thread_layout(warp_group_idx)) + wg_mma_pv = tiled_mma_pv.get_slice(warp_group_thread_layout(warp_group_idx)) + _, tSrQ, tSrK = sm90_utils.partition_fragment_ABC( + wg_mma_qk, (self.tile_m, self.tile_n, self.tile_hdim), sQ, sK + ) + mma_qk_fn = partial( + sm90_utils.gemm_zero_init, tiled_mma_qk, (self.tile_m, self.tile_n), tSrQ, tSrK + ) + acc_O, tOrP, tOrVt = sm90_utils.partition_fragment_ABC( + wg_mma_pv, (self.tile_m, self.tile_hdimv, self.tile_n), sP, sVt + ) + mma_pv_fn = partial(sm90_utils.gemm_w_idx, tiled_mma_pv, acc_O, tOrP, tOrVt) + + # /////////////////////////////////////////////////////////////////////////////// + # Smem copy atom tiling + # /////////////////////////////////////////////////////////////////////////////// + smem_copy_atom_P = utils.get_smem_store_atom( + self.arch.major * 10 + self.arch.minor, self.dtype + ) + smem_thr_copy_P = cute.make_tiled_copy_C(smem_copy_atom_P, tiled_mma_qk).get_slice(tidx) + tPsP = smem_thr_copy_P.partition_D(sP) if const_expr(sP is not None) else None + smem_copy_params = SimpleNamespace(smem_thr_copy_P=smem_thr_copy_P, tPsP=tPsP) + + self.mma_init() + + q_consumer_phase = Int32(0) + kv_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.num_stages + ) + + tile_scheduler = TileSchedulerCls() + work_tile = tile_scheduler.initial_work_tile_info() + softmax = Softmax.create( + softmax_scale_log2, + num_rows=acc_O.shape[0][0] * acc_O.shape[1], + softmax_scale=softmax_scale, + ) + + # For RescaleOBeforeGemm: persistent scores_scale across iterations + scores_scale = None + if const_expr(self.rescale_O_before_gemm): + scores_scale = cute.make_rmem_tensor_like(softmax.row_max, Float32) + + mma_one_n_block_all = partial( + self.mma_one_n_block_intrawg_overlap + if const_expr(self.intra_wg_overlap) + else self.mma_one_n_block, + mma_qk_fn=mma_qk_fn, + pipeline_k=pipeline_k, + pipeline_v=pipeline_v, + acc_O=acc_O, + tOrP=tOrP, + smem_copy_params=smem_copy_params, + check_inf=True, + scores_scale=scores_scale, + ) + + process_first_half_block = partial( + self.first_half_block_overlap, + mma_qk_fn=mma_qk_fn, + pipeline_k=pipeline_k, + tOrP=tOrP, + smem_copy_params=smem_copy_params, + scores_scale=scores_scale, + softmax=softmax, + acc_O=acc_O, + ) + process_last_half_block = partial( + self.last_half_block_overlap, + pipeline_v=pipeline_v, + mma_pv_fn=mma_pv_fn, + scores_scale=scores_scale, + softmax=softmax, + acc_O=acc_O, + ) + while work_tile.is_valid_tile: + # if work_tile.is_valid_tile: + + # shape: (atom_v_m * rest_m) + m_block, head_idx, batch_idx, _ = work_tile.tile_idx + seqlen = SeqlenInfoCls(batch_idx) + + # Recompute fastdiv_mods if necessary for varlen with aux_tensors + recompute_fastdiv_mods_q = cutlass.const_expr( + aux_tensors is not None and (seqlen.has_cu_seqlens_q or seqlen.has_seqused_q) + ) + recompute_fastdiv_mods_k = cutlass.const_expr( + aux_tensors is not None and (seqlen.has_cu_seqlens_k or seqlen.has_seqused_k) + ) + if cutlass.const_expr(fastdiv_mods is not None): + seqlen_q_divmod, seqlen_k_divmod = fastdiv_mods + fastdiv_mods = ( + seqlen_q_divmod + if not recompute_fastdiv_mods_q + else FastDivmodDivisor(seqlen.seqlen_q), + seqlen_k_divmod + if not recompute_fastdiv_mods_k + else FastDivmodDivisor(seqlen.seqlen_k), + ) + + mask = AttentionMaskCls(seqlen) + mask_fn = partial( + mask.apply_mask, + batch_idx=batch_idx, + head_idx=head_idx, + m_block=m_block, + thr_mma=thr_mma_qk, + mask_causal=self.is_causal, + mask_local=self.is_local, + aux_tensors=aux_tensors, + fastdiv_mods=fastdiv_mods, + ) + score_mod_fn = None + if const_expr(self.score_mod is not None): + score_mod_fn = partial( + self.apply_score_mod, + thr_mma_qk, + batch_idx, + head_idx, + m_block, + softmax_scale=softmax_scale, + aux_tensors=aux_tensors, + fastdiv_mods=fastdiv_mods, + ) + mma_one_n_block = partial( + mma_one_n_block_all, seqlen=seqlen, softmax=softmax, score_mod_fn=score_mod_fn + ) + n_block_min, n_block_max = block_info.get_n_block_min_max(seqlen, m_block) + pipeline_q.consumer_wait_w_index_phase(0, q_consumer_phase) + # For performance reason, we separate out two kinds of iterations: + # those that need masking on S, and those that don't. + # We need masking on S for the very last block when K and V has length not multiple of tile_n. + # We also need masking on S if it's causal, for the last several blocks. + # softmax.reset() # Don't need reset as we explicitly call softmax w is_first=True + O_should_accumulate = False + + # ========================================== + # MAINLOOP + # ========================================== + if const_expr(not self.use_block_sparsity): + # ========================================== + # No block-sparsity (original path) + # ========================================== + # First iteration with seqlen masking + if const_expr(self.intra_wg_overlap): + kv_consumer_state = process_first_half_block( + n_block=n_block_max - 1, + seqlen=seqlen, + kv_consumer_state=kv_consumer_state, + mask_fn=partial(mask_fn, mask_mod=self.mask_mod), + score_mod_fn=score_mod_fn, + is_first_block=True, + ) + else: + self.warp_scheduler_barrier_sync() + kv_consumer_state = mma_one_n_block( + kv_consumer_state, + n_block=n_block_max - 1, + seqlen=seqlen, + mma_pv_fn=partial(mma_pv_fn, zero_init=True), + is_first_n_block=True, + mask_fn=partial(mask_fn, mask_mod=self.mask_mod, mask_seqlen=True), + ) + O_should_accumulate = True + # if cute.arch.thread_idx()[0] == 128: cute.printf("m_block = {}, n_block_max = {}, n_block_min = {}", m_block, n_block_max, n_block_min) + n_block_max -= 1 + # Next couple of iterations with causal masking + if const_expr(self.is_causal or self.is_local): + n_block_min_causal_local_mask = block_info.get_n_block_min_causal_local_mask( + seqlen, m_block, n_block_min + ) + # if cute.arch.thread_idx()[0] == 128: cute.printf("n_block_min_causal_local_mask = {}", n_block_min_causal_local_mask) + for n_tile in cutlass.range( + n_block_max - n_block_min_causal_local_mask, unroll=1 + ): + kv_consumer_state = mma_one_n_block( + kv_consumer_state, + n_block=n_block_max - 1 - n_tile, + seqlen=seqlen, + mma_pv_fn=partial(mma_pv_fn, zero_init=not O_should_accumulate), + mask_fn=partial(mask_fn, mask_mod=self.mask_mod, mask_seqlen=False), + ) + O_should_accumulate = True + n_block_max = cutlass.min(n_block_max, n_block_min_causal_local_mask) + # The remaining iterations have no masking + n_block_min_before_local_mask = block_info.get_n_block_min_before_local_mask( + seqlen, m_block, n_block_min + ) + # if cute.arch.thread_idx()[0] == 128: cute.printf("n_block_min_before_local_mask = {}, n_block_min = {}", n_block_min_before_local_mask, n_block_min) + for n_tile in cutlass.range(n_block_max - n_block_min_before_local_mask, unroll=1): + kv_consumer_state = mma_one_n_block( + kv_consumer_state, + n_block=n_block_max - 1 - n_tile, + seqlen=seqlen, + mma_pv_fn=partial(mma_pv_fn, zero_init=not O_should_accumulate), + mask_fn=partial(mask_fn, mask_mod=self.mask_mod, mask_seqlen=False), + ) + O_should_accumulate = True + # Separate iterations with local masking on the left + if const_expr(self.is_local and block_info.window_size_left is not None): + n_block_max = cutlass.min(n_block_max, n_block_min_before_local_mask) + for n_tile in cutlass.range(n_block_max - n_block_min, unroll=1): + kv_consumer_state = mma_one_n_block( + kv_consumer_state, + n_block=n_block_max - 1 - n_tile, + seqlen=seqlen, + mma_pv_fn=partial(mma_pv_fn, zero_init=not O_should_accumulate), + mask_fn=partial(mask_fn, mask_mod=self.mask_mod, mask_seqlen=False), + ) + O_should_accumulate = True + # Release Q pipeline so the producer can load the next tile's Q + pipeline_q.consumer_release_w_index(0) + # Last "half" iteration + if const_expr(self.intra_wg_overlap): + kv_consumer_state = process_last_half_block( + kv_consumer_state=kv_consumer_state, + zero_init=not O_should_accumulate, + ) + O_should_accumulate = True + else: + self.warp_scheduler_barrier_arrive() + + else: + # ========================================== + # Block sparsity + # ========================================== + kv_consumer_state, O_should_accumulate, processed_any = consume_block_sparse_loads( + blocksparse_tensors, + batch_idx, + head_idx, + m_block, + seqlen, + kv_consumer_state, + mma_pv_fn, + mma_one_n_block, + process_first_half_block, + process_last_half_block, + mask_fn, + score_mod_fn, + O_should_accumulate, + self.mask_mod, + fastdiv_mods, + self.intra_wg_overlap, + self.warp_scheduler_barrier_sync, + self.warp_scheduler_barrier_arrive, + self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1, + self.q_subtile_factor if self.q_subtile_factor is not None else 1, + ) + + # Release Q pipeline so the producer can load the next tile's Q + pipeline_q.consumer_release_w_index(0) + + # Handle empty case (when no blocks to process) + if not processed_any: + softmax.reset() + acc_O.fill(0.0) + + q_consumer_phase ^= 1 + + sink_val = None + if const_expr(learnable_sink is not None): + if const_expr(not self.pack_gqa): + sink_val = Float32(learnable_sink[head_idx]) + else: # Each thread might have a different sink value due to different q_head + sink_val = cute.make_rmem_tensor_like(softmax.row_max, Float32) + cS = cute.make_identity_tensor((self.tile_m, self.tile_n)) + tScS_mn = layout_utils.reshape_acc_to_mn(thr_mma_qk.partition_C(cS)) + for r in cutlass.range(cute.size(sink_val), unroll_full=True): + row = m_block * self.tile_m + tScS_mn[r][0] + q_head_idx = row % self.qhead_per_kvhead + head_idx * self.qhead_per_kvhead + sink_val[r] = Float32(learnable_sink[q_head_idx]) + + # normalize acc_O by row_sum and calculate the lse + row_scale = softmax.finalize(sink_val=sink_val) + softmax.rescale_O(acc_O, row_scale) + + # /////////////////////////////////////////////////////////////////////////////// + # Epilogue + # /////////////////////////////////////////////////////////////////////////////// + self.epilogue( + acc_O, + softmax.row_sum, + mO, + mLSE, + sO, + seqlen, + gmem_tiled_copy_O, + tma_atom_O, + tiled_mma_pv, + tidx, + m_block, + head_idx, + batch_idx, + ) + + tile_scheduler.advance_to_next_work() + work_tile = tile_scheduler.get_current_work() + + @cute.jit + def first_half_block_overlap( + self, + n_block: Int32, + mma_qk_fn: Callable, + kv_consumer_state, + pipeline_k, + tOrP: cute.Tensor, + smem_copy_params: SimpleNamespace, + softmax: Softmax, + seqlen: SeqlenInfoQK, + scores_scale: Optional[cute.Tensor] = None, + acc_O: Optional[cute.Tensor] = None, + mask_fn: Callable = None, + score_mod_fn: Optional[Callable] = None, + is_first_block: bool = False, + ): + """Processes the first half block when using intra-warpgroup-overlap""" + + pipeline_k.consumer_wait(kv_consumer_state, pipeline_k.consumer_try_wait(kv_consumer_state)) + acc_S = mma_qk_fn(B_idx=kv_consumer_state.index, wg_wait=0) + pipeline_k.consumer_release(kv_consumer_state) + + # Apply score modification if present + if const_expr(score_mod_fn is not None): + score_mod_fn(acc_S, n_block=n_block, seqlen=seqlen) + + # Apply mask; mask_seqlen always True for first block + # Caveat: if full block further right than mask block, seqlen masking is redundant; + # however, masking is being applied anyway, so essentially no perf hit + mask_fn(acc_S, n_block=n_block, mask_seqlen=True) + + row_scale = softmax.online_softmax(acc_S, is_first=is_first_block) + + tOrP_acc = layout_utils.reshape_acc_to_frgA(acc_S) + tOrP_cur = ( + tOrP + if const_expr(self.mma_pv_is_rs) + else cute.make_rmem_tensor_like(tOrP_acc, self.dtype) + ) + tOrP_cur.store(tOrP_acc.load().to(self.dtype)) + + if const_expr(not self.mma_pv_is_rs): + tPrP = smem_copy_params.smem_thr_copy_P.retile(tOrP_cur) + cute.copy(smem_copy_params.smem_thr_copy_P, tPrP, smem_copy_params.tPsP) + # Fence and barrier to make smem store visible to WGMMA + cute.arch.fence_view_async_shared() + cute.arch.sync_warp() + + # For RescaleOBeforeGemm: initialize acc_O + if const_expr(self.rescale_O_before_gemm): + acc_O.fill(0.0) + scores_scale.store(row_scale.load()) + + return kv_consumer_state + + @cute.jit + def last_half_block_overlap( + self, + kv_consumer_state, + pipeline_v, + mma_pv_fn: Callable, + zero_init: bool, + scores_scale: Optional[cute.Tensor] = None, + softmax: Optional[Softmax] = None, + acc_O: Optional[cute.Tensor] = None, + ): + """Processes the final PV GEMM when using intra-warpgroup-overlap""" + + # For RescaleOBeforeGemm: rescale O before the final PV GEMM + if const_expr(self.rescale_O_before_gemm): + softmax.rescale_O(acc_O, scores_scale) + + pipeline_v.consumer_wait(kv_consumer_state, pipeline_v.consumer_try_wait(kv_consumer_state)) + mma_pv_fn(B_idx=kv_consumer_state.index, zero_init=zero_init, wg_wait=0) + pipeline_v.consumer_release(kv_consumer_state) + kv_consumer_state.advance() + return kv_consumer_state + + @cute.jit + def mma_one_n_block( + self, + smem_pipe_read: pipeline.PipelineState | pipeline_custom.PipelineStateSimple, + n_block: Int32, + mma_qk_fn: Callable, + mma_pv_fn: Callable, + pipeline_k: pipeline.PipelineAsync, + pipeline_v: pipeline.PipelineAsync, + acc_O: cute.Tensor, + tOrP: cute.Tensor, + smem_copy_params: SimpleNamespace, + softmax: Softmax, + seqlen: SeqlenInfoQK, + scores_scale: Optional[cute.Tensor] = None, # not used + score_mod_fn: Optional[Callable] = None, + mask_fn: Optional[Callable] = None, + is_first_n_block: cutlass.Constexpr = False, + check_inf: cutlass.Constexpr = True, + ): + pipeline_k.consumer_wait(smem_pipe_read, pipeline_k.consumer_try_wait(smem_pipe_read)) + # S = Q @ K.T + acc_S = mma_qk_fn(B_idx=smem_pipe_read.index, wg_wait=-1) + self.warp_scheduler_barrier_arrive() + warpgroup.wait_group(0) + pipeline_k.consumer_release(smem_pipe_read) + + # handle score mods and masking + if const_expr(score_mod_fn is not None): + score_mod_fn(acc_S, n_block=n_block, seqlen=seqlen) + if const_expr(mask_fn is not None): + mask_fn(acc_S=acc_S, n_block=n_block) + + row_scale = softmax.online_softmax(acc_S, is_first=is_first_n_block, check_inf=check_inf) + # if cute.arch.thread_idx()[0] == 0: cute.print_tensor(layout_utils.reshape_acc_to_mn(acc_S)) + tOrP_acc = layout_utils.reshape_acc_to_frgA(acc_S) + tOrP_cur = ( + tOrP + if const_expr(self.mma_pv_is_rs) + else cute.make_rmem_tensor_like(tOrP_acc, self.dtype) + ) + # tOrP.store(tOrP_acc.load().to(self.dtype)) + # the "to(self.dtype)" conversion fails to vectorize for block sizes other + # than 128 x 128, i.e. it calls convert on 1 fp32 element at a time instead of + # 2 elements. So we just call ptx directly. + utils.cvt_f16(tOrP_acc, tOrP_cur) + if const_expr(not self.mma_pv_is_rs): + tPrP = smem_copy_params.smem_thr_copy_P.retile(tOrP_cur) + cute.copy(smem_copy_params.smem_thr_copy_P, tPrP, smem_copy_params.tPsP) + softmax.rescale_O(acc_O, row_scale) + if const_expr(not self.mma_pv_is_rs): + # Fence and barrier to make sure smem store is visible to WGMMA + cute.arch.fence_view_async_shared() + cute.arch.sync_warp() # Only need syncwarp since each warp is using its own P values for MmaPV + pipeline_v.consumer_wait(smem_pipe_read, pipeline_v.consumer_try_wait(smem_pipe_read)) + self.warp_scheduler_barrier_sync() + # O += P @ V + mma_pv_fn(B_idx=smem_pipe_read.index, wg_wait=0) + pipeline_v.consumer_release(smem_pipe_read) + smem_pipe_read.advance() + return smem_pipe_read + + @cute.jit + def mma_one_n_block_intrawg_overlap( + self, + smem_pipe_read: pipeline.PipelineState | pipeline_custom.PipelineStateSimple, + n_block: Int32, + mma_qk_fn: Callable, + mma_pv_fn: Callable, + pipeline_k: pipeline.PipelineAsync, + pipeline_v: pipeline.PipelineAsync, + acc_O: cute.Tensor, + tOrP: cute.Tensor, + smem_copy_params: SimpleNamespace, + softmax: Softmax, + seqlen: SeqlenInfoQK, + scores_scale: Optional[cute.Tensor] = None, + score_mod_fn: Optional[Callable] = None, + mask_fn: Optional[Callable] = None, + check_inf: cutlass.Constexpr = True, + ): + smem_pipe_read_v = smem_pipe_read.clone() + smem_pipe_read.advance() + pipeline_k.consumer_wait(smem_pipe_read, pipeline_k.consumer_try_wait(smem_pipe_read)) + self.warp_scheduler_barrier_sync() + # S = Q @ K.T + acc_S = mma_qk_fn(B_idx=smem_pipe_read.index, wg_wait=-1) + # RescaleOBeforeGemm: rescale O while QK GEMM is in flight, before PV GEMM + if const_expr(self.rescale_O_before_gemm): + softmax.rescale_O(acc_O, scores_scale) + pipeline_v.consumer_wait(smem_pipe_read_v, pipeline_v.consumer_try_wait(smem_pipe_read_v)) + # O += P @ V + mma_pv_fn(B_idx=smem_pipe_read_v.index, wg_wait=-1) + self.warp_scheduler_barrier_arrive() + warpgroup.wait_group(1) + pipeline_k.consumer_release(smem_pipe_read) + + # handle score mods and masking + if const_expr(score_mod_fn is not None): + score_mod_fn(acc_S, n_block=n_block, seqlen=seqlen) + if const_expr(mask_fn is not None): + mask_fn(acc_S=acc_S, n_block=n_block) + # if cute.arch.thread_idx()[0] == 128: cute.print_tensor(layout_utils.reshape_acc_to_mn(acc_S)) + + row_scale = softmax.online_softmax(acc_S, check_inf=check_inf) + warpgroup.wait_group(0) + pipeline_v.consumer_release(smem_pipe_read_v) + tOrP_acc = layout_utils.reshape_acc_to_frgA(acc_S) + tOrP_cur = ( + tOrP + if const_expr(self.mma_pv_is_rs) + else cute.make_rmem_tensor_like(tOrP_acc, self.dtype) + ) + # tOrP_cur.store(tOrP_acc.load().to(self.dtype)) + # the "to(self.dtype)" conversion fails to vectorize for block sizes other + # than 128 x 128, i.e. it calls convert on 1 fp32 element at a time instead of + # 2 elements. So we just call ptx directly. + utils.cvt_f16(tOrP_acc, tOrP_cur) + if const_expr(not self.mma_pv_is_rs): + tPrP = smem_copy_params.smem_thr_copy_P.retile(tOrP_cur) + cute.copy(smem_copy_params.smem_thr_copy_P, tPrP, smem_copy_params.tPsP) + if const_expr(not self.rescale_O_before_gemm): + softmax.rescale_O(acc_O, row_scale) + if const_expr(self.rescale_O_before_gemm): + scores_scale.store(row_scale.load()) + if const_expr(not self.mma_pv_is_rs): + # Fence and barrier to make sure smem store is visible to WGMMA + cute.arch.fence_view_async_shared() + cute.arch.sync_warp() # Only need syncwarp since each warp is using its own P values for MmaPV + return smem_pipe_read + + @cute.jit + def mma_init(self): + warp_group_idx = utils.canonical_warp_group_idx(sync=False) + if const_expr(self.use_scheduler_barrier): + if warp_group_idx == 1: + cute.arch.barrier_arrive( + barrier_id=int(NamedBarrierFwd.WarpSchedulerWG1), + number_of_threads=2 * self.num_threads_per_warp_group, + ) + + @cute.jit + def apply_score_mod( + self, + thr_mma_qk, + batch_idx, + head_idx, + m_block, + acc_S, + n_block, + softmax_scale, + seqlen, + aux_tensors: Optional[list] = None, + fastdiv_mods=None, + ): + # Prepare index tensor + cS = cute.make_identity_tensor((self.tile_m, self.tile_n)) + cS = cute.domain_offset((m_block * self.tile_m, n_block * self.tile_n), cS) + tScS = thr_mma_qk.partition_C(cS) + + apply_score_mod_inner( + acc_S, + tScS, + self.score_mod, + batch_idx, + head_idx, + softmax_scale, + self.vec_size, + self.qk_acc_dtype, + aux_tensors, + fastdiv_mods, + seqlen_info=seqlen, + constant_q_idx=None, + qhead_per_kvhead=self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1, + ) + + def warp_scheduler_barrier_sync(self): + if const_expr(self.use_scheduler_barrier): + cute.arch.barrier( + barrier_id=int(NamedBarrierFwd.WarpSchedulerWG1) + - 1 + + utils.canonical_warp_group_idx(sync=False), + number_of_threads=2 * self.num_threads_per_warp_group, + ) + + def warp_scheduler_barrier_arrive(self): + if const_expr(self.use_scheduler_barrier): + assert self.num_wg_mma in [2, 3] + cur_wg = utils.canonical_warp_group_idx(sync=False) - 1 + if const_expr(self.num_wg_mma == 2): + next_wg = 1 - cur_wg + else: + t = cur_wg + 1 + next_wg = t % self.num_wg_mma + cute.arch.barrier_arrive( + barrier_id=int(NamedBarrierFwd.WarpSchedulerWG1) + next_wg, + number_of_threads=2 * self.num_threads_per_warp_group, + ) diff --git a/flash-attn4/torch-ext/flash_attn4/interface.py b/flash-attn4/torch-ext/flash_attn4/interface.py index cc1c9cc4..38dc2702 100644 --- a/flash-attn4/torch-ext/flash_attn4/interface.py +++ b/flash-attn4/torch-ext/flash_attn4/interface.py @@ -21,6 +21,7 @@ import os import math +from dataclasses import dataclass from functools import lru_cache from typing import Optional, Tuple, Callable @@ -31,6 +32,8 @@ import cutlass import cutlass.cute as cute +from cutlass import Int32, Float32 +from .quack.compile_utils import make_fake_tensor as fake_tensor from .cache_utils import get_jit_cache from .testing import is_fake_mode @@ -43,30 +46,201 @@ from . import utils +from . import fa_logging from .cute_dsl_utils import ( to_cute_tensor, to_cute_aux_tensor, get_aux_tensor_metadata, get_broadcast_dims, ) -from .flash_fwd import FlashAttentionForwardSm90 +from .flash_fwd import FlashAttentionForwardSm80 +from .flash_fwd_sm90 import FlashAttentionForwardSm90 from .flash_fwd_sm100 import FlashAttentionForwardSm100 +from .flash_fwd_sm120 import FlashAttentionForwardSm120 from .flash_bwd_preprocess import FlashAttentionBackwardPreprocess from .flash_bwd import FlashAttentionBackwardSm80 from .flash_bwd_sm90 import FlashAttentionBackwardSm90 from .flash_bwd_sm100 import FlashAttentionBackwardSm100 +from .flash_bwd_sm120 import FlashAttentionBackwardSm120 from .flash_bwd_postprocess import FlashAttentionBackwardPostprocess from .flash_fwd_combine import FlashAttentionForwardCombine from .block_sparsity import ( BlockSparseTensorsTorch, + get_sparse_q_block_size, to_cute_block_sparse_tensors, normalize_block_sparse_config, normalize_block_sparse_config_bwd, ) +def _parse_arch_str(arch_str): + """Parse arch string (e.g. 'sm_80', 'sm_90a', '80', '100') to int (e.g. 80, 90, 100).""" + import re + match = re.match(r"^(?:sm_?|SM_?)?(\d+)(\d)([af]?)$", arch_str) + if not match: + raise ValueError(f"Invalid arch format: {arch_str}") + major, minor, _ = match.groups() + return int(major) * 10 + int(minor) + + @lru_cache(maxsize=None) def _get_device_arch(): - """Cached device arch check.""" + """Cached device arch check. + + Override with FLASH_ATTENTION_ARCH (e.g. 'sm_80' or '80') to select which + kernel path to use (SM80/SM90/SM100/SM120) independently of the compilation + target (CUTE_DSL_ARCH). + + For CPU-only compilation (no GPU), set both: + FLASH_ATTENTION_ARCH=sm_80 (kernel selection) + CUTE_DSL_ARCH=sm_80 (compilation target) + """ + arch_override = os.environ.get("FLASH_ATTENTION_ARCH", None) + if arch_override is not None: + return _parse_arch_str(arch_override) major, minor = torch.cuda.get_device_capability() - return major * 10 + minor + return major * 10 + int(minor) + + +def _validate_head_dims(head_dim: int, head_dim_v: int, compute_capability: int, alignment: int) -> None: + """Validate head dimension constraints based on compute capability.""" + is_deepseek_shape = head_dim == 192 and head_dim_v == 128 + is_standard_range = 8 <= head_dim <= 128 and 8 <= head_dim_v <= 128 + + is_sm90_range = 8 <= head_dim <= 256 and 8 <= head_dim_v <= 256 + if compute_capability == 9: + assert is_sm90_range and head_dim % alignment == 0 and head_dim_v % alignment == 0, ( + f"(head_dim, head_dim_v)=({head_dim}, {head_dim_v}) is not supported on SM90. " + f"head_dim and head_dim_v must be between 8 and 256 and divisible by {alignment}." + ) + elif compute_capability in [10, 11]: + assert (is_standard_range or is_deepseek_shape) and head_dim % alignment == 0 and head_dim_v % alignment == 0, ( + f"(head_dim, head_dim_v)=({head_dim}, {head_dim_v}) is not supported on SM100/SM110. " + f"head_dim and head_dim_v must be between 8 and 128 and divisible by {alignment}, or (192, 128) for DeepSeek." + ) + + +@dataclass(frozen=True) +class FwdConfig: + m_block_size: int + n_block_size: int + mma_pv_is_rs: bool + intra_wg_overlap: bool + + +def _tile_size_fwd_sm90(head_dim, head_dim_v, is_causal, is_local, sparse_block_size_q=None): + """Return FwdConfig for SM90 forward. + + Tile sizes and flags based on tile_size_fwd_sm90 in hopper/tile_size.h, adjusted + for the Python kernel's different register/smem tradeoffs (benchmarked on H100 SXM). + + When sparse_block_size_q is set, tile_m must divide it. For head_dim <= 96 the + optimal tile_m=192 is used when compatible, otherwise we fall back to 128. + """ + if head_dim <= 64: + # C++: 192ร—192 non-causal, 192ร—128 causal/local. + # Python: 192ร—128 RS+OL is consistently best across seqlens. + if sparse_block_size_q is not None and sparse_block_size_q % 192 != 0: + return FwdConfig(128, 128, True, True) + return FwdConfig(192, 128, True, True) + elif head_dim <= 96: + # C++: 192ร—144 noRS+OL for all cases. + # Python: RS is catastrophic with 192ร— tiles (~300 vs ~600 TFLOPS). + # noRS+OL is always required. Causal: 192ร—128 slightly better short seqlen. + if sparse_block_size_q is not None and sparse_block_size_q % 192 != 0: + return FwdConfig(128, 128, False, True) + if is_causal or is_local: + return FwdConfig(192, 128, False, True) + else: + return FwdConfig(192, 144, False, True) + elif head_dim <= 128: + return FwdConfig(128, 128, True, True) + elif head_dim <= 192: + tile_n = 96 if is_local else (128 if head_dim_v <= 128 else 112) + return FwdConfig(128, tile_n, True, True) + else: # hdim 256 + tile_n = 64 if is_local else 80 + return FwdConfig(128, tile_n, True, True) + +@dataclass(frozen=True) +class BwdConfig: + m_block_size: int + n_block_size: int + num_stages_Q: int + num_stages_dO: int + num_stages_PdS: int + SdP_swapAB: bool + dKV_swapAB: bool + dQ_swapAB: bool + AtomLayoutMSdP: int + AtomLayoutNdKV: int + AtomLayoutMdQ: int + num_wg: int = 2 # MMA warp groups (total threads = (num_wg + 1) * 128) + dQ_single_wg: bool = False + + +def _tile_size_bwd_sm90(head_dim, head_dim_v, causal, local, sparse_block_size_q=None): + """Return BwdConfig for SM90. + + Configs based on C++ FA3 hopper/flash_bwd_launch_template.h, + benchmarked on H100 SXM. + """ + if head_dim <= 64: + # C++ FA3: 128, 128, 64, ..., 2, 2, true, false, false, 2, 1, 2, 2 + return BwdConfig( + m_block_size=128, n_block_size=128, + num_stages_Q=2, num_stages_dO=2, num_stages_PdS=2, + SdP_swapAB=True, dKV_swapAB=False, dQ_swapAB=False, + AtomLayoutMSdP=1, AtomLayoutNdKV=2, AtomLayoutMdQ=2, + ) + elif head_dim <= 96: + # C++ FA3: 64, 128, 96, dQ_swapAB=False + return BwdConfig( + m_block_size=64, n_block_size=128, + num_stages_Q=2, num_stages_dO=2, num_stages_PdS=2, + SdP_swapAB=True, dKV_swapAB=False, dQ_swapAB=False, + AtomLayoutMSdP=1, AtomLayoutNdKV=2, AtomLayoutMdQ=1, + dQ_single_wg=True, + ) + elif head_dim <= 128: + # C++ FA3: causal/local: 64, 128; non-causal: 80, 128 with dQ_swapAB + is_causal_or_local = causal or local + m_block_size = 64 if is_causal_or_local else 80 + if sparse_block_size_q is not None and sparse_block_size_q % m_block_size != 0: + m_block_size = 64 + return BwdConfig( + m_block_size=m_block_size, + n_block_size=128, + num_stages_Q=2, num_stages_dO=2, num_stages_PdS=2, + SdP_swapAB=True, dKV_swapAB=False, + dQ_swapAB=m_block_size % 64 != 0, + AtomLayoutMSdP=1, AtomLayoutNdKV=2, AtomLayoutMdQ=1, + ) + elif head_dim <= 192: + hdimv128 = head_dim_v <= 128 + if hdimv128: + return BwdConfig( + m_block_size=64, n_block_size=96, + num_stages_Q=2, num_stages_dO=2, num_stages_PdS=1, + SdP_swapAB=False, dKV_swapAB=True, dQ_swapAB=False, + AtomLayoutMSdP=1, AtomLayoutNdKV=2, AtomLayoutMdQ=1, + num_wg=2, + ) + else: + return BwdConfig( + m_block_size=64, n_block_size=96, + num_stages_Q=2, num_stages_dO=1, num_stages_PdS=1, + SdP_swapAB=False, dKV_swapAB=True, dQ_swapAB=False, + AtomLayoutMSdP=1, AtomLayoutNdKV=2, AtomLayoutMdQ=1, + num_wg=2, + ) + else: + # hdim 256 + return BwdConfig( + m_block_size=64, n_block_size=64, + num_stages_Q=1, num_stages_dO=1, num_stages_PdS=1, + SdP_swapAB=False, dKV_swapAB=False, dQ_swapAB=False, + AtomLayoutMSdP=1, AtomLayoutNdKV=1, AtomLayoutMdQ=1, + ) + + def maybe_contiguous(x): return x.contiguous() if x is not None and x.stride(-1) != 1 else x @@ -76,7 +250,8 @@ def _validate_tensor(t, name, expected_shape, expected_dtype, expected_device): assert t.shape == expected_shape, f"{name} shape {t.shape} != expected {expected_shape}" assert t.dtype == expected_dtype, f"{name} dtype {t.dtype} != expected {expected_dtype}" assert t.device == expected_device, f"{name} device {t.device} != expected {expected_device}" - assert t.is_cuda, f"{name} must be on CUDA" + if not is_fake_mode(): + assert t.is_cuda, f"{name} must be on CUDA" torch2cute_dtype_map = { @@ -96,6 +271,29 @@ def num_splits_heuristic(total_mblocks, num_SMs, num_n_blocks, max_splits): return min(num_SMs // total_mblocks, max_splits, num_n_blocks) +def _resolve_causal_local_window(causal, window_size_left, window_size_right, mask_mod=None): + """Resolve causal/local/window settings into canonical form. + + Returns (causal, local, window_size_left, window_size_right). + """ + if mask_mod is not None: + return False, False, window_size_left, window_size_right + if causal: + window_size_right = 0 + if window_size_left is not None and window_size_right is not None and window_size_left + window_size_right < 0: + window_size_left = None + window_size_right = None + if window_size_left is not None or window_size_right is not None: + if window_size_left is None and window_size_right == 0: + causal, local = True, False + window_size_right = None + else: + causal, local = False, True + else: + local = False + return causal, local, window_size_left, window_size_right + + def _flash_attn_fwd( q: torch.Tensor, k: torch.Tensor, @@ -113,11 +311,9 @@ def _flash_attn_fwd( window_size_left: Optional[int] = None, window_size_right: Optional[int] = None, learnable_sink: Optional[torch.Tensor] = None, - # m_block_size: int = 128, - # n_block_size: int = 64, - # num_threads: int = 128, - m_block_size: int = 128, - n_block_size: int = 128, + tile_mn: Optional[Tuple[int, int]] = None, + mma_pv_is_rs: Optional[bool] = None, + intra_wg_overlap: Optional[bool] = None, num_threads: int = 384, num_splits: int = 1, pack_gqa: Optional[bool] = None, @@ -138,7 +334,7 @@ def _flash_attn_fwd( mask_mod: A callable that takes token position information and selectively masks block_sparse_tensors: A tuple of tensors used for block sparsity. return_lse: Whether to return the log softmax of the attention scores. If set to True will always calculate - Note: the returned LSE currently does not support taking gradient. + The returned LSE supports taking gradient. out: Optional pre-allocated output tensor. If None, will be allocated internally. lse: Optional pre-allocated log-sum-exp tensor. If None, will be allocated when needed. aux_tensors: Some score_mods will want to read from global aux_tensors. This is how we thread them through to the inner kernel. @@ -203,25 +399,27 @@ def _flash_attn_fwd( assert learnable_sink.shape == (num_head,) assert learnable_sink.dtype == torch.bfloat16, "learnable_sink must be bfloat16" - assert all( - t is None or t.is_cuda - for t in ( - q, - k, - v, - cu_seqlens_q, - cu_seqlens_k, - seqused_q, - seqused_k, - page_table, - learnable_sink, - ) - ), "inputs must be on CUDA device" + if not is_fake_mode(): + assert all( + t is None or t.is_cuda + for t in ( + q, + k, + v, + cu_seqlens_q, + cu_seqlens_k, + seqused_q, + seqused_k, + page_table, + learnable_sink, + ) + ), "inputs must be on CUDA device" + arch = _get_device_arch() if _arch is None else _arch + assert arch // 10 in [8, 9, 10, 11, 12], "Unsupported compute capability. Supported: 8.x, 9.x, 10.x, 11.x, 12.x" assert num_head % num_head_kv == 0, "num_head must be divisible by num_head_kv" - assert head_dim <= 256, "head_dim must be less than or equal to 256" alignment = 16 // q.element_size() - assert head_dim % alignment == 0, f"head_dim must be divisible by {alignment}" - assert head_dim_v % alignment == 0, f"head_dim_v must be divisible by {alignment}" + if arch // 10 not in [8, 12]: + _validate_head_dims(head_dim, head_dim_v, arch // 10, alignment) if softmax_scale is None: softmax_scale = 1.0 / math.sqrt(head_dim) if softcap == 0.0: @@ -253,43 +451,47 @@ def _flash_attn_fwd( _validate_tensor(lse, "lse", lse_shape, torch.float32, device) dtype = torch2cute_dtype_map[q.dtype] - arch = _get_device_arch() if _arch is None else _arch + use_block_sparsity = block_sparse_tensors is not None + + causal, local, window_size_left, window_size_right = _resolve_causal_local_window( + causal, window_size_left, window_size_right, mask_mod + ) - assert arch // 10 in [9, 10, 11], "Unsupported compute capability. Supported: 9.x, 10.x, 11.x" + requested_use_clc_scheduler = utils._get_use_clc_scheduler_default() + requested_disable_2cta = utils._get_disable_2cta_default() - use_block_sparsity = block_sparse_tensors is not None + current_stream = cute.runtime.make_fake_stream(use_tvm_ffi_env_stream=True) - if mask_mod is None: - if causal: - window_size_right = 0 - if window_size_left is not None and window_size_right is not None and window_size_left + window_size_right < 0: - window_size_left = None - window_size_right = None - local = window_size_left is not None or window_size_right is not None - if window_size_left is not None or window_size_right is not None: - if window_size_left is None and window_size_right == 0: - causal, local = True, False - window_size_right = None + # SM80/SM120: uses SM80 MMA, 128 threads (4 warps) + if arch // 10 in [8, 12]: + num_threads = 128 + + fwd_cfg = FwdConfig(128, 128, True, True) # default + if tile_mn is None: + if arch // 10 == 12: + # SM120 tile sizes tuned for 99 KB SMEM capacity: + # D<=64: 128x128 โ†’ 48 KB (good occupancy) + # D>64: 128x64 โ†’ 64 KB (128x128 would use 96 KB, hurting occupancy) + if head_dim <= 64: + fwd_cfg = FwdConfig(128, 128, True, True) else: - causal, local = False, True + fwd_cfg = FwdConfig(128, 64, True, True) + elif arch // 10 == 8: + fwd_cfg = FwdConfig(128, 64, True, True) # SM80, should tune + elif arch // 10 == 9: + sparse_q = get_sparse_q_block_size(block_sparse_tensors, seqlen_q) + fwd_cfg = _tile_size_fwd_sm90(head_dim, head_dim_v, causal, local, sparse_block_size_q=sparse_q) else: - causal, local = False, False - - current_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) + fwd_cfg = FwdConfig(tile_mn[0], tile_mn[1], fwd_cfg.mma_pv_is_rs, fwd_cfg.intra_wg_overlap) + tile_m, tile_n = fwd_cfg.m_block_size, fwd_cfg.n_block_size + if mma_pv_is_rs is None: + mma_pv_is_rs = fwd_cfg.mma_pv_is_rs + if intra_wg_overlap is None: + intra_wg_overlap = fwd_cfg.intra_wg_overlap - if arch // 10 == 9: # TODO: tune block size according to hdim. - if head_dim == head_dim_v == 128 and not causal and not local and not use_block_sparsity: - n_block_size = 192 - - if arch // 10 in [10, 11]: - if ( - pack_gqa - and (128 % qhead_per_kvhead != 0) - ): - pack_gqa = False - # TODO: fix GQA + SplitKV + non-varlen - if pack_gqa and num_splits != 1 and cu_seqlens_q is None: - pack_gqa = False + # TODO: fix GQA + SplitKV + non-varlen + if pack_gqa and num_splits != 1 and cu_seqlens_q is None: + pack_gqa = False if max_seqlen_q is None: max_seqlen_q = seqlen_q if cu_seqlens_q is None else total_q @@ -297,28 +499,50 @@ def _flash_attn_fwd( max_seqlen_k = seqlen_k seqlen_q_packgqa = max_seqlen_q * qhead_per_kvhead if arch // 10 == 10: - q_stage = 2 if seqlen_q_packgqa > m_block_size else 1 + q_stage = 2 if seqlen_q_packgqa > tile_m else 1 else: q_stage = 1 + m_block_size_effective = q_stage * tile_m + seqlen_k_loaded = max_seqlen_k if not local else max(0, min(max_seqlen_k, (window_size_right or max_seqlen_k) + (window_size_left or max_seqlen_k) + 1 + tile_m)) + num_m_blocks = (seqlen_q_packgqa + m_block_size_effective - 1) // m_block_size_effective + total_mblocks = batch_size * num_head_kv * num_m_blocks + num_n_blocks = (seqlen_k_loaded + tile_n - 1) // tile_n + num_SMs = 132 if is_fake_mode() else torch.cuda.get_device_properties(device).multi_processor_count if num_splits < 1: - m_block_size_effective = q_stage * m_block_size - seqlen_k_loaded = max_seqlen_k if not local else max(0, min(max_seqlen_k, window_size_right + window_size_left + 1 + m_block_size)) - num_n_blocks = (seqlen_k_loaded + n_block_size - 1) // n_block_size - num_m_blocks = (seqlen_q_packgqa + m_block_size_effective - 1) // m_block_size_effective - total_mblocks = batch_size * num_head_kv * num_m_blocks - num_splits = num_splits_heuristic( - total_mblocks, - torch.cuda.get_device_properties(device).multi_processor_count, - num_n_blocks, - 128, - ) + num_splits = num_splits_heuristic(total_mblocks, num_SMs, num_n_blocks, 128) + + # SplitKV uses float32 partial output, which doubles the O buffer size + # in shared memory, causing OOM for diff-headdim (192, 128) + if arch // 10 in [10, 11] and head_dim != head_dim_v and num_splits > 1: + if num_n_blocks >= 64: + tile_n = 64 + num_n_blocks = (seqlen_k_loaded + tile_n - 1) // tile_n + num_splits = num_splits_heuristic(total_mblocks, num_SMs, num_n_blocks, 128) + else: + num_splits = 1 is_split_kv = num_splits > 1 if is_split_kv: out_partial = torch.empty(num_splits, *q_batch_seqlen_shape, num_head, head_dim_v, dtype=torch.float32, device=device) lse_partial = torch.empty(num_splits, *lse_shape, dtype=torch.float32, device=device) + use_2cta_instrs = ( + arch // 10 in [10, 11] + and not requested_disable_2cta + and not causal + and not local + and not is_split_kv + and cu_seqlens_q is None + and seqused_q is None + and not use_block_sparsity + and page_size in [None, 128] + and int(math.ceil(head_dim / 16) * 16) in [128, 192] + and int(math.ceil(head_dim_v / 16) * 16) == 128 + and seqlen_q_packgqa > 2 * tile_m + and (tile_m % qhead_per_kvhead == 0 or not pack_gqa) + ) + # hash score and mask mods for compile cache score_mod_hash = utils.hash_callable(score_mod) if score_mod is not None else False mask_mod_hash = utils.hash_callable(mask_mod) if mask_mod is not None else False @@ -370,14 +594,14 @@ def _flash_attn_fwd( num_head=num_head, seqlen_q=seqlen_q, seqlen_k=seqlen_k, - block_size=(m_block_size, n_block_size), + block_size=(tile_m, tile_n), q_stage=q_stage, ) - if aux_tensors is not None: + if aux_tensors is not None: aux_tensor_metadata = get_aux_tensor_metadata(aux_tensors) else: aux_tensor_metadata = None - + compile_key = ( dtype, head_dim, @@ -398,15 +622,20 @@ def _flash_attn_fwd( window_size_left is not None, window_size_right is not None, learnable_sink is not None, - m_block_size, - n_block_size, + tile_m, + tile_n, q_stage, num_threads, is_split_kv, pack_gqa, arch, - page_size not in [None, 128], # paged KV non-TMA + page_size not in [None, tile_n], # paged KV non-TMA + use_2cta_instrs, q_subtile_factor, + mma_pv_is_rs, + intra_wg_overlap, + requested_use_clc_scheduler, + fa_logging.get_fa_log_level(), ) if compile_key not in _flash_attn_fwd.compile_cache: ( @@ -445,10 +674,28 @@ def _flash_attn_fwd( if aux_tensors is not None: cute_aux_tensors = [to_cute_aux_tensor(buf) for buf in aux_tensors] - if arch // 10 == 9: - assert page_table is None, "paged KV not supported on SM 9.0" + if arch // 10 == 8: + assert page_table is None, "paged KV not supported on SM 8.0" + assert not is_split_kv, "SplitKV not supported on SM 8.0" + fa_fwd = FlashAttentionForwardSm80( + dtype, + head_dim, + head_dim_v, + qhead_per_kvhead, + is_causal=causal, + is_local=local, + pack_gqa=pack_gqa, + tile_m=tile_m, + tile_n=tile_n, + num_stages=1, + num_threads=num_threads, + Q_in_regs=False, + score_mod=score_mod, + mask_mod=mask_mod, + has_aux_tensors=aux_tensors is not None, + ) + elif arch // 10 == 9: assert not is_split_kv, "SplitKV not supported on SM 9.0" - # fa_fwd = FlashAttentionForwardSm80( fa_fwd = FlashAttentionForwardSm90( dtype, head_dim, @@ -457,33 +704,21 @@ def _flash_attn_fwd( is_causal=causal, is_local=local, pack_gqa=pack_gqa, - tile_m=m_block_size, - tile_n=n_block_size, + tile_m=tile_m, + tile_n=tile_n, # num_stages=1, num_stages=2, num_threads=num_threads, Q_in_regs=False, - intra_wg_overlap=True, - mma_pv_is_rs=True, + intra_wg_overlap=intra_wg_overlap, + mma_pv_is_rs=mma_pv_is_rs, mask_mod=mask_mod, score_mod=score_mod, has_aux_tensors=aux_tensors is not None, q_subtile_factor=q_subtile_factor, + paged_kv_non_tma=page_size not in [None, tile_n], ) elif arch // 10 in [10, 11]: - head_dim_padded = int(math.ceil(head_dim / 16) * 16) - head_dim_v_padded = int(math.ceil(head_dim / 16) * 16) - use_2cta_instrs = ( - not causal - and not local - and not is_split_kv - and cu_seqlens_q is None - and seqused_q is None - and not use_block_sparsity - and page_size in [None, 128] - and head_dim_padded == 128 - and head_dim_v_padded == 128 - ) fa_fwd = FlashAttentionForwardSm100( head_dim, head_dim_v, @@ -492,8 +727,8 @@ def _flash_attn_fwd( is_local=local, is_split_kv=is_split_kv, pack_gqa=pack_gqa, - m_block_size=m_block_size, - n_block_size=n_block_size, + m_block_size=tile_m, + n_block_size=tile_n, q_stage=q_stage, is_persistent=not causal and not local @@ -503,14 +738,37 @@ def _flash_attn_fwd( score_mod=score_mod, mask_mod=mask_mod, has_aux_tensors=aux_tensors is not None, - paged_kv_non_tma=page_size not in [None, 128], + paged_kv_non_tma=page_size not in [None, tile_n], is_varlen_q=cu_seqlens_q is not None or seqused_q is not None, q_subtile_factor=q_subtile_factor, use_2cta_instrs=use_2cta_instrs, + use_clc_scheduler=requested_use_clc_scheduler, + ) + elif arch // 10 == 12: + # SM120 (Blackwell GeForce / DGX Spark): uses SM80 MMA with SM120 SMEM capacity + assert not use_block_sparsity, "Block sparsity not supported on SM 12.0" + assert page_table is None, "Paged KV not supported on SM 12.0 in this PR" + assert not is_split_kv, "SplitKV not supported on SM 12.0 in this PR" + fa_fwd = FlashAttentionForwardSm120( + dtype, + head_dim, + head_dim_v, + qhead_per_kvhead, + is_causal=causal, + is_local=local, + pack_gqa=pack_gqa, + tile_m=tile_m, + tile_n=tile_n, + num_stages=1, + num_threads=num_threads, + Q_in_regs=False, + score_mod=score_mod, + mask_mod=mask_mod, + has_aux_tensors=aux_tensors is not None, ) else: raise ValueError( - f"Unsupported compute capability: {arch}. Supported: 9.x, 10.x, 11.x" + f"Unsupported compute capability: {arch}. Supported: 8.x, 9.x, 10.x, 11.x, 12.x" ) # TODO: check @can_implement _flash_attn_fwd.compile_cache[compile_key] = cute.compile( @@ -521,7 +779,6 @@ def _flash_attn_fwd( o_tensor, lse_tensor, softmax_scale, - current_stream, cu_seqlens_q_tensor, cu_seqlens_k_tensor, seqused_q_tensor, @@ -532,6 +789,7 @@ def _flash_attn_fwd( learnable_sink_tensor, sparse_tensors, cute_aux_tensors, + current_stream, options="--enable-tvm-ffi", ) @@ -547,7 +805,6 @@ def _flash_attn_fwd( out.detach() if not is_split_kv else out_partial, lse_partial if is_split_kv else lse, softmax_scale, - current_stream, cu_seqlens_q, cu_seqlens_k, seqused_q, @@ -574,6 +831,140 @@ def _flash_attn_fwd( _flash_attn_fwd.compile_cache = get_jit_cache("fwd") +def make_fake_bwd_tensors(dtype, has_gqa, varlen_q, varlen_k): + sym = cute.sym_int + # divisibility in elements: assumed_align_bytes = divisibility * dtype.width // 8 + # For 16-byte align: fp16/bf16 โ†’ divisibility=8, float32 โ†’ divisibility=4 + div = 128 // dtype.width # 8 for fp16/bf16 + # Shared sym_ints for dimensions that must match across tensors + b, seqlen_q, seqlen_k, h_q, d, d_v = sym(), sym(), sym(), sym(), sym(), sym() + h_kv = h_q if not has_gqa else sym() + seqlen_q_rounded, seqlen_k_rounded = sym(), sym() + seqlen_q_d_rounded, seqlen_k_d_rounded, seqlen_k_dv_rounded = sym(), sym(), sym() + total_q, total_k, total_q_rounded, total_k_rounded = sym(), sym(), sym(), sym() + total_q_d_rounded, total_k_d_rounded, total_k_dv_rounded = sym(), sym(), sym() + b_seqlenq = (b, seqlen_q) if not varlen_q else (total_q,) + b_seqlenk = (b, seqlen_k) if not varlen_k else (total_k,) + mQ = fake_tensor(dtype, (*b_seqlenq, h_q, d), divisibility=div) + mO = fake_tensor(dtype, (*b_seqlenq, h_q, d_v), divisibility=div) + mdO = fake_tensor(dtype, (*b_seqlenq, h_q, d_v), divisibility=div) + mK = fake_tensor(dtype, (*b_seqlenk, h_kv, d), divisibility=div) + mV = fake_tensor(dtype, (*b_seqlenk, h_kv, d_v), divisibility=div) + mdQ = fake_tensor(dtype, (*b_seqlenq, h_q, d), divisibility=div) + mdK = fake_tensor(dtype, (*b_seqlenk, h_kv, d), divisibility=div) + mdV = fake_tensor(dtype, (*b_seqlenk, h_kv, d_v), divisibility=div) + if not varlen_q: + mLSE = fake_tensor(Float32, (b, h_q, seqlen_q), divisibility=1) + mLSElog2 = fake_tensor(Float32, (b, h_q, seqlen_q_rounded), divisibility=4) + mPdPsum = fake_tensor(Float32, (b, h_q, seqlen_q_rounded), divisibility=4) + dQaccum = fake_tensor(Float32, (b, h_q, seqlen_q_d_rounded), divisibility=4) + else: + mLSE = fake_tensor(Float32, (h_q, total_q), divisibility=1) + mLSElog2 = fake_tensor(Float32, (h_q, total_q_rounded), divisibility=4) + mPdPsum = fake_tensor(Float32, (h_q, total_q_rounded), divisibility=4) + dQaccum = fake_tensor(Float32, (h_q, total_q_d_rounded), divisibility=4) + if not has_gqa: + mdKaccum, mdVaccum = None, None + else: + if not varlen_k: + mdKaccum = fake_tensor(Float32, (b, h_kv, seqlen_k_rounded), divisibility=4) + mdVaccum = fake_tensor(Float32, (b, h_kv, seqlen_k_dv_rounded), divisibility=4) + else: + mdKaccum = fake_tensor(Float32, (h_kv, total_k_rounded), divisibility=4) + mdVaccum = fake_tensor(Float32, (h_kv, total_k_dv_rounded), divisibility=4) + return mQ, mK, mV, mO, mdO, mdQ, mdK, mdV, mLSE, mLSElog2, mPdPsum, dQaccum, mdKaccum, mdVaccum + + +def _compile_bwd_preprocess( + dtype, head_dim, head_dim_v, m_block_size, has_cuseqlens_q, has_seqused_q, has_dlse, +): + """Compile bwd preprocess kernel using cute fake tensors (no real GPU tensors needed).""" + mQ, mK, mV, mO, mdO, mdQ, mdK, mdV, mLSE, mLSElog2, mPdPsum, mdQaccum, mdKaccum, mdVaccum = make_fake_bwd_tensors( + dtype, has_gqa=True, varlen_q=has_cuseqlens_q, varlen_k=False + ) + batch = mQ.shape[0] if not has_cuseqlens_q else cute.sym_int() + batchp1 = cute.sym_int() + mCuSeqlensQ = fake_tensor(Int32, (batchp1,), divisibility=1) if has_cuseqlens_q else None + mSequsedQ = fake_tensor(Int32, (batch,), divisibility=1) if has_seqused_q else None + mdLSE = fake_tensor(Float32, mLSE.shape, divisibility=1) if has_dlse else None + fa_bwd_pre = FlashAttentionBackwardPreprocess(dtype, head_dim, head_dim_v, m_block_size) + return cute.compile( + fa_bwd_pre, mO, mdO, mPdPsum, mLSE, mLSElog2, mdQaccum, mCuSeqlensQ, mSequsedQ, mdLSE, + cute.runtime.make_fake_stream(use_tvm_ffi_env_stream=True), + options="--enable-tvm-ffi", + ) + + +def _bwd_preprocess( + out, dout, dpsum, lse, lse_log2, dq_accum, + cu_seqlens_q, seqused_q, dlse, + dtype, head_dim, head_dim_v, m_block_size, +): + """Backward preprocess: compute (o * dout).sum(dim=-1) - dLSE, lse * log2_e, and zero out dq_accum.""" + is_varlen = cu_seqlens_q is not None + compile_key = ( + dtype, head_dim, head_dim_v, m_block_size, is_varlen, seqused_q is not None, dlse is not None, + ) + if compile_key not in _bwd_preprocess.compile_cache: + _bwd_preprocess.compile_cache[compile_key] = _compile_bwd_preprocess(*compile_key) + if not is_fake_mode(): + _bwd_preprocess.compile_cache[compile_key]( + out, dout, dpsum, lse, lse_log2, dq_accum, cu_seqlens_q, seqused_q, dlse + ) + + +_bwd_preprocess.compile_cache = get_jit_cache("bwd_pre") + + +def _compile_bwd_postprocess( + dtype, hdim, block_size, num_threads, atom_layout, swap_ab, + has_cuseqlens_q, has_seqused_q, + use_2cta_instrs, cluster_size, arch, +): + """Compile bwd postprocess kernel using cute fake tensors.""" + mQ, mK, mV, mO, mdO, mdQ, mdK, mdV, mLSE, mLSElog2, mPdPsum, mdQaccum, mdKaccum, mdVaccum = make_fake_bwd_tensors( + dtype, has_gqa=True, varlen_q=has_cuseqlens_q, varlen_k=False + ) + batch = mQ.shape[0] if not has_cuseqlens_q else cute.sym_int() + batchp1 = cute.sym_int() + mCuSeqlensQ = fake_tensor(Int32, (batchp1,), divisibility=1) if has_cuseqlens_q else None + mSeqUsedQ = fake_tensor(Int32, (batch,), divisibility=1) if has_seqused_q else None + fa_bwd_post = FlashAttentionBackwardPostprocess( + dtype, hdim, arch, block_size, num_threads, atom_layout, swap_ab, + use_2cta_instrs=use_2cta_instrs, + cluster_size=cluster_size, + ) + return cute.compile( + fa_bwd_post, mdQaccum, mdQ, Float32(0.0), mCuSeqlensQ, mSeqUsedQ, + cute.runtime.make_fake_stream(use_tvm_ffi_env_stream=True), + options="--enable-tvm-ffi", + ) + + +def _bwd_postprocess_convert( + accum, output, scale, + cu_seqlens, seqused, + arch, dtype, hdim, block_size, num_threads, + atom_layout, swap_ab, + use_2cta_instrs=False, cluster_size=1, +): + """Backward postprocess: convert float32 accumulator to bf16/fp16 output.""" + compile_key = ( + dtype, hdim, block_size, num_threads, atom_layout, swap_ab, + cu_seqlens is not None, seqused is not None, + use_2cta_instrs, cluster_size, arch, + ) + if compile_key not in _bwd_postprocess_convert.compile_cache: + _bwd_postprocess_convert.compile_cache[compile_key] = _compile_bwd_postprocess(*compile_key) + if not is_fake_mode(): + _bwd_postprocess_convert.compile_cache[compile_key]( + accum, output, scale, cu_seqlens, seqused, + ) + + +_bwd_postprocess_convert.compile_cache = get_jit_cache("bwd_post") + + def _flash_attn_bwd( q: torch.Tensor, k: torch.Tensor, @@ -614,47 +1005,74 @@ def _flash_attn_bwd( mask_mod: Optional[Callable] = None, aux_tensors: Optional[list[torch.Tensor]] = None, block_sparse_tensors: Optional[BlockSparseTensorsTorch] = None, + dlse: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: arch = _get_device_arch() - assert arch // 10 in [9, 10, 11], "Unsupported compute capability. Supported: 9.x, 10.x, 11.x" + assert arch // 10 in [9, 10, 11, 12], "Unsupported compute capability. Supported: 9.x, 10.x, 11.x, 12.x" + sparse_q = None + if block_sparse_tensors is not None and arch // 10 == 9: + sparse_q = block_sparse_tensors.block_size[0] if block_sparse_tensors.block_size is not None else 128 num_head, head_dim = q.shape[-2:] + head_dim_v = v.shape[-1] - if causal: - window_size_right = 0 - if window_size_left is not None and window_size_right is not None and window_size_left + window_size_right < 0: - window_size_left = None - window_size_right = None - local = window_size_left is not None or window_size_right is not None - if local: - if window_size_left is None and window_size_right == 0: - causal, local = True, False - window_size_right = None - else: - causal, local = False, True + causal, local, window_size_left, window_size_right = _resolve_causal_local_window( + causal, window_size_left, window_size_right + ) - if arch // 10 == 9: - m_block_size = 80 if not causal else 64 - n_block_size = 128 - num_stages_Q = 2 - num_stages_dO = 2 - num_stages_PdS = 2 - SdP_swapAB = True + if arch // 10 == 12: + # SM120: uses SM80 MMA with 99 KB SMEM, 128 threads (4 warps). + m_block_size = 64 + n_block_size = 64 + if head_dim <= 64: + num_stages_Q = 2 + num_stages_dO = 2 + else: + num_stages_Q = 1 + num_stages_dO = 1 + SdP_swapAB = False dKV_swapAB = False - dQ_swapAB = not causal - AtomLayoutMSdP = 1 - AtomLayoutNdKV = 2 - AtomLayoutMdQ = 1 + dQ_swapAB = False + AtomLayoutMSdP = 4 + AtomLayoutNdKV = 4 + AtomLayoutMdQ = 4 + V_in_regs = False + cluster_size = 1 + use_2cta_instrs = False + num_threads = 128 + assert not (block_sparse_tensors is not None), "Block sparsity backward not supported on SM 12.0" + assert score_mod is None and score_mod_bwd is None, "score_mod backward not supported on SM 12.0" + assert mask_mod is None, "mask_mod backward not supported on SM 12.0" + assert deterministic is False, "deterministic backward not supported on SM 12.0" + elif arch // 10 == 9: + cfg = _tile_size_bwd_sm90( + head_dim, + head_dim_v, + causal, + local, + sparse_block_size_q=sparse_q, + ) + m_block_size = cfg.m_block_size + n_block_size = cfg.n_block_size + num_stages_Q = cfg.num_stages_Q + num_stages_dO = cfg.num_stages_dO + num_stages_PdS = cfg.num_stages_PdS + SdP_swapAB = cfg.SdP_swapAB + dKV_swapAB = cfg.dKV_swapAB + dQ_swapAB = cfg.dQ_swapAB + AtomLayoutMSdP = cfg.AtomLayoutMSdP + AtomLayoutNdKV = cfg.AtomLayoutNdKV + AtomLayoutMdQ = cfg.AtomLayoutMdQ + num_threads = (cfg.num_wg + 1) * 128 + dQ_single_wg = cfg.dQ_single_wg cluster_size = 1 use_2cta_instrs = False - assert window_size_left is None and window_size_right is None, "local not supported yet on 9.x" is_varlen = ( cu_seqlens_q is not None or cu_seqlens_k is not None or seqused_q is not None or seqused_k is not None ) - assert not is_varlen, "varlen backward is not yet supported on sm90" else: m_block_size = 128 n_block_size = 128 @@ -662,15 +1080,17 @@ def _flash_attn_bwd( dKV_swapAB = False AtomLayoutMdQ = 1 AtomLayoutNdKV = 1 + requested_disable_2cta = utils._get_disable_2cta_default() disable_2cta = ( - local + requested_disable_2cta or score_mod is not None or score_mod_bwd is not None or mask_mod is not None + or block_sparse_tensors is not None ) cluster_size = 2 if head_dim >= 128 and not disable_2cta else 1 use_2cta_instrs = cluster_size==2 - + q, k, v, out, dout, lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k = [ maybe_contiguous(t) for t in (q, k, v, out, dout, lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k) @@ -692,19 +1112,9 @@ def _flash_attn_bwd( seqlen_k = max_seqlen_k if max_seqlen_k is not None else total_k num_head_kv = k.shape[-2] - head_dim_v = v.shape[-1] use_block_sparsity = block_sparse_tensors is not None - - # SM90 block-sparse backward: tile_m=64 is the GCD between a m_block_size that fits, - # the base block_m of 128 from forward, and block-sparse size for subtiling. - if arch // 10 == 9 and use_block_sparsity: - m_block_size = 64 - # dQ_swapAB tuning: use False when m_block_size=64 (same as causal case) - dQ_swapAB = False - - # NB: this could be derived from the block_sparse_tensors but for now we hardcode it to 2 - subtile_factor = 2 + subtile_factor = sparse_q // m_block_size if sparse_q is not None else 2 seqlen_q_rounded = (seqlen_q + m_block_size - 1) // m_block_size * m_block_size seqlen_k_rounded = (seqlen_k + n_block_size - 1) // n_block_size * n_block_size num_n_blocks = seqlen_k_rounded // n_block_size @@ -744,14 +1154,16 @@ def _flash_attn_bwd( if t is not None: assert t.dtype == torch.int32, "cu_seqlens_q, cu_seqlens_k must be int32" assert lse.dtype == torch.float32, "lse must be float32" - assert all( - t is None or t.is_cuda for t in (q, k, v, out, dout, lse, cu_seqlens_q, cu_seqlens_k) - ), "inputs must be on CUDA device" + if dlse is not None: + dlse = maybe_contiguous(dlse) + if not is_fake_mode(): + assert all( + t is None or t.is_cuda for t in (q, k, v, out, dout, lse, cu_seqlens_q, cu_seqlens_k) + ), "inputs must be on CUDA device" assert num_head % num_head_kv == 0, "num_head must be divisible by num_head_kv" - assert head_dim <= 256, "head_dim must be less than or equal to 256" alignment = 16 // q.element_size() - assert head_dim % alignment == 0, f"head_dim must be divisible by {alignment}" - assert head_dim_v % alignment == 0, f"head_dim_v must be divisible by {alignment}" + if arch // 10 != 12: + _validate_head_dims(head_dim, head_dim_v, arch // 10, alignment) if softmax_scale is None: softmax_scale = 1.0 / math.sqrt(head_dim) qhead_per_kvhead = num_head // num_head_kv @@ -759,9 +1171,6 @@ def _flash_attn_bwd( pack_gqa = qhead_per_kvhead > 1 # pack_gqa backward not yet supported in bwd pack_gqa = False - if arch // 10 not in [10, 11]: - assert deterministic is False, "bwd deterministic only supported for sm100/sm110 for now" - if score_mod is not None: assert score_mod_bwd is not None, "score_mod_bwd is required when score_mod is provided" assert softcap == 0.0, "softcap and score_mod are mutually exclusive (different log2 scaling)" @@ -813,6 +1222,9 @@ def _flash_attn_bwd( dpsum = torch.empty(num_head, total_q_rounded_padded, dtype=torch.float32, device=device) lse_log2 = torch.empty(num_head, total_q_rounded_padded, dtype=torch.float32, device=device) + # GQA (qhead_per_kvhead > 1) needs dK/dV accum+postprocess since multiple Q heads + # accumulate into the same dK/dV. SM90 varlen_k with qhead_per_kvhead==1 now uses + # ragged TMA tensors for direct store, so no longer needs accum+postprocess. dKV_postprocess = qhead_per_kvhead > 1 if dKV_postprocess: head_dim_v_rounded = (head_dim_v + 32 - 1) // 32 * 32 @@ -850,83 +1262,30 @@ def _flash_attn_bwd( ) dtype = torch2cute_dtype_map[q.dtype] - current_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) + current_stream = cute.runtime.make_fake_stream(use_tvm_ffi_env_stream=True) if deterministic: - dQ_semaphore = torch.zeros(batch_size, num_head, seqlen_q_rounded // m_block_size, cluster_size, dtype=torch.int32, device="cuda") + dQ_semaphore = torch.zeros(batch_size, num_head, seqlen_q_rounded // m_block_size, cluster_size, dtype=torch.int32, device=device) else: dQ_semaphore = None if deterministic and qhead_per_kvhead > 1: - dK_semaphore = torch.zeros(batch_size, num_head_kv, seqlen_k_rounded // n_block_size, 2, dtype=torch.int32, device="cuda") - dV_semaphore = torch.zeros(batch_size, num_head_kv, seqlen_k_rounded // n_block_size, 2, dtype=torch.int32, device="cuda") + dK_semaphore = torch.zeros(batch_size, num_head_kv, seqlen_k_rounded // n_block_size, 2, dtype=torch.int32, device=device) + dV_semaphore = torch.zeros(batch_size, num_head_kv, seqlen_k_rounded // n_block_size, 2, dtype=torch.int32, device=device) else: dK_semaphore = None dV_semaphore = None - # Preprocess kernel: compute (o * dout).sum(dim=-1), lse * log2_e, and zero out dq_accum. - compile_key_pre = ( - arch, - dtype, - head_dim, - head_dim_v, - m_block_size, - num_threads, - cu_seqlens_q is None, - seqused_q is None, - get_broadcast_dims(out), - get_broadcast_dims(dout), + # Preprocess kernel: compute (o * dout).sum(dim=-1) - dLSE, lse * log2_e, and zero out dq_accum. + _bwd_preprocess( + out, dout, dpsum, lse, lse_log2, dq_accum, + cu_seqlens_q, seqused_q, dlse, + dtype, head_dim, head_dim_v, m_block_size, ) - if compile_key_pre not in _flash_attn_bwd.compile_cache_pre: - o_tensor, do_tensor = [to_cute_tensor(t) for t in (out, dout)] - dq_accum_tensor, dpsum_tensor, lse_log2_tensor = [ - to_cute_tensor(t) for t in (dq_accum, dpsum, lse_log2) - ] - lse_tensor = to_cute_tensor(lse, assumed_align=4) - cu_seqlens_q_tensor, seqused_q_tensor = [ - to_cute_tensor(t, assumed_align=4) if t is not None else None - for t in (cu_seqlens_q, seqused_q) - ] - fa_bwd_pre = FlashAttentionBackwardPreprocess( - dtype, - head_dim, - head_dim_v, - arch, - m_block_size, - num_threads=num_threads, - ) - # TODO: check @can_implement - _flash_attn_bwd.compile_cache_pre[compile_key_pre] = cute.compile( - fa_bwd_pre, - o_tensor, - do_tensor, - dpsum_tensor, - lse_tensor, - lse_log2_tensor, - dq_accum_tensor, - cu_seqlens_q_tensor, - seqused_q_tensor, - current_stream, - options="--enable-tvm-ffi", - ) - if not is_fake_mode(): - _flash_attn_bwd.compile_cache_pre[compile_key_pre]( - out, - dout, - dpsum, - lse, - lse_log2, - dq_accum, - cu_seqlens_q, - seqused_q, - current_stream, - ) - - # NB num_threads application for 3 kernels - # There are pre, main, post processing kernels, currenlty num_threads is only actually - # used for the pre proc, and then we hard code to 384 for the main and post proc, and we do - # before cache key gen - num_threads = 384 + # num_threads: SM90 derives from BwdConfig.num_wg, SM120 is set to 128 above, + # SM100/SM110 uses default from function signature (384). + if arch // 10 not in [9, 12]: + num_threads = 384 # Backward kernel: compute dk, dv, dq_accum. score_mod_hash = utils.hash_callable(score_mod) if score_mod else False @@ -953,7 +1312,7 @@ def _flash_attn_bwd( subtile_factor=subtile_factor, ) - if arch // 10 == 9: + if arch // 10 in [8, 9, 12]: compile_key = ( arch, dtype, @@ -961,6 +1320,8 @@ def _flash_attn_bwd( head_dim_v, qhead_per_kvhead, causal, + window_size_left is not None, + window_size_right is not None, softcap != 0.0, m_block_size, n_block_size, @@ -975,6 +1336,8 @@ def _flash_attn_bwd( AtomLayoutNdKV, AtomLayoutMdQ, V_in_regs, + dQ_single_wg, + deterministic, cu_seqlens_q is None, cu_seqlens_k is None, seqused_q is None, @@ -1043,51 +1406,56 @@ def _flash_attn_bwd( if t is not None else None for t in (dQ_semaphore, dK_semaphore, dV_semaphore) ] - fa_bwd_sm80 = FlashAttentionBackwardSm80( - dtype, - head_dim, - head_dim_v, - qhead_per_kvhead, - m_block_size, - n_block_size, - num_stages_Q, - num_stages_dO, - num_threads, - pack_gqa, - causal, - SdP_swapAB, - dKV_swapAB, - dQ_swapAB, - AtomLayoutMSdP, - AtomLayoutNdKV, - AtomLayoutMdQ, - V_in_regs=V_in_regs, - ) - if arch // 10 == 9: - fa_bwd_obj = FlashAttentionBackwardSm90( + if arch // 10 in [8, 12]: + flash_bwd_obj_cls = FlashAttentionBackwardSm120 if arch // 10 == 12 else FlashAttentionBackwardSm80 + fa_bwd_obj = flash_bwd_obj_cls( dtype, head_dim, head_dim_v, qhead_per_kvhead, - causal, m_block_size, n_block_size, num_stages_Q, num_stages_dO, - num_stages_PdS, + num_threads, + pack_gqa, + causal, SdP_swapAB, dKV_swapAB, dQ_swapAB, AtomLayoutMSdP, AtomLayoutNdKV, AtomLayoutMdQ, - num_threads, + V_in_regs=V_in_regs, + ) + elif arch // 10 == 9: + fa_bwd_obj = FlashAttentionBackwardSm90( + dtype, + head_dim, + head_dim_v, + qhead_per_kvhead, + causal, + is_local=local, + deterministic=deterministic, + tile_m=m_block_size, + tile_n=n_block_size, + Q_stage=num_stages_Q, + dO_stage=num_stages_dO, + PdS_stage=num_stages_PdS, + SdP_swapAB=SdP_swapAB, + dKV_swapAB=dKV_swapAB, + dQ_swapAB=dQ_swapAB, + AtomLayoutMSdP=AtomLayoutMSdP, + AtomLayoutNdKV=AtomLayoutNdKV, + AtomLayoutMdQ=AtomLayoutMdQ, + num_threads=num_threads, V_in_regs=V_in_regs, score_mod=score_mod, score_mod_bwd=score_mod_bwd, mask_mod=mask_mod, has_aux_tensors=aux_tensors is not None, subtile_factor=subtile_factor, + dQ_single_wg=dQ_single_wg, ) else: fa_bwd_obj = FlashAttentionBackwardSm100( @@ -1126,7 +1494,6 @@ def _flash_attn_bwd( dk_tensor if not dKV_postprocess else dk_accum_tensor, dv_tensor if not dKV_postprocess else dv_accum_tensor, softmax_scale, - current_stream, cu_seqlens_q_tensor, cu_seqlens_k_tensor, seqused_q_tensor, @@ -1139,6 +1506,7 @@ def _flash_attn_bwd( dV_semaphore_tensor, cute_aux_tensors, sparse_tensors_compile, + current_stream, options="--enable-tvm-ffi", ) if not is_fake_mode(): @@ -1153,7 +1521,6 @@ def _flash_attn_bwd( dk if not dKV_postprocess else dk_accum, dv if not dKV_postprocess else dv_accum, softmax_scale, - current_stream, cu_seqlens_q, cu_seqlens_k, seqused_q, @@ -1168,157 +1535,45 @@ def _flash_attn_bwd( normalized_block_sparse_tensors[:4] if normalized_block_sparse_tensors is not None else None, ) - num_threads = 256 if arch // 10 == 9 else 128 - # Postprocess kernel: convert dq_accum from float32 to dq in bf16/fp16 - compile_key_post = ( - arch, - dtype, - head_dim, - m_block_size, - num_threads, - AtomLayoutMdQ, - dQ_swapAB, - cu_seqlens_q is None, - seqused_q is None, - use_2cta_instrs, - 1, # no cluster for tile_m - get_broadcast_dims(dq_accum), - get_broadcast_dims(dq), + if arch // 10 == 9: + # dQ postprocess: match main kernel's MMA WG count, unless dQ_single_wg + num_threads_post_dQ = 128 if dQ_single_wg else cfg.num_wg * 128 + num_threads_post_dKV = cfg.num_wg * 128 + else: + num_threads_post_dQ = 128 + num_threads_post_dKV = 128 + + # Postprocess: convert dq_accum from float32 to dq in bf16/fp16 + _bwd_postprocess_convert( + dq_accum, dq, softmax_scale, + cu_seqlens_q, seqused_q, + arch, dtype, head_dim, m_block_size, num_threads_post_dQ, + AtomLayoutMdQ, dQ_swapAB, + use_2cta_instrs=use_2cta_instrs, cluster_size=1, ) - if compile_key_post not in _flash_attn_bwd.compile_cache_post: - dq_accum_tensor = to_cute_tensor(dq_accum) - dq_tensor = to_cute_tensor(dq) - cu_seqlens_q_tensor, seqused_q_tensor = [ - to_cute_tensor(t, assumed_align=4) if t is not None else None - for t in (cu_seqlens_q, seqused_q) - ] - fa_bwd_post = FlashAttentionBackwardPostprocess( - dtype, head_dim, arch, m_block_size, num_threads, AtomLayoutMdQ, dQ_swapAB, - use_2cta_instrs=use_2cta_instrs, - ) - # TODO: check @can_implement - _flash_attn_bwd.compile_cache_post[compile_key_post] = cute.compile( - fa_bwd_post, - dq_accum_tensor, - dq_tensor, - softmax_scale, - cu_seqlens_q_tensor, - seqused_q_tensor, - current_stream, - options="--enable-tvm-ffi", - ) - - if not is_fake_mode(): - _flash_attn_bwd.compile_cache_post[compile_key_post]( - dq_accum, - dq, - softmax_scale, - cu_seqlens_q, - seqused_q, - current_stream, - ) if dKV_postprocess: - # Postprocess kernel: convert dk_accum & dv_accum from float32 to bf16/fp16 - compile_key_post = ( - arch, - dtype, - head_dim, - n_block_size, - num_threads, - AtomLayoutNdKV, - dKV_swapAB, - cu_seqlens_k is None, - seqused_k is None, - False, # even for 2cta, is split along hdim, so always False - cluster_size, # cluster is for tile_n - get_broadcast_dims(dk_accum), - get_broadcast_dims(dk), + # Postprocess: convert dk_accum from float32 to dk in bf16/fp16 + _bwd_postprocess_convert( + dk_accum, dk, softmax_scale, + cu_seqlens_k, seqused_k, + arch, dtype, head_dim, n_block_size, num_threads_post_dKV, + AtomLayoutNdKV, dKV_swapAB, + cluster_size=cluster_size, ) - if compile_key_post not in _flash_attn_bwd.compile_cache_post: - dk_accum_tensor = to_cute_tensor(dk_accum) - dk_tensor = to_cute_tensor(dk) - cu_seqlens_k_tensor, seqused_k_tensor = [ - to_cute_tensor(t, assumed_align=4) if t is not None else None - for t in (cu_seqlens_k, seqused_k) - ] - fa_bwd_post = FlashAttentionBackwardPostprocess( - dtype, head_dim, arch, n_block_size, num_threads, AtomLayoutNdKV, dKV_swapAB, - cluster_size=cluster_size, - ) - # TODO: check @can_implement - _flash_attn_bwd.compile_cache_post[compile_key_post] = cute.compile( - fa_bwd_post, - dk_accum_tensor, - dk_tensor, - softmax_scale, - cu_seqlens_k_tensor, - seqused_k_tensor, - current_stream, - options="--enable-tvm-ffi", - ) - if not is_fake_mode(): - _flash_attn_bwd.compile_cache_post[compile_key_post]( - dk_accum, - dk, - softmax_scale, - cu_seqlens_k, - seqused_k, - current_stream, - ) - compile_key_post = ( - arch, - dtype, - head_dim_v, - n_block_size, - num_threads, - AtomLayoutNdKV, - dKV_swapAB, - cu_seqlens_k is None, - seqused_k is None, - False, - cluster_size, - get_broadcast_dims(dv_accum), - get_broadcast_dims(dv), + # Postprocess: convert dv_accum from float32 to dv in bf16/fp16 + _bwd_postprocess_convert( + dv_accum, dv, 1.0, + cu_seqlens_k, seqused_k, + arch, dtype, head_dim_v, n_block_size, num_threads_post_dKV, + AtomLayoutNdKV, dKV_swapAB, + cluster_size=cluster_size, ) - if compile_key_post not in _flash_attn_bwd.compile_cache_post: - dv_accum_tensor = to_cute_tensor(dv_accum) - dv_tensor = to_cute_tensor(dv) - cu_seqlens_k_tensor, seqused_k_tensor = [ - to_cute_tensor(t, assumed_align=4) if t is not None else None - for t in (cu_seqlens_k, seqused_k) - ] - fa_bwd_post = FlashAttentionBackwardPostprocess( - dtype, head_dim_v, arch, n_block_size, num_threads, AtomLayoutNdKV, dKV_swapAB, - cluster_size=cluster_size, - ) - # TODO: check @can_implement - _flash_attn_bwd.compile_cache_post[compile_key_post] = cute.compile( - fa_bwd_post, - dv_accum_tensor, - dv_tensor, - cutlass.Float32(1.0), - cu_seqlens_k_tensor, - seqused_k_tensor, - current_stream, - options="--enable-tvm-ffi", - ) - if not is_fake_mode(): - _flash_attn_bwd.compile_cache_post[compile_key_post]( - dv_accum, - dv, - 1.0, - cu_seqlens_k, - seqused_k, - current_stream, - ) return dq, dk, dv -_flash_attn_bwd.compile_cache_pre = get_jit_cache("bwd_pre") _flash_attn_bwd.compile_cache = get_jit_cache("bwd") -_flash_attn_bwd.compile_cache_post = get_jit_cache("bwd_post") class FlashAttnFunc(torch.autograd.Function): @@ -1376,14 +1631,17 @@ def forward( ctx.window_size = window_size ctx.softcap = softcap ctx.deterministic = deterministic - # LSE gradient is not supported yet - if lse is not None: - ctx.mark_non_differentiable(lse) + ctx.return_lse = return_lse + ctx.set_materialize_grads(False) return out, lse @staticmethod - def backward(ctx, dout, *args): + def backward(ctx, dout, dlse): q, k, v, out, lse = ctx.saved_tensors + if not ctx.return_lse: + dlse = None + if dout is None: + dout = torch.zeros_like(out) dq, dk, dv = _flash_attn_bwd( q, k, @@ -1397,6 +1655,7 @@ def backward(ctx, dout, *args): window_size_left=ctx.window_size[0], window_size_right=ctx.window_size[1], deterministic=ctx.deterministic, + dlse=dlse, ) return dq, dk, dv, *((None,) * 20) # Extra Nones is fine @@ -1458,15 +1717,18 @@ def forward( ctx.deterministic = deterministic ctx.max_seqlen_q = max_seqlen_q ctx.max_seqlen_k = max_seqlen_k - # LSE gradient is not supported yet - if lse is not None: - ctx.mark_non_differentiable(lse) + ctx.return_lse = return_lse + ctx.set_materialize_grads(False) return out, lse @staticmethod - def backward(ctx, dout, *args): + def backward(ctx, dout, dlse): q, k, v, out, lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k = ctx.saved_tensors assert ctx.softcap == 0.0 + if not ctx.return_lse: + dlse = None + if dout is None: + dout = torch.zeros_like(out) dq, dk, dv = _flash_attn_bwd( q, k, @@ -1486,6 +1748,7 @@ def backward(ctx, dout, *args): max_seqlen_q=ctx.max_seqlen_q, max_seqlen_k=ctx.max_seqlen_k, deterministic=ctx.deterministic, + dlse=dlse, ) return dq, dk, dv, *((None,) * 20) @@ -1581,6 +1844,63 @@ def flash_attn_varlen_func( ) +def _compile_fwd_combine( + dtype, dtype_partial, head_dim, tile_m, k_block_size, log_max_splits, + has_cu_seqlens, has_seqused, has_lse, has_varlen_batch_idx, +): + """Compile fwd combine kernel using cute fake tensors (no real GPU tensors needed).""" + sym = cute.sym_int + div = 128 // dtype_partial.width # 16-byte alignment in elements + + fa_combine = FlashAttentionForwardCombine( + dtype=dtype, + dtype_partial=dtype_partial, + head_dim=head_dim, + tile_m=tile_m, + k_block_size=k_block_size, + log_max_splits=log_max_splits, + ) + if not fa_combine.can_implement( + dtype, dtype_partial, head_dim, tile_m, k_block_size, log_max_splits, + num_threads=256, + ): + raise RuntimeError( + "FlashAttention combine kernel cannot be implemented with given parameters" + ) + + if has_cu_seqlens: + # Varlen: (num_splits, total_q, nheads, headdim) + num_splits, total_q, nheads = sym(), sym(), sym() + mO_partial = fake_tensor(dtype_partial, (num_splits, total_q, nheads, head_dim), divisibility=div) + mLSE_partial = fake_tensor(Float32, (num_splits, total_q, nheads), divisibility=1, leading_dim=1) + mO = fake_tensor(dtype, (total_q, nheads, head_dim), divisibility=div) + mLSE = fake_tensor(Float32, (total_q, nheads), divisibility=1, leading_dim=0) if has_lse else None + else: + # Batched: (num_splits, batch, seqlen, nheads, headdim) + num_splits, batch, seqlen, nheads = sym(), sym(), sym(), sym() + mO_partial = fake_tensor(dtype_partial, (num_splits, batch, seqlen, nheads, head_dim), divisibility=div) + mLSE_partial = fake_tensor(Float32, (num_splits, batch, seqlen, nheads), divisibility=1, leading_dim=2) + mO = fake_tensor(dtype, (batch, seqlen, nheads, head_dim), divisibility=div) + mLSE = fake_tensor(Float32, (batch, seqlen, nheads), divisibility=1, leading_dim=1) if has_lse else None + batch = mO_partial.shape[1] + + batch_for_1d = batch if not has_cu_seqlens else sym() + batchp1 = sym() + mCuSeqlens = fake_tensor(Int32, (batchp1,), divisibility=1) if has_cu_seqlens else None + mSeqused = fake_tensor(Int32, (batch_for_1d,), divisibility=1) if has_seqused else None + mNumSplitsDynamic = None # Not parametrized in compile_key + mVarlenBatchIdx = fake_tensor(Int32, (batch_for_1d,), divisibility=1) if has_varlen_batch_idx else None + mSemaphore = None # Not parametrized in compile_key + + return cute.compile( + fa_combine, + mO_partial, mLSE_partial, mO, mLSE, + mCuSeqlens, mSeqused, mNumSplitsDynamic, mVarlenBatchIdx, mSemaphore, + cute.runtime.make_fake_stream(use_tvm_ffi_env_stream=True), + options="--enable-tvm-ffi", + ) + + def _flash_attn_fwd_combine( out_partial: torch.Tensor, lse_partial: torch.Tensor, @@ -1589,6 +1909,7 @@ def _flash_attn_fwd_combine( cu_seqlens: Optional[torch.Tensor] = None, seqused: Optional[torch.Tensor] = None, num_splits_dynamic_ptr: Optional[torch.Tensor] = None, + varlen_batch_idx: Optional[torch.Tensor] = None, semaphore_to_reset: Optional[torch.Tensor] = None, ) -> None: """Forward combine kernel for split attention computation. @@ -1612,27 +1933,13 @@ def _flash_attn_fwd_combine( Returns: None """ - # Input validation - assert out_partial.dim() in [4, 5], "out_partial must have 4 or 5 dimensions" - assert lse_partial.dim() in [3, 4], "lse_partial must have 3 or 4 dimensions" assert out_partial.dtype in [torch.float16, torch.bfloat16, torch.float32], ( "out_partial must be fp16, bf16, or fp32" ) - assert lse_partial.dtype == torch.float32, "lse_partial must be fp32" - assert out_partial.is_cuda and lse_partial.is_cuda, "tensors must be on CUDA device" - assert out_partial.stride(-1) == 1, "out_partial must be contiguous in the last dimension" - assert lse_partial.stride(-2) == 1, "lse_partial must be contiguous in the seqlen dimension" - assert lse_partial.shape == out_partial.shape[:-1] - + if not is_fake_mode(): + assert out_partial.is_cuda and lse_partial.is_cuda, "tensors must be on CUDA device" # Determine if this is variable length based on dimensions is_varlen = out_partial.dim() == 4 - - # Validate output tensor shapes and types - assert out.shape == out_partial.shape[1:], "out shape mismatch" - if lse is not None: - assert lse.shape == lse_partial.shape[1:], "lse shape mismatch" - assert lse.dtype == torch.float32, "lse must be fp32" - # Validate optional tensors for t, name in [ (cu_seqlens, "cu_seqlens"), @@ -1640,10 +1947,9 @@ def _flash_attn_fwd_combine( (num_splits_dynamic_ptr, "num_splits_dynamic_ptr"), ]: if t is not None: - assert t.dtype == torch.int32, f"{name} must be int32" - assert t.is_cuda, f"{name} must be on CUDA device" + if not is_fake_mode(): + assert t.is_cuda, f"{name} must be on CUDA device" assert t.is_contiguous(), f"{name} must be contiguous" - head_dim = out_partial.shape[-1] num_splits = out_partial.shape[0] assert num_splits <= 256 @@ -1652,101 +1958,37 @@ def _flash_attn_fwd_combine( k_block_size = 64 if head_dim <= 64 else 128 # We want kBlockM to be as small as possible to maximize parallelism. # E.g., if hdim is 64, we want kBlockM to be 16 so that we can use 256 threads, each reading 4 elements (floats). - m_block_size = 8 if k_block_size % 128 == 0 else (16 if k_block_size % 64 == 0 else 32) + tile_m = 8 if k_block_size % 128 == 0 else (16 if k_block_size % 64 == 0 else 32) log_max_splits = max(math.ceil(math.log2(num_splits)), 4) - if m_block_size == 8: + if tile_m == 8: # If kBlockM == 8 then the minimum number of splits is 32. # TODO: we can deal w this by using 128 threads instead log_max_splits = max(log_max_splits, 5) - current_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) - # Create combine kernel configuration dtype = torch2cute_dtype_map[out.dtype] dtype_partial = torch2cute_dtype_map[out_partial.dtype] - compile_key = ( dtype, dtype_partial, head_dim, - m_block_size, + tile_m, k_block_size, log_max_splits, cu_seqlens is not None, seqused is not None, lse is not None, + varlen_batch_idx is not None, ) - if compile_key not in _flash_attn_fwd_combine.compile_cache: - out_partial_tensor = to_cute_tensor( - out_partial, leading_dim=4 if not is_varlen else 3 - ) - lse_partial_tensor = to_cute_tensor( - lse_partial, assumed_align=4, leading_dim=lse_partial.ndim - 2 - ) - out_tensor = to_cute_tensor(out, leading_dim=3 if not is_varlen else 2) - lse_tensor = ( - to_cute_tensor(lse, assumed_align=4, leading_dim=lse.ndim - 2) - if lse is not None - else None - ) - - optional_tensors = [ - to_cute_tensor(t, assumed_align=4, leading_dim=0) - if t is not None - else None - for t in (cu_seqlens, seqused, num_splits_dynamic_ptr, semaphore_to_reset) - ] - cu_seqlens_tensor, seqused_tensor, num_splits_dynamic_tensor, semaphore_tensor = ( - optional_tensors - ) - fa_combine = FlashAttentionForwardCombine( - dtype=dtype, - dtype_partial=dtype_partial, - head_dim=head_dim, - m_block_size=m_block_size, - k_block_size=k_block_size, - log_max_splits=log_max_splits, - ) - - # Check if implementation is supported - if not fa_combine.can_implement( - dtype, - dtype_partial, - head_dim, - m_block_size, - k_block_size, - log_max_splits, - num_threads=256, - ): - raise RuntimeError( - "FlashAttention combine kernel cannot be implemented with given parameters" - ) - - _flash_attn_fwd_combine.compile_cache[compile_key] = cute.compile( - fa_combine, - out_partial_tensor, - lse_partial_tensor, - out_tensor, - lse_tensor, - cu_seqlens_tensor, - seqused_tensor, - num_splits_dynamic_tensor, - semaphore_tensor, - current_stream, - options="--enable-tvm-ffi", + _flash_attn_fwd_combine.compile_cache[compile_key] = _compile_fwd_combine( + *compile_key ) if not is_fake_mode(): _flash_attn_fwd_combine.compile_cache[compile_key]( - out_partial, - lse_partial, - out, - lse, - cu_seqlens, - seqused, - num_splits_dynamic_ptr, + out_partial, lse_partial, out, lse, + cu_seqlens, seqused, num_splits_dynamic_ptr, varlen_batch_idx, semaphore_to_reset, - current_stream, ) @@ -1760,6 +2002,7 @@ def flash_attn_combine( out_dtype: Optional[torch.dtype] = None, cu_seqlens: Optional[torch.Tensor] = None, seqused: Optional[torch.Tensor] = None, + varlen_batch_idx: Optional[torch.Tensor] = None, return_lse: bool = True, ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: """Flash Attention combine function for split attention computation. @@ -1779,6 +2022,9 @@ def flash_attn_combine( out_dtype: Optional output dtype. If None, will use fp16/bf16 based on input. cu_seqlens: Cumulative sequence lengths for variable length sequences seqused: Used sequence lengths for each batch + varlen_batch_idx: Optional mapping from virtual batch index to real batch index + (int32 tensor of shape (batch_size,)). Used by persistent tile schedulers + that reorder batch processing for load balancing. return_lse: Whether to return the combined LSE tensor. Default is True. Returns: @@ -1795,32 +2041,19 @@ def flash_attn_combine( """ # Input validation assert out_partial.dim() in [4, 5], "out_partial must have 4 or 5 dimensions" - assert lse_partial.dim() in [3, 4], "lse_partial must have 3 or 4 dimensions" - assert out_partial.dtype == torch.float32, "out_partial must be fp32 (from accumulation)" - assert lse_partial.dtype == torch.float32, "lse_partial must be fp32" - # Determine if this is variable length based on dimensions is_varlen = out_partial.dim() == 4 - if is_varlen: # Variable length: (num_splits, total_q, num_heads, head_size) num_splits, total_q, num_heads, head_size = out_partial.shape - assert lse_partial.shape == (num_splits, total_q, num_heads), ( - "lse_partial shape mismatch for varlen" - ) batch_size = 1 # Treat as single batch for varlen seqlen = total_q else: # Regular batched: (num_splits, batch_size, seqlen, num_heads, head_size) num_splits, batch_size, seqlen, num_heads, head_size = out_partial.shape - assert lse_partial.shape == (num_splits, batch_size, seqlen, num_heads), ( - "lse_partial shape mismatch" - ) - # Determine output dtype if out_dtype is None: out_dtype = out_partial.dtype - # Create output if not provided device = out_partial.device if out is None: @@ -1830,20 +2063,15 @@ def flash_attn_combine( out = torch.empty( batch_size, seqlen, num_heads, head_size, dtype=out_dtype, device=device ) - # Create lse output only if requested if return_lse: if is_varlen: - lse = torch.empty(num_heads, total_q, dtype=torch.float32, device=device).transpose( - 0, 1 - ) + lse = torch.empty(num_heads, total_q, dtype=torch.float32, device=device) else: - lse = torch.empty( - batch_size, num_heads, seqlen, dtype=torch.float32, device=device - ).transpose(1, 2) + lse = torch.empty(batch_size, num_heads, seqlen, dtype=torch.float32, device=device) + lse = lse.transpose(-1, -2) else: lse = None - _flash_attn_fwd_combine( out_partial, lse_partial, @@ -1851,5 +2079,6 @@ def flash_attn_combine( lse, cu_seqlens, seqused, + varlen_batch_idx=varlen_batch_idx, ) return out, lse diff --git a/flash-attn4/torch-ext/flash_attn4/mask.py b/flash-attn4/torch-ext/flash_attn4/mask.py index 8da83570..da9da248 100644 --- a/flash-attn4/torch-ext/flash_attn4/mask.py +++ b/flash-attn4/torch-ext/flash_attn4/mask.py @@ -1,109 +1,102 @@ # Copyright (c) 2025, Tri Dao. -from typing import Optional, Callable +from typing import Optional, Callable, TypeAlias from dataclasses import dataclass import cutlass import cutlass.cute as cute -from cutlass import Float32, Int32, const_expr +from cutlass import Float32, Int32, Uint32, const_expr from .quack import layout_utils -from . import utils +from . import utils as utils from .seqlen_info import SeqlenInfoQK +MaskGenFn: TypeAlias = Callable[[int], Uint32] +MASK_R2P_CHUNK_SIZE: int = 32 + @cute.jit -def mask_r2p(X: cute.Tensor, col_limit: Int32, arch: int = 90, rank1: bool = False) -> None: - # Bit manipulation, compiles down to the R2P instruction - # For sm100: we know that tScS_t2r[i][1] == i, for the particular tmem copy atom we're using. - # For sm90: instead of comparing limit to 0, 1, 8, 9, 16, 17, ..., - # we compare a transformed version of limit to 0, 1, 2, 3, 4, 5, ... - if const_expr(arch == 90): - col_limit_transformed = col_limit // 8 * 2 + min(col_limit % 8, 2) - else: - col_limit_transformed = col_limit - ncol = const_expr(cute.size(X.shape[cute.rank(X) - 1]) if not rank1 else cute.size(X.shape)) - # Ideally we'd move by 32 instead of 24, but mask >> i isn't correct for i == 31 - for s in cutlass.range_constexpr(cute.ceil_div(ncol, 24)): - # Don't need to clamp to 32 since the shr.u32 instruction does that already - col_limit_right_s = max(col_limit_transformed - s * 24, 0) - # 0 -> 0b00...00, 1 -> 0b00...01, ..., 31 -> 0b01...11, 32 -> 0b11...11 - mask = (1 << col_limit_right_s) - 1 - # This needs to be range_constexpr, o/w the compiler can't generate the R2P instruction - for i in cutlass.range_constexpr(min(24, ncol - s * 24)): - in_bound = cutlass.Boolean(mask & (1 << i)) - c = s * 24 + i - if const_expr(rank1): - X[c] = X[c] if in_bound else -Float32.inf - # This is the equivalent of: - # X[s * 24 + i] = X[s * 24 + i] if col_limit_right_s <= i else -Float32.inf - else: - for r in cutlass.range_constexpr(cute.size(X.shape[0])): - X[r, c] = X[r, c] if in_bound else -Float32.inf +def r2p_bitmask_below(limit: Int32, s: int) -> Uint32: + """32-bit R2P bitmask keeping positions < limit (exclusive upper bound). + + Positions 0..limit-1 in chunk `s` get bit=1 (keep), the rest bit=0 (mask). + Uses inline PTX to avoid shift-by-type-width UB. + """ + m = max((s + 1) * MASK_R2P_CHUNK_SIZE - limit, 0) + return utils.shr_u32(Uint32(0xFFFFFFFF), Uint32(m)) @cute.jit -def mask_r2p_transposed(X: cute.Tensor, row_limit_top: Int32, num_rep: int) -> None: - # Bit manipulation, compiles down to the R2P instruction - # For sm100: we know that tScS_t2r[i][0] has the form 0, 1, ..., 31, 64, ..., 127 - # or 0, 1, ..., 15, 32, ..., 47, 64, ... - # We compare a transformed version of limit to 0, 1, 2, 3, 4, 5, ... - # Here we hardcode for the case of 2 warp groups. - num_wg = 2 - row_limit_top_transformed = row_limit_top // (num_rep * num_wg) * num_rep + min( - row_limit_top % (num_rep * num_wg), num_rep - ) - ncol = cute.size(X.shape) - # Ideally we'd move by 32 instead of 24, but mask >> i isn't correct for i == 31 - for s in cutlass.range_constexpr(cute.ceil_div(ncol, 24)): - row_limit_top_s = max(row_limit_top_transformed - s * 24, 0) - # 0 -> 0b00...00, 1 -> 0b00...01, ..., 31 -> 0b01...11, 32 -> 0b11...11 - mask = (1 << row_limit_top_s) - 1 - # This needs to be range_constexpr, o/w the compiler can't generate the R2P instruction - for i in cutlass.range_constexpr(min(24, ncol - s * 24)): - out_bound = cutlass.Boolean(mask & (1 << i)) - c = s * 24 + i - X[c] = -Float32.inf if out_bound else X[c] - # tidx = cute.arch.thread_idx()[0] % 256 - # if tidx == 128: - # cute.printf("tidx = {}, s = {}, i = {}, row_limit_top = {}, row_limit_top_s = {}, mask = {}, out_bound = {}", tidx, s, i, row_limit_top, row_limit_top_s, mask, out_bound) +def r2p_bitmask_above(limit: Int32, s: int) -> Uint32: + """32-bit R2P bitmask keeping positions >= limit (inclusive lower bound). + + Positions limit..31 in chunk `s` get bit=1 (keep), the rest bit=0 (mask). + Uses inline PTX to avoid shift-by-type-width UB. + """ + n = max(limit - s * MASK_R2P_CHUNK_SIZE, 0) + return utils.shl_u32(Uint32(0xFFFFFFFF), Uint32(n)) @cute.jit -def mask_r2p_dual_bound( +def mask_r2p_lambda( X: cute.Tensor, - col_limit_left: Int32, # Inclusive lower bound - col_limit_right: Int32, # Exclusive upper bound + mask_gen_fn: cutlass.Constexpr[MaskGenFn], + rank1: bool = False, ) -> None: - """ - Dual-bound masking using two bitmasks for SM100, following mask_r2p. - Masks elements where: NOT (col_limit_left <= col < col_limit_right) + """Apply R2P masking with a custom bitmask generator. - Uses bit manipulation to create a range mask: - mask_right = (1 << right) - 1 -> bits (right-1)..0 are 1 - mask_left = (1 << left) - 1 -> bits (left-1)..0 are 1 - mask_range = mask_range = mask_right & ~ mask_left -> bits (right-1)..left are 1 + mask_gen_fn(chunk_idx: constexpr int) -> Uint32: + Returns a 32-bit bitmask for the chunk. Bit i set means column + chunk_idx * chunk_size + i is KEPT; bit i clear means masked to -inf. """ - ncol = const_expr(cute.size(X.shape)) + ncol = const_expr(cute.size(X.shape[cute.rank(X) - 1]) if not rank1 else cute.size(X.shape)) + # 32-column chunks. The mask_gen_fn returns a Uint32 bitmask (1=keep). + CHUNK_SIZE = MASK_R2P_CHUNK_SIZE + for s in cutlass.range_constexpr(cute.ceil_div(ncol, CHUNK_SIZE)): + mask = mask_gen_fn(s) + # This needs to be range_constexpr, o/w the compiler can't generate the R2P instruction + for i in cutlass.range_constexpr(min(CHUNK_SIZE, ncol - s * CHUNK_SIZE)): + in_bound = cutlass.Boolean(mask & (Uint32(1) << i)) + c = s * CHUNK_SIZE + i + if const_expr(rank1): + X[c] = X[c] if in_bound else -Float32.inf + else: + for r in cutlass.range_constexpr(cute.size(X.shape[0])): + X[r, c] = X[r, c] if in_bound else -Float32.inf - for s in cutlass.range_constexpr(cute.ceil_div(ncol, 24)): - right_s = max(col_limit_right - s * 24, 0) - left_s = max(col_limit_left - s * 24, 0) - # otherwise cute dsl complains about python int too large to convert into c long - right_s = min(right_s, 24) - left_s = min(left_s, 24) +@cute.jit +def sm90_col_to_r2p_idx(col_limit: Int32) -> Int32: + """Transform SM90 MMA column coordinate to R2P element index. - # bits (right-1)..left are 1 - mask_right = (1 << right_s) - 1 - mask_left = (1 << left_s) - 1 - mask_range = mask_right & ~mask_left + SM90 MMA accumulator column indices are non-contiguous: 0, 1, 8, 9, 16, 17, ... + Element indices are contiguous: 0, 1, 2, 3, 4, 5, ... + This converts a column-space threshold to element-space for r2p_bitmask_below/above. + """ + return col_limit // 8 * 2 + min(col_limit % 8, 2) - # This needs to be range_constexpr, o/w the compiler can't generate the R2P instruction - for i in cutlass.range_constexpr(min(24, ncol - s * 24)): - in_bound = cutlass.Boolean(mask_range & (1 << i)) - c = s * 24 + i - X[c] = X[c] if in_bound else -Float32.inf + +@cute.jit +def row_to_r2p_idx(x: Int32, num_rep: int, num_wg: int) -> Int32: + """Convert a row coordinate to an R2P element index in the warp-group interleaved layout. + + In the SM100 backward pass, 2 warp groups share TMEM. The TMEM load atom + distributes rows in an interleaved pattern: elements 0..num_rep-1 map to + rows 0..num_rep-1 (warp group 0), elements num_rep..2*num_rep-1 map to + rows num_rep*num_wg..num_rep*num_wg+num_rep-1 (warp group 1), and so on. + Row-coordinate thresholds (causal limits, window bounds, uih_len) must be + converted to element indices before use with r2p_bitmask_above/below. + + Rows not owned by this thread (in the gap between warp groups) are clamped + to the boundary element index, which is safe because R2P thresholds are + monotonic. + + Example with num_rep=16, num_wg=2: + row 0 -> elem 0, row 15 -> elem 15, + row 16 -> elem 16 (clamped), row 31 -> elem 16 (clamped), + row 32 -> elem 16, row 33 -> elem 17, row 47 -> elem 31. + """ + return x // (num_rep * num_wg) * num_rep + min(x % (num_rep * num_wg), num_rep) @dataclass(frozen=True) @@ -161,8 +154,7 @@ def apply_mask( seqlenk_col_limit = self.seqlen_k - n_block * self.tile_n - thr_col_offset if const_expr(not mask_causal and not mask_local and mask_mod is None): if const_expr(mask_seqlen): - # The compiler now choses not to use R2P - r2p = const_expr(False and not self.swap_AB) + r2p = const_expr(not self.swap_AB) if const_expr(not r2p): # traverse column index. for c in cutlass.range(cute.size(tScS_mn.shape[1]), unroll_full=True): @@ -170,7 +162,8 @@ def apply_mask( for r in cutlass.range(cute.size(tScS_mn.shape[0]), unroll_full=True): acc_S_mn[r, c] = -Float32.inf if oob else acc_S_mn[r, c] else: - mask_r2p(acc_S_mn, seqlenk_col_limit, arch=90) + seqlenk_col_limit_r2p = sm90_col_to_r2p_idx(seqlenk_col_limit) + mask_r2p_lambda(acc_S_mn, lambda s: r2p_bitmask_below(seqlenk_col_limit_r2p, s)) elif const_expr( not mask_causal and not mask_local and mask_mod is not None @@ -272,7 +265,12 @@ def apply_mask( else acc_S_mn[r, c] ) else: - mask_r2p(acc_S_mn[r, None], col_limit_right, arch=90, rank1=True) + col_limit_r2p = sm90_col_to_r2p_idx(col_limit_right) + mask_r2p_lambda( + acc_S_mn[r, None], + lambda s: r2p_bitmask_below(col_limit_r2p, s), + rank1=True, + ) else: # Local local_row_offset_right = ( causal_row_offset + self.window_size_right @@ -284,6 +282,7 @@ def apply_mask( if const_expr(self.window_size_left is not None) else None ) + r2p_local = const_expr(not self.swap_AB) for r in cutlass.range(cute.size(tScS_mn.shape[0]), unroll_full=True): if const_expr(self.qhead_per_kvhead_packgqa == 1): row_idx = tScS_mn[r, 0][0] + m_block * self.tile_m @@ -302,13 +301,22 @@ def apply_mask( if const_expr(self.window_size_left is not None) else 0 ) - # if cute.arch.thread_idx()[0] == 128: cute.printf("n_block = {}, r = {}, row_idx = {}, causal_row_offset = {}, col_limit_right = {}, col_limit_left = {}", n_block, r, row_idx, causal_row_offset, col_limit_right, col_limit_left) - # traverse column index. - for c in cutlass.range(cute.size(tScS_mn.shape[1]), unroll_full=True): - col_idx = t0ScS_mn[0, c][1] - # only consider the column index, so the row index sets to 0. - if col_idx >= col_limit_right or col_idx < col_limit_left: - acc_S_mn[r, c] = -Float32.inf + if const_expr(not r2p_local): + # traverse column index. + for c in cutlass.range(cute.size(tScS_mn.shape[1]), unroll_full=True): + col_idx = t0ScS_mn[0, c][1] + if col_idx >= col_limit_right or col_idx < col_limit_left: + acc_S_mn[r, c] = -Float32.inf + else: + col_limit_right_r2p = sm90_col_to_r2p_idx(col_limit_right) + col_limit_left_r2p = sm90_col_to_r2p_idx(col_limit_left) + + def mask_gen_fn(s: int) -> Uint32: + return r2p_bitmask_below( + col_limit_right_r2p, s + ) & r2p_bitmask_above(col_limit_left_r2p, s) + + mask_r2p_lambda(acc_S_mn[r, None], mask_gen_fn, rank1=True) else: # swap_AB assert self.qhead_per_kvhead_packgqa == 1 thr_row_offset = tScS_mn[0][ROW] @@ -338,11 +346,18 @@ def apply_mask( # column, by setting row limit to be self.tile_m. row_limit_top = ( self.tile_m - if col0 >= seqlenk_col_limit - else col0 - causal_row_offset - self.window_size_right + if col0 >= seqlenk_col_limit and mask_seqlen + else ( + col0 - causal_row_offset - self.window_size_right + if const_expr(self.window_size_right is not None) + else 0 + ) + ) + row_limit_bot = ( + col0 - causal_row_offset + self.window_size_left + if const_expr(self.window_size_left is not None) + else self.tile_m ) - # TODO: do we need col_limit_sink? - row_limit_bot = col0 - causal_row_offset + self.window_size_left for r in cutlass.range(cute.size(tScS_mn.shape[0]), unroll_full=True): row_idx = t0ScS_mn[r, 0][ROW] acc_S_mn[r, c] = ( @@ -392,7 +407,11 @@ def apply_mask_sm100( # For some reason the 2 lines above generate really bad SASS acc_S[i] = -Float32.inf if tScS_t2r[i][1] >= seqlenk_col_limit else acc_S[i] else: - mask_r2p(acc_S, seqlenk_col_limit, arch=100, rank1=True) + mask_r2p_lambda( + acc_S, + lambda s: r2p_bitmask_below(seqlenk_col_limit, s), + rank1=True, + ) elif const_expr(not mask_causal and not mask_local and mask_mod is not None): # Block sparse case w/ mask_mod @@ -445,12 +464,12 @@ def apply_mask_sm100( acc_S[i] = -Float32.inf if mask_row >= self.seqlen_q else acc_S[i] else: # Causal or local - causal_row_offset = 1 + self.seqlen_k - n_block * self.tile_n - self.seqlen_q + causal_row_offset = self.seqlen_k - n_block * self.tile_n - self.seqlen_q row_idx = tScS_t2r[0][0] + m_block * self.tile_m if const_expr(self.qhead_per_kvhead_packgqa != 1): row_idx = row_idx // self.qhead_per_kvhead_packgqa if const_expr(mask_causal): - col_limit_right = row_idx + causal_row_offset + col_limit_right = row_idx + causal_row_offset + 1 if const_expr(mask_seqlen): col_limit_right = cutlass.min(col_limit_right, seqlenk_col_limit) # if cute.arch.thread_idx()[0] % 32 == 0: @@ -460,15 +479,19 @@ def apply_mask_sm100( for i in cutlass.range(ncol, unroll_full=True): acc_S[i] = -Float32.inf if tScS_t2r[i][1] >= col_limit_right else acc_S[i] else: - mask_r2p(acc_S, col_limit_right, arch=100, rank1=True) + mask_r2p_lambda( + acc_S, + lambda s: r2p_bitmask_below(col_limit_right, s), + rank1=True, + ) else: local_row_offset_right = ( - causal_row_offset + self.window_size_right + causal_row_offset + 1 + self.window_size_right if const_expr(self.window_size_right is not None) else None ) local_row_offset_left = ( - causal_row_offset - 1 - self.window_size_left + causal_row_offset - self.window_size_left if const_expr(self.window_size_left is not None) else None ) @@ -493,8 +516,15 @@ def apply_mask_sm100( else acc_S[i] ) else: - # XOR-based R2P dual bound masking - mask_r2p_dual_bound(acc_S, col_limit_left, col_limit_right) + # Dual-bound R2P masking for SM100. + # Masks elements where: NOT (col_limit_left <= col < col_limit_right) + + def mask_gen_fn(s: int) -> Uint32: + return r2p_bitmask_below(col_limit_right, s) & r2p_bitmask_above( + col_limit_left, s + ) + + mask_r2p_lambda(acc_S, mask_gen_fn, rank1=True) @cute.jit def apply_mask_sm100_transposed( @@ -634,7 +664,13 @@ def apply_mask_sm100_transposed( ) else: num_rep = cute.size(tScS_t2r, mode=[0]) # 16 or 32 - mask_r2p_transposed(acc_S, row_limit_top, num_rep) + num_wg = 2 + row_limit = row_to_r2p_idx(row_limit_top, num_rep, num_wg) + mask_r2p_lambda( + acc_S, + lambda s: r2p_bitmask_above(row_limit, s), + rank1=True, + ) else: if const_expr(self.window_size_right is not None): row_limit_top = causal_offset - self.window_size_right @@ -645,9 +681,31 @@ def apply_mask_sm100_transposed( if const_expr(mask_seqlen): if seqlenk_col_limit <= 0: row_limit_top = self.tile_m - for i in cutlass.range(cute.size(acc_S.shape), unroll_full=True): - row_idx = t0ScS_t2r[i][ROW] - local_mask = row_idx < row_limit_top - if const_expr(self.window_size_left is not None): - local_mask |= row_idx > row_limit_bot - acc_S[i] = -cutlass.Float32.inf if local_mask else acc_S[i] + r2p = True + if const_expr(not r2p): + for i in cutlass.range(cute.size(acc_S.shape), unroll_full=True): + row_idx = t0ScS_t2r[i][ROW] + local_mask = row_idx < row_limit_top + if const_expr(self.window_size_left is not None): + local_mask |= row_idx > row_limit_bot + acc_S[i] = -cutlass.Float32.inf if local_mask else acc_S[i] + else: + + def mask_gen_fn(s: int) -> Uint32: + num_rep = cute.size(tScS_t2r, mode=[0]) + num_wg = 2 + + row_limit = row_to_r2p_idx(row_limit_top, num_rep, num_wg) + mask = r2p_bitmask_above(row_limit, s) + + if const_expr(self.window_size_left is not None): + row_limit_bottom = row_to_r2p_idx(row_limit_bot + 1, num_rep, num_wg) + mask = mask & r2p_bitmask_below(row_limit_bottom, s) + + return mask + + mask_r2p_lambda( + acc_S, + mask_gen_fn, + rank1=True, + ) diff --git a/flash-attn4/torch-ext/flash_attn4/named_barrier.py b/flash-attn4/torch-ext/flash_attn4/named_barrier.py index eadac4b9..dd0d1988 100644 --- a/flash-attn4/torch-ext/flash_attn4/named_barrier.py +++ b/flash-attn4/torch-ext/flash_attn4/named_barrier.py @@ -12,6 +12,19 @@ class NamedBarrierFwd(enum.IntEnum): PEmpty = enum.auto() +class NamedBarrierFwdSm100(enum.IntEnum): + Epilogue = enum.auto() # starts from 1 as barrier 0 is reserved for sync_threads() + TmemPtr = enum.auto() + SoftmaxStatsW0 = enum.auto() + SoftmaxStatsW1 = enum.auto() + SoftmaxStatsW2 = enum.auto() + SoftmaxStatsW3 = enum.auto() + SoftmaxStatsW4 = enum.auto() + SoftmaxStatsW5 = enum.auto() + SoftmaxStatsW6 = enum.auto() + SoftmaxStatsW7 = enum.auto() + + class NamedBarrierBwd(enum.IntEnum): Epilogue = enum.auto() WarpSchedulerWG1 = enum.auto() @@ -20,8 +33,10 @@ class NamedBarrierBwd(enum.IntEnum): PdS = enum.auto() dQFullWG0 = enum.auto() dQFullWG1 = enum.auto() + dQFullWG2 = enum.auto() dQEmptyWG0 = enum.auto() dQEmptyWG1 = enum.auto() + dQEmptyWG2 = enum.auto() class NamedBarrierBwdSm100(enum.IntEnum): diff --git a/flash-attn4/torch-ext/flash_attn4/pack_gqa.py b/flash-attn4/torch-ext/flash_attn4/pack_gqa.py index 8b0fd735..b616bb86 100644 --- a/flash-attn4/torch-ext/flash_attn4/pack_gqa.py +++ b/flash-attn4/torch-ext/flash_attn4/pack_gqa.py @@ -1,25 +1,123 @@ # Copyright (c) 2025, Tri Dao. +from dataclasses import dataclass +from typing import Union, Tuple import cutlass import cutlass.cute as cute +from cutlass.cute.nvgpu import cpasync + from .quack import layout_utils -from . import utils +from . import utils as utils + + +def pack_gqa_layout(T, qhead_per_kvhead, nheads_kv, head_idx): + """Reshape a tensor to fold qhead_per_kvhead into the seqlen dimension (mode 0). + + The head dimension is at mode ``head_idx``. Modes before it (1..head_idx-1) + are kept as-is (e.g. headdim for Q/O tensors), and modes after it are kept + as-is (e.g. batch). + + For Q/O tensors (head_idx=2): + (seqlen_q, headdim, nheads, batch, ...) -> ((qhead_per_kvhead, seqlen_q), headdim, nheads_kv, batch, ...) + For LSE tensors (head_idx=1): + (seqlen_q, nheads, batch, ...) -> ((qhead_per_kvhead, seqlen_q), nheads_kv, batch, ...) + """ + head_stride = T.stride[head_idx] + shape_packed = ( + (qhead_per_kvhead, T.shape[0]), + *[T.shape[i] for i in range(1, head_idx)], + nheads_kv, + *[T.shape[i] for i in range(head_idx + 1, len(T.shape))], + ) + stride_packed = ( + (head_stride, T.stride[0]), + *[T.stride[i] for i in range(1, head_idx)], + head_stride * qhead_per_kvhead, + *[T.stride[i] for i in range(head_idx + 1, len(T.shape))], + ) + return cute.make_tensor(T.iterator, cute.make_layout(shape_packed, stride=stride_packed)) + + +def make_packgqa_tiled_tma_atom( + op: cute.atom.CopyOp, + gmem_tensor: cute.Tensor, + smem_layout: Union[cute.Layout, cute.ComposedLayout], + cta_tiler: Tuple[int, int], + qhead_per_kvhead: int, + head_idx: int, +): + # This packing and unpacking of the layout is so that we keep the same TMA dimension as usual. + # e.g. for (seqlen, d, nheads, b) layout, we still have 4D TMA after packing to + # ((nheads, seqlen), d, b). + # If we instead pack directly to ((qhead_per_kvhead, seqlen), d, nheads_kv, b) we'd have 5D TMA. + # Pack headdim and seqlen dim into 1: (seqlen, d, nheads, b) -> ((nheads, seqlen), d, b) + gmem_tensor = layout_utils.select( + gmem_tensor, [head_idx, *range(head_idx), *range(head_idx + 1, cute.rank(gmem_tensor))] + ) + gmem_tensor = cute.group_modes(gmem_tensor, 0, 2) + assert cta_tiler[0] % qhead_per_kvhead == 0, ( + "CTA tile size in the seqlen dimension must be divisible by qhead_per_kvhead" + ) + tma_atom, tma_tensor = cpasync.make_tiled_tma_atom( + op, + gmem_tensor, + smem_layout, + ((qhead_per_kvhead, cta_tiler[0] // qhead_per_kvhead), cta_tiler[1]), # No mcast + ) + # Unpack from ((nheads, seqlen), d, b) -> ((qhead_per_kvhead, seqlen), d, nheads_kv, b) + T = tma_tensor + shape_packed = ( + (qhead_per_kvhead, T.shape[0][1]), + *[T.shape[i] for i in range(1, head_idx)], + T.shape[0][0] // qhead_per_kvhead, + *[T.shape[i] for i in range(head_idx, len(T.shape))], + ) + stride_packed = ( + *[T.stride[i] for i in range(head_idx)], + T.stride[0][0] * qhead_per_kvhead, + *[T.stride[i] for i in range(head_idx, len(T.shape))], + ) + tma_tensor = cute.make_tensor(T.iterator, cute.make_layout(shape_packed, stride=stride_packed)) + return tma_atom, tma_tensor +def unpack_gqa_layout(T, qhead_per_kvhead, head_idx): + """Reverse of pack_gqa_layout: unfold qhead_per_kvhead from the seqlen dimension (mode 0). + + The head dimension is at mode ``head_idx``. Modes before it (1..head_idx-1) + are kept as-is (e.g. headdim for Q/O tensors), and modes after it are kept + as-is (e.g. batch). + + For Q/O tensors (head_idx=2): + ((qhead_per_kvhead, seqlen_q), headdim, nheads_kv, batch, ...) -> (seqlen_q, headdim, nheads, batch, ...) + For LSE tensors (head_idx=1): + ((qhead_per_kvhead, seqlen_q), nheads_kv, batch, ...) -> (seqlen_q, nheads, batch, ...) + """ + seqlen_stride = T.stride[0][1] + head_stride = T.stride[0][0] + shape_unpacked = ( + T.shape[0][1], + *[T.shape[i] for i in range(1, head_idx)], + T.shape[head_idx] * qhead_per_kvhead, + *[T.shape[i] for i in range(head_idx + 1, len(T.shape))], + ) + stride_unpacked = ( + seqlen_stride, + *[T.stride[i] for i in range(1, head_idx)], + head_stride, + *[T.stride[i] for i in range(head_idx + 1, len(T.shape))], + ) + return cute.make_tensor(T.iterator, cute.make_layout(shape_unpacked, stride=stride_unpacked)) + + +@dataclass class PackGQA: - def __init__( - self, - m_block_size: cutlass.Constexpr[int], - head_dim_padded: cutlass.Constexpr[int], - check_hdim_oob: cutlass.Constexpr[bool], - qhead_per_kvhead: cutlass.Constexpr[bool], - ): - self.m_block_size = m_block_size - self.head_dim_padded = head_dim_padded - self.check_hdim_oob = check_hdim_oob - self.qhead_per_kvhead = qhead_per_kvhead + m_block_size: cutlass.Constexpr[int] + head_dim_padded: cutlass.Constexpr[int] + check_hdim_oob: cutlass.Constexpr[bool] + qhead_per_kvhead: cutlass.Constexpr[bool] @cute.jit def compute_ptr( diff --git a/flash-attn4/torch-ext/flash_attn4/paged_kv.py b/flash-attn4/torch-ext/flash_attn4/paged_kv.py index e76e9a0a..9efbc141 100644 --- a/flash-attn4/torch-ext/flash_attn4/paged_kv.py +++ b/flash-attn4/torch-ext/flash_attn4/paged_kv.py @@ -28,6 +28,9 @@ class PagedKVManager(ParamsBase): head_dim_padded: cutlass.Constexpr[Int32] head_dim_v_padded: cutlass.Constexpr[Int32] + arch: cutlass.Constexpr[Int32] + v_gmem_transposed: cutlass.Constexpr[bool] + gmem_threads_per_row: cutlass.Constexpr[Int32] page_entry_per_thread: Int32 async_copy_elems: Int32 @@ -55,7 +58,11 @@ def create( head_dim_v_padded: cutlass.Constexpr[Int32], num_threads: cutlass.Constexpr[Int32], dtype: Type[cutlass.Numeric], + arch: cutlass.Constexpr[int] = 100, ): + # SM100 transposes V in gmem to (dv, page_size, num_pages); + # SM90 keeps V as (page_size, dv, num_pages), same layout as K. + v_gmem_transposed = arch != 90 universal_copy_bits = 128 async_copy_elems = universal_copy_bits // dtype.width dtype_bytes = dtype.width // 8 @@ -97,7 +104,8 @@ def create( else: cV = cute.make_identity_tensor((n_block_size, head_dim_v_padded)) tVcV = gmem_thr_copy_KV.partition_S(cV) - tVpV = utils.predicate_k(tVcV, limit=mV_paged.shape[0]) + # When V is transposed in gmem, dv is shape[0]; otherwise dv is shape[1] (same as K) + tVpV = utils.predicate_k(tVcV, limit=mV_paged.shape[0 if v_gmem_transposed else 1]) return PagedKVManager( mPageTable, @@ -111,6 +119,8 @@ def create( num_threads, head_dim_padded, head_dim_v_padded, + arch, + v_gmem_transposed, gmem_threads_per_row, page_entry_per_thread, async_copy_elems, @@ -146,13 +156,17 @@ def load_page_table(self, n_block: Int32): @cute.jit def compute_X_ptr(self, K_or_V: str): tPrXPtr = cute.make_rmem_tensor((self.page_entry_per_thread,), cutlass.Int64) + mX = self.mK_paged if const_expr(K_or_V == "K") else self.mV_paged + # K is always (page_size, d, num_pages). V matches K when not transposed, + # but is (dv, page_size, num_pages) when transposed (SM100). + transposed = const_expr(K_or_V == "V" and self.v_gmem_transposed) for i in cutlass.range(self.page_entry_per_thread, unroll=1): page = self.tPrPage[i] page_offset = self.tPrPageOffset[i] - if const_expr(K_or_V == "K"): - tPrXPtr[i] = utils.elem_pointer(self.mK_paged, (page_offset, 0, page)).toint() + if const_expr(transposed): + tPrXPtr[i] = utils.elem_pointer(mX, (0, page_offset, page)).toint() else: - tPrXPtr[i] = utils.elem_pointer(self.mV_paged, (0, page_offset, page)).toint() + tPrXPtr[i] = utils.elem_pointer(mX, (page_offset, 0, page)).toint() return tPrXPtr @cute.jit @@ -161,18 +175,24 @@ def load_KV(self, n_block: Int32, sX: cute.Tensor, K_or_V: str): tPrXPtr = self.compute_X_ptr(K_or_V) - # Finesse sX layout to be (M, N). - sX_pi = cute.make_tensor( - sX.iterator, - cute.make_layout( - (sX.shape[0][0], (sX.shape[0][1], sX.shape[2])), - stride=(sX.stride[0][0], (sX.stride[0][1], sX.stride[2])), - ), - ) + if const_expr(self.arch == 90): + # SM90: sX is already stage-sliced by caller (sK[None, None, stage]). + # Flatten hierarchical modes to get (n_block_size, head_dim). + sX_pi = cute.group_modes(sX, 0, 1) + # SM90 does NOT transpose V here (it's transposed via utils.transpose_view before MMA) + else: + # SM100: Finesse sX layout to be (M, N). + sX_pi = cute.make_tensor( + sX.iterator, + cute.make_layout( + (sX.shape[0][0], (sX.shape[0][1], sX.shape[2])), + stride=(sX.stride[0][0], (sX.stride[0][1], sX.stride[2])), + ), + ) - if const_expr(K_or_V == "V"): - # Need to transpose V - sX_pi = cute.make_tensor(sX_pi.iterator, cute.select(sX_pi.layout, mode=[1, 0])) + if const_expr(K_or_V == "V"): + # Transpose smem V to match transposed gmem layout + sX_pi = cute.make_tensor(sX_pi.iterator, cute.select(sX_pi.layout, mode=[1, 0])) head_dim = self.head_dim_v_padded if const_expr(K_or_V == "V") else self.head_dim_padded cX = cute.make_identity_tensor((self.n_block_size, head_dim)) diff --git a/flash-attn4/torch-ext/flash_attn4/pipeline.py b/flash-attn4/torch-ext/flash_attn4/pipeline.py index e45284ff..f8fdc1e8 100644 --- a/flash-attn4/torch-ext/flash_attn4/pipeline.py +++ b/flash-attn4/torch-ext/flash_attn4/pipeline.py @@ -1,6 +1,5 @@ # Copyright (c) 2025, Tri Dao. -# import math from typing import Optional from dataclasses import dataclass @@ -11,12 +10,31 @@ from cutlass.pipeline import PipelineUserType from cutlass.pipeline import NamedBarrier as NamedBarrierOg from cutlass.pipeline import PipelineAsync as PipelineAsyncOg +from cutlass.pipeline import PipelineCpAsync as PipelineCpAsyncOg from cutlass.pipeline import PipelineTmaAsync as PipelineTmaAsyncOg from cutlass.pipeline import PipelineTmaUmma as PipelineTmaUmmaOg from cutlass.pipeline import PipelineUmmaAsync as PipelineUmmaAsyncOg from cutlass.pipeline import PipelineAsyncUmma as PipelineAsyncUmmaOg +def _override_create(parent_cls, child_cls): + """Create a static factory that constructs parent_cls then re-classes to child_cls.""" + + @staticmethod + def create(*args, **kwargs): + obj = parent_cls.create(*args, **kwargs) + # Can't assign to __class__ directly since the dataclass is frozen + object.__setattr__(obj, "__class__", child_cls) + return obj + + return create + + +def _make_state(index: Int32, phase: Int32) -> PipelineState: + """Construct a PipelineState from index and phase (count/stages unused by callers).""" + return PipelineState(stages=0, count=Int32(0), index=index, phase=phase) + + class PipelineStateSimple: """ Pipeline state contains an index and phase bit corresponding to the current position in the circular buffer. @@ -25,9 +43,6 @@ class PipelineStateSimple: """ def __init__(self, stages: int, phase_index: Int32): - # assert stages < 2**16 - # self._log_stages = int(math.log2(stages)) - # assert 1 << self._log_stages == stages, "Number of stages must be a power of 2." self._stages = stages self._phase_index = phase_index @@ -36,13 +51,10 @@ def clone(self) -> "PipelineStateSimple": @property def stages(self) -> int: - # return 1 << self._log_stages return self._stages @property def index(self) -> Int32: - # return self._phase_index & 0xFFFF - # return self._phase_index & ((1 << self._log_stages) - 1) if const_expr(self._stages == 1): return Int32(0) else: @@ -50,11 +62,8 @@ def index(self) -> Int32: @property def phase(self) -> Int32: - # return self._phase_index >> 16 # PTX docs say that the phase parity needs to be 0 or 1, so by right we need to # take modulo 2. But in practice just passing the phase in without modulo works fine. - # return (self._phase_index >> self._log_stages) % 2 - # return self._phase_index >> self._log_stages if const_expr(self._stages == 1): return self._phase_index else: @@ -66,21 +75,6 @@ def advance(self): else: self._phase_index += 1 - # def then_body(phase_index): - # # XOR the phase bit and set the index to 0 - # return (phase_index & 0xFFFF0000) ^ (1 << 16) - - # def else_body(phase_index): - # return phase_index - - # self._phase_index = if_generate( - # (self._phase_index & 0xFFFF) == self.stages, - # then_body, - # else_body, - # [self._phase_index], - # [Int32], - # ) - def __extract_mlir_values__(self): phase_index = self._phase_index return [phase_index.ir_value()] @@ -94,7 +88,6 @@ def make_pipeline_state(type: PipelineUserType, stages: int): Creates a pipeline state. Producers are assumed to start with an empty buffer and have a flipped phase bit of 1. """ if type is PipelineUserType.Producer: - # return PipelineStateSimple(stages, Int32(1 << 16)) return PipelineStateSimple(stages, Int32(stages)) elif type is PipelineUserType.Consumer: return PipelineStateSimple(stages, Int32(0)) @@ -102,14 +95,73 @@ def make_pipeline_state(type: PipelineUserType, stages: int): assert False, "Error: invalid PipelineUserType specified for make_pipeline_state." +# โ”€โ”€ Shared helpers โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ + + +def _call_with_elect_one(parent_method, self, state, elect_one, syncwarp, loc, ip): + """Optionally wrap a parent pipeline method call in sync_warp + elect_one.""" + if const_expr(elect_one): + if const_expr(syncwarp): + cute.arch.sync_warp() + with cute.arch.elect_one(): + parent_method(self, state, loc=loc, ip=ip) + else: + parent_method(self, state, loc=loc, ip=ip) + + +# โ”€โ”€ Mixin: _w_index / _w_index_phase variants that delegate to parent โ”€โ”€โ”€โ”€โ”€โ”€โ”€ +# Each parent class has PipelineState-based methods (producer_acquire, producer_commit, +# consumer_wait, consumer_release). The _w_index_phase variants just construct a +# PipelineState from (index, phase) and delegate. + + +class _PipelineIndexPhaseMixin: + """Mixin providing _w_index_phase / _w_index methods that delegate to PipelineState-based parents.""" + + @dsl_user_op + def producer_acquire_w_index_phase( + self, + index: Int32, + phase: Int32, + try_acquire_token: Optional[Boolean] = None, + *, + loc=None, + ip=None, + ): + state = _make_state(index, phase) + # Call the parent's producer_acquire (which takes PipelineState) + self.producer_acquire(state, try_acquire_token, loc=loc, ip=ip) + + @dsl_user_op + def producer_commit_w_index(self, index: Int32, *, loc=None, ip=None): + state = _make_state(index, Int32(0)) + self.producer_commit(state, loc=loc, ip=ip) + + @dsl_user_op + def consumer_wait_w_index_phase( + self, + index: Int32, + phase: Int32, + try_wait_token: Optional[Boolean] = None, + *, + loc=None, + ip=None, + ): + state = _make_state(index, phase) + self.consumer_wait(state, try_wait_token, loc=loc, ip=ip) + + @dsl_user_op + def consumer_release_w_index(self, index: Int32, *, loc=None, ip=None): + state = _make_state(index, Int32(0)) + self.consumer_release(state, loc=loc, ip=ip) + + +# โ”€โ”€ NamedBarrier โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ + + @dataclass(frozen=True) class NamedBarrier(NamedBarrierOg): - @staticmethod - def create(*args, **kwargs): - obj = NamedBarrierOg.create(*args, **kwargs) - # Can't assign to __class__ directly since the dataclass is frozen - object.__setattr__(obj, "__class__", NamedBarrier) - return obj + create = _override_create(NamedBarrierOg, None) # patched below @dsl_user_op def arrive_w_index(self, index: Int32, *, loc=None, ip=None) -> None: @@ -134,72 +186,121 @@ def arrive_and_wait_w_index(self, index: Int32, *, loc=None, ip=None) -> None: ) +NamedBarrier.create = _override_create(NamedBarrierOg, NamedBarrier) + + +# โ”€โ”€ PipelineAsync โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ + + @dataclass(frozen=True) -class PipelineAsync(PipelineAsyncOg): +class PipelineAsync(_PipelineIndexPhaseMixin, PipelineAsyncOg): + """ + PipelineAsync with optional elect_one for producer_commit and consumer_release. + + When elect_one_*=True (set at create time), only one elected thread per warp + signals the barrier arrive. This is useful when the mask count is set to 1 per warp. + + Args (to create): + elect_one_commit: If True, only elected thread signals producer_commit. + syncwarp_before_commit: If True (default), issue syncwarp before elect_one. + elect_one_release: If True, only elected thread signals consumer_release. + syncwarp_before_release: If True (default), issue syncwarp before elect_one. + Set syncwarp to False when threads are already converged (e.g. after wgmma wait_group). + """ + + _elect_one_commit: bool = False + _syncwarp_before_commit: bool = True + _elect_one_release: bool = False + _syncwarp_before_release: bool = True + @staticmethod - def create(*args, **kwargs): + def create( + *args, + elect_one_commit: bool = False, + syncwarp_before_commit: bool = True, + elect_one_release: bool = False, + syncwarp_before_release: bool = True, + **kwargs, + ): obj = PipelineAsyncOg.create(*args, **kwargs) - # Can't assign to __class__ directly since the dataclass is frozen - # obj.__class__ = PipelineAsync object.__setattr__(obj, "__class__", PipelineAsync) + object.__setattr__(obj, "_elect_one_commit", elect_one_commit) + object.__setattr__(obj, "_syncwarp_before_commit", syncwarp_before_commit) + object.__setattr__(obj, "_elect_one_release", elect_one_release) + object.__setattr__(obj, "_syncwarp_before_release", syncwarp_before_release) return obj @dsl_user_op - def producer_acquire_w_index_phase( - self, - index: Int32, - phase: Int32, - try_acquire_token: Optional[Boolean] = None, - *, - loc=None, - ip=None, - ): - if_generate( - try_acquire_token is None or try_acquire_token == 0, - lambda: self.sync_object_empty.wait(index, phase, loc=loc, ip=ip), - loc=loc, - ip=ip, + def producer_commit(self, state: PipelineState, *, loc=None, ip=None): + _call_with_elect_one( + PipelineAsyncOg.producer_commit, + self, + state, + self._elect_one_commit, + self._syncwarp_before_commit, + loc, + ip, ) @dsl_user_op - def producer_commit_w_index(self, index: Int32, *, loc=None, ip=None): - self.sync_object_full.arrive(index, self.producer_mask, loc=loc, ip=ip) - - @dsl_user_op - def consumer_wait_w_index_phase( - self, - index: Int32, - phase: Int32, - try_wait_token: Optional[Boolean] = None, - *, - loc=None, - ip=None, - ): - if_generate( - try_wait_token is None or try_wait_token == 0, - lambda: self.sync_object_full.wait(index, phase, loc=loc, ip=ip), - loc=loc, - ip=ip, + def consumer_release(self, state: PipelineState, *, loc=None, ip=None): + _call_with_elect_one( + PipelineAsyncOg.consumer_release, + self, + state, + self._elect_one_release, + self._syncwarp_before_release, + loc, + ip, ) - @dsl_user_op - def consumer_release_w_index(self, index: Int32, *, loc=None, ip=None): - self.sync_object_empty.arrive(index, self.consumer_mask, loc=loc, ip=ip) + # _w_index variants inherited from _PipelineIndexPhaseMixin, which delegate + # to producer_commit / consumer_release above. + + +# โ”€โ”€ PipelineCpAsync โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ @dataclass(frozen=True) -class PipelineTmaAsync(PipelineTmaAsyncOg): - """ - Override producer_acquire to take in extra_tx_count parameter. - """ +class PipelineCpAsync(_PipelineIndexPhaseMixin, PipelineCpAsyncOg): + _elect_one_release: bool = False + _syncwarp_before_release: bool = True @staticmethod - def create(*args, **kwargs): - obj = PipelineTmaAsyncOg.create(*args, **kwargs) - # Can't assign to __class__ directly since the dataclass is frozen - object.__setattr__(obj, "__class__", PipelineTmaAsync) + def create( + *args, + elect_one_release: bool = False, + syncwarp_before_release: bool = True, + **kwargs, + ): + obj = PipelineCpAsyncOg.create(*args, **kwargs) + object.__setattr__(obj, "__class__", PipelineCpAsync) + object.__setattr__(obj, "_elect_one_release", elect_one_release) + object.__setattr__(obj, "_syncwarp_before_release", syncwarp_before_release) return obj + @dsl_user_op + def consumer_release(self, state: PipelineState, *, loc=None, ip=None): + _call_with_elect_one( + PipelineCpAsyncOg.consumer_release, + self, + state, + self._elect_one_release, + self._syncwarp_before_release, + loc, + ip, + ) + + # _w_index variants inherited from _PipelineIndexPhaseMixin. + + +# โ”€โ”€ PipelineTmaAsync โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ + + +@dataclass(frozen=True) +class PipelineTmaAsync(_PipelineIndexPhaseMixin, PipelineTmaAsyncOg): + """Override producer_acquire to take in extra_tx_count parameter.""" + @dsl_user_op def producer_acquire( self, @@ -226,19 +327,15 @@ def producer_acquire( self.sync_object_full.arrive_and_expect_tx(state.index, tx_count, loc=loc, ip=ip) -@dataclass(frozen=True) -class PipelineTmaUmma(PipelineTmaUmmaOg): - """ - Override producer_acquire to take in extra_tx_count parameter. - """ +PipelineTmaAsync.create = _override_create(PipelineTmaAsyncOg, PipelineTmaAsync) + + +# โ”€โ”€ PipelineTmaUmma โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ - @staticmethod - def create(*args, **kwargs): - obj = PipelineTmaUmmaOg.create(*args, **kwargs) - # Can't assign to __class__ directly since the dataclass is frozen - # obj.__class__ = PipelineTmaUmma - object.__setattr__(obj, "__class__", PipelineTmaUmma) - return obj + +@dataclass(frozen=True) +class PipelineTmaUmma(_PipelineIndexPhaseMixin, PipelineTmaUmmaOg): + """Override producer_acquire to take in extra_tx_count parameter.""" @dsl_user_op def producer_acquire( @@ -279,162 +376,27 @@ def producer_acquire( ip=ip, ) - @dsl_user_op - def producer_acquire_w_index_phase( - self, - index: Int32, - phase: Int32, - try_acquire_token: Optional[Boolean] = None, - *, - loc=None, - ip=None, - ): - """ - TMA producer commit conditionally waits on buffer empty and sets the transaction barrier for leader threadblocks. - """ - if_generate( - try_acquire_token is None or try_acquire_token == 0, - lambda: self.sync_object_empty.wait(index, phase, loc=loc, ip=ip), - loc=loc, - ip=ip, - ) - if_generate( - self.is_leader_cta, - lambda: self.sync_object_full.arrive(index, self.producer_mask, loc=loc, ip=ip), - loc=loc, - ip=ip, - ) - @dsl_user_op - def consumer_wait_w_index_phase( - self, - index: Int32, - phase: Int32, - try_wait_token: Optional[Boolean] = None, - *, - loc=None, - ip=None, - ): - if_generate( - try_wait_token is None or try_wait_token == 0, - lambda: self.sync_object_full.wait(index, phase, loc=loc, ip=ip), - loc=loc, - ip=ip, - ) +PipelineTmaUmma.create = _override_create(PipelineTmaUmmaOg, PipelineTmaUmma) - @dsl_user_op - def consumer_release_w_index(self, index: Int32, *, loc=None, ip=None): - """ - UMMA consumer release buffer empty, cta_group needs to be provided. - """ - self.sync_object_empty.arrive(index, self.consumer_mask, self.cta_group, loc=loc, ip=ip) + +# โ”€โ”€ PipelineUmmaAsync โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ @dataclass(frozen=True) -class PipelineUmmaAsync(PipelineUmmaAsyncOg): - @staticmethod - def create(*args, **kwargs): - obj = PipelineUmmaAsyncOg.create(*args, **kwargs) - # Can't assign to __class__ directly since the dataclass is frozen - object.__setattr__(obj, "__class__", PipelineUmmaAsync) - return obj +class PipelineUmmaAsync(_PipelineIndexPhaseMixin, PipelineUmmaAsyncOg): + pass - @dsl_user_op - def producer_acquire_w_index_phase( - self, - index: Int32, - phase: Int32, - try_acquire_token: Optional[Boolean] = None, - *, - loc=None, - ip=None, - ): - if_generate( - try_acquire_token is None or try_acquire_token == 0, - lambda: self.sync_object_empty.wait(index, phase, loc=loc, ip=ip), - loc=loc, - ip=ip, - ) - @dsl_user_op - def producer_commit_w_index(self, index: Int32, *, loc=None, ip=None): - """ - UMMA producer commit buffer full, cta_group needs to be provided. - """ - self.sync_object_full.arrive(index, self.producer_mask, self.cta_group, loc=loc, ip=ip) +PipelineUmmaAsync.create = _override_create(PipelineUmmaAsyncOg, PipelineUmmaAsync) - @dsl_user_op - def consumer_wait_w_index_phase( - self, - index: Int32, - phase: Int32, - try_wait_token: Optional[Boolean] = None, - *, - loc=None, - ip=None, - ): - if_generate( - try_wait_token is None or try_wait_token == 0, - lambda: self.sync_object_full.wait(index, phase, loc=loc, ip=ip), - loc=loc, - ip=ip, - ) - @dsl_user_op - def consumer_release_w_index(self, index: Int32, *, loc=None, ip=None): - self.sync_object_empty.arrive(index, self.consumer_mask, loc=loc, ip=ip) +# โ”€โ”€ PipelineAsyncUmma โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ @dataclass(frozen=True) -class PipelineAsyncUmma(PipelineAsyncUmmaOg): - @staticmethod - def create(*args, **kwargs): - obj = PipelineAsyncUmmaOg.create(*args, **kwargs) - # Can't assign to __class__ directly since the dataclass is frozen - object.__setattr__(obj, "__class__", PipelineAsyncUmma) - return obj +class PipelineAsyncUmma(_PipelineIndexPhaseMixin, PipelineAsyncUmmaOg): + pass - @dsl_user_op - def producer_acquire_w_index_phase( - self, - index: Int32, - phase: Int32, - try_acquire_token: Optional[Boolean] = None, - *, - loc=None, - ip=None, - ): - if_generate( - try_acquire_token is None or try_acquire_token == 0, - lambda: self.sync_object_empty.wait(index, phase, loc=loc, ip=ip), - loc=loc, - ip=ip, - ) - - @dsl_user_op - def producer_commit_w_index(self, index: Int32, *, loc=None, ip=None): - self.sync_object_full.arrive(index, self.producer_mask, loc=loc, ip=ip) - - @dsl_user_op - def consumer_wait_w_index_phase( - self, - index: Int32, - phase: Int32, - try_wait_token: Optional[Boolean] = None, - *, - loc=None, - ip=None, - ): - if_generate( - try_wait_token is None or try_wait_token == 0, - lambda: self.sync_object_full.wait(index, phase, loc=loc, ip=ip), - loc=loc, - ip=ip, - ) - @dsl_user_op - def consumer_release_w_index(self, index: Int32, *, loc=None, ip=None): - """ - UMMA consumer release buffer empty, cta_group needs to be provided. - """ - self.sync_object_empty.arrive(index, self.consumer_mask, self.cta_group, loc=loc, ip=ip) +PipelineAsyncUmma.create = _override_create(PipelineAsyncUmmaOg, PipelineAsyncUmma) diff --git a/flash-attn4/torch-ext/flash_attn4/quack/copy_utils.py b/flash-attn4/torch-ext/flash_attn4/quack/copy_utils.py index 2c336389..c213667c 100644 --- a/flash-attn4/torch-ext/flash_attn4/quack/copy_utils.py +++ b/flash-attn4/torch-ext/flash_attn4/quack/copy_utils.py @@ -15,6 +15,9 @@ from cutlass._mlir import ir from cutlass._mlir.dialects import cute_nvgpu as _cute_nvgpu_ir +from . import layout_utils +from .utils import make_vector + Sm100MmaPeerBitMask = 0xFEFFFFFF @@ -41,6 +44,30 @@ def cvt_copy( cute.copy(tiled_copy, src, dst, pred=pred, loc=loc, ip=ip, **kwargs) +@dsl_user_op +def sr_cvt_copy( + tiled_copy: cute.TiledCopy, + src: cute.Tensor, + dst: cute.Tensor, + seed: Int32, + tidx: Int32, + *, + loc=None, + ip=None, +) -> None: + """Like cvt_copy but uses stochastic rounding for FP32 -> BF16 conversion.""" + assert isinstance(src.iterator, cute.Pointer) and src.memspace == cute.AddressSpace.rmem + from .rounding import convert_f32_to_bf16_sr + from cutlass.cute.tensor import TensorSSA + + src_cvt = cute.make_rmem_tensor_like(src, dst.element_type) + src_vec = src.load() + raw_vec = convert_f32_to_bf16_sr(src_vec, seed, tidx, loc=loc, ip=ip) + src_cvt.store(TensorSSA(raw_vec, src_vec.shape, dst.element_type)) + src = src_cvt + cute.copy(tiled_copy, src, dst, loc=loc, ip=ip) + + @dsl_user_op def load_s2r(src: cute.Tensor, *, loc=None, ip=None) -> cute.Tensor: dst = cute.make_rmem_tensor_like(src, src.element_type, loc=loc, ip=ip) @@ -796,17 +823,17 @@ def gather_m_get_copy_fn( limit_m: Int32, limit_k: Int32, ) -> Callable: - tile_shape_mk = (cute.size(sA, mode=[0]), cute.size(sA, mode=[1])) - tAsA = thr_copy_A.partition_D(sA) + tile_M, tile_K = cute.size(sA, mode=[0]), cute.size(sA, mode=[1]) + tAsA = partition_D_position_independent(thr_copy_A, sA) # k-major assert tAsA.shape[2] == 1 tAsA = cute.group_modes(cute.slice_(tAsA, (None, None, 0, None)), 0, 2) - is_even_m_smem = tile_shape_mk[0] % thr_copy_A.tiler_mn[0].shape == 0 + is_even_m_smem = tile_M % thr_copy_A.tiler_mn[0].shape == 0 if const_expr(not is_even_m_smem): - limit_m = min(limit_m, tile_shape_mk[0]) + limit_m = min(limit_m, tile_M) elems_per_load = cute.size(tAsA.shape[0][0]) - cA = cute.make_identity_tensor(tile_shape_mk) + cA = cute.make_identity_tensor((tile_M, tile_K)) tAcA = thr_copy_A.partition_S(cA) t0AcA = thr_copy_A.get_slice(0).partition_S(cA) # Instead of comparing tAcA to limit_m, we instead compare t0AcA to limit_m - tAcA[0][0] @@ -828,13 +855,13 @@ def gather_m_get_copy_fn( else: m_idx[m] = 0 # It's ok to load row 0 in the case of OOB - mA_k = cute.logical_divide(mA, (None, tile_shape_mk[1])) + mA_k = cute.logical_divide(mA, (None, tile_K)) def copy_fn(src_idx, dst_idx, pred: bool = False): tApA_k = None if const_expr(pred): tApA_k = cute.make_rmem_tensor(cols_per_thread, Boolean) - limit_k_cur = limit_k - src_idx * tile_shape_mk[1] + limit_k_cur = limit_k - src_idx * tile_K for k in cutlass.range(cols_per_thread, unroll_full=True): tApA_k[k] = t0AcA[0, 0, k][1] < limit_k_cur mA_cur = mA_k[None, (None, src_idx)] @@ -997,11 +1024,162 @@ def gather_m_get_tma_copy_fn( tma_gather4_load_fn = partial(tma_gather4_load, tma_desc_ptr, num_cta=cta_group) def copy_fn(src_idx, dst_idx, tma_bar_ptr: cute.Pointer): + tSR_sA_cur = tSR_sA[None, None, None, dst_idx] col_idx = tile_K * src_idx for m in cutlass.range(cute.size(tSR_rAIdx, mode=[1]), unroll_full=True): row_indices = [tSR_rAIdx[v, m] for v in range(4)] - smem_ptr = tSR_sA[None, m, None, dst_idx].iterator + smem_ptr = tSR_sA_cur[None, m, None].iterator with cute.arch.elect_one(): tma_gather4_load_fn(smem_ptr, tma_bar_ptr, col_idx, row_indices) return copy_fn + + +@cute.jit +def gather_k_get_tma_copy_fn( + tma_atom: cute.CopyAtom, + sA: cute.Tensor, # ((4, tile_K/4), (tile_M,), STAGE) โ€” K-grouped load layout + sAIdx: cute.Tensor, # (tile_K, a_prefetch_stage) โ€” K indices in smem + col_idx: Int32, # M offset in global tensor (contiguous dim for M-major) + warp_idx: Int32, + num_warps: int, + num_cta: int = 1, +) -> Tuple[Callable, Callable]: + """Build a copy function for TMA gather4 in K dimension (M-major A). + + Each gather4 instruction loads 4 K-columns ร— tile_M contiguous M-elements. + col_idx is the absolute M position in the global tensor. + K indices come from sAIdx (prefetched to smem by the scheduler warp). + + Returns copy_fn(src_idx, dst_idx, tma_bar_ptr) which: + Issues gather4 calls with those K indices as row_indices + """ + tile_K = cute.size(sAIdx, mode=[0]) + assert tile_K % 4 == 0 + cta_group = num_cta + + # Tiled copy for loading K indices from smem to registers (4 per vector, across warps) + copy_AIdx_s2r = cute.make_tiled_copy_tv( + cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), Int32, num_bits_per_copy=128), + cute.make_layout(num_warps), # thr_layout + cute.make_layout(4), # val_layout โ€” 4 K indices per gather4 + ) + warp_idx = cute.arch.make_warp_uniform(warp_idx) + warp_copy_AIdx_s2r = copy_AIdx_s2r.get_slice(warp_idx) + tSR_sAIdx = warp_copy_AIdx_s2r.partition_S(sAIdx) # (((4,1),4,4)) + # ((4,1),4,(64,2),(1,4)):((64,0),1024,(1,4096),(0,8192)) + tSR_sA = warp_copy_AIdx_s2r.partition_S(layout_utils.transpose_view(sA)) + tma_desc_ptr = get_tma_desc_addr(tma_atom) + tma_gather4_load_fn = partial(tma_gather4_load, tma_desc_ptr, num_cta=cta_group) + + def prefetch_from_smem_fn( + a_prefetch_pipeline, + src_idx, + dst_idx, + a_prefetch_consumer_state, + ) -> cute.Tensor: + a_prefetch_pipeline.consumer_wait(a_prefetch_consumer_state) + tSR_rAIdx = load_s2r(tSR_sAIdx[None, None, dst_idx]) + cute.arch.sync_warp() + with cute.arch.elect_one(): + a_prefetch_pipeline.consumer_release(a_prefetch_consumer_state) + return tSR_rAIdx + + def copy_fn(src_idx, dst_idx, tSR_rAIdx, tma_bar_ptr: cute.Pointer): + # Issue gather4: col_idx = M position, row_indices = 4 K positions + tSR_sA_cur = tSR_sA[None, None, None, dst_idx] + gather_dim = cute.size(tSR_sA_cur, mode=[2, 0]) # Typically 64 + for k in cutlass.range(cute.size(tSR_rAIdx, mode=[1]), unroll_full=True): + row_indices = [tSR_rAIdx[v, k] for v in range(4)] + for m in cutlass.range(cute.size(tSR_sA_cur, mode=[2, 1]), unroll_full=True): + smem_ptr = tSR_sA_cur[None, k, (None, m)].iterator + with cute.arch.elect_one(): + tma_gather4_load_fn( + smem_ptr, tma_bar_ptr, col_idx + m * gather_dim, row_indices + ) + + return copy_fn, prefetch_from_smem_fn + + +# --------------------------------------------------------------------------- +# Store helpers +# --------------------------------------------------------------------------- + + +@dsl_user_op +@cute.jit +def store( + ptr: cute.Pointer, + val, + pred: Optional[Boolean] = None, + cop: cutlass.Constexpr = None, + *, + loc=None, + ip=None, +): + """Store a scalar value via cute.arch.store. + + ptr: cute.Pointer (any address space). + val: DSL Numeric value. + pred: None โ†’ unconditional. DSL Boolean โ†’ skipped when pred == 0. + cop: Cache operator โ€” "wb" (default), "cg", "cs" (streaming), "wt". + """ + if const_expr(pred is None): + cute.arch.store(ptr.llvm_ptr, type(val)(val), cop=cop, loc=loc, ip=ip) + else: + if pred: + cute.arch.store(ptr.llvm_ptr, type(val)(val), cop=cop, loc=loc, ip=ip) + + +@dsl_user_op +@cute.jit +def store_v2( + ptr: cute.Pointer, + v0, + v1, + pred: Optional[Boolean] = None, + cop: cutlass.Constexpr = None, + *, + loc=None, + ip=None, +): + """Vectorized store of 2 elements via cute.arch.store. + + Packs v0, v1 into an MLIR <2 x T> vector. + ptr: cute.Pointer (any address space, must be aligned for vector width). + cop: Cache operator โ€” "wb" (default), "cg", "cs" (streaming), "wt". + """ + vec = make_vector(type(v0), v0, v1, loc=loc, ip=ip) + if const_expr(pred is None): + cute.arch.store(ptr.llvm_ptr, vec, cop=cop, loc=loc, ip=ip) + else: + if pred: + cute.arch.store(ptr.llvm_ptr, vec, cop=cop, loc=loc, ip=ip) + + +@dsl_user_op +@cute.jit +def store_v4( + ptr: cute.Pointer, + v0, + v1, + v2, + v3, + pred: Optional[Boolean] = None, + cop: cutlass.Constexpr = None, + *, + loc=None, + ip=None, +): + """Vectorized store of 4 elements via cute.arch.store. + + Packs v0โ€“v3 into an MLIR <4 x T> vector. + ptr: cute.Pointer (any address space, must be aligned for vector width). + cop: Cache operator โ€” "wb" (default), "cg", "cs" (streaming), "wt". + """ + vec = make_vector(type(v0), v0, v1, v2, v3, loc=loc, ip=ip) + if const_expr(pred is None): + cute.arch.store(ptr.llvm_ptr, vec, cop=cop, loc=loc, ip=ip) + else: + if pred: + cute.arch.store(ptr.llvm_ptr, vec, cop=cop, loc=loc, ip=ip) diff --git a/flash-attn4/torch-ext/flash_attn4/quack/cute_dsl_utils.py b/flash-attn4/torch-ext/flash_attn4/quack/cute_dsl_utils.py index 044a45a1..4e8814d8 100644 --- a/flash-attn4/torch-ext/flash_attn4/quack/cute_dsl_utils.py +++ b/flash-attn4/torch-ext/flash_attn4/quack/cute_dsl_utils.py @@ -4,6 +4,9 @@ from functools import lru_cache from dataclasses import dataclass, fields +import os +import re + import torch try: @@ -14,7 +17,6 @@ import cutlass import cutlass.cute as cute from cutlass import Int32, Int64, Float16, BFloat16, Float32 -from cutlass.base_dsl.typing import JitArgument from cutlass.base_dsl.tvm_ffi_builder import spec from cutlass.cutlass_dsl import NumericMeta @@ -65,8 +67,25 @@ def get_max_active_clusters(cluster_size): return cutlass.utils.HardwareInfo().get_max_active_clusters(cluster_size=cluster_size) +def _parse_arch_str(arch_str: str) -> Tuple[int, int]: + """Parse arch string (e.g. 'sm_90', 'sm90', '90', 'sm_100a') to (major, minor) tuple.""" + match = re.match(r"^(?:sm_?)?(\d+)(\d)([af]?)$", arch_str.strip(), re.IGNORECASE) + if not match: + raise ValueError(f"Invalid QUACK_ARCH format: {arch_str!r} (expected e.g. '90', 'sm_90')") + major, minor, _ = match.groups() + return int(major), int(minor) + + @lru_cache def get_device_capacity(device: torch.device = None) -> Tuple[int, int]: + """Return (major, minor) device capability. + + Override with QUACK_ARCH (e.g. 'sm_90' or '90') for CPU-only compilation + without a GPU present. + """ + arch_override = os.environ.get("QUACK_ARCH") + if arch_override is not None: + return _parse_arch_str(arch_override) return torch.cuda.get_device_capability(device) @@ -138,28 +157,3 @@ def __extract_mlir_values__(self): return values __new_from_mlir_values__ = _new_from_mlir_values - - -@dataclass -class ArgumentsBase(JitArgument): - def __c_pointers__(self): - _, non_constexpr_fields = _partition_fields(self) - c_ptrs = [] - for obj in non_constexpr_fields.values(): - if hasattr(obj, "__c_pointers__"): - c_ptrs.extend(obj.__c_pointers__()) - return c_ptrs - - def __get_mlir_types__(self): - _, non_constexpr_fields = _partition_fields(self) - types, self._values_pos = [], [] - for obj in non_constexpr_fields.values(): - if hasattr(obj, "__get_mlir_types__"): - obj_types = obj.__get_mlir_types__() - types.extend(obj_types) - self._values_pos.append(len(obj_types)) - else: - self._values_pos.append(0) - return types - - __new_from_mlir_values__ = _new_from_mlir_values diff --git a/flash-attn4/torch-ext/flash_attn4/quack/layout_utils.py b/flash-attn4/torch-ext/flash_attn4/quack/layout_utils.py index 9a474804..8955a420 100644 --- a/flash-attn4/torch-ext/flash_attn4/quack/layout_utils.py +++ b/flash-attn4/torch-ext/flash_attn4/quack/layout_utils.py @@ -295,3 +295,37 @@ def mma_partition_A_vec( sVec_mma = cute.make_tensor(sVec.iterator, cute.make_layout(shape, stride=stride)) tC_sVec = make_acc_tensor_mn_view(thr_mma.partition_A(sVec_mma)) return tC_sVec[None, 0, None] if const_expr(is_colvec) else tC_sVec[0, None, None] + + +def copy_partition_S_vec( + sVec: cute.Tensor, thr_copy: cute.core.ThrCopy, expand_shape: int, is_colvec: bool +) -> cute.Tensor: + assert cute.rank(sVec) == 2 + assert sVec.stride[0] == 1 + stage = sVec.shape[1] + shape = ( + (sVec.shape[0], expand_shape, stage) + if const_expr(is_colvec) + else (expand_shape, sVec.shape[0], stage) + ) + stride = (1, 0, sVec.stride[1]) if const_expr(is_colvec) else (0, 1, sVec.stride[1]) + sVec_thr = cute.make_tensor(sVec.iterator, cute.make_layout(shape, stride=stride)) + tC_sVec = reshape_acc_to_mn(thr_copy.partition_S(sVec_thr)) + return tC_sVec[None, 0, None] if const_expr(is_colvec) else tC_sVec[0, None, None] + + +def copy_partition_D_vec( + sVec: cute.Tensor, thr_copy: cute.core.ThrCopy, expand_shape: int, is_colvec: bool +) -> cute.Tensor: + assert cute.rank(sVec) == 2 + assert sVec.stride[0] == 1 + stage = sVec.shape[1] + shape = ( + (sVec.shape[0], expand_shape, stage) + if const_expr(is_colvec) + else (expand_shape, sVec.shape[0], stage) + ) + stride = (1, 0, sVec.stride[1]) if const_expr(is_colvec) else (0, 1, sVec.stride[1]) + sVec_thr = cute.make_tensor(sVec.iterator, cute.make_layout(shape, stride=stride)) + tC_sVec = reshape_acc_to_mn(thr_copy.partition_D(sVec_thr)) + return tC_sVec[None, 0, None] if const_expr(is_colvec) else tC_sVec[0, None, None] diff --git a/flash-attn4/torch-ext/flash_attn4/quack/utils.py b/flash-attn4/torch-ext/flash_attn4/quack/utils.py new file mode 100644 index 00000000..80ca4767 --- /dev/null +++ b/flash-attn4/torch-ext/flash_attn4/quack/utils.py @@ -0,0 +1,324 @@ +# Copyright (c) 2025, Wentao Guo, Ted Zadouri, Tri Dao. + +import math +from typing import Optional, Tuple, Union + +import cutlass +import cutlass.cute as cute + +from cutlass import Float32, Int32, const_expr +from cutlass._mlir.dialects import arith as _arith +from cutlass._mlir.dialects import llvm, nvvm, vector +from cutlass.cutlass_dsl import T, dsl_user_op + + +@dsl_user_op +def elem_pointer(x: cute.Tensor, coord: cute.Coord, *, loc=None, ip=None) -> cute.Pointer: + return x.iterator + cute.crd2idx(coord, x.layout, loc=loc, ip=ip) + + +@cute.jit +def load_scalar_or_pointer(x, dtype=Float32): + if const_expr(isinstance(x, cute.Pointer)): + return dtype(cute.make_tensor(x, cute.make_layout(1))[0]) + else: + return x + + +@dsl_user_op +def set_block_rank( + smem_ptr: cute.Pointer, peer_cta_rank_in_cluster: Int32, *, loc=None, ip=None +) -> Int32: + """Map the given smem pointer to the address at another CTA rank in the cluster.""" + smem_ptr_i32 = smem_ptr.toint(loc=loc, ip=ip).ir_value() + return Int32( + llvm.inline_asm( + T.i32(), + [smem_ptr_i32, peer_cta_rank_in_cluster.ir_value()], + "mapa.shared::cluster.u32 $0, $1, $2;", + "=r,r,r", + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + ) + + +@dsl_user_op +def store_shared_remote( + val: float | Float32 | Int32 | cutlass.Int64, + smem_ptr: cute.Pointer, + mbar_ptr: cute.Pointer, + peer_cta_rank_in_cluster: cute.typing.Int, + *, + loc=None, + ip=None, +) -> None: + remote_smem_ptr_i32 = set_block_rank( + smem_ptr, peer_cta_rank_in_cluster, loc=loc, ip=ip + ).ir_value() + remote_mbar_ptr_i32 = set_block_rank( + mbar_ptr, peer_cta_rank_in_cluster, loc=loc, ip=ip + ).ir_value() + if const_expr(isinstance(val, float)): + val = Float32(val) + assert isinstance(val, (Float32, Int32, cutlass.Int64)), "val must be Float32, Int32, or Int64" + suffix = {Float32: "f32", Int32: "s32", cutlass.Int64: "s64"}[type(val)] + constraint = {Float32: "f", Int32: "r", cutlass.Int64: "l"}[type(val)] + llvm.inline_asm( + None, + [remote_smem_ptr_i32, val.ir_value(loc=loc, ip=ip), remote_mbar_ptr_i32], + f"st.async.shared::cluster.mbarrier::complete_tx::bytes.{suffix} [$0], $1, [$2];", + f"r,{constraint},r", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + + +@dsl_user_op +def store_shared_remote_x4( + val0: Float32 | Int32, + val1: Float32 | Int32, + val2: Float32 | Int32, + val3: Float32 | Int32, + smem_ptr: cute.Pointer, + mbar_ptr: cute.Pointer, + peer_cta_rank_in_cluster: cute.typing.Int, + *, + loc=None, + ip=None, +) -> None: + remote_smem_ptr_i32 = set_block_rank( + smem_ptr, peer_cta_rank_in_cluster, loc=loc, ip=ip + ).ir_value() + remote_mbar_ptr_i32 = set_block_rank( + mbar_ptr, peer_cta_rank_in_cluster, loc=loc, ip=ip + ).ir_value() + assert isinstance(val0, (Float32, Int32)), "val must be Float32, or Int32" + dtype = Float32 if isinstance(val0, Float32) else Int32 + suffix = {Float32: "f32", Int32: "s32"}[dtype] + constraint = {Float32: "f", Int32: "r"}[dtype] + llvm.inline_asm( + None, + [ + remote_smem_ptr_i32, + remote_mbar_ptr_i32, + dtype(val0).ir_value(loc=loc, ip=ip), + dtype(val1).ir_value(loc=loc, ip=ip), + dtype(val2).ir_value(loc=loc, ip=ip), + dtype(val3).ir_value(loc=loc, ip=ip), + ], + "{\n\t" + f".reg .v4 .{suffix} abcd;\n\t" + f"mov.{suffix} abcd.x, $2;\n\t" + f"mov.{suffix} abcd.y, $3;\n\t" + f"mov.{suffix} abcd.z, $4;\n\t" + f"mov.{suffix} abcd.w, $5;\n\t" + f"st.async.shared::cluster.mbarrier::complete_tx::bytes.v4.{suffix} [$0], abcd, [$1];\n\t" + "}\n", + f"r,r,{constraint},{constraint},{constraint},{constraint}", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + + +@dsl_user_op +def fmin(a: Union[float, Float32], b: Union[float, Float32], *, loc=None, ip=None) -> Float32: + if cutlass.const_expr(cutlass.CUDA_VERSION.major) == 12: + return Float32( + nvvm.fmin( + T.f32(), + Float32(a).ir_value(loc=loc, ip=ip), + Float32(b).ir_value(loc=loc, ip=ip), + loc=loc, + ip=ip, + ) + ) + return Float32( + nvvm.fmin( + Float32(a).ir_value(loc=loc, ip=ip), + Float32(b).ir_value(loc=loc, ip=ip), + loc=loc, + ip=ip, + ) + ) + + +@dsl_user_op +def sqrt(a: float | Float32, *, loc=None, ip=None) -> Float32: + return Float32( + llvm.inline_asm( + T.f32(), + [Float32(a).ir_value(loc=loc, ip=ip)], + "sqrt.approx.f32 $0, $1;", + "=f,f", + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + ) + + +@dsl_user_op +def ceil(a: float | Float32, *, loc=None, ip=None) -> Int32: + return Int32( + llvm.inline_asm( + T.i32(), + [Float32(a).ir_value(loc=loc, ip=ip)], + "cvt.rpi.ftz.s32.f32 $0, $1;", + "=r,f", + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + ) + + +@cute.jit +def fill_oob(tXsX: cute.Tensor, tXpX: Optional[cute.Tensor], fill_value: cute.Numeric) -> None: + """Fill out-of-bounds values in shared memory tensor. + + Args: + tXsX: Shared memory tensor to fill + tXpX: Predicate tensor indicating valid elements + fill_value: Value to fill OOB locations with + """ + tXrX_fill = cute.make_rmem_tensor_like(tXsX[(None, 0), None, 0]) + tXrX_fill.fill(fill_value) + for rest_v in cutlass.range_constexpr(tXsX.shape[0][1]): + for rest_k in cutlass.range_constexpr(tXsX.shape[2]): + if const_expr(tXpX is not None): + if not tXpX[rest_v, 0, rest_k]: + cute.autovec_copy(tXrX_fill, tXsX[(None, rest_v), None, rest_k]) + else: + cute.autovec_copy(tXrX_fill, tXsX[(None, rest_v), None, rest_k]) + + +# --------------------------------------------------------------------------- +# General-purpose DSL store / vector helpers +# --------------------------------------------------------------------------- + + +@dsl_user_op +def make_vector(elem_type, *values, loc=None, ip=None): + """Build an MLIR vector from N scalar DSL values. + + Example: make_vector(cutlass.Uint32, v0, v1) -> <2 x i32> MLIR vector + """ + from cutlass._mlir import ir + + n = len(values) + mlir_ty = elem_type.mlir_type + vec_ty = ir.VectorType.get([n], mlir_ty) + vec = llvm.mlir_undef(vec_ty, loc=loc, ip=ip) + for i, v in enumerate(values): + vec = vector.insertelement( + elem_type(v).ir_value(loc=loc, ip=ip), + vec, + position=_arith.constant(T.i32(), i, loc=loc, ip=ip), + loc=loc, + ip=ip, + ) + return vec + + +@dsl_user_op +def f32x2_to_i64(a: Float32, b: Float32, *, loc=None, ip=None) -> cutlass.Int64: + vec_f32x2 = vector.from_elements( + T.vector(2, T.f32()), (a.ir_value(), b.ir_value()), loc=loc, ip=ip + ) + vec_i64x1 = vector.bitcast(T.vector(1, T.i64()), vec_f32x2) + res = cutlass.Int64( + vector.extract(vec_i64x1, dynamic_position=[], static_position=[0], loc=loc, ip=ip) + ) + return res + + +@dsl_user_op +def i64_to_f32x2(c: cutlass.Int64, *, loc=None, ip=None) -> Tuple[Float32, Float32]: + vec_i64x1 = vector.from_elements(T.vector(1, T.i64()), (c.ir_value(),), loc=loc, ip=ip) + vec_f32x2 = vector.bitcast(T.vector(2, T.f32()), vec_i64x1) + res0 = Float32( + vector.extract(vec_f32x2, dynamic_position=[], static_position=[0], loc=loc, ip=ip) + ) + res1 = Float32( + vector.extract(vec_f32x2, dynamic_position=[], static_position=[1], loc=loc, ip=ip) + ) + return res0, res1 + + +@cute.jit +def warp_prefix_sum(val: Int32, lane: Optional[Int32] = None) -> Int32: + if const_expr(lane is None): + lane = cute.arch.lane_idx() + for i in cutlass.range_constexpr(int(math.log2(cute.arch.WARP_SIZE))): + offset = 1 << i + # Very important that we set mask_and_clamp to 0 + partial_sum = cute.arch.shuffle_sync_up(val, offset=offset, mask_and_clamp=0) + if lane >= offset: + val += partial_sum + return val + + +@dsl_user_op +def atomic_inc_i32(a: int | Int32, gmem_ptr: cute.Pointer, *, loc=None, ip=None) -> Int32: + from cutlass import CUDA_VERSION + + # * NVVM call based on nvvm version + if CUDA_VERSION.major == 12 and CUDA_VERSION.minor == 9: + # Old API: requires explicit result type as first positional argument + return nvvm.atomicrmw( + res=T.i32(), op=nvvm.AtomicOpKind.INC, ptr=gmem_ptr.llvm_ptr, a=Int32(a).ir_value() + ) + else: + # New API: infers result type automatically + return nvvm.atomicrmw( + op=nvvm.AtomicOpKind.INC, ptr=gmem_ptr.llvm_ptr, a=Int32(a).ir_value() + ) + + +@dsl_user_op +def atomic_add_i32(a: int | Int32, gmem_ptr: cute.Pointer, *, loc=None, ip=None) -> Int32: + from cutlass import CUDA_VERSION + + # * NVVM call based on nvvm version + if CUDA_VERSION.major == 12 and CUDA_VERSION.minor == 9: + # Old API: requires explicit result type as first positional argument + return nvvm.atomicrmw( + res=T.i32(), op=nvvm.AtomicOpKind.ADD, ptr=gmem_ptr.llvm_ptr, a=Int32(a).ir_value() + ) + else: + # New API: infers result type automatically + return nvvm.atomicrmw( + op=nvvm.AtomicOpKind.ADD, ptr=gmem_ptr.llvm_ptr, a=Int32(a).ir_value() + ) + + +@dsl_user_op +def issue_clc_query_nomulticast( + mbar_ptr: cute.Pointer, + clc_response_ptr: cute.Pointer, + loc=None, + ip=None, +) -> None: + """ + The clusterlaunchcontrol.try_cancel instruction requests atomically cancelling the launch + of a cluster that has not started running yet. It asynchronously writes an opaque response + to shared memory indicating whether the operation succeeded or failed. On success, the + opaque response contains the ctaid of the first CTA of the canceled cluster. + + :param mbar_ptr: A pointer to the mbarrier address in SMEM + :type mbar_ptr: Pointer + :param clc_response_ptr: A pointer to the cluster launch control response address in SMEM + :type clc_response_ptr: Pointer + """ + mbar_llvm_ptr = mbar_ptr.llvm_ptr + clc_response_llvm_ptr = clc_response_ptr.llvm_ptr + nvvm.clusterlaunchcontrol_try_cancel( + clc_response_llvm_ptr, + mbar_llvm_ptr, + loc=loc, + ip=ip, + ) diff --git a/flash-attn4/torch-ext/flash_attn4/seqlen_info.py b/flash-attn4/torch-ext/flash_attn4/seqlen_info.py index 6d8c6feb..aa071296 100644 --- a/flash-attn4/torch-ext/flash_attn4/seqlen_info.py +++ b/flash-attn4/torch-ext/flash_attn4/seqlen_info.py @@ -5,6 +5,8 @@ import cutlass.cute as cute from cutlass import Int32, const_expr +from .quack import copy_utils + """ This consolidates all the info related to sequence length. This is so that we can do all the gmem reads once at the beginning of each tile, rather than having to repeat these reads @@ -14,34 +16,61 @@ @dataclass(frozen=True) class SeqlenInfo: - offset: cutlass.Int32 - seqlen: cutlass.Int32 + offset: Int32 + offset_padded: Int32 + seqlen: Int32 + has_cu_seqlens: cutlass.Constexpr[bool] = False @staticmethod def create( - batch_idx: cutlass.Int32, - seqlen_static: cutlass.Int32, + batch_idx: Int32, + seqlen_static: Int32, cu_seqlens: Optional[cute.Tensor] = None, seqused: Optional[cute.Tensor] = None, + tile: cutlass.Constexpr[int] = 128, ): offset = 0 if const_expr(cu_seqlens is None) else cu_seqlens[batch_idx] + offset_padded = ( + 0 + if const_expr(cu_seqlens is None) + # Add divby so that the compiler knows the alignment when moving by offset_padded + else cute.assume((offset + batch_idx * tile) // tile * tile, divby=tile) + ) if const_expr(seqused is not None): seqlen = seqused[batch_idx] elif const_expr(cu_seqlens is not None): seqlen = cu_seqlens[batch_idx + 1] - cu_seqlens[batch_idx] else: seqlen = seqlen_static - return SeqlenInfo(offset, seqlen) + return SeqlenInfo(offset, offset_padded, seqlen, has_cu_seqlens=cu_seqlens is not None) + + def offset_batch( + self, + mT: cute.Tensor, + batch_idx: Int32, + dim: int, + padded: cutlass.Constexpr[bool] = False, + multiple: int = 1, + ) -> cute.Tensor: + """Offset a tensor by batch index. batch dim is at position `dim`, seqlen is at dim=0.""" + if const_expr(not self.has_cu_seqlens): + idx = (None,) * dim + (batch_idx,) + (None,) * (cute.rank(mT) - 1 - dim) + return mT[idx] + else: + off = multiple * (self.offset if const_expr(not padded) else self.offset_padded) + offset = off if const_expr(cute.rank(mT.shape[0]) == 1) else (0, off) + idx = (offset,) + (None,) * (cute.rank(mT) - 1) + return cute.domain_offset(idx, mT) @dataclass(frozen=True) class SeqlenInfoQK: - offset_q: cutlass.Int32 - offset_k: cutlass.Int32 - padded_offset_q: cutlass.Int32 - padded_offset_k: cutlass.Int32 - seqlen_q: cutlass.Int32 - seqlen_k: cutlass.Int32 + offset_q: Int32 + offset_k: Int32 + padded_offset_q: Int32 + padded_offset_k: Int32 + seqlen_q: Int32 + seqlen_k: Int32 has_cu_seqlens_q: cutlass.Constexpr[bool] has_cu_seqlens_k: cutlass.Constexpr[bool] has_seqused_q: cutlass.Constexpr[bool] @@ -49,27 +78,27 @@ class SeqlenInfoQK: @staticmethod def create( - batch_idx: cutlass.Int32, - seqlen_q_static: cutlass.Int32, - seqlen_k_static: cutlass.Int32, + batch_idx: Int32, + seqlen_q_static: Int32, + seqlen_k_static: Int32, mCuSeqlensQ: Optional[cute.Tensor] = None, mCuSeqlensK: Optional[cute.Tensor] = None, mSeqUsedQ: Optional[cute.Tensor] = None, mSeqUsedK: Optional[cute.Tensor] = None, - tile_m: cutlass.Constexpr[cutlass.Int32] = 128, - tile_n: cutlass.Constexpr[cutlass.Int32] = 128, + tile_m: cutlass.Constexpr[Int32] = 128, + tile_n: cutlass.Constexpr[Int32] = 128, ): offset_q = 0 if const_expr(mCuSeqlensQ is None) else mCuSeqlensQ[batch_idx] offset_k = 0 if const_expr(mCuSeqlensK is None) else mCuSeqlensK[batch_idx] padded_offset_q = ( 0 if const_expr(mCuSeqlensQ is None) - else (offset_q + batch_idx * tile_m) // tile_m * tile_m + else cute.assume((offset_q + batch_idx * tile_m) // tile_m * tile_m, divby=tile_m) ) padded_offset_k = ( 0 if const_expr(mCuSeqlensK is None) - else (offset_k + batch_idx * tile_n) // tile_n * tile_n + else cute.assume((offset_k + batch_idx * tile_n) // tile_n * tile_n, divby=tile_n) ) if const_expr(mSeqUsedQ is not None): seqlen_q = mSeqUsedQ[batch_idx] @@ -87,10 +116,6 @@ def create( if const_expr(mCuSeqlensK is None) else mCuSeqlensK[batch_idx + 1] - offset_k ) - has_cu_seqlens_q: int = mCuSeqlensQ is not None - has_cu_seqlens_k: int = mCuSeqlensK is not None - has_seqused_q: int = mSeqUsedQ is not None - has_seqused_k: int = mSeqUsedK is not None return SeqlenInfoQK( offset_q, offset_k, @@ -98,10 +123,10 @@ def create( padded_offset_k, seqlen_q, seqlen_k, - has_cu_seqlens_q, - has_cu_seqlens_k, - has_seqused_q, - has_seqused_k, + has_cu_seqlens_q=mCuSeqlensQ is not None, + has_cu_seqlens_k=mCuSeqlensK is not None, + has_seqused_q=mSeqUsedQ is not None, + has_seqused_k=mSeqUsedK is not None, ) def offset_batch_Q( @@ -110,16 +135,38 @@ def offset_batch_Q( batch_idx: Int32, dim: int, padded: cutlass.Constexpr[bool] = False, + ragged: cutlass.Constexpr[bool] = False, ) -> cute.Tensor: """Seqlen must be the first dimension of mQ""" - if const_expr(not self.has_cu_seqlens_q): - idx = (None,) * dim + (batch_idx,) + (None,) * (cute.rank(mQ) - 1 - dim) - return mQ[idx] + if const_expr(not ragged): + if const_expr(not self.has_cu_seqlens_q): + idx = (None,) * dim + (batch_idx,) + (None,) * (cute.rank(mQ) - 1 - dim) + return mQ[idx] + else: + offset_q = self.offset_q if const_expr(not padded) else self.padded_offset_q + offset_q = offset_q if const_expr(cute.rank(mQ.shape[0]) == 1) else (None, offset_q) + idx = (offset_q,) + (None,) * (cute.rank(mQ) - 1) + return cute.domain_offset(idx, mQ) else: - offset_q = self.offset_q if const_expr(not padded) else self.padded_offset_q - offset = offset_q if const_expr(cute.rank(mQ.shape[0]) == 1) else (0, offset_q) - idx = (offset,) + (0,) * (cute.rank(mQ) - 1) - return cute.domain_offset(idx, mQ) + if const_expr(not self.has_cu_seqlens_q): + offset_q = 0 + idx = (None,) * dim + (batch_idx,) + (None,) * (cute.rank(mQ) - 1 - dim) + mQ = mQ[idx] + else: + offset_q = self.offset_q if const_expr(not padded) else self.padded_offset_q + if const_expr(cute.rank(mQ.shape[0]) == 1): + return copy_utils.offset_ragged_tensor( + mQ, offset_q, self.seqlen_q, ragged_dim=0, ptr_shift=True + ) + else: # PackGQA + assert cute.rank(mQ.shape[0]) == 2 + # Unpack before calling offset_ragged_tensor, then pack + idx = ((None, None),) + (None,) * (cute.rank(mQ) - 1) + mQ = mQ[idx] + mQ = copy_utils.offset_ragged_tensor( + mQ, offset_q, self.seqlen_q, ragged_dim=1, ptr_shift=True + ) + return cute.group_modes(mQ, 0, 2) def offset_batch_K( self, @@ -127,12 +174,114 @@ def offset_batch_K( batch_idx: Int32, dim: int, padded: cutlass.Constexpr[bool] = False, + ragged: cutlass.Constexpr[bool] = False, + multiple: int = 1, ) -> cute.Tensor: """Seqlen must be the first dimension of mK""" - if const_expr(not self.has_cu_seqlens_k): - idx = (None,) * dim + (batch_idx,) + (None,) * (cute.rank(mK) - 1 - dim) - return mK[idx] + if const_expr(not ragged): + if const_expr(not self.has_cu_seqlens_k): + idx = (None,) * dim + (batch_idx,) + (None,) * (cute.rank(mK) - 1 - dim) + return mK[idx] + else: + offset_k = self.offset_k if const_expr(not padded) else self.padded_offset_k + offset_k *= multiple + idx = (offset_k,) + (None,) * (cute.rank(mK) - 1) + return cute.domain_offset(idx, mK) else: - offset_k = self.offset_k if const_expr(not padded) else self.padded_offset_k - idx = (offset_k,) + (0,) * (cute.rank(mK) - 1) - return cute.domain_offset(idx, mK) + if const_expr(not self.has_cu_seqlens_k): + offset_k = 0 + idx = (None,) * dim + (batch_idx,) + (None,) * (cute.rank(mK) - 1 - dim) + mK = mK[idx] + else: + offset_k = self.offset_k if const_expr(not padded) else self.padded_offset_k + offset_k *= multiple + return copy_utils.offset_ragged_tensor( + mK, offset_k, self.seqlen_k, ragged_dim=0, ptr_shift=True + ) + + +@dataclass(frozen=True) +class SeqlenInfoQKNewK: + """Sequence length info for append-KV with left-padding and new K support. + + Extends SeqlenInfoQK with: + - leftpad_k: left padding for K (tokens to skip at the start of the KV cache) + - offset_k_new: offset into the new K tensor + - seqlen_k_og: original K length (before appending new K), excluding leftpad + - seqlen_k_new: length of new K to append + - seqlen_k: total K length (seqlen_k_og + seqlen_k_new) + - seqlen_rotary: position for rotary embedding computation + """ + + leftpad_k: Int32 + offset_q: Int32 + offset_k: Int32 + offset_k_new: Int32 + seqlen_q: Int32 + seqlen_k_og: Int32 + seqlen_k_new: Int32 + seqlen_k: Int32 + seqlen_rotary: Int32 + + @staticmethod + def create( + batch_idx: Int32, + seqlen_q_static: Int32, + seqlen_k_static: Int32, + shape_K_new_0: Int32, + mCuSeqlensQ: Optional[cute.Tensor] = None, + mCuSeqlensK: Optional[cute.Tensor] = None, + mCuSeqlensKNew: Optional[cute.Tensor] = None, + mSeqUsedQ: Optional[cute.Tensor] = None, + mSeqUsedK: Optional[cute.Tensor] = None, + mLeftpadK: Optional[cute.Tensor] = None, + mSeqlensRotary: Optional[cute.Tensor] = None, + ): + leftpad_k = 0 if const_expr(mLeftpadK is None) else mLeftpadK[batch_idx] + offset_q = 0 if const_expr(mCuSeqlensQ is None) else mCuSeqlensQ[batch_idx] + if const_expr(mCuSeqlensK is not None): + offset_k = mCuSeqlensK[batch_idx] + leftpad_k + else: + offset_k = leftpad_k if const_expr(mCuSeqlensQ is not None) else 0 + offset_k_new = 0 if const_expr(mCuSeqlensKNew is None) else mCuSeqlensKNew[batch_idx] + # seqlen_q + if const_expr(mSeqUsedQ is not None): + seqlen_q = mSeqUsedQ[batch_idx] + elif const_expr(mCuSeqlensQ is not None): + seqlen_q = mCuSeqlensQ[batch_idx + 1] - mCuSeqlensQ[batch_idx] + else: + seqlen_q = seqlen_q_static + # seqlen_k_og: original K length (excluding leftpad) + if const_expr(mSeqUsedK is not None): + seqlen_k_og = mSeqUsedK[batch_idx] - leftpad_k + elif const_expr(mCuSeqlensK is not None): + seqlen_k_og = mCuSeqlensK[batch_idx + 1] - mCuSeqlensK[batch_idx] - leftpad_k + else: + seqlen_k_og = ( + seqlen_k_static - leftpad_k + if const_expr(mCuSeqlensQ is not None) + else seqlen_k_static + ) + # seqlen_k_new + if const_expr(mCuSeqlensKNew is None): + seqlen_k_new = 0 if const_expr(mCuSeqlensQ is None) else shape_K_new_0 + else: + seqlen_k_new = mCuSeqlensKNew[batch_idx + 1] - mCuSeqlensKNew[batch_idx] + seqlen_k = seqlen_k_og if const_expr(mCuSeqlensQ is None) else seqlen_k_og + seqlen_k_new + + # seqlen_rotary: defaults to seqlen_k_og + leftpad_k unless explicitly provided + if const_expr(mSeqlensRotary is not None): + seqlen_rotary = mSeqlensRotary[batch_idx] + else: + seqlen_rotary = seqlen_k_og + leftpad_k + return SeqlenInfoQKNewK( + leftpad_k, + offset_q, + offset_k, + offset_k_new, + seqlen_q, + seqlen_k_og, + seqlen_k_new, + seqlen_k, + seqlen_rotary, + ) diff --git a/flash-attn4/torch-ext/flash_attn4/sm90_config_search.py b/flash-attn4/torch-ext/flash_attn4/sm90_config_search.py new file mode 100644 index 00000000..6c9584ea --- /dev/null +++ b/flash-attn4/torch-ext/flash_attn4/sm90_config_search.py @@ -0,0 +1,402 @@ +"""Search feasible SM90 fwd/bwd attention configs for given (head_dim, head_dim_v). + +Enumerates tile sizes, swap modes, atom layouts, and staging options. +Checks GMMA divisibility, register budget, and shared memory budget. + +Usage: + python flash_attn/cute/sm90_config_search.py --headdim 128 + python flash_attn/cute/sm90_config_search.py --mode fwd --headdim 192-128 + python flash_attn/cute/sm90_config_search.py --mode bwd --headdim 192 --tile-n 64,96 +""" + +import math + +# H100 hardware limits +SMEM_LIMIT = 224 * 1024 # 228 KB minus ~3 KB for LSE, dPsum, mbarriers +REG_LIMITS = {2: 216, 3: 128} # per-WG budget: 2WG=240-24, 3WG=160-32 +THREADS_PER_WG = 128 + + +def _divisors(n): + return [d for d in range(1, n + 1) if n % d == 0] + + +def _acc_regs(M, N, num_wg): + """Accumulator registers per thread per WG.""" + return M * N // (num_wg * THREADS_PER_WG) + + +def _check_mma(M, N, num_wg, atom_layout_m, swap_AB): + """Check MMA feasibility. Returns regs per WG, or None if infeasible. + + GMMA atom M=64. Swap exchanges (M, N) and atom layout. + Requires: M divisible by (atom_layout_m * 64), N by (atom_layout_n * 8). + """ + if swap_AB: + M, N = N, M + atom_layout_m = num_wg // atom_layout_m + atom_layout_n = num_wg // atom_layout_m + if M % (atom_layout_m * 64) != 0 or N % (atom_layout_n * 8) != 0: + return None + return _acc_regs(M, N, num_wg) + + +def _mma_traffic(M_eff, N_eff, K_red, num_wg, wg_n, is_rs=False): + """Total SMEM read traffic for one MMA (all WGs combined). + + num_instr = (M_eff / 64) * wg_n instructions total. + Each reads A(64, K_red) and B(N_eff/wg_n, K_red) from smem (bf16). + """ + num_instr = (M_eff // 64) * wg_n + A_per = 64 * K_red * 2 if not is_rs else 0 + B_per = (N_eff // wg_n) * K_red * 2 + return num_instr * (A_per + B_per) + + +# ============================================================================ +# Backward +# ============================================================================ + + +def _check_bwd_config( + hdim, + hdimv, + tile_m, + tile_n, + num_wg, + SdP_swapAB, + dKV_swapAB, + dQ_swapAB, + AtomLayoutMSdP, + AtomLayoutNdKV, + AtomLayoutMdQ, +): + reg_limit = REG_LIMITS[num_wg] + + # MMA feasibility + regs_SdP = _check_mma(tile_m, tile_n, num_wg, AtomLayoutMSdP, SdP_swapAB) + regs_dK = _check_mma(tile_n, hdim, num_wg, AtomLayoutNdKV, dKV_swapAB) + regs_dV = _check_mma(tile_n, hdimv, num_wg, AtomLayoutNdKV, dKV_swapAB) + regs_dQ = _check_mma(tile_m, hdim, num_wg, AtomLayoutMdQ, dQ_swapAB) + if any(r is None for r in (regs_SdP, regs_dK, regs_dV, regs_dQ)): + return None + + # Peak regs: max(S+dP, dQ) + dK + dV + total_regs = max(2 * regs_SdP, regs_dQ) + regs_dK + regs_dV + if total_regs > reg_limit: + return None + + # SMEM + mma_dkv_is_rs = ( + AtomLayoutMSdP == 1 and AtomLayoutNdKV == num_wg and SdP_swapAB and not dKV_swapAB + ) + Q_stage, PdS_stage = 2, 1 + + for dO_stage in (2, 1): + sQ = tile_m * hdim * 2 * Q_stage + sK = tile_n * hdim * 2 + sV = tile_n * hdimv * 2 + sdO = tile_m * hdimv * 2 * dO_stage + sPdS = tile_m * tile_n * 2 * PdS_stage + sP = sPdS if not mma_dkv_is_rs else 0 + sdQaccum = tile_m * hdim * 4 + smem = sQ + sK + sV + sdO + sP + sPdS + sdQaccum + if smem <= SMEM_LIMIT: + break + else: + return None + + # SMEM traffic + def _swap(a, b, s): + return (b, a) if s else (a, b) + + def _wg_n(al_m, s): + return al_m if s else num_wg // al_m + + M_s, N_s = _swap(tile_m, tile_n, SdP_swapAB) + wn_SdP = _wg_n(AtomLayoutMSdP, SdP_swapAB) + traffic_S = _mma_traffic(M_s, N_s, hdim, num_wg, wn_SdP) + traffic_dP = _mma_traffic(M_s, N_s, hdimv, num_wg, wn_SdP) + + wn_dKV = _wg_n(AtomLayoutNdKV, dKV_swapAB) + M_dv, N_dv = _swap(tile_n, hdimv, dKV_swapAB) + traffic_dV = _mma_traffic(M_dv, N_dv, tile_m, num_wg, wn_dKV, is_rs=mma_dkv_is_rs) + M_dk, N_dk = _swap(tile_n, hdim, dKV_swapAB) + traffic_dK = _mma_traffic(M_dk, N_dk, tile_m, num_wg, wn_dKV, is_rs=mma_dkv_is_rs) + + M_dq, N_dq = _swap(tile_m, hdim, dQ_swapAB) + wn_dQ = _wg_n(AtomLayoutMdQ, dQ_swapAB) + traffic_dQ = _mma_traffic(M_dq, N_dq, tile_n, num_wg, wn_dQ) + + traffic_P_store = tile_m * tile_n * 2 if not mma_dkv_is_rs else 0 + traffic_dS_store = tile_m * tile_n * 2 + traffic_dQ_smem = tile_m * hdim * 4 * 2 # store + TMA load + + smem_traffic = ( + traffic_S + + traffic_dP + + traffic_dV + + traffic_dK + + traffic_dQ + + traffic_P_store + + traffic_dS_store + + traffic_dQ_smem + ) + + return dict( + tile_m=tile_m, + tile_n=tile_n, + num_wg=num_wg, + Q_stage=Q_stage, + dO_stage=dO_stage, + PdS_stage=PdS_stage, + SdP_swapAB=SdP_swapAB, + dKV_swapAB=dKV_swapAB, + dQ_swapAB=dQ_swapAB, + AtomLayoutMSdP=AtomLayoutMSdP, + AtomLayoutNdKV=AtomLayoutNdKV, + AtomLayoutMdQ=AtomLayoutMdQ, + mma_dkv_is_rs=mma_dkv_is_rs, + regs_SdP=regs_SdP, + regs_dK=regs_dK, + regs_dV=regs_dV, + regs_dQ=regs_dQ, + total_regs=total_regs, + reg_limit=reg_limit, + smem_bytes=smem, + smem_kb=smem / 1024, + smem_traffic=smem_traffic, + smem_traffic_kb=smem_traffic / 1024, + smem_traffic_per_block=smem_traffic / (tile_m * tile_n), + ) + + +def find_feasible_bwd_configs( + head_dim, + head_dim_v=None, + tile_m_choices=(64, 80, 96, 112, 128), + tile_n_choices=(64, 80, 96, 112, 128), +): + if head_dim_v is None: + head_dim_v = head_dim + hdim = int(math.ceil(head_dim / 32) * 32) + hdimv = int(math.ceil(head_dim_v / 32) * 32) + + results = [] + for num_wg in (2, 3): + divs = _divisors(num_wg) + for tile_m in tile_m_choices: + for tile_n in tile_n_choices: + for SdP_swap in (False, True): + if (tile_n if SdP_swap else tile_m) % 64 != 0: + continue + for dKV_swap in (False, True): + if not dKV_swap and tile_n % 64 != 0: + continue + if dKV_swap and (hdim % 64 != 0 or hdimv % 64 != 0): + continue + for dQ_swap in (False, True): + if (hdim if dQ_swap else tile_m) % 64 != 0: + continue + for a1 in divs: + for a2 in divs: + for a3 in divs: + cfg = _check_bwd_config( + hdim, + hdimv, + tile_m, + tile_n, + num_wg, + SdP_swap, + dKV_swap, + dQ_swap, + a1, + a2, + a3, + ) + if cfg is not None: + results.append(cfg) + + results.sort(key=lambda c: (-c["tile_n"], -c["tile_m"], c["smem_traffic_per_block"])) + return results + + +def print_bwd_configs(configs, max_results=20): + if not configs: + print("No feasible configs found!") + return + n = min(len(configs), max_results) + print(f"Found {len(configs)} feasible configs (showing top {n}):\n") + hdr = ( + f"{'wg':>2} {'tm':>3} {'tn':>3} " + f"{'SdP':>3} {'dKV':>3} {'dQ':>3} " + f"{'aSdP':>4} {'adKV':>4} {'adQ':>4} " + f"{'Qs':>2} {'dOs':>3} " + f"{'rS':>3} {'rdK':>3} {'rdV':>3} {'rdQ':>3} {'tot':>4}/{'':<3} " + f"{'smem':>5} {'traffic':>7} {'tr/blk':>6}" + ) + print(hdr) + print("-" * len(hdr)) + B = lambda b: "T" if b else "F" + for c in configs[:max_results]: + print( + f"{c['num_wg']:>2} {c['tile_m']:>3} {c['tile_n']:>3} " + f"{B(c['SdP_swapAB']):>3} {B(c['dKV_swapAB']):>3} {B(c['dQ_swapAB']):>3} " + f"{c['AtomLayoutMSdP']:>4} {c['AtomLayoutNdKV']:>4} {c['AtomLayoutMdQ']:>4} " + f"{c['Q_stage']:>2} {c['dO_stage']:>3} " + f"{c['regs_SdP']:>3} {c['regs_dK']:>3} {c['regs_dV']:>3} {c['regs_dQ']:>3} " + f"{c['total_regs']:>4}/{c['reg_limit']:<3} " + f"{c['smem_kb']:>4.0f}K " + f"{c['smem_traffic_kb']:>6.0f}K " + f"{c['smem_traffic_per_block']:>6.1f}" + ) + + +# ============================================================================ +# Forward +# ============================================================================ + + +def _check_fwd_config(hdim, hdimv, tile_n, num_wg, pv_is_rs, overlap_wg): + reg_limit = REG_LIMITS[num_wg] + tile_m = num_wg * 64 + + if tile_n % 8 != 0: + return None + + regs_S = _acc_regs(tile_m, tile_n, num_wg) + regs_O = _acc_regs(tile_m, hdimv, num_wg) + regs_P = regs_S // 2 # bf16 = half of f32 + + if overlap_wg: + total_regs = regs_S + regs_P + regs_O + else: + total_regs = regs_S + regs_O + + if total_regs > reg_limit: + return None + + # SMEM: 1 stage Q, 2 stages K/V, O overlaps Q, sP if not RS + sQ = tile_m * hdim * 2 + sK = tile_n * hdim * 2 * 2 + sV = tile_n * hdimv * 2 * 2 + sO = tile_m * hdimv * 2 + sP = tile_m * tile_n * 2 if not pv_is_rs else 0 + smem = max(sQ, sO) + sK + sV + sP + if smem > SMEM_LIMIT: + return None + + # SMEM traffic: num_instr = num_wg (all WGs in M, wg_n=1) + traffic_S = num_wg * (64 * hdim * 2 + tile_n * hdim * 2) + A_pv = 64 * tile_n * 2 if not pv_is_rs else 0 + traffic_O = num_wg * (A_pv + hdimv * tile_n * 2) + traffic_P_store = tile_m * tile_n * 2 if not pv_is_rs else 0 + smem_traffic = traffic_S + traffic_O + traffic_P_store + + return dict( + tile_m=tile_m, + tile_n=tile_n, + num_wg=num_wg, + pv_is_rs=pv_is_rs, + overlap_wg=overlap_wg, + regs_S=regs_S, + regs_O=regs_O, + regs_P=regs_P, + total_regs=total_regs, + reg_limit=reg_limit, + smem_bytes=smem, + smem_kb=smem / 1024, + smem_traffic=smem_traffic, + smem_traffic_kb=smem_traffic / 1024, + smem_traffic_per_block=smem_traffic / (tile_m * tile_n), + ) + + +def find_feasible_fwd_configs( + head_dim, head_dim_v=None, tile_n_choices=(64, 80, 96, 112, 128, 144, 160, 176, 192) +): + if head_dim_v is None: + head_dim_v = head_dim + hdim = int(math.ceil(head_dim / 32) * 32) + hdimv = int(math.ceil(head_dim_v / 32) * 32) + + results = [] + for num_wg in (2, 3): + for tile_n in tile_n_choices: + for pv_is_rs in (True, False): + for overlap_wg in (True, False): + cfg = _check_fwd_config(hdim, hdimv, tile_n, num_wg, pv_is_rs, overlap_wg) + if cfg is not None: + results.append(cfg) + + results.sort(key=lambda c: (-c["tile_n"], c["smem_traffic_per_block"])) + return results + + +def print_fwd_configs(configs, max_results=20): + if not configs: + print("No feasible configs found!") + return + n = min(len(configs), max_results) + print(f"Found {len(configs)} feasible configs (showing top {n}):\n") + hdr = ( + f"{'wg':>2} {'tm':>3} {'tn':>3} " + f"{'RS':>2} {'olap':>4} " + f"{'rS':>3} {'rP':>3} {'rO':>3} {'tot':>4}/{'':<3} " + f"{'smem':>5} {'traffic':>7} {'tr/blk':>6}" + ) + print(hdr) + print("-" * len(hdr)) + B = lambda b: "T" if b else "F" + for c in configs[:max_results]: + print( + f"{c['num_wg']:>2} {c['tile_m']:>3} {c['tile_n']:>3} " + f"{B(c['pv_is_rs']):>2} {B(c['overlap_wg']):>4} " + f"{c['regs_S']:>3} {c['regs_P']:>3} {c['regs_O']:>3} " + f"{c['total_regs']:>4}/{c['reg_limit']:<3} " + f"{c['smem_kb']:>4.0f}K " + f"{c['smem_traffic_kb']:>6.0f}K " + f"{c['smem_traffic_per_block']:>6.1f}" + ) + + +# ============================================================================ +# CLI +# ============================================================================ + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser(description="Search feasible SM90 MMA configs") + parser.add_argument("--mode", choices=["fwd", "bwd", "both"], default="both") + parser.add_argument( + "--headdim", type=str, default="128", help="Head dim, or hdim-hdimv (e.g. 192-128)" + ) + parser.add_argument("--tile-m", type=str, default="64,80,96,112,128", help="Bwd tile_m choices") + parser.add_argument( + "--tile-n", + type=str, + default=None, + help="tile_n choices (default: fwd up to 192, bwd up to 128)", + ) + parser.add_argument("-n", "--num-results", type=int, default=30) + args = parser.parse_args() + + parts = args.headdim.split("-") + hdim = int(parts[0]) + hdimv = int(parts[1]) if len(parts) > 1 else hdim + + TN_FWD = "64,80,96,112,128,144,160,176,192" + TN_BWD = "64,80,96,112,128" + + if args.mode in ("fwd", "both"): + tn = tuple(int(x) for x in (args.tile_n or TN_FWD).split(",")) + print(f"=== FWD configs: hdim={hdim}, hdimv={hdimv} ===\n") + print_fwd_configs(find_feasible_fwd_configs(hdim, hdimv, tn), args.num_results) + print() + + if args.mode in ("bwd", "both"): + tm = tuple(int(x) for x in args.tile_m.split(",")) + tn = tuple(int(x) for x in (args.tile_n or TN_BWD).split(",")) + print(f"=== BWD configs: hdim={hdim}, hdimv={hdimv} ===\n") + print_bwd_configs(find_feasible_bwd_configs(hdim, hdimv, tm, tn), args.num_results) diff --git a/flash-attn4/torch-ext/flash_attn4/softmax.py b/flash-attn4/torch-ext/flash_attn4/softmax.py index 41e6cba1..d67ef62e 100644 --- a/flash-attn4/torch-ext/flash_attn4/softmax.py +++ b/flash-attn4/torch-ext/flash_attn4/softmax.py @@ -10,7 +10,7 @@ from cutlass import Float32 from .quack import layout_utils -from . import utils +from . import utils as utils from .quack.cute_dsl_utils import ParamsBase from .seqlen_info import SeqlenInfoQK diff --git a/flash-attn4/torch-ext/flash_attn4/tile_scheduler.py b/flash-attn4/torch-ext/flash_attn4/tile_scheduler.py index f05709bd..563e06d3 100644 --- a/flash-attn4/torch-ext/flash_attn4/tile_scheduler.py +++ b/flash-attn4/torch-ext/flash_attn4/tile_scheduler.py @@ -1,6 +1,7 @@ # Copyright (c) 2025, Tri Dao. -from typing import Optional, Tuple +from enum import IntEnum, auto +from typing import Optional, Tuple, Protocol, runtime_checkable from dataclasses import dataclass try: @@ -9,17 +10,80 @@ from typing_extensions import override import cutlass +from cutlass.pipeline import PipelineClcFetchAsync, PipelineState from cutlass._mlir import ir import cutlass.cute as cute from cutlass import Int32, const_expr from cutlass.cute import FastDivmodDivisor +from cutlass.utils import ClcDynamicPersistentTileScheduler, ClcDynamicPersistentTileSchedulerParams from .quack.cute_dsl_utils import ParamsBase -from . import utils +from . import utils as utils from .fast_math import clz +class SchedulingMode(IntEnum): + NONE = auto() + STATIC = auto() + DYNAMIC = auto() + CLC = auto() + + +@dataclass +class ClcState(ParamsBase): + """Owns the runtime state shared by CLC-capable tile schedulers. + + `FlashAttentionForwardSm100` constructs this state because it owns the CLC + response buffer, mbarrier storage, and launch geometry needed to initialize + the hardware scheduler and async pipeline. Individual tile schedulers then + consume this state and map the returned hardware work tiles into their own + logical `WorkTileInfo` coordinates. + + To add CLC support to a scheduler: + - implement `clc_problem_shape(params)` so the kernel can create the hardware scheduler + - accept `clc: ClcState | None` in `create(...)` / `__init__` + - map `clc.initial_work_tile_info()` and `clc.get_current_work()` into scheduler coordinates + """ + + _hw_scheduler: ClcDynamicPersistentTileScheduler + _pipeline: PipelineClcFetchAsync + _consumer_state: PipelineState + _producer_state: PipelineState + + @staticmethod + def create( + *, + hw_scheduler: ClcDynamicPersistentTileScheduler, + pipeline: PipelineClcFetchAsync, + consumer_state: PipelineState, + producer_state: PipelineState, + ) -> "ClcState": + return ClcState(hw_scheduler, pipeline, consumer_state, producer_state) + + def initial_work_tile_info(self): + return self._hw_scheduler.initial_work_tile_info() + + def get_current_work(self): + return self._hw_scheduler.get_current_work() + + def prefetch_next_work(self, *, loc=None, ip=None): + self._pipeline.producer_acquire(self._producer_state, loc=loc, ip=ip) + mbarrier_addr = self._pipeline.producer_get_barrier(self._producer_state, loc=loc, ip=ip) + self._hw_scheduler.advance_to_next_work(mbarrier_addr, loc=loc, ip=ip) + self._producer_state.advance(loc=loc, ip=ip) + + def consumer_wait(self, *, loc=None, ip=None): + self._pipeline.consumer_wait(self._consumer_state, loc=loc, ip=ip) + + def consumer_release(self, *, loc=None, ip=None): + self._pipeline.consumer_release(self._consumer_state, loc=loc, ip=ip) + self._consumer_state.advance(loc=loc, ip=ip) + + def producer_tail(self, *, loc=None, ip=None): + self._pipeline.producer_tail(self._producer_state, loc=loc, ip=ip) + + class WorkTileInfo(cutlass.utils.WorkTileInfo): """Altered WorkTileInfo which includes four axes: (block, head, batch, split)""" @@ -31,6 +95,47 @@ def __new_from_mlir_values__(self, values: list[ir.Value]) -> "WorkTileInfo": return WorkTileInfo(new_tile_idx, new_is_valid_tile) +@runtime_checkable +class TileSchedulerProtocol(Protocol): + """Protocol defining the interface all tile schedulers must implement. + + Schedulers are responsible for: + 1. Coordinate mapping: linear tile index -> (m_block, head, batch, split) + 2. Work distribution: how to get the next tile (static grid-stride vs CLC dynamic) + """ + + def get_current_work(self) -> WorkTileInfo: + """Get the current work tile coordinates.""" + ... + + def initial_work_tile_info(self) -> WorkTileInfo: + """Get the initial work tile for this CTA.""" + ... + + def advance_to_next_work(self, *, loc=None, ip=None): + """Consumer-side advance: move to next tile and return it. + + For static schedulers: grid-stride increment + get_current_work. + For CLC schedulers: consumer wait + get_current_work + consumer release + state advance. + """ + ... + + def prefetch_next_work(self, *, loc=None, ip=None) -> None: + """Producer-side prefetch of next work tile (no-op for static schedulers). + + For CLC schedulers: producer acquire + issue CLC query + producer state advance. + Only called by the scheduler warp. + """ + ... + + def producer_tail(self, *, loc=None, ip=None) -> None: + """Producer-side cleanup after the last tile. + + No-op for static schedulers. For CLC schedulers: pipeline producer_tail. + """ + ... + + @dataclass class TileSchedulerArguments(ParamsBase): num_block: Int32 @@ -51,6 +156,7 @@ class TileSchedulerArguments(ParamsBase): lpt: cutlass.Constexpr[bool] = False is_split_kv: cutlass.Constexpr[bool] = False head_swizzle: cutlass.Constexpr[bool] = False + use_cluster_idx: cutlass.Constexpr[bool] = False class SingleTileScheduler: @@ -63,6 +169,7 @@ class Params(ParamsBase): num_splits_divmod: FastDivmodDivisor is_split_kv: cutlass.Constexpr[bool] = False cluster_shape_mn: cutlass.Constexpr[Tuple[int, int]] = (1, 1) + use_cluster_idx: cutlass.Constexpr[bool] = False @staticmethod def create( @@ -76,6 +183,7 @@ def create( FastDivmodDivisor(args.num_splits), args.is_split_kv, args.cluster_shape_mn, + args.use_cluster_idx, ) def __init__(self, params: Params, blk_coord: cute.Coord, *, loc=None, ip=None): @@ -86,18 +194,26 @@ def __init__(self, params: Params, blk_coord: cute.Coord, *, loc=None, ip=None): self._ip = ip @staticmethod - def to_underlying_arguments(args: TileSchedulerArguments, *, loc=None, ip=None) -> Params: + def to_underlying_arguments( + args: TileSchedulerArguments, + *, + scheduling_mode: SchedulingMode = SchedulingMode.STATIC, + loc=None, + ip=None, + ) -> Params: + assert scheduling_mode == SchedulingMode.STATIC, ( + f"SingleTileScheduler only supports STATIC, got {scheduling_mode!r}" + ) return SingleTileScheduler.Params.create(args, loc=loc, ip=ip) @staticmethod - def create(params: Params, *, loc=None, ip=None) -> "SingleTileScheduler": - # if const_expr(cute.size(params.cluster_shape_mn) == 1): - # blk_coord = cute.arch.block_idx() - # else: - # # All CTAs in a cluster must get the same block coordinate - # blk_coord = cute.arch.cluster_idx() - # Temporary set to block_idx until we sort out the best way to handle cluster - blk_coord = cute.arch.block_idx() + def create( + params: Params, clc: ClcState | None = None, *, loc=None, ip=None + ) -> "SingleTileScheduler": + if const_expr(cute.size(params.cluster_shape_mn) == 1 or not params.use_cluster_idx): + blk_coord = cute.arch.block_idx() + else: + blk_coord = cute.arch.cluster_idx() return SingleTileScheduler(params, blk_coord, loc=loc, ip=ip) # called by host @@ -110,8 +226,13 @@ def get_grid_shape( ) -> Tuple[Int32, Int32, Int32]: # TODO: this hard-codes the fact that we only use cluster = (1, 1) or (2, 1) assert params.cluster_shape_mn[1] == 1, "Only cluster_shape_mn[1] == 1 is supported" + if const_expr(params.use_cluster_idx): + # Grid must have num_block * cluster_m physical blocks so that there are num_block clusters + grid_x = params.num_block * params.cluster_shape_mn[0] + else: + grid_x = cute.round_up(params.num_block, params.cluster_shape_mn[0]) return ( - cute.round_up(params.num_block, params.cluster_shape_mn[0]), + grid_x, params.num_head * params.num_splits, params.num_batch, ) @@ -135,6 +256,10 @@ def prefetch_next_work(self, *, loc=None, ip=None): def advance_to_next_work(self, *, loc=None, ip=None): self._is_first_block = False + return self.get_current_work() + + def producer_tail(self, *, loc=None, ip=None): + pass def __extract_mlir_values__(self): values, self._values_pos = [], [] @@ -180,18 +305,28 @@ def __init__(self, params: Params, tile_idx: Int32, *, loc=None, ip=None): self._ip = ip @staticmethod - def to_underlying_arguments(args: TileSchedulerArguments, *, loc=None, ip=None) -> Params: + def to_underlying_arguments( + args: TileSchedulerArguments, + *, + scheduling_mode: SchedulingMode = SchedulingMode.STATIC, + loc=None, + ip=None, + ) -> Params: + assert scheduling_mode == SchedulingMode.STATIC, ( + f"StaticPersistentTileScheduler only supports STATIC, got {scheduling_mode!r}" + ) return StaticPersistentTileScheduler.Params.create(args, loc=loc, ip=ip) @staticmethod - def create(params: Params, *, loc=None, ip=None) -> "StaticPersistentTileScheduler": + def create( + params: Params, clc: ClcState | None = None, *, loc=None, ip=None + ) -> "StaticPersistentTileScheduler": if const_expr(cute.size(params.cluster_shape_m) == 1): tile_idx = cute.arch.block_idx()[0] else: tile_idx = cute.arch.cluster_idx()[0] return StaticPersistentTileScheduler(params, tile_idx, loc=loc, ip=ip) - # called by host @staticmethod def get_grid_shape( params: Params, @@ -201,18 +336,14 @@ def get_grid_shape( ) -> Tuple[Int32, Int32, Int32]: hardware_info = cutlass.utils.HardwareInfo() sm_count = hardware_info.get_device_multiprocessor_count() - # Grid must be a multiple of cluster_shape_m for CUDA cluster launch. max_ctas = (sm_count // params.cluster_shape_m) * params.cluster_shape_m grid_x = cutlass.min(max_ctas, params.total_blocks_cluster * params.cluster_shape_m) return (grid_x, Int32(1), Int32(1)) - # @cute.jit def get_current_work(self, *, loc=None, ip=None) -> WorkTileInfo: hn_idx, block_idx = divmod(self._tile_idx, self.params.num_block_cluster_divmod) batch_idx, head_idx = divmod(hn_idx, self.params.num_head_divmod) is_valid = self._tile_idx < self.params.total_blocks_cluster - # if cute.arch.thread_idx()[0] == 0: - # cute.printf("TileScheduler: tile_idx=%d, hn_idx=%d, block_idx=%d, batch_idx=%d, head_idx=%d, is_valid=%d", self._tile_idx, hn_idx, block_idx, batch_idx, head_idx, is_valid) return WorkTileInfo( (Int32(block_idx), Int32(head_idx), Int32(batch_idx), Int32(0)), is_valid ) @@ -228,6 +359,10 @@ def advance_to_next_work(self, *, loc=None, ip=None): self._tile_idx += cute.arch.grid_dim()[0] else: self._tile_idx += cute.arch.cluster_dim()[0] + return self.get_current_work() + + def producer_tail(self, *, loc=None, ip=None): + pass def __extract_mlir_values__(self): values, self._values_pos = [], [] @@ -254,32 +389,41 @@ class Params(ParamsBase): total_blocks: Int32 num_splits: Int32 num_block: Int32 + num_head: Int32 + num_batch: Int32 l2_minor: Int32 - num_block_divmod: FastDivmodDivisor num_head_divmod: FastDivmodDivisor l2_minor_divmod: FastDivmodDivisor l2_major_divmod: FastDivmodDivisor l2_minor_residual_divmod: FastDivmodDivisor num_hb_quotient: Int32 + num_splits_divmod: FastDivmodDivisor is_split_kv: cutlass.Constexpr[bool] = False + cluster_shape_m: cutlass.Constexpr[int] = 1 + scheduling_mode: cutlass.Constexpr[SchedulingMode] = SchedulingMode.STATIC + lpt: cutlass.Constexpr[bool] = True @staticmethod @cute.jit def create( - args: TileSchedulerArguments, *, loc=None, ip=None + args: TileSchedulerArguments, + *, + scheduling_mode: SchedulingMode = SchedulingMode.STATIC, + loc=None, + ip=None, ) -> "SingleTileLPTScheduler.Params": - # cute.printf(args.num_block, args.num_head, args.num_batch, args.seqlen_k, args.headdim, args.headdim_v, args.total_q, args.tile_shape_mn, args.qhead_per_kvhead_packgqa, args.element_size) + assert scheduling_mode in (SchedulingMode.STATIC, SchedulingMode.CLC), ( + f"Only STATIC and CLC are supported, got {scheduling_mode!r}" + ) size_one_kv_head = args.seqlen_k * (args.headdim + args.headdim_v) * args.element_size size_one_head = size_one_kv_head size_l2 = 50 * 1024 * 1024 # 40 MB for K & V # Swizzle is the size of each "section". Round swizzle to a power of 2 # Need to be careful about the case where only one head will fit # swizzle is how many heads can fit in L2 - # swizzle = 1 if size_l2 < size_one_head else (size_l2 // size_one_head) - # Seems faster if swizzle if a power of 2 + # Seems faster if swizzle is a power of 2 log2_floor = lambda n: 31 - clz(n) swizzle = 1 if size_l2 < size_one_head else (1 << log2_floor(size_l2 // size_one_head)) - # swizzle = 1 if size_l2 < size_one_head else (size_l2 // size_one_head) # If we're in the last section (called residual), we don't want to divide by # swizzle. Instead we want to divide by the remainder. num_hb_quotient = (args.num_head * args.num_batch) // swizzle @@ -287,37 +431,84 @@ def create( return SingleTileLPTScheduler.Params( total_blocks=args.num_block * args.num_head * args.num_batch, num_block=args.num_block, + num_head=args.num_head, + num_batch=args.num_batch, l2_minor=Int32(swizzle), - num_block_divmod=FastDivmodDivisor(args.num_block), num_head_divmod=FastDivmodDivisor(args.num_head), l2_minor_divmod=FastDivmodDivisor(swizzle), l2_major_divmod=FastDivmodDivisor(swizzle * args.num_block), - l2_minor_residual_divmod=FastDivmodDivisor( - max(num_hb_remainder, 1) - ), # don't divide by 0 + l2_minor_residual_divmod=FastDivmodDivisor(max(num_hb_remainder, 1)), num_hb_quotient=Int32(num_hb_quotient), num_splits=args.num_splits, + num_splits_divmod=FastDivmodDivisor(args.num_splits), is_split_kv=args.is_split_kv, + cluster_shape_m=args.cluster_shape_mn[0], + scheduling_mode=scheduling_mode, + lpt=args.lpt, ) - def __init__(self, params: Params, tile_idx: Int32, split_idx: Int32, *, loc=None, ip=None): + def __init__( + self, + params: Params, + tile_idx: Int32, + split_idx: Int32, + clc: ClcState | None = None, + *, + loc=None, + ip=None, + ): self.params = params self._tile_idx = tile_idx self._split_idx = split_idx + self.clc = clc self._loc = loc self._ip = ip @staticmethod - def to_underlying_arguments(args: TileSchedulerArguments, *, loc=None, ip=None) -> Params: - return SingleTileLPTScheduler.Params.create(args, loc=loc, ip=ip) + def to_underlying_arguments( + args: TileSchedulerArguments, + *, + scheduling_mode: SchedulingMode = SchedulingMode.STATIC, + loc=None, + ip=None, + ) -> Params: + return SingleTileLPTScheduler.Params.create( + args, scheduling_mode=scheduling_mode, loc=loc, ip=ip + ) + + @staticmethod + def _clc_grid_shape(params: Params): + num_batch_splits = ( + params.num_batch * params.num_splits + if const_expr(params.is_split_kv) + else params.num_batch + ) + return ( + cute.round_up(params.num_block, params.cluster_shape_m), + params.num_head, + num_batch_splits, + ) @staticmethod @cute.jit - def create(params: Params, *, loc=None, ip=None) -> "SingleTileLPTScheduler": + def clc_problem_shape(params: Params): + return ClcDynamicPersistentTileSchedulerParams( + problem_shape_ntile_mnl=SingleTileLPTScheduler._clc_grid_shape(params), + cluster_shape_mnk=(params.cluster_shape_m, 1, 1), + ) + + @staticmethod + @cute.jit + def create( + params: Params, clc: ClcState | None = None, *, loc=None, ip=None + ) -> "SingleTileLPTScheduler": + if const_expr(params.scheduling_mode == SchedulingMode.CLC): + return SingleTileLPTScheduler( + params, cute.arch.block_idx()[0], Int32(0), clc, loc=loc, ip=ip + ) tile_idx, split_idx, _ = cute.arch.block_idx() return SingleTileLPTScheduler(params, tile_idx, split_idx, loc=loc, ip=ip) - # called by host @staticmethod def get_grid_shape( params: Params, @@ -325,10 +516,40 @@ def get_grid_shape( loc=None, ip=None, ) -> Tuple[Int32, Int32, Int32]: + if const_expr(params.scheduling_mode == SchedulingMode.CLC): + return SingleTileLPTScheduler._clc_grid_shape(params) return (params.total_blocks, params.num_splits, Int32(1)) + @cute.jit + def clc_work_to_coords(self, work) -> WorkTileInfo: + """Convert CLC response (block, head, batch_split) to WorkTileInfo. + + CLC returns raw grid coordinates โ€” no L2 swizzle (hardware decides order). + We only apply cluster division, optional LPT block reversal, and split_kv unpacking. + """ + block_idx = work.tile_idx[0] + if const_expr(self.params.cluster_shape_m > 1): + block_idx = block_idx // self.params.cluster_shape_m + if const_expr(self.params.lpt): + # Longest-processing-time-first: reverse block order + block_idx = self.params.num_block - 1 - block_idx + split_idx = Int32(0) + if const_expr(self.params.is_split_kv): + batch_idx, split_idx = divmod(work.tile_idx[2], self.params.num_splits_divmod) + else: + batch_idx = work.tile_idx[2] + return WorkTileInfo( + (Int32(block_idx), Int32(work.tile_idx[1]), Int32(batch_idx), Int32(split_idx)), + work.is_valid_tile, + ) + @cute.jit def get_current_work(self, *, loc=None, ip=None) -> WorkTileInfo: + if const_expr(self.params.scheduling_mode == SchedulingMode.CLC): + work = self.clc.get_current_work() + self._tile_idx = work.tile_idx[0] + return self.clc_work_to_coords(work) + # Static path: L2-swizzled coordinate mapping params = self.params # Implement LPT scheduling coordinate calculation bidhb, l2_mod = divmod(self._tile_idx, params.l2_major_divmod) @@ -342,25 +563,45 @@ def get_current_work(self, *, loc=None, ip=None) -> WorkTileInfo: bidhb_actual = bidhb * params.l2_minor + bidhb_residual batch_idx, head_idx = divmod(bidhb_actual, params.num_head_divmod) # Longest-processing-time-first - block = params.num_block - 1 - block + if const_expr(params.lpt): + block = params.num_block - 1 - block is_valid = self._tile_idx < params.total_blocks return WorkTileInfo( (Int32(block), Int32(head_idx), Int32(batch_idx), Int32(self._split_idx)), is_valid ) + @cute.jit def initial_work_tile_info(self, *, loc=None, ip=None): + if const_expr(self.params.scheduling_mode == SchedulingMode.CLC): + work = self.clc.initial_work_tile_info() + self._tile_idx = work.tile_idx[0] + return self.clc_work_to_coords(work) return self.get_current_work(loc=loc, ip=ip) def prefetch_next_work(self, *, loc=None, ip=None): - pass + if const_expr(self.params.scheduling_mode == SchedulingMode.CLC): + self.clc.prefetch_next_work(loc=loc, ip=ip) def advance_to_next_work(self, *, loc=None, ip=None): + if const_expr(self.params.scheduling_mode == SchedulingMode.CLC): + self.clc.consumer_wait(loc=loc, ip=ip) + work = self.get_current_work() + self.clc.consumer_release(loc=loc, ip=ip) + return work # Single tile scheduler - set to invalid tile_idx to indicate no more work self._tile_idx = self.params.total_blocks + return self.get_current_work() + + def producer_tail(self, *, loc=None, ip=None): + if const_expr(self.params.scheduling_mode == SchedulingMode.CLC): + self.clc.producer_tail(loc=loc, ip=ip) def __extract_mlir_values__(self): values, self._values_pos = [], [] - for obj in [self.params, self._tile_idx, self._split_idx]: + objs = [self.params, self._tile_idx, self._split_idx] + if const_expr(self.params.scheduling_mode == SchedulingMode.CLC): + objs += [self.clc] + for obj in objs: obj_values = cutlass.extract_mlir_values(obj) values += obj_values self._values_pos.append(len(obj_values)) @@ -368,10 +609,13 @@ def __extract_mlir_values__(self): def __new_from_mlir_values__(self, values): obj_list = [] - for obj, n_items in zip([self.params, self._tile_idx, self._split_idx], self._values_pos): + objs = [self.params, self._tile_idx, self._split_idx] + if const_expr(self.params.scheduling_mode == SchedulingMode.CLC): + objs += [self.clc] + for obj, n_items in zip(objs, self._values_pos): obj_list.append(cutlass.new_from_mlir_values(obj, values[:n_items])) values = values[n_items:] - return self.__class__(*(tuple(obj_list)), loc=self._loc) + return self.__class__(*obj_list, loc=self._loc) class SingleTileLPTBwdScheduler: @@ -395,8 +639,8 @@ def create( ) -> "SingleTileLPTBwdScheduler.Params": size_l2 = 50 * 1024 * 1024 size_one_qdo_head = args.seqlen_k * (args.headdim + args.headdim_v) * args.element_size - # size_one_dqaccum_head = args.seqlen_k * (args.headdim) * 4 - size_one_dqaccum_head = 0 + size_one_dqaccum_head = args.seqlen_k * (args.headdim) * 4 + # size_one_dqaccum_head = 0 size_one_head = size_one_qdo_head + size_one_dqaccum_head log2_floor = lambda n: 31 - clz(n) swizzle = 1 if size_l2 < size_one_head else (1 << log2_floor(size_l2 // size_one_head)) @@ -430,7 +674,16 @@ def __init__(self, params: Params, tile_idx: Int32, *, loc=None, ip=None): self._ip = ip @staticmethod - def to_underlying_arguments(args: TileSchedulerArguments, *, loc=None, ip=None) -> Params: + def to_underlying_arguments( + args: TileSchedulerArguments, + *, + scheduling_mode: SchedulingMode = SchedulingMode.STATIC, + loc=None, + ip=None, + ) -> Params: + assert scheduling_mode == SchedulingMode.STATIC, ( + f"SingleTileLPTBwdScheduler only supports STATIC, got {scheduling_mode!r}" + ) return SingleTileLPTBwdScheduler.Params.create(args, loc=loc, ip=ip) @staticmethod @@ -481,6 +734,7 @@ def prefetch_next_work(self, *, loc=None, ip=None): def advance_to_next_work(self, *, loc=None, ip=None): # Single tile scheduler - set to invalid tile_idx to indicate no more work self._tile_idx = self.params.total_blocks + return self.get_current_work() def __extract_mlir_values__(self): values, self._values_pos = [], [] @@ -514,20 +768,38 @@ class Params(ParamsBase): is_split_kv: cutlass.Constexpr[bool] = False head_swizzle: cutlass.Constexpr[bool] = False cluster_shape_m: cutlass.Constexpr[int] = 1 + scheduling_mode: cutlass.Constexpr[SchedulingMode] = SchedulingMode.STATIC @staticmethod @cute.jit def create( - args: TileSchedulerArguments, *, loc=None, ip=None + args: TileSchedulerArguments, + *, + scheduling_mode: SchedulingMode = SchedulingMode.STATIC, + loc=None, + ip=None, ) -> "SingleTileVarlenScheduler.Params": + assert scheduling_mode in (SchedulingMode.STATIC, SchedulingMode.CLC), ( + f"Only STATIC and CLC are supported, got {scheduling_mode!r}" + ) size_l2 = 50 * 1024 * 1024 # 50 MB for K & V - max_kvblock_in_l2 = size_l2 // ( + # if backward, this is qdo block size + kv_block_size = ( (args.headdim + args.headdim_v) * args.element_size * args.tile_shape_mn[1] ) + # if backward, add dqaccum block size to calculate swizzle + if args.head_swizzle: + kv_block_size += args.headdim * 4 * args.tile_shape_mn[1] + max_kvblock_in_l2 = size_l2 // kv_block_size assert args.mCuSeqlensQ is not None or args.mSeqUsedQ is not None, ( "At least one of mCuSeqlensQ or mSeqUsedQ must be provided" ) assert args.cluster_shape_mn[1] == 1, "Only cluster_shape_mn[1] == 1 is supported" + # TODO: Support varlen CLC with cluster_shape_m > 1 by refactoring the + # flattened-tile decode so cluster unpacking semantics are explicit. + assert scheduling_mode != SchedulingMode.CLC or args.cluster_shape_mn[0] == 1, ( + "Varlen CLC currently requires cluster_shape_mn[0] == 1" + ) return SingleTileVarlenScheduler.Params( num_head=args.num_head, num_batch=args.num_batch, @@ -542,22 +814,65 @@ def create( is_split_kv=args.is_split_kv, head_swizzle=args.head_swizzle, cluster_shape_m=args.cluster_shape_mn[0], + scheduling_mode=scheduling_mode, ) - def __init__(self, params: Params, tile_idx: Int32, split_idx: Int32, *, loc=None, ip=None): + def __init__( + self, + params: Params, + tile_idx: Int32, + split_idx: Int32, + clc: ClcState | None = None, + *, + loc=None, + ip=None, + ): self.params = params self._tile_idx = tile_idx self._split_idx = split_idx self._is_first_block = True + self.clc = clc self._loc = loc self._ip = ip @staticmethod - def to_underlying_arguments(args: TileSchedulerArguments, *, loc=None, ip=None) -> Params: - return SingleTileVarlenScheduler.Params.create(args, loc=loc, ip=ip) + def to_underlying_arguments( + args: TileSchedulerArguments, + *, + scheduling_mode: SchedulingMode = SchedulingMode.STATIC, + loc=None, + ip=None, + ) -> Params: + return SingleTileVarlenScheduler.Params.create( + args, scheduling_mode=scheduling_mode, loc=loc, ip=ip + ) @staticmethod - def create(params: Params, *, loc=None, ip=None) -> "SingleTileVarlenScheduler": + @cute.jit + def clc_problem_shape(params: Params): + return ClcDynamicPersistentTileSchedulerParams( + problem_shape_ntile_mnl=SingleTileVarlenScheduler.get_grid_shape(params), + cluster_shape_mnk=(1, 1, 1), + ) + + @staticmethod + @cute.jit + def create( + params: Params, clc: ClcState | None = None, *, loc=None, ip=None + ) -> "SingleTileVarlenScheduler": + if const_expr(params.scheduling_mode == SchedulingMode.CLC): + block_idx = cute.arch.block_idx() + split_idx = Int32(0) + if const_expr(params.is_split_kv): + split_idx = block_idx[1] + return SingleTileVarlenScheduler( + params, + block_idx[0], + split_idx, + clc, + loc=loc, + ip=ip, + ) tile_idx, split_idx, _ = cute.arch.block_idx() return SingleTileVarlenScheduler(params, tile_idx, split_idx, loc=loc, ip=ip) @@ -573,7 +888,7 @@ def get_grid_shape( params.total_q + params.num_batch * (params.cluster_shape_m * params.tile_shape_mn[0] - 1) ) // params.tile_shape_mn[0] - # round down to nearest multiple of cluster since odd excess is always padding + # Round down to nearest multiple of cluster since odd excess is always padding. total_blocks_max = total_blocks_max // params.cluster_shape_m * params.cluster_shape_m return (total_blocks_max * params.num_head, params.num_splits, Int32(1)) @@ -601,7 +916,8 @@ def _get_num_m_blocks(self, lane: Int32, bidb_start: Int32) -> Int32: ) @cute.jit - def get_current_work(self, *, loc=None, ip=None) -> WorkTileInfo: + def _varlen_coord_map(self) -> WorkTileInfo: + """Map self._tile_idx to (block, head, batch) via warp-level prefix sums.""" params = self.params lane_idx = cute.arch.lane_idx() num_m_blocks = self._get_num_m_blocks(lane_idx, bidb_start=0) @@ -654,6 +970,7 @@ def get_current_work(self, *, loc=None, ip=None) -> WorkTileInfo: num_n_blocks = ( num_m_blocks * params.tile_shape_mn[0] + * params.cluster_shape_m // params.qhead_per_kvhead_packgqa // params.tile_shape_mn[1] ) @@ -698,19 +1015,62 @@ def get_current_work(self, *, loc=None, ip=None) -> WorkTileInfo: split_idx = self._split_idx if const_expr(params.is_split_kv) else Int32(0) return WorkTileInfo((Int32(block), Int32(head_idx), Int32(batch_idx), split_idx), is_valid) + @cute.jit + def get_current_work(self, *, loc=None, ip=None) -> WorkTileInfo: + if const_expr(self.params.scheduling_mode == SchedulingMode.CLC): + clc_work = self.clc.get_current_work() + # Default to grid_dim (one past last valid flat index) so _varlen_coord_map + # returns is_valid=False when CLC is exhausted. CLC tile_idx is garbage when + # invalid, so we can't trust it. Local-then-assign avoids CuTe DSL structural + # mismatch on self inside the runtime if. + new_tile_idx = cute.arch.grid_dim()[0] + new_split_idx = Int32(0) + if clc_work.is_valid_tile: + new_tile_idx = clc_work.tile_idx[0] + if const_expr(self.params.is_split_kv): + new_split_idx = clc_work.tile_idx[1] + self._tile_idx = new_tile_idx + self._split_idx = new_split_idx + return self._varlen_coord_map() + + @cute.jit def initial_work_tile_info(self, *, loc=None, ip=None): - return self.get_current_work(loc=loc, ip=ip) + if const_expr(self.params.scheduling_mode == SchedulingMode.CLC): + clc_work = self.clc.initial_work_tile_info() + # See get_current_work for why grid_dim and local-then-assign. + new_tile_idx = cute.arch.grid_dim()[0] + new_split_idx = Int32(0) + if clc_work.is_valid_tile: + new_tile_idx = clc_work.tile_idx[0] + if const_expr(self.params.is_split_kv): + new_split_idx = clc_work.tile_idx[1] + self._tile_idx = new_tile_idx + self._split_idx = new_split_idx + return self._varlen_coord_map() def prefetch_next_work(self, *, loc=None, ip=None): - pass + if const_expr(self.params.scheduling_mode == SchedulingMode.CLC): + self.clc.prefetch_next_work(loc=loc, ip=ip) def advance_to_next_work(self, *, loc=None, ip=None): - # Single tile scheduler - set to invalid tile_idx to indicate no more work + if const_expr(self.params.scheduling_mode == SchedulingMode.CLC): + self.clc.consumer_wait(loc=loc, ip=ip) + work = self.get_current_work() + self.clc.consumer_release(loc=loc, ip=ip) + return work self._is_first_block = False + return self.get_current_work() + + def producer_tail(self, *, loc=None, ip=None): + if const_expr(self.params.scheduling_mode == SchedulingMode.CLC): + self.clc.producer_tail(loc=loc, ip=ip) def __extract_mlir_values__(self): values, self._values_pos = [], [] - for obj in [self.params, self._tile_idx, self._split_idx]: + objs = [self.params, self._tile_idx, self._split_idx] + if const_expr(self.params.scheduling_mode == SchedulingMode.CLC): + objs += [self.clc] + for obj in objs: obj_values = cutlass.extract_mlir_values(obj) values += obj_values self._values_pos.append(len(obj_values)) @@ -718,10 +1078,10 @@ def __extract_mlir_values__(self): def __new_from_mlir_values__(self, values): obj_list = [] - for obj, n_items in zip( - [self.params, self._tile_idx, self._split_idx], - self._values_pos, - ): + objs = [self.params, self._tile_idx, self._split_idx] + if const_expr(self.params.scheduling_mode == SchedulingMode.CLC): + objs += [self.clc] + for obj, n_items in zip(objs, self._values_pos): obj_list.append(cutlass.new_from_mlir_values(obj, values[:n_items])) values = values[n_items:] - return SingleTileVarlenScheduler(*(tuple(obj_list)), loc=self._loc) + return self.__class__(*obj_list, loc=self._loc) diff --git a/flash-attn4/torch-ext/flash_attn4/utils.py b/flash-attn4/torch-ext/flash_attn4/utils.py index a05305b2..35209575 100644 --- a/flash-attn4/torch-ext/flash_attn4/utils.py +++ b/flash-attn4/torch-ext/flash_attn4/utils.py @@ -3,12 +3,14 @@ import math import hashlib import inspect +import os from typing import Type, Callable, Optional, Tuple, overload import cutlass import cutlass.cute as cute -from cutlass import Float32, const_expr +from cutlass import Float32, Int32, const_expr +from cutlass.cute import FastDivmodDivisor from cutlass.cutlass_dsl import T, dsl_user_op from cutlass._mlir.dialects import nvvm, llvm from cutlass.cute.runtime import from_dlpack @@ -54,6 +56,17 @@ ), } +_fa_clc_enabled: bool = os.environ.get("FA_CLC", "0") == "1" +_fa_disable_2cta_enabled: bool = os.environ.get("FA_DISABLE_2CTA", "0") == "1" + + +def _get_use_clc_scheduler_default() -> bool: + return _fa_clc_enabled + + +def _get_disable_2cta_default() -> bool: + return _fa_disable_2cta_enabled + def _compute_base_hash(func: Callable) -> str: """Compute hash from source code or bytecode and closure values.""" @@ -123,6 +136,40 @@ def scoremod_premask_fn(acc_S_SSA, batch_idx, head_idx, q_idx, kv_idx, aux_tenso return scoremod_premask_fn +LOG2_E = math.log2(math.e) + + +def compute_softmax_scale_log2(softmax_scale, score_mod): + """Compute softmax_scale_log2 and adjusted softmax_scale based on whether score_mod is used. + + When score_mod is None, fold the log2(e) factor into softmax_scale_log2 and set softmax_scale + to None. When score_mod is present, keep softmax_scale separate so it can be applied before + the score_mod, and set softmax_scale_log2 to just the change-of-base constant. + + Returns (softmax_scale_log2, softmax_scale). + """ + if const_expr(score_mod is None): + return softmax_scale * LOG2_E, None + else: + return LOG2_E, softmax_scale + + +def compute_fastdiv_mods(mQ, mK, qhead_per_kvhead, pack_gqa, aux_tensors, mPageTable=None): + """Compute FastDivmodDivisor pairs for aux_tensors index computation. + + Returns a (seqlen_q_divmod, seqlen_k_divmod) tuple, or None if aux_tensors is None. + """ + if const_expr(aux_tensors is None): + return None + seqlen_q = cute.size(mQ.shape[0]) // (qhead_per_kvhead if const_expr(pack_gqa) else 1) + seqlen_k = ( + cute.size(mK.shape[0]) + if const_expr(mPageTable is None) + else mK.shape[0] * mPageTable.shape[1] + ) + return (FastDivmodDivisor(seqlen_q), FastDivmodDivisor(seqlen_k)) + + def convert_from_dlpack(x, leading_dim, alignment=16, divisibility=1) -> cute.Tensor: return ( from_dlpack(x, assumed_align=alignment) @@ -215,6 +262,21 @@ def warp_reduce( return val +@dsl_user_op +def smid(*, loc=None, ip=None) -> Int32: + return Int32( + llvm.inline_asm( + T.i32(), + [], + "mov.u32 $0, %smid;", + "=r", + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + ) + + @dsl_user_op def fmax( a: float | Float32, b: float | Float32, c: float | Float32 | None = None, *, loc=None, ip=None @@ -429,8 +491,48 @@ def shuffle_sync( return val[0] +@dsl_user_op +def shl_u32(val: cutlass.Uint32, shift: cutlass.Uint32, *, loc=None, ip=None) -> cutlass.Uint32: + """ + Left-shift val by shift bits using PTX shl.b32 (sign-agnostic). + + Named ``shl_u32`` (not ``shl_b32``) because python type annotations + distinguish signed/unsigned. + + PTX semantics (ยง9.7.8.8): "Shift amounts greater than the register width N + are clamped to N." So ``shl.b32 d, a, 32`` is well-defined and yields 0. + + This differs from C/C++ and LLVM IR, where shifting by >= the type width is + undefined behavior. CuTeDSL compiles through MLIR -> LLVM IR, so a plain + Python-level ``Uint32(x) << Uint32(n)`` inherits LLVM's UB: the optimizer + may treat the result as poison and eliminate dependent code. Inline PTX + bypasses the LLVM IR shift entirely โ€” the instruction is emitted verbatim + into PTX where clamping makes it safe for all shift amounts. + """ + return cutlass.Uint32( + llvm.inline_asm( + T.i32(), + [ + cutlass.Uint32(val).ir_value(loc=loc, ip=ip), + cutlass.Uint32(shift).ir_value(loc=loc, ip=ip), + ], + "shl.b32 $0, $1, $2;", + "=r,r,r", + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + ) + + @dsl_user_op def shr_u32(val: cutlass.Uint32, shift: cutlass.Uint32, *, loc=None, ip=None) -> cutlass.Uint32: + """ + Unsigned right-shift val by shift bits using PTX shr.u32 (zero-fills). + + See ``shl_u32`` docstring for why inline PTX is used instead of plain + CuTeDSL shift operators (LLVM shift-by-type-width UB). + """ return cutlass.Uint32( llvm.inline_asm( T.i32(), @@ -438,7 +540,7 @@ def shr_u32(val: cutlass.Uint32, shift: cutlass.Uint32, *, loc=None, ip=None) -> cutlass.Uint32(val).ir_value(loc=loc, ip=ip), cutlass.Uint32(shift).ir_value(loc=loc, ip=ip), ], - "shr.s32 $0, $1, $2;", + "shr.u32 $0, $1, $2;", "=r,r,r", has_side_effects=False, is_align_stack=False,