Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions aiter/ops/flydsl/gemm_kernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,14 @@

from ..shuffle import shuffle_weight
from .kernels.splitk_hgemm import compile_hgemm_kernel
from .utils import is_flydsl_available
from .utils import get_shared_memory_per_block, is_flydsl_available

__all__ = [
"flydsl_hgemm",
]

SPLIT_K_COUNTER_MAX_LEN = 128
SPLIT_K_SIGNAL_STATE_COUNT = 3
MAX_LDS_BYTES = 163840
FIXED_STAGE = 2
FIXED_C_TO_LDS = False
KERNEL_ASYNC_COPY = get_rocm_arch() != "gfx942"
Expand Down Expand Up @@ -333,10 +332,11 @@ def _validate_hgemm_tiling(
stages=stages,
b_to_lds=b_to_lds,
)
if lds_bytes > MAX_LDS_BYTES:
lds_limit = get_shared_memory_per_block(fallback_gfx=get_gfx())
if lds_bytes > lds_limit:
raise ValueError(
"Invalid tile combination: estimated LDS usage "
f"{lds_bytes} exceeds the hardware limit {MAX_LDS_BYTES}"
f"{lds_bytes} exceeds the hardware limit {lds_limit}"
)


Expand Down
26 changes: 9 additions & 17 deletions aiter/ops/flydsl/gemm_tune/flydsl_gemm_a8w8_bpreshuffle_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,11 @@
import math
import os

from aiter.ops.flydsl.utils import (
addressable_lds_bytes_for_gfx as _addressable_lds_bytes_for_gfx,
get_shared_memory_per_block,
)


def get_gfx():
"""Detect GPU arch: honour GPU_ARCHS env, fall back to chip_info, default gfx942."""
Expand Down Expand Up @@ -165,26 +170,13 @@ def kernel_instance_estimated_lds_bytes(ki: kernelInstance) -> int:
)


# Per-kernel LDS cap for tune filtering (must match LLVM AMDGPU
# getAddressableLocalMemorySize for the compile target).
# When arch cannot be parsed (no GPU, bad string), stay conservative for CDNA.
_FALLBACK_MAX_LDS_BYTES = 65536


def addressable_lds_bytes_for_gfx(gfx: str) -> int:
g = (gfx or "").strip().lower().split(":")[0]
if not g.startswith("gfx"):
return _FALLBACK_MAX_LDS_BYTES
if g.startswith("gfx950"):
return 163840
if g.startswith("gfx7") or g.startswith("gfx8"):
return 32768
return 65536
return _addressable_lds_bytes_for_gfx(gfx)


def max_lds_bytes_for_tune() -> int:
"""Addressable LDS limit for current target (from ``get_gfx()``)."""
return addressable_lds_bytes_for_gfx(get_gfx())
"""Addressable LDS limit for current target."""
return get_shared_memory_per_block(fallback_gfx=get_gfx())


# fmt: off
Expand Down Expand Up @@ -277,7 +269,7 @@ def _estimate_max_wpe(tile_m: int, tile_n: int, total_vgpr: int = 512) -> int:

