Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 37 additions & 0 deletions AGENTS.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
# Kernel-specific instructions

## flash-attn4

When the user asks to sync a flash-attn4 release, carry out the following
steps:

- Fetch the upstream Git repository from https://github.com/Dao-AILab/flash-attention.git
- Check out the tag that the user specified.
- Flash Attention 4 is in the directory `flash_attn/cute` of the upstream repo.
- Copy Flash Attention 4 upstream files to `flash-attn4/torch-ext/flash_attn4`.
- Copy tests from the upstream directory `tests/cute` to
  `flash-attn4/tests/cute`.
- Check the upstream `flash_attn/cute/pyproject.toml` to determine which
  version of quack is required.
- Get this version of quack from https://github.com/Dao-AILab/quack.git
- Copy the `quack` directory from quack to `flash-attn4/torch-ext/flash_attn4/quack`
- Convert all imports of Flash Attention 4 and quack in
  `flash-attn4/torch-ext/flash_attn4` and `flash-attn4/torch-ext/flash_attn4/quack`
  into relative imports.
- Remove all quack files in `flash-attn4/torch-ext/flash_attn4/quack` that are not used.
- Update imports of `flash_attn.cute` in `flash-attn4/tests/cute` to `flash_attn4`.
- Set `__version__` in `flash-attn4/torch-ext/flash_attn4/__init__.py` to the
version from the tag (e.g. for tag `fa4-v4.0.0.beta8` set it to
`"4.0.0.beta8"`). Remove any `importlib.metadata` version lookup code.
- Check whether any Torch custom ops are defined in `flash-attn4/torch-ext/flash_attn4`
or `flash-attn4/torch-ext/flash_attn4/quack` (look for `torch.library.custom_op`,
`torch.library.define`, etc.). If any are found, update them to use
`add_op_namespace_prefix` for the op name. For example, a definition like
`@torch.library.custom_op("_flash_attn_forward", mutates_args=(), device_types="cuda")`
should become
`@torch.library.custom_op(add_op_namespace_prefix("_flash_attn_forward"), mutates_args=(), device_types="cuda")`.
`add_op_namespace_prefix` is imported from `._ops` (see
`flash-attn3/torch-ext/flash_attn3/flash_attn_interface.py` for an example).

If the user did not specify the version tag, stop and ask which tag to sync
from.
2 changes: 1 addition & 1 deletion flash-attn4/tests/cute/benchmark_mask_mod.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import numpy as np
import torch

from flash_attn4.flash_fwd import FlashAttentionForwardSm90
from flash_attn4.flash_fwd_sm90 import FlashAttentionForwardSm90
from mask_mod_definitions import (
get_mask_pair,
random_doc_id_tensor,
Expand Down
44 changes: 42 additions & 2 deletions flash-attn4/tests/cute/conftest.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
import os
import subprocess
import logging
import tempfile
import json
import time
from pathlib import Path
from getpass import getuser


def _get_gpu_ids():
Expand All @@ -16,16 +22,50 @@ def _get_gpu_ids():
)
if result.returncode == 0:
return result.stdout.strip().splitlines()
except (FileNotFoundError, subprocess.TimeoutExpired):
except (FileNotFoundError,):
pass

logging.warning("Failed to get gpu ids, use default '0'")
return ["0"]


def pytest_configure(config):
    """Set up per-worker logging and GPU pinning for pytest-xdist workers.

    Each xdist worker logs to its own file under a per-user temp directory
    and is pinned to a single GPU via ``CUDA_VISIBLE_DEVICES`` (round-robin
    over the available GPUs). Worker 0 discovers the GPU ids once and caches
    them in a JSON file so the remaining workers read the cache instead of
    each shelling out to ``nvidia-smi``.

    When not running under pytest-xdist (no ``PYTEST_XDIST_WORKER`` in the
    environment), only logging is configured.
    """
    tmp = Path(tempfile.gettempdir()) / getuser() / "flash_attention_tests"
    tmp.mkdir(parents=True, exist_ok=True)

    worker_id = os.environ.get("PYTEST_XDIST_WORKER")
    logging.basicConfig(
        format=config.getini("log_file_format"),
        filename=str(tmp / f"tests_{worker_id}.log"),
        level=config.getini("log_file_level"),
    )
    if not worker_id:
        # Not running under pytest-xdist; nothing to pin.
        return
    worker_num = int(worker_id.replace("gw", ""))

    # Cache gpu_ids because nvidia-smi is expensive when we launch many
    # workers at once. Worker 0 is always elected to run the discovery;
    # everyone else waits for the cache file. (Previously every worker
    # called _get_gpu_ids() unconditionally, defeating the cache.)
    cached_gpu_ids = tmp / "gpu_ids.json"
    if worker_num == 0:
        gpu_ids = _get_gpu_ids()
        # Write to a temp file, then atomically rename, so that polling
        # readers never observe a partially-written JSON file.
        partial = cached_gpu_ids.with_name(f"gpu_ids.{worker_id}.tmp")
        with partial.open(mode="w") as f:
            json.dump(gpu_ids, f)
        os.replace(partial, cached_gpu_ids)
    else:
        while not cached_gpu_ids.exists():
            time.sleep(1)
        with cached_gpu_ids.open() as f:
            gpu_ids = json.load(f)

    os.environ["CUDA_VISIBLE_DEVICES"] = gpu_ids[worker_num % len(gpu_ids)]

def pytest_collection_finish(session):
    """Print a JSON summary of collected tests after collection finishes.

    Active only under ``pytest --collect-only``. Emits a nested mapping of
    file name -> test function name -> number of collected items, so
    parametrized tests are counted once per parameter set.
    """
    if not session.config.option.collectonly:
        return

    # Nested counters: file name -> test name -> number of collected items.
    summary: dict[str, dict[str, int]] = {}
    for collected in session.items:
        per_file = summary.setdefault(collected.parent.name, {})
        name = collected.function.__name__
        per_file[name] = per_file.get(name, 0) + 1
    print(json.dumps(summary, indent=2))
Loading
Loading