24 changes: 15 additions & 9 deletions xla/stream_executor/rocm/cub_scan_kernel_rocm_impl.cu.cc
@@ -22,7 +22,6 @@ limitations under the License.
 #include "rocm/include/rocprim/block/block_load.hpp"
 #include "rocm/include/rocprim/block/block_scan.hpp"
 #include "rocm/include/rocprim/block/block_store.hpp"
-#include "rocm/include/rocprim/device/detail/device_config_helper.hpp"
 #include "rocm/include/rocprim/device/device_scan.hpp"
 #include "rocm/include/rocprim/functional.hpp"
 #include "xla/stream_executor/rocm/cub_scan_kernel_rocm.h"
@@ -34,17 +33,24 @@ namespace stream_executor::rocm {

namespace {

-// Architecture-aware tuning from rocPRIM's autotuned scan config.
+// Scan config using rocPRIM's generic default formula for block size and
+// items-per-thread. Hardcoded to avoid depending on rocPRIM's internal
+// detail:: API (default_scan_config_base was renamed/restructured since
+// TheRock >= 7.11).
 template <typename T>
 struct ScanConfig {
-  using RocprimConfig =
-      typename rocprim::detail::default_scan_config_base<T>::type;
-  static constexpr int kBlockSize = RocprimConfig::block_size;
-  static constexpr int kItemsPerThread = RocprimConfig::items_per_thread;
+  static constexpr unsigned int kItemScale =
+      (sizeof(T) + sizeof(int) - 1) / sizeof(int);
+  static constexpr int kBlockSize = 256;
+  static constexpr int kItemsPerThread =
+      static_cast<int>(16u / kItemScale > 0 ? 16u / kItemScale : 1u);
nit: The ternary expression 16u / kItemScale > 0 ? 16u / kItemScale : 1u is correct by C++ operator precedence, but can trip up readers who have to mentally verify the binding. Consider std::max(1u, 16u / kItemScale) — same semantics, immediately clear intent, and avoids evaluating the division expression twice (though the compiler would optimize that away for constexpr values anyway).

Suggested change:
-      static_cast<int>(16u / kItemScale > 0 ? 16u / kItemScale : 1u);
+      static_cast<int>(std::max(1u, 16u / kItemScale));

   static constexpr int kTileSize = kBlockSize * kItemsPerThread;
-  static constexpr auto kLoadMethod = RocprimConfig::block_load_method;
-  static constexpr auto kStoreMethod = RocprimConfig::block_store_method;
-  static constexpr auto kScanAlgorithm = RocprimConfig::block_scan_method;
+  static constexpr auto kLoadMethod =
+      rocprim::block_load_method::block_load_transpose;
+  static constexpr auto kStoreMethod =
+      rocprim::block_store_method::block_store_transpose;
+  static constexpr auto kScanAlgorithm =
+      rocprim::block_scan_algorithm::using_warp_scan;
Comment on lines +36 to +53

Observation on performance vs. the CUDA path: the CUDA implementation uses CUB's policy_hub for architecture-tuned configs at compile time, meaning it automatically adapts to different GPU architectures. This hardcoded config can't adapt. Your own benchmark data shows config 256/16/using_warp_scan (which is what this PR produces for float) achieves ~236 GB/s average on MI355X — about 12.8% slower than the best config you found (256/8/reduce_then_scan at 266.2 GB/s).

Since this is opened for discussion, a few options to consider:

  1. Keep the current generic default (as in this PR) for safety across architectures, but add a TODO comment noting the performance gap and intent to revisit if rocPRIM exposes a stable public config API.
  2. Use the better-performing config (reduce_then_scan with items_per_thread=8 at block_size=256) if MI355X is the primary target — but this needs validation on other architectures (MI210, MI300X, etc.) to avoid regressions.
  3. Parameterize by architecture at compile time using __gfx*__ macros, similar to how the CUDA path dispatches by SM version. This is more work but would let you pick optimal configs per GPU family.

Any of these are defensible; the key is documenting the choice and the known gap.

};

template <typename T>