[ROCm] Remove rocPRIM detail:: dependency from CUB scan kernel#792
[ROCm] Remove rocPRIM detail:: dependency from CUB scan kernel#792magaonka-amd wants to merge 1 commit intoROCm:mainfrom
Conversation
- Replace rocprim::detail::default_scan_config_base<T> with hardcoded scan config using the same generic formula (block_size=256, items_per_thread=max(1, 16/item_scale), transpose load/store, warp_scan) - Remove #include of rocprim/device/detail/device_config_helper.hpp - Fixes build breakage since TheRock >= 7.11 where default_scan_config_base was renamed/restructured as part of rocPRIM config modernization
| (sizeof(T) + sizeof(int) - 1) / sizeof(int); | ||
| static constexpr int kBlockSize = 256; | ||
| static constexpr int kItemsPerThread = | ||
| static_cast<int>(16u / kItemScale > 0 ? 16u / kItemScale : 1u); |
There was a problem hiding this comment.
nit: The ternary expression 16u / kItemScale > 0 ? 16u / kItemScale : 1u is correct by C++ operator precedence, but can trip up readers who have to mentally verify the binding. Consider std::max(1u, 16u / kItemScale) — same semantics, immediately clear intent, and avoids evaluating the division expression twice (though the compiler would optimize that away for constexpr values anyway).
| static_cast<int>(16u / kItemScale > 0 ? 16u / kItemScale : 1u); | |
| static_cast<int>(std::max(1u, 16u / kItemScale)); |
| // Scan config using rocPRIM's generic default formula for block size and | ||
| // items-per-thread. Hardcoded to avoid depending on rocPRIM internal | ||
| // detail:: API (default_scan_config_base was renamed/restructured since | ||
| // TheRock >= 7.11). | ||
| template <typename T> | ||
| struct ScanConfig { | ||
| using RocprimConfig = | ||
| typename rocprim::detail::default_scan_config_base<T>::type; | ||
| static constexpr int kBlockSize = RocprimConfig::block_size; | ||
| static constexpr int kItemsPerThread = RocprimConfig::items_per_thread; | ||
| static constexpr unsigned int kItemScale = | ||
| (sizeof(T) + sizeof(int) - 1) / sizeof(int); | ||
| static constexpr int kBlockSize = 256; | ||
| static constexpr int kItemsPerThread = | ||
| static_cast<int>(16u / kItemScale > 0 ? 16u / kItemScale : 1u); | ||
| static constexpr int kTileSize = kBlockSize * kItemsPerThread; | ||
| static constexpr auto kLoadMethod = RocprimConfig::block_load_method; | ||
| static constexpr auto kStoreMethod = RocprimConfig::block_store_method; | ||
| static constexpr auto kScanAlgorithm = RocprimConfig::block_scan_method; | ||
| static constexpr auto kLoadMethod = | ||
| rocprim::block_load_method::block_load_transpose; | ||
| static constexpr auto kStoreMethod = | ||
| rocprim::block_store_method::block_store_transpose; | ||
| static constexpr auto kScanAlgorithm = | ||
| rocprim::block_scan_algorithm::using_warp_scan; |
There was a problem hiding this comment.
Observation on performance vs. the CUDA path: the CUDA implementation uses CUB's policy_hub for architecture-tuned configs at compile time, meaning it automatically adapts to different GPU architectures. This hardcoded config can't adapt. Your own benchmark data shows config 256/16/using_warp_scan (which is what this PR produces for float) achieves ~236 GB/s average on MI355X — about 12.8% slower than the best config you found (256/8/reduce_then_scan at 266.2 GB/s).
Since this is opened for discussion, a few options to consider:
- Keep the current generic default (as in this PR) for safety across architectures, but add a TODO comment noting the performance gap and intent to revisit if rocPRIM exposes a stable public config API.
- Use the better-performing config (
reduce_then_scanwithitems_per_thread=8atblock_size=256) if MI355X is the primary target — but this needs validation on other architectures (MI210, MI300X, etc.) to avoid regressions. - Parameterize by architecture at compile time using
__gfx*__macros, similar to how the CUDA path dispatches by SM version. This is more work but would let you pick optimal configs per GPU family.
Any of these are defensible; the key is documenting the choice and the known gap.
Review SummaryReviewed the removal of Key observations (details in inline comments):
No correctness issues found — the kernel interface is unchanged, so existing tests should continue to pass. The main open question is the performance/portability tradeoff for the scan config, which the author has already flagged. |
|
It looks good to me. Regarding the config change, I am OK with changing it but should we test performance for other datatypes (fp16, fp8) first? |
|
@magaonka-amd https://github.com/ROCm/rocm-libraries/blob/13bf528af264e243f181b85b23acb739ebb35d61/projects/rocprim/rocprim/include/rocprim/device/device_segmented_scan.hpp#L469 |
==========OPENED PR for DISCUSSION not for MERGING here=============
why move to hardcoded configs instead of try_compile() approach ?
I honestly didn't find single instance of this being used in XLA so I would be the first one to do so wasn't sure it will fly in upstream.
why not version guard this?:
currently rocm version numbers are kind of all over the place due to rocm to theRock transition and this makes it version guarding hard.
also existing code was also pulling default configs, my bad I was trying to see if there is way to do it per arch but I was wrong. rocprim has arch specific scan configs ( up until gfx942 ) for runtime but here we are trying to have scan configs compile time.
so I think having hardcoded value is not bad. but which exact scan config we should choose is debatable.
debate part is listed below :
Default numbers I added in this PR is recommended default value from rocprim. But to get better picture I did small sweep exercise where I go through all possible scan configs and try to see if they are significantly better than default config. All my testing was on MI355 and using ROCm 7.2.
My benchmarking patch looks like this:
diff --git a/xla/stream_executor/rocm/BUILD b/xla/stream_executor/rocm/BUILD index af04328591..fa9865ce64 100644 --- a/xla/stream_executor/rocm/BUILD +++ b/xla/stream_executor/rocm/BUILD @@ -1270,10 +1270,12 @@ xla_test( "//xla/tsl/platform:errors", "//xla/tsl/platform:statusor", "@com_google_absl//absl/cleanup", + "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/strings:string_view", - "@com_google_googletest//:gtest_main", + "//xla/tsl/platform:test_benchmark", + "//xla/tsl/platform:test_main", "@local_config_rocm//rocm:rocm_headers", ], ) diff --git a/xla/stream_executor/rocm/cub_scan_kernel_rocm_test.cc b/xla/stream_executor/rocm/cub_scan_kernel_rocm_test.cc index 043ed8d041..c0c52d3d9a 100644 --- a/xla/stream_executor/rocm/cub_scan_kernel_rocm_test.cc +++ b/xla/stream_executor/rocm/cub_scan_kernel_rocm_test.cc @@ -26,6 +26,7 @@ limitations under the License. #include <gmock/gmock.h> #include <gtest/gtest.h> #include "absl/cleanup/cleanup.h" +#include "absl/log/check.h" #include "absl/status/status.h" #include "absl/strings/str_format.h" #include "absl/strings/string_view.h" @@ -39,6 +40,7 @@ limitations under the License. #include "xla/stream_executor/stream.h" #include "xla/stream_executor/stream_executor.h" #include "xla/tsl/lib/core/status_test_util.h" +#include "xla/tsl/platform/test_benchmark.h" #include "xla/tsl/platform/errors.h" #include "xla/tsl/platform/statusor.h" #include "xla/xla_data.pb.h" @@ -346,5 +348,72 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values(false)), ParametersToString); +//===----------------------------------------------------------------------===// +// Performance benchmarks +//===----------------------------------------------------------------------===// + +static se::Platform* GetRocmPlatform() { + return se::PlatformManager::PlatformWithName("ROCM").value(); +} + +static void BM_CubScan1D(benchmark::State& state) { + int64_t row_length = state.range(0); + se::StreamExecutor* executor = + GetRocmPlatform()->ExecutorForDevice(0).value(); + auto stream = executor->CreateStream(std::nullopt).value(); + se::DeviceAddress<float> device_data = + executor->AllocateArray<float>(row_length); + size_t scratch_bytes = + CubScanGetScratchSize(xla::F32, 1, row_length, 1, CubScanKind::kSum, + false) + .value(); + se::DeviceMemory<uint8_t> scratch = + executor->AllocateArray<uint8_t>(std::max(scratch_bytes, size_t{1})); + auto cleanup = absl::MakeCleanup([&]() { + executor->Deallocate(&device_data); + executor->Deallocate(&scratch); + }); + auto hip_stream = + static_cast<hipStream_t>(stream->platform_specific_handle().stream); + + for (auto _ : state) { + CHECK_OK(CubScanLaunchKernel(xla::F32, scratch.opaque(), scratch_bytes, + device_data.opaque(), device_data.opaque(), 1, + row_length, 1, CubScanKind::kSum, false, + hip_stream)); + CHECK_OK(stream->BlockHostUntilDone()); + } + state.SetBytesProcessed(state.iterations() * row_length * sizeof(float)); +} +BENCHMARK(BM_CubScan1D)->RangeMultiplier(4)->Range(4096, 16 * 1024 * 1024); + +static void BM_CubScan2D(benchmark::State& state) { + int64_t row_length = state.range(0); + int64_t col_length = state.range(1); + int64_t total = row_length * col_length; + se::StreamExecutor* executor = + GetRocmPlatform()->ExecutorForDevice(0).value(); + auto stream = executor->CreateStream(std::nullopt).value(); + se::DeviceAddress<float> device_data = executor->AllocateArray<float>(total); + auto cleanup = + absl::MakeCleanup([&]() { executor->Deallocate(&device_data); }); + auto hip_stream = + static_cast<hipStream_t>(stream->platform_specific_handle().stream); + + for (auto _ : state) { + CHECK_OK(CubScanLaunchKernel( + xla::F32, nullptr, 0, device_data.opaque(), device_data.opaque(), 1, + row_length, col_length, CubScanKind::kSum, false, hip_stream)); + CHECK_OK(stream->BlockHostUntilDone()); + } + state.SetBytesProcessed(state.iterations() * total * sizeof(float)); +} +BENCHMARK(BM_CubScan2D) + ->Args({1024, 1024}) + ->Args({4096, 256}) + ->Args({256, 4096}) + ->Args({512, 2048}) + ->Args({8192, 128}); + } // namespace } // namespace stream_executor::rocmMy iteration logic to sweep all combos:
And below is my 2D scan results with various scan config combos:
my experiments above show default is not the best configuration at least for MI355 case. I would like to hear opinion on if it is okay to move away from rocprim default config and introduce our own scan config numbers here??.
Also good number in MI355 may not mean good performance on all HW so I'm little confused on picking best config here.