Preshuffle GEMM always uses 16x16 MFMA (4 VGPRs per thread per block).
Per-thread accum VGPRs = round_up(tile_m, 16) * round_up(tile_n, 16) / 256.
Estimated total accum * 1.5 (pipeline overhead for A/B buffers).
Estimated total ≈ accum * 1.5 (pipeline overhead for A/B buffers).
Returns the max waves_per_eu that the register file can support.
"""
padded_m = math.ceil(tile_m / _MFMA_M) * _MFMA_M
Expand Down
42 changes: 29 additions & 13 deletions aiter/ops/flydsl/kernels/splitk_hgemm.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from flydsl.utils.smem_allocator import SmemAllocator, SmemPtr

from .tensor_shim import GTensor, STensor, _to_raw, get_dtype_in_kernel
from ..utils import get_shared_memory_per_block

SPLIT_K_COUNTER_MAX_LEN = 128
SPLIT_K_SIGNAL_STATE_COUNT = 3
Expand Down Expand Up @@ -178,6 +179,15 @@ def compile_hgemm_kernel(
assert BLOCK_MN_SIZE % BLOCK_VECS == 0
BLOCK_K_BYTES = BLOCK_K * DTYPE_BYTES

KERNEL_NAME = f"hgemm_{dtype}_{BLOCK_M}x{BLOCK_N}x{BLOCK_K}_S{STAGES}TN"
KERNEL_NAME += "_NA" if not ASYNC_COPY else "_AS"
if B_PRE_SHUFFLE:
KERNEL_NAME += "_BP"
if IS_SPLIT_K:
KERNEL_NAME += f"_SPK{SPLIT_K}"
if B_TO_LDS:
KERNEL_NAME += "_BS"

allocator = SmemAllocator(None, arch=GPU_ARCH, global_sym_name="smem")
smem_a_offset = allocator._align(allocator.ptr, 16)
AS_BYTES = STAGES * BLOCK_M * BLOCK_K * DTYPE_BYTES
Expand All @@ -188,22 +198,20 @@ def compile_hgemm_kernel(
smem_b_offset = allocator._align(allocator.ptr, 16)
allocator.ptr = smem_b_offset + STAGES * BLOCK_N * BLOCK_K * DTYPE_BYTES
SMEM_USE += STAGES * BLOCK_N * BLOCK_K * DTYPE_BYTES
assert SMEM_USE <= 163840
smem_limit = get_shared_memory_per_block(fallback_gfx=GPU_ARCH)
if SMEM_USE > smem_limit:
raise RuntimeError(
f"{KERNEL_NAME} requires {SMEM_USE} bytes LDS, "
f"but device limit is {smem_limit} bytes "
f"(arch={GPU_ARCH}, TILE_M={TILE_M}, TILE_N={TILE_N}, TILE_K={TILE_K}, "
f"SPLIT_K={SPLIT_K}, B_TO_LDS={B_TO_LDS})",
)
LDG_ASYNC_VEC_SIZE = DMA_BYTES // DTYPE_BYTES
LDG_A_X_THREADS_AS = BLOCK_K // LDG_ASYNC_VEC_SIZE
LDG_REG_A_COUNT_AS = BLOCK_MK_SIZE // LDG_ASYNC_VEC_SIZE // BLOCK_THREADS
LDG_B_X_THREADS_AS = BLOCK_K // LDG_ASYNC_VEC_SIZE
LDG_REG_B_COUNT_AS = BLOCK_NK_SIZE // LDG_ASYNC_VEC_SIZE // BLOCK_THREADS

KERNEL_NAME = f"hgemm_{dtype}_{BLOCK_M}x{BLOCK_N}x{BLOCK_K}_S{STAGES}TN"
KERNEL_NAME += "_NA" if not ASYNC_COPY else "_AS"
if B_PRE_SHUFFLE:
KERNEL_NAME += "_BP"
if IS_SPLIT_K:
KERNEL_NAME += f"_SPK{SPLIT_K}"
if B_TO_LDS:
KERNEL_NAME += "_BS"

@flyc.kernel
def hgemm_kernel(
C: fx.Tensor,
Expand Down Expand Up @@ -925,9 +933,17 @@ def _launch(*args, **kwargs):
def _compile(C, A, B, m, COUNTER, signal_state, stream):
with CompilationContext.compile_hints(_compile_hints):
if _compile_cache.get(m, None) is None:
_compile_cache[m] = flyc.compile(
launch_hgemm_kernel, C, A, B, m, COUNTER, signal_state, stream
)
try:
_compile_cache[m] = flyc.compile(
launch_hgemm_kernel, C, A, B, m, COUNTER, signal_state, stream
)
except Exception as e:
raise RuntimeError(
f"{KERNEL_NAME} failed "
f"(arch={GPU_ARCH}, n={n}, k={k}, TILE_M={TILE_M}, TILE_N={TILE_N}, "
f"TILE_K={TILE_K}, SPLIT_K={SPLIT_K}, B_TO_LDS={B_TO_LDS}, "
f"SMEM_USE={SMEM_USE}, SMEM_LIMIT={smem_limit}): {e}",
) from e
return _compile_cache[m]

_launch.compile = _compile
Expand Down
31 changes: 31 additions & 0 deletions aiter/ops/flydsl/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,37 @@

import importlib.util

import torch

_FALLBACK_MAX_LDS_BYTES = 65536


def addressable_lds_bytes_for_gfx(gfx: str) -> int:
    """Map a gfx arch string (e.g. ``gfx950:sramecc+``) to its addressable LDS bytes.

    Feature suffixes after ``:`` are ignored; strings that do not look like a
    gfx arch fall back to the conservative ``_FALLBACK_MAX_LDS_BYTES`` default.
    """
    arch = (gfx or "").strip().lower().split(":", 1)[0]
    if not arch.startswith("gfx"):
        return _FALLBACK_MAX_LDS_BYTES
    if arch.startswith("gfx950"):
        return 163840
    if arch.startswith(("gfx7", "gfx8")):
        return 32768
    return 65536


def get_shared_memory_per_block(device=None, fallback_gfx: str = "") -> int:
    """Return per-block shared memory/LDS limit for the active device.

    Prefers the limit reported by the torch device properties; when that is
    missing or zero, derives it from the device's ``gcnArchName`` (or
    ``fallback_gfx`` when no device can be queried).
    """
    try:
        dev = torch.cuda.current_device() if device is None else device
        props = torch.cuda.get_device_properties(dev)
        reported = int(getattr(props, "shared_memory_per_block", 0) or 0)
        if reported > 0:
            return reported
        arch = getattr(props, "gcnArchName", fallback_gfx)
        return addressable_lds_bytes_for_gfx(arch)
    except Exception:
        # No usable device (or torch query failed): fall back to the arch table.
        return addressable_lds_bytes_for_gfx(fallback_gfx)


def is_flydsl_available() -> bool:
    """Report whether the optional ``flydsl`` package can be imported."""
    spec = importlib.util.find_spec("flydsl")
    return spec is not None
10 changes: 5 additions & 5 deletions aiter/utility/base_tuner.py
Original file line number Diff line number Diff line change
Expand Up @@ -443,13 +443,13 @@ def post_process(self, rets, args, topk=-1, fast_mode=False):
f"error: no valid candidate found for {info_key}, please check the result or errRatio in all result file running with --profile_file"
)

if len(filtered_time) < topk:
topk = len(filtered_time)
print(f"choose {topk} kernels")
self.topk = topk
effective_topk = min(topk, len(filtered_time))
if effective_topk < topk:
print(f"choose {effective_topk} kernels")
self.topk = effective_topk
best_config = [
((info_key, *info_ex), us, max_err_ratio)
for info_ex, us, max_err_ratio in filtered_time[0:topk]
for info_ex, us, max_err_ratio in filtered_time[0:effective_topk]
]
if not best_config:
logger.info(f"No kernel can be used for {info_key}")
Expand Down
19 changes: 14 additions & 5 deletions aiter/utility/mp_tuner.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,14 @@
from aiter import logger


def _is_mapping_error(exc: BaseException) -> bool:
return isinstance(exc, KeyError)


def _is_accelerator_error(exc: BaseException) -> bool:
return type(exc).__name__ == "AcceleratorError"


def worker(
gpu_id,
info,
Expand Down Expand Up @@ -36,7 +44,7 @@ def worker(
res, us = run_perftest(func, *args, **kwargs)
us = round(us, 4)

except RuntimeError as e:
except (RuntimeError, ValueError) as e:
print(f"run gpu func warning: info:{info}\t {e}", flush=True)
us = -1 # not support or error
max_err_ratio = 1.0
Expand All @@ -50,6 +58,8 @@ def worker(
if us == 0:
print(f"Warning: try run {max_retries} times, but still get 0!")
torch.cuda.synchronize()
if us == -1 or res is None:
return info, us, round(max_err_ratio, 4)
if ref is not None:
if isinstance(ref, torch.Tensor):
ref = [ref]
Expand Down Expand Up @@ -448,14 +458,13 @@ def add_dummy_result(k, results_list):
except Exception as e:
# Check if it's a process crash (segfault, memory fault, etc.)
error_type = type(e).__name__

# Special handling for KeyError (PID mapping issue)
is_mapping_error = error_type == "KeyError"
is_mapping_error = _is_mapping_error(e)
is_accelerator_error = _is_accelerator_error(e)
# do not restart, as this is not the root cause
if is_mapping_error:
error_msg = f"[Mapping Error] Task {k} - Process PID not in GPU map: {error_type} - {e}"
dummy_failed_tasks.append((k, "mapping error"))
elif error_type == "AcceleratorError":
elif is_accelerator_error:
# GPU fault (e.g. illegal memory access): worker returns exception instead of
# hanging. Unlike hang->timeout, the faulting worker may stay alive and accept
# more tasks on the same bad GPU. Break immediately to trigger restart and
Expand Down
Loading