From decaca3fa93c7d2ed88083bc9fc503d68723ada5 Mon Sep 17 00:00:00 2001 From: Elias Stehle <3958403+elstehle@users.noreply.github.com> Date: Tue, 7 Apr 2026 00:29:35 -0700 Subject: [PATCH 1/8] extracts histo into own kernel --- cub/cub/agent/agent_topk.cuh | 279 +++++++++++++--------- cub/cub/device/dispatch/dispatch_topk.cuh | 241 +++++++++++-------- 2 files changed, 303 insertions(+), 217 deletions(-) diff --git a/cub/cub/agent/agent_topk.cuh b/cub/cub/agent/agent_topk.cuh index ccdb1fac20c..fe6ea64ec5e 100644 --- a/cub/cub/agent/agent_topk.cuh +++ b/cub/cub/agent/agent_topk.cuh @@ -447,7 +447,6 @@ struct AgentTopK } // Fused filtering of the current pass and building histogram for the next pass - template _CCCL_DEVICE _CCCL_FORCEINLINE void filter_and_histogram( key_in_t* in_buf, OffsetT* in_idx_buf, @@ -465,116 +464,103 @@ struct AgentTopK // Make sure the histogram was initialized __syncthreads(); - if constexpr (IsFirstPass) - { - // During the first pass, compute per-thread block histograms over the full input. The per-thread block histograms - // are being added to the global histogram further down below. - auto f = [this](key_in_t key, OffsetT /*index*/) { - const int bucket = extract_bin_op(key); - atomicAdd(temp_storage.histogram + bucket, OffsetT{1}); - }; - process_range(d_keys_in, previous_len, f); - } - else - { - OffsetT* p_filter_cnt = &counter->filter_cnt; - OutOffsetT* p_out_cnt = &counter->out_cnt; - - // Lambda for early_stop = true (i.e., we have identified the exact "splitter" key): - // Select all items that fall into the bin of the k-th item (i.e., the 'candidates') and the ones that fall into - // bins preceding the k-th item bin (i.e., 'selected' items), write them to output. - // We can skip histogram computation because we don't need to further passes to refine the candidates. - auto f_early_stop = [load_from_original_input, in_idx_buf, p_out_cnt, this](key_in_t key, OffsetT i) { - const candidate_class pre_res = identify_candidates_op(key); - if (pre_res == candidate_class::candidate || pre_res == candidate_class::selected) - { - const OutOffsetT pos = atomicAdd(p_out_cnt, OutOffsetT{1}); - d_keys_out[pos] = key; - if constexpr (!keys_only) - { - const OffsetT index = load_from_original_input ? i : in_idx_buf[i]; - d_values_out[pos] = d_values_in[index]; - } - } - }; - - // Lambda for early_stop = false, out_buf != nullptr (i.e., we need to further refine the candidates in the next - // pass): Write out selected items to output, write candidates to out_buf, and build histogram for candidates. - auto f_with_out_buf = [load_from_original_input, in_idx_buf, out_buf, out_idx_buf, p_filter_cnt, p_out_cnt, this]( - key_in_t key, OffsetT i) { - const candidate_class pre_res = identify_candidates_op(key); - if (pre_res == candidate_class::candidate) - { - const OffsetT pos = atomicAdd(p_filter_cnt, OffsetT{1}); - out_buf[pos] = key; - if constexpr (!keys_only) - { - const OffsetT index = load_from_original_input ? i : in_idx_buf[i]; - out_idx_buf[pos] = index; - } + OffsetT* p_filter_cnt = &counter->filter_cnt; + OutOffsetT* p_out_cnt = &counter->out_cnt; - const int bucket = extract_bin_op(key); - atomicAdd(temp_storage.histogram + bucket, OffsetT{1}); - } - else if (pre_res == candidate_class::selected) - { - const OutOffsetT pos = atomicAdd(p_out_cnt, OutOffsetT{1}); - d_keys_out[pos] = key; - if constexpr (!keys_only) - { - const OffsetT index = in_idx_buf ? in_idx_buf[i] : i; - d_values_out[pos] = d_values_in[index]; - } - } - }; - - // Lambda for early_stop = false, out_buf = nullptr (i.e., we need to further refine the candidates in the next - // pass, but we skip writing candidates to out_buf): - // Just build histogram for candidates. - // Note: We will only begin writing to d_keys_out starting from the pass in which the number of output-candidates - // is small enough to fit into the output buffer (otherwise, we would be writing the same items to d_keys_out - // multiple times). - auto f_no_out_buf = [this](key_in_t key, OffsetT i) { - const candidate_class pre_res = identify_candidates_op(key); - if (pre_res == candidate_class::candidate) + // Lambda for early_stop = true (i.e., we have identified the exact "splitter" key): + // Select all items that fall into the bin of the k-th item (i.e., the 'candidates') and the ones that fall into + // bins preceding the k-th item bin (i.e., 'selected' items), write them to output. + // We can skip histogram computation because we don't need to further passes to refine the candidates. + auto f_early_stop = [load_from_original_input, in_idx_buf, p_out_cnt, this](key_in_t key, OffsetT i) { + const candidate_class pre_res = identify_candidates_op(key); + if (pre_res == candidate_class::candidate || pre_res == candidate_class::selected) + { + const OutOffsetT pos = atomicAdd(p_out_cnt, OutOffsetT{1}); + d_keys_out[pos] = key; + if constexpr (!keys_only) { - const int bucket = extract_bin_op(key); - atomicAdd(temp_storage.histogram + bucket, OffsetT{1}); + const OffsetT index = load_from_original_input ? i : in_idx_buf[i]; + d_values_out[pos] = d_values_in[index]; } - }; + } + }; - // Choose and invoke the appropriate lambda with the correct input source - // If the input size exceeds the allocated buffer size, we know for sure we haven't started writing candidates to - // the output buffer yet - if (load_from_original_input) + // Lambda for early_stop = false, out_buf != nullptr (i.e., we need to further refine the candidates in the next + // pass): Write out selected items to output, write candidates to out_buf, and build histogram for candidates. + auto f_with_out_buf = [load_from_original_input, in_idx_buf, out_buf, out_idx_buf, p_filter_cnt, p_out_cnt, this]( + key_in_t key, OffsetT i) { + const candidate_class pre_res = identify_candidates_op(key); + if (pre_res == candidate_class::candidate) { - if (early_stop) - { - process_range(d_keys_in, previous_len, f_early_stop); - } - else if (out_buf) + const OffsetT pos = atomicAdd(p_filter_cnt, OffsetT{1}); + out_buf[pos] = key; + if constexpr (!keys_only) { - process_range(d_keys_in, previous_len, f_with_out_buf); + const OffsetT index = load_from_original_input ? i : in_idx_buf[i]; + out_idx_buf[pos] = index; } - else + + const int bucket = extract_bin_op(key); + atomicAdd(temp_storage.histogram + bucket, OffsetT{1}); + } + else if (pre_res == candidate_class::selected) + { + const OutOffsetT pos = atomicAdd(p_out_cnt, OutOffsetT{1}); + d_keys_out[pos] = key; + if constexpr (!keys_only) { - process_range(d_keys_in, previous_len, f_no_out_buf); + const OffsetT index = in_idx_buf ? in_idx_buf[i] : i; + d_values_out[pos] = d_values_in[index]; } } + }; + + // Lambda for early_stop = false, out_buf = nullptr (i.e., we need to further refine the candidates in the next + // pass, but we skip writing candidates to out_buf): + // Just build histogram for candidates. + // Note: We will only begin writing to d_keys_out starting from the pass in which the number of output-candidates + // is small enough to fit into the output buffer (otherwise, we would be writing the same items to d_keys_out + // multiple times). + auto f_no_out_buf = [this](key_in_t key, OffsetT i) { + const candidate_class pre_res = identify_candidates_op(key); + if (pre_res == candidate_class::candidate) + { + const int bucket = extract_bin_op(key); + atomicAdd(temp_storage.histogram + bucket, OffsetT{1}); + } + }; + + // Choose and invoke the appropriate lambda with the correct input source + // If the input size exceeds the allocated buffer size, we know for sure we haven't started writing candidates to + // the output buffer yet + if (load_from_original_input) + { + if (early_stop) + { + process_range(d_keys_in, previous_len, f_early_stop); + } + else if (out_buf) + { + process_range(d_keys_in, previous_len, f_with_out_buf); + } else { - if (early_stop) - { - process_range(in_buf, previous_len, f_early_stop); - } - else if (out_buf) - { - process_range(in_buf, previous_len, f_with_out_buf); - } - else - { - process_range(in_buf, previous_len, f_no_out_buf); - } + process_range(d_keys_in, previous_len, f_no_out_buf); + } + } + else + { + if (early_stop) + { + process_range(in_buf, previous_len, f_early_stop); + } + else if (out_buf) + { + process_range(in_buf, previous_len, f_with_out_buf); + } + else + { + process_range(in_buf, previous_len, f_no_out_buf); } } @@ -703,7 +689,6 @@ struct AgentTopK } } - template _CCCL_DEVICE _CCCL_FORCEINLINE void invoke_filter_and_histogram( key_in_t* in_buf, OffsetT* in_idx_buf, @@ -713,22 +698,9 @@ struct AgentTopK OffsetT* histogram, int pass) { - OutOffsetT current_k; - OffsetT previous_len; - OffsetT current_len; - - if constexpr (IsFirstPass) - { - current_k = k; - previous_len = num_items; - current_len = num_items; - } - else - { - current_k = counter->k; - current_len = counter->len; - previous_len = counter->previous_len; - } + const OutOffsetT current_k = counter->k; + const OffsetT current_len = counter->len; + OffsetT previous_len = counter->previous_len; // If current_len is 0, it means all the candidates have been found in previous passes. if (current_len == 0) @@ -739,7 +711,7 @@ struct AgentTopK // Early stop means that the bin containing the k-th element has been identified, and all // the elements in this bin are exactly the remaining k items we need to find. So we can // stop the process right here. - const bool early_stop = ((!IsFirstPass) && current_len == static_cast(current_k)); + const bool early_stop = (current_len == static_cast(current_k)); // If previous_len > buffer_length, it means we haven't started writing candidates to out_buf yet, // so have to make sure to load input directly from the original input. @@ -761,7 +733,7 @@ struct AgentTopK } // Fused filtering of candidates and histogram computation over the output-candidates - filter_and_histogram( + filter_and_histogram( in_buf, in_idx_buf, out_buf, out_idx_buf, previous_len, counter, histogram, early_stop, load_from_original_input); // We need this `__threadfence()` to make sure all writes to the global memory-histogram are visible to all @@ -827,6 +799,77 @@ struct AgentTopK } } } + + // Histogram-only pass: computes the histogram over the full input without filtering. + // Used for the first radix pass before any candidates have been identified. + _CCCL_DEVICE _CCCL_FORCEINLINE void + invoke_histogram_only(Counter* counter, OffsetT* histogram, int pass) + { + // Initialize shared memory histogram + init_histograms(temp_storage.histogram); + __syncthreads(); + + // Compute per-thread block histograms over the full input + auto f = [this](key_in_t key, OffsetT /*index*/) { + const int bucket = extract_bin_op(key); + atomicAdd(temp_storage.histogram + bucket, OffsetT{1}); + }; + process_range(d_keys_in, num_items, f); + + // Ensure all threads have contributed to the histogram before accumulating in global memory + __syncthreads(); + + // Merge the locally aggregated histogram into the global histogram + merge_histograms(histogram); + + // We need this `__threadfence()` to make sure all writes to the global memory-histogram are visible to all + // threads before we proceed to compute the prefix sum over the histogram. + __threadfence(); + + // Identify the last block in the grid to perform the prefix sum over the histogram and identify the bin that + // the k-th item falls into + bool is_last_block = false; + if (threadIdx.x == 0) + { + unsigned int finished = atomicInc(&counter->finished_block_cnt, gridDim.x - 1); + is_last_block = (finished == (gridDim.x - 1)); + } + + if (__syncthreads_or(is_last_block)) + { + if (threadIdx.x == 0) + { + counter->previous_len = num_items; + counter->filter_cnt = 0; + } + + // Compute prefix sum over the histogram's bin counts + compute_bin_offsets(histogram); + + // Make sure the prefix sum has been written to shared memory before choose_bucket() + __syncthreads(); + + // Identify the bucket that the bin that the k-th item falls into + choose_bucket(counter, k, pass); + + // Reset histogram for the next pass + // TODO: Refactor calc_start_bit, calc_mask, and calc_num_passes to uniformly work with + // total_bits (passed as a kernel parameter) instead of sizeof(KeyT), then use a single + // unconditional path for both fundamental and non-fundamental types. + if constexpr (detail::radix::is_fundamental_type_v) + { + constexpr int num_passes = calc_num_passes(bits_per_pass); + if (pass != num_passes - 1) + { + init_histograms(histogram); + } + } + else + { + init_histograms(histogram); + } + } + } }; } // namespace detail::topk CUB_NAMESPACE_END diff --git a/cub/cub/device/dispatch/dispatch_topk.cuh b/cub/cub/device/dispatch/dispatch_topk.cuh index 15dd565a4c8..acb1d301c35 100644 --- a/cub/cub/device/dispatch/dispatch_topk.cuh +++ b/cub/cub/device/dispatch/dispatch_topk.cuh @@ -243,8 +243,7 @@ template + typename IdentifyCandidatesOpT> #if _CCCL_HAS_CONCEPTS() requires topk_policy_selector #endif // _CCCL_HAS_CONCEPTS() @@ -297,8 +296,67 @@ __launch_bounds__(int(PolicySelector{}(::cuda::arch_id{CUB_PTX_ARCH / 10}).block buffer_length, extract_bin_op, identify_candidates_op) - .template invoke_filter_and_histogram( - in_buf, in_idx_buf, out_buf, out_idx_buf, counter, histogram, pass); + .invoke_filter_and_histogram(in_buf, in_idx_buf, out_buf, out_idx_buf, counter, histogram, pass); +} + +template +#if _CCCL_HAS_CONCEPTS() + requires topk_policy_selector +#endif // _CCCL_HAS_CONCEPTS() +__launch_bounds__(int(PolicySelector{}(::cuda::arch_id{CUB_PTX_ARCH / 10}).block_threads)) + _CCCL_KERNEL_ATTRIBUTES void DeviceTopKHistogramKernel( + _CCCL_GRID_CONSTANT const KeyInputIteratorT d_keys_in, + _CCCL_GRID_CONSTANT const KeyOutputIteratorT d_keys_out, + _CCCL_GRID_CONSTANT const ValueInputIteratorT d_values_in, + _CCCL_GRID_CONSTANT const ValueOutputIteratorT d_values_out, + Counter, OffsetT, OutOffsetT>* counter, + _CCCL_GRID_CONSTANT OffsetT* const histogram, + _CCCL_GRID_CONSTANT const OffsetT num_items, + _CCCL_GRID_CONSTANT const OutOffsetT k, + _CCCL_GRID_CONSTANT const OffsetT buffer_length, + ExtractBinOpT extract_bin_op, + _CCCL_GRID_CONSTANT const int pass) +{ + static constexpr topk_policy policy = PolicySelector{}(::cuda::arch_id{CUB_PTX_ARCH / 10}); + using agent_topk_policy_t = + AgentTopKPolicy; + using identify_candidates_op_t = NullType; + using agent_topk_t = + AgentTopK; + + __shared__ typename agent_topk_t::TempStorage temp_storage; + agent_topk_t( + temp_storage, + d_keys_in, + d_keys_out, + d_values_in, + d_values_out, + num_items, + k, + buffer_length, + extract_bin_op, + identify_candidates_op_t{}) + .invoke_histogram_only(counter, histogram, pass); } template ; + auto topk_kernel = + DeviceTopKKernel; int main_kernel_blocks_per_sm = 0; if (const auto error = @@ -526,7 +583,6 @@ CUB_RUNTIME_FUNCTION _CCCL_FORCEINLINE cudaError_t dispatch( const auto main_kernel_max_occupancy = static_cast(main_kernel_blocks_per_sm * num_sms); const auto topk_grid_size = (::cuda::std::min) (main_kernel_max_occupancy, num_tiles); -// Log topk_kernel configuration @todo check the kernel launch #ifdef CUB_DEBUG_LOG _CubLog("Invoking topk_kernel<<<%d, %d, 0, " "%lld>>>(), %d items per thread, %d SM occupancy\n", @@ -540,101 +596,89 @@ CUB_RUNTIME_FUNCTION _CCCL_FORCEINLINE cudaError_t dispatch( // Initialize address variables counter_t* counter = static_cast(allocations[0]); OffsetT* histogram = static_cast(allocations[1]); - key_in_t* in_buf = nullptr; key_in_t* out_buf = nullptr; - OffsetT* in_idx_buf = nullptr; OffsetT* out_idx_buf = nullptr; - int pass = 0; + + // Pass 0: dedicated histogram-only kernel over the full input + { + auto histogram_kernel = DeviceTopKHistogramKernel< + PolicySelector, + KeyInputIteratorT, + KeyOutputIteratorT, + ValueInputIteratorT, + ValueOutputIteratorT, + OffsetT, + OutOffsetT, + key_in_t, + extract_bin_op>; + + int histogram_kernel_blocks_per_sm = 0; + if (const auto error = + CubDebug(launcher_factory.MaxSmOccupancy(histogram_kernel_blocks_per_sm, histogram_kernel, block_threads))) + { + return error; + } + const auto histogram_kernel_max_occupancy = static_cast(histogram_kernel_blocks_per_sm * num_sms); + const auto histogram_grid_size = (::cuda::std::min) (histogram_kernel_max_occupancy, num_tiles); + + extract_bin_op extract_op(0, total_bits, decomposer); + if (const auto error = CubDebug( + launcher_factory(histogram_grid_size, block_threads, 0, stream) + .doit(histogram_kernel, + d_keys_in, + d_keys_out, + d_values_in, + d_values_out, + counter, + histogram, + num_items, + k, + candidate_buffer_length, + extract_op, + 0))) + { + return error; + } + } + + // Passes 1..num_passes-1: fused filter + histogram kernel + key_in_t* in_buf = nullptr; + OffsetT* in_idx_buf = nullptr; + int pass = 1; for (; pass < num_passes; pass++) { - // Set operator extract_bin_op extract_op(pass, total_bits, decomposer); identify_candidates_op identify_op(&counter->kth_key_bits, pass, total_bits, decomposer); - // Initialize address variables in_buf = static_cast(pass % 2 == 0 ? allocations[2] : allocations[3]); - out_buf = pass == 0 ? nullptr : static_cast(pass % 2 == 0 ? allocations[3] : allocations[2]); + out_buf = static_cast(pass % 2 == 0 ? allocations[3] : allocations[2]); if constexpr (!keys_only) { in_idx_buf = pass <= 1 ? nullptr : static_cast(pass % 2 == 0 ? allocations[4] : allocations[5]); - out_idx_buf = pass == 0 ? nullptr : static_cast(pass % 2 == 0 ? allocations[5] : allocations[4]); + out_idx_buf = static_cast(pass % 2 == 0 ? allocations[5] : allocations[4]); } - // Invoke kernel - if (pass == 0) - { - auto topk_first_pass_kernel = DeviceTopKKernel< - PolicySelector, - KeyInputIteratorT, - KeyOutputIteratorT, - ValueInputIteratorT, - ValueOutputIteratorT, - OffsetT, - OutOffsetT, - key_in_t, - extract_bin_op, - identify_candidates_op, - true>; - - // Compute grid size for the histogram kernel of the first pass - int first_pass_kernel_blocks_per_sm = 0; - if (const auto error = CubDebug( - launcher_factory.MaxSmOccupancy(first_pass_kernel_blocks_per_sm, topk_first_pass_kernel, block_threads))) - { - return error; - } - const auto first_pass_kernel_max_occupancy = - static_cast(first_pass_kernel_blocks_per_sm * num_sms); - const auto topk_first_pass_grid_size = (::cuda::std::min) (first_pass_kernel_max_occupancy, num_tiles); - - // Compute histogram of the first pass - if (const auto error = CubDebug( - launcher_factory(topk_first_pass_grid_size, block_threads, 0, stream) - .doit(topk_first_pass_kernel, - d_keys_in, - d_keys_out, - d_values_in, - d_values_out, - in_buf, - in_idx_buf, - out_buf, - out_idx_buf, - counter, - histogram, - num_items, - k, - candidate_buffer_length, - extract_op, - identify_op, - pass))) - { - return error; - } - } - else + if (const auto error = CubDebug( + launcher_factory(topk_grid_size, block_threads, 0, stream) + .doit(topk_kernel, + d_keys_in, + d_keys_out, + d_values_in, + d_values_out, + in_buf, + in_idx_buf, + out_buf, + out_idx_buf, + counter, + histogram, + num_items, + k, + candidate_buffer_length, + extract_op, + identify_op, + pass))) { - if (const auto error = CubDebug( - launcher_factory(topk_grid_size, block_threads, 0, stream) - .doit(topk_kernel, - d_keys_in, - d_keys_out, - d_values_in, - d_values_out, - in_buf, - in_idx_buf, - out_buf, - out_idx_buf, - counter, - histogram, - num_items, - k, - candidate_buffer_length, - extract_op, - identify_op, - pass))) - { - return error; - } + return error; } } @@ -677,7 +721,6 @@ CUB_RUNTIME_FUNCTION _CCCL_FORCEINLINE cudaError_t dispatch( return error; } - // pass==num_passes to align with the usage of identify_candidates_op in previous passes. return cudaSuccess; }); } From b449e382310d8aed2615208458d291869787a2cb Mon Sep 17 00:00:00 2001 From: Elias Stehle <3958403+elstehle@users.noreply.github.com> Date: Tue, 7 Apr 2026 04:22:04 -0700 Subject: [PATCH 2/8] [bench-only] benchmarking changes --- ci/bench.yaml | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/ci/bench.yaml b/ci/bench.yaml index 0527f349a96..922ca6e669a 100644 --- a/ci/bench.yaml +++ b/ci/bench.yaml @@ -27,18 +27,18 @@ benchmarks: # Inclusive regex filters (required). filters: # Examples: - # - '^cub\.bench\.for_each\.base' - # - '^cub\.bench\.reduce\.(sum|min)\.' + - '^cub\.bench\.topk\.keys\.base' + - '^cub\.bench\.topk\.pairs\.base' # Select GPUs. These are limited and shared, be intentional and conservative. gpus: - # - "t4" # sm_75, 16 GB + - "t4" # sm_75, 16 GB # - "rtx2080" # sm_75, 8 GB - # - "rtxa6000" # sm_86, 48 GB + - "rtxa6000" # sm_86, 48 GB # - "l4" # sm_89, 24 GB # - "rtx4090" # sm_89, 24 GB - # - "h100" # sm_90, 80 GB - # - "rtxpro6000" # sm_120 + - "h100" # sm_90, 80 GB + - "rtxpro6000" # sm_120 # Extra .devcontainer/launch.sh -d args # launch_args: "--cuda 13.1 --host gcc14" From 671b1058252fbdcd364e370aebfbd07c337fa0f3 Mon Sep 17 00:00:00 2001 From: Elias Stehle <3958403+elstehle@users.noreply.github.com> Date: Tue, 7 Apr 2026 07:56:16 -0700 Subject: [PATCH 3/8] [bench-only] benchmarking run --- ci/bench.yaml | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/ci/bench.yaml b/ci/bench.yaml index 922ca6e669a..c944a587560 100644 --- a/ci/bench.yaml +++ b/ci/bench.yaml @@ -32,9 +32,9 @@ benchmarks: # Select GPUs. These are limited and shared, be intentional and conservative. gpus: - - "t4" # sm_75, 16 GB + # - "t4" # sm_75, 16 GB # - "rtx2080" # sm_75, 8 GB - - "rtxa6000" # sm_86, 48 GB + # - "rtxa6000" # sm_86, 48 GB # - "l4" # sm_89, 24 GB # - "rtx4090" # sm_89, 24 GB - "h100" # sm_90, 80 GB @@ -54,4 +54,8 @@ benchmarks: --stopping-criterion entropy --throttle-threshold 90 --throttle-recovery-delay 0.15 + --axis "OffsetT{ct}=I32" + --axis "SelectedElements[pow2]=[3:23:5]" + --axis "Entropy=[1.000,0.201,0.000]" + --axis "OutOffsetT{ct}=I32" nvbench_compare_args: "" From 4cad24d483e019259d0be56b85a98b1fb7ca3d6b Mon Sep 17 00:00:00 2001 From: Elias Stehle <3958403+elstehle@users.noreply.github.com> Date: Wed, 8 Apr 2026 01:27:51 -0700 Subject: [PATCH 4/8] [bench-only] adds smem buffering before gmem writes --- cub/cub/agent/agent_topk.cuh | 461 ++++++++++++++---- cub/cub/device/dispatch/dispatch_topk.cuh | 9 +- .../device/dispatch/tuning/tuning_topk.cuh | 16 +- 3 files changed, 378 insertions(+), 108 deletions(-) diff --git a/cub/cub/agent/agent_topk.cuh b/cub/cub/agent/agent_topk.cuh index fe6ea64ec5e..87551ac2f85 100644 --- a/cub/cub/agent/agent_topk.cuh +++ b/cub/cub/agent/agent_topk.cuh @@ -48,7 +48,8 @@ template + BlockScanAlgorithm ScanAlgorithm, + bool UseSmemWriteCoordination = true> struct AgentTopKPolicy { static constexpr int block_threads = BlockThreads; @@ -56,6 +57,7 @@ struct AgentTopKPolicy static constexpr int bits_per_pass = BitsPerPass; static constexpr BlockLoadAlgorithm load_algorithm = LoadAlgorithm; static constexpr BlockScanAlgorithm SCAN_ALGORITHM = ScanAlgorithm; + static constexpr bool use_smem_write_coordination = UseSmemWriteCoordination; }; template > @@ -252,6 +254,8 @@ struct AgentTopK static constexpr bool keys_only = ::cuda::std::is_same_v; static constexpr int bins_per_thread = ::cuda::ceil_div(num_buckets, block_threads); + static constexpr bool use_smem_write_coordination = AgentTopKPolicyT::use_smem_write_coordination; + // Parameterized BlockLoad type for input data using block_load_input_t = BlockLoad; using block_load_trans_t = BlockLoad; @@ -260,6 +264,11 @@ struct AgentTopK // Parameterized BlockStore type using block_store_trans_t = BlockStore; + struct noop_tile_epilogue + { + _CCCL_DEVICE void operator()() const {} + }; + // Shared memory struct _TempStorage { @@ -272,8 +281,22 @@ struct AgentTopK typename block_scan_t::TempStorage scan; // Smem needed for storing typename block_store_trans_t::TempStorage store_trans; + + // Staging buffer for coalesced writes (active when use_smem_write_coordination == true) + struct + { + key_in_t keys[use_smem_write_coordination ? tile_items : 1]; + OffsetT indices[(use_smem_write_coordination && !keys_only) ? tile_items : 1]; + } staging; }; OffsetT histogram[num_buckets]; + + // Write coordination counters (used when use_smem_write_coordination == true) + OffsetT smem_filter_cnt; + OutOffsetT smem_out_cnt; + OffsetT block_filter_base; + OutOffsetT block_out_base; + OutOffsetT block_out_back_base; }; /// Alias wrapper allowing storage to be unioned struct TempStorage : Uninitialized<_TempStorage> @@ -356,6 +379,15 @@ struct AgentTopK // Process a range of input data in tiles, calling f(key, index) for each element template _CCCL_DEVICE _CCCL_FORCEINLINE void process_range(InputItT in, const OffsetT num_items, FuncT f) + { + process_range(in, num_items, f, noop_tile_epilogue{}); + } + + // Process a range of input data in tiles, calling f(key, index) for each element + // and tile_epilogue() after each tile completes + template + _CCCL_DEVICE _CCCL_FORCEINLINE void + process_range(InputItT in, const OffsetT num_items, FuncT f, TileEpilogueT tile_epilogue) { key_in_t thread_data[items_per_thread]; @@ -378,6 +410,9 @@ struct AgentTopK { f(thread_data[j], offset + j); } + + tile_epilogue(); + tile_base += items_per_pass; offset += items_per_pass; } @@ -404,6 +439,8 @@ struct AgentTopK f(thread_data[j], offset + j); } } + + tile_epilogue(); } } @@ -461,106 +498,246 @@ struct AgentTopK // Initialize shared memory histogram init_histograms(temp_storage.histogram); - // Make sure the histogram was initialized + if constexpr (use_smem_write_coordination) + { + if (threadIdx.x == 0) + { + temp_storage.smem_filter_cnt = 0; + temp_storage.smem_out_cnt = 0; + } + } + + // Make sure the histogram and counters were initialized __syncthreads(); OffsetT* p_filter_cnt = &counter->filter_cnt; OutOffsetT* p_out_cnt = &counter->out_cnt; - // Lambda for early_stop = true (i.e., we have identified the exact "splitter" key): - // Select all items that fall into the bin of the k-th item (i.e., the 'candidates') and the ones that fall into - // bins preceding the k-th item bin (i.e., 'selected' items), write them to output. - // We can skip histogram computation because we don't need to further passes to refine the candidates. - auto f_early_stop = [load_from_original_input, in_idx_buf, p_out_cnt, this](key_in_t key, OffsetT i) { - const candidate_class pre_res = identify_candidates_op(key); - if (pre_res == candidate_class::candidate || pre_res == candidate_class::selected) - { - const OutOffsetT pos = atomicAdd(p_out_cnt, OutOffsetT{1}); - d_keys_out[pos] = key; - if constexpr (!keys_only) + if constexpr (use_smem_write_coordination) + { + // === SMEM WRITE COORDINATION PATH === + + auto f_early_stop = [load_from_original_input, in_idx_buf, this](key_in_t key, OffsetT i) { + const candidate_class pre_res = identify_candidates_op(key); + if (pre_res == candidate_class::candidate || pre_res == candidate_class::selected) { - const OffsetT index = load_from_original_input ? i : in_idx_buf[i]; - d_values_out[pos] = d_values_in[index]; + const auto local_pos = atomicAdd(&temp_storage.smem_out_cnt, OutOffsetT{1}); + temp_storage.staging.keys[local_pos] = key; + if constexpr (!keys_only) + { + temp_storage.staging.indices[local_pos] = load_from_original_input ? i : in_idx_buf[i]; + } } - } - }; + }; - // Lambda for early_stop = false, out_buf != nullptr (i.e., we need to further refine the candidates in the next - // pass): Write out selected items to output, write candidates to out_buf, and build histogram for candidates. - auto f_with_out_buf = [load_from_original_input, in_idx_buf, out_buf, out_idx_buf, p_filter_cnt, p_out_cnt, this]( - key_in_t key, OffsetT i) { - const candidate_class pre_res = identify_candidates_op(key); - if (pre_res == candidate_class::candidate) - { - const OffsetT pos = atomicAdd(p_filter_cnt, OffsetT{1}); - out_buf[pos] = key; - if constexpr (!keys_only) + auto early_stop_epilogue = [p_out_cnt, this]() { + __syncthreads(); + const OutOffsetT n_out = temp_storage.smem_out_cnt; + if (threadIdx.x == 0) { - const OffsetT index = load_from_original_input ? i : in_idx_buf[i]; - out_idx_buf[pos] = index; + temp_storage.block_out_base = atomicAdd(p_out_cnt, n_out); } + __syncthreads(); + const OutOffsetT out_base = temp_storage.block_out_base; + for (OutOffsetT i = threadIdx.x; i < n_out; i += block_threads) + { + d_keys_out[out_base + i] = temp_storage.staging.keys[i]; + if constexpr (!keys_only) + { + d_values_out[out_base + i] = d_values_in[temp_storage.staging.indices[i]]; + } + } + if (threadIdx.x == 0) + { + temp_storage.smem_out_cnt = 0; + } + }; - const int bucket = extract_bin_op(key); - atomicAdd(temp_storage.histogram + bucket, OffsetT{1}); - } - else if (pre_res == candidate_class::selected) - { - const OutOffsetT pos = atomicAdd(p_out_cnt, OutOffsetT{1}); - d_keys_out[pos] = key; - if constexpr (!keys_only) + auto f_with_out_buf = [load_from_original_input, in_idx_buf, this](key_in_t key, OffsetT i) { + const candidate_class pre_res = identify_candidates_op(key); + if (pre_res == candidate_class::candidate) { - const OffsetT index = in_idx_buf ? in_idx_buf[i] : i; - d_values_out[pos] = d_values_in[index]; + const auto local_pos = atomicAdd(&temp_storage.smem_filter_cnt, OffsetT{1}); + temp_storage.staging.keys[tile_items - 1 - local_pos] = key; + if constexpr (!keys_only) + { + temp_storage.staging.indices[tile_items - 1 - local_pos] = load_from_original_input ? i : in_idx_buf[i]; + } + + const int bucket = extract_bin_op(key); + atomicAdd(temp_storage.histogram + bucket, OffsetT{1}); } - } - }; + else if (pre_res == candidate_class::selected) + { + const auto local_pos = atomicAdd(&temp_storage.smem_out_cnt, OutOffsetT{1}); + temp_storage.staging.keys[local_pos] = key; + if constexpr (!keys_only) + { + temp_storage.staging.indices[local_pos] = in_idx_buf ? in_idx_buf[i] : i; + } + } + }; - // Lambda for early_stop = false, out_buf = nullptr (i.e., we need to further refine the candidates in the next - // pass, but we skip writing candidates to out_buf): - // Just build histogram for candidates. - // Note: We will only begin writing to d_keys_out starting from the pass in which the number of output-candidates - // is small enough to fit into the output buffer (otherwise, we would be writing the same items to d_keys_out - // multiple times). - auto f_no_out_buf = [this](key_in_t key, OffsetT i) { - const candidate_class pre_res = identify_candidates_op(key); - if (pre_res == candidate_class::candidate) - { - const int bucket = extract_bin_op(key); - atomicAdd(temp_storage.histogram + bucket, OffsetT{1}); - } - }; + auto with_out_buf_epilogue = [p_filter_cnt, p_out_cnt, out_buf, out_idx_buf, this]() { + __syncthreads(); + const OffsetT n_candidates = temp_storage.smem_filter_cnt; + const OutOffsetT n_selected = temp_storage.smem_out_cnt; + if (threadIdx.x == 0) + { + temp_storage.block_filter_base = atomicAdd(p_filter_cnt, n_candidates); + temp_storage.block_out_base = atomicAdd(p_out_cnt, n_selected); + } + __syncthreads(); + const OffsetT filter_base = temp_storage.block_filter_base; + const OutOffsetT out_base = temp_storage.block_out_base; + for (OutOffsetT i = threadIdx.x; i < n_selected; i += block_threads) + { + d_keys_out[out_base + i] = temp_storage.staging.keys[i]; + if constexpr (!keys_only) + { + d_values_out[out_base + i] = d_values_in[temp_storage.staging.indices[i]]; + } + } + const OffsetT cand_start = static_cast(tile_items) - n_candidates; + for (OffsetT i = threadIdx.x; i < n_candidates; i += block_threads) + { + out_buf[filter_base + i] = temp_storage.staging.keys[cand_start + i]; + if constexpr (!keys_only) + { + out_idx_buf[filter_base + i] = temp_storage.staging.indices[cand_start + i]; + } + } + if (threadIdx.x == 0) + { + temp_storage.smem_filter_cnt = 0; + temp_storage.smem_out_cnt = 0; + } + }; - // Choose and invoke the appropriate lambda with the correct input source - // If the input size exceeds the allocated buffer size, we know for sure we haven't started writing candidates to - // the output buffer yet - if (load_from_original_input) - { - if (early_stop) - { - process_range(d_keys_in, previous_len, f_early_stop); - } - else if (out_buf) + auto f_no_out_buf = [this](key_in_t key, OffsetT i) { + const candidate_class pre_res = identify_candidates_op(key); + if (pre_res == candidate_class::candidate) + { + const int bucket = extract_bin_op(key); + atomicAdd(temp_storage.histogram + bucket, OffsetT{1}); + } + }; + + if (load_from_original_input) { - process_range(d_keys_in, previous_len, f_with_out_buf); + if (early_stop) + { + process_range(d_keys_in, previous_len, f_early_stop, early_stop_epilogue); + } + else if (out_buf) + { + process_range(d_keys_in, previous_len, f_with_out_buf, with_out_buf_epilogue); + } + else + { + process_range(d_keys_in, previous_len, f_no_out_buf); + } } else { - process_range(d_keys_in, previous_len, f_no_out_buf); + if (early_stop) + { + process_range(in_buf, previous_len, f_early_stop, early_stop_epilogue); + } + else if (out_buf) + { + process_range(in_buf, previous_len, f_with_out_buf, with_out_buf_epilogue); + } + else + { + process_range(in_buf, previous_len, f_no_out_buf); + } } } else { - if (early_stop) - { - process_range(in_buf, previous_len, f_early_stop); - } - else if (out_buf) + // === ORIGINAL PATH (no smem write coordination) === + + auto f_early_stop = [load_from_original_input, in_idx_buf, p_out_cnt, this](key_in_t key, OffsetT i) { + const candidate_class pre_res = identify_candidates_op(key); + if (pre_res == candidate_class::candidate || pre_res == candidate_class::selected) + { + const OutOffsetT pos = atomicAdd(p_out_cnt, OutOffsetT{1}); + d_keys_out[pos] = key; + if constexpr (!keys_only) + { + const OffsetT index = load_from_original_input ? i : in_idx_buf[i]; + d_values_out[pos] = d_values_in[index]; + } + } + }; + + auto f_with_out_buf = [load_from_original_input, in_idx_buf, out_buf, out_idx_buf, p_filter_cnt, p_out_cnt, this]( + key_in_t key, OffsetT i) { + const candidate_class pre_res = identify_candidates_op(key); + if (pre_res == candidate_class::candidate) + { + const OffsetT pos = atomicAdd(p_filter_cnt, OffsetT{1}); + out_buf[pos] = key; + if constexpr (!keys_only) + { + const OffsetT index = load_from_original_input ? i : in_idx_buf[i]; + out_idx_buf[pos] = index; + } + + const int bucket = extract_bin_op(key); + atomicAdd(temp_storage.histogram + bucket, OffsetT{1}); + } + else if (pre_res == candidate_class::selected) + { + const OutOffsetT pos = atomicAdd(p_out_cnt, OutOffsetT{1}); + d_keys_out[pos] = key; + if constexpr (!keys_only) + { + const OffsetT index = in_idx_buf ? in_idx_buf[i] : i; + d_values_out[pos] = d_values_in[index]; + } + } + }; + + auto f_no_out_buf = [this](key_in_t key, OffsetT i) { + const candidate_class pre_res = identify_candidates_op(key); + if (pre_res == candidate_class::candidate) + { + const int bucket = extract_bin_op(key); + atomicAdd(temp_storage.histogram + bucket, OffsetT{1}); + } + }; + + if (load_from_original_input) { - process_range(in_buf, previous_len, f_with_out_buf); + if (early_stop) + { + process_range(d_keys_in, previous_len, f_early_stop); + } + else if (out_buf) + { + process_range(d_keys_in, previous_len, f_with_out_buf); + } + else + { + process_range(d_keys_in, previous_len, f_no_out_buf); + } } else { - process_range(in_buf, previous_len, f_no_out_buf); + if (early_stop) + { + process_range(in_buf, previous_len, f_early_stop); + } + else if (out_buf) + { + process_range(in_buf, previous_len, f_with_out_buf); + } + else + { + process_range(in_buf, previous_len, f_no_out_buf); + } } } @@ -648,44 +825,130 @@ struct AgentTopK OutOffsetT* p_out_cnt = &counter->out_cnt; OutOffsetT* p_out_back_cnt = &counter->out_back_cnt; - auto f = [this, p_out_cnt, in_idx_buf, p_out_back_cnt, num_of_kth_needed, k, load_from_original_input]( - key_in_t key, OffsetT i) { - const candidate_class res = identify_candidates_op(key); - if (res == candidate_class::selected) + if constexpr (use_smem_write_coordination) + { + if (threadIdx.x == 0) { - const OutOffsetT pos = atomicAdd(p_out_cnt, OffsetT{1}); - d_keys_out[pos] = key; - if constexpr (!keys_only) + temp_storage.smem_filter_cnt = 0; + temp_storage.smem_out_cnt = 0; + } + __syncthreads(); + + auto f = [this, in_idx_buf, load_from_original_input](key_in_t key, OffsetT i) { + const candidate_class res = identify_candidates_op(key); + if (res == candidate_class::selected) + { + const auto local_pos = atomicAdd(&temp_storage.smem_out_cnt, OutOffsetT{1}); + temp_storage.staging.keys[local_pos] = key; + if constexpr (!keys_only) + { + temp_storage.staging.indices[local_pos] = load_from_original_input ? i : in_idx_buf[i]; + } + } + else if (res == candidate_class::candidate) + { + const auto local_pos = atomicAdd(&temp_storage.smem_filter_cnt, OffsetT{1}); + temp_storage.staging.keys[tile_items - 1 - local_pos] = key; + if constexpr (!keys_only) + { + temp_storage.staging.indices[tile_items - 1 - local_pos] = load_from_original_input ? i : in_idx_buf[i]; + } + } + }; + + auto epilogue = [this, p_out_cnt, p_out_back_cnt, num_of_kth_needed, k]() { + __syncthreads(); + const OutOffsetT n_selected = temp_storage.smem_out_cnt; + const OffsetT n_candidates = temp_storage.smem_filter_cnt; + if (threadIdx.x == 0) { - // If writing has been skipped up to this point, `in_idx_buf` is nullptr - const OffsetT index = load_from_original_input ? i : in_idx_buf[i]; - d_values_out[pos] = d_values_in[index]; + temp_storage.block_out_base = atomicAdd(p_out_cnt, n_selected); + temp_storage.block_out_back_base = atomicAdd(p_out_back_cnt, static_cast(n_candidates)); } - } - else if (res == candidate_class::candidate) - { - const OutOffsetT back_pos = atomicAdd(p_out_back_cnt, OffsetT{1}); + __syncthreads(); + const OutOffsetT out_base = temp_storage.block_out_base; + const OutOffsetT back_base = temp_storage.block_out_back_base; - if (back_pos < num_of_kth_needed) + for (OutOffsetT i = threadIdx.x; i < n_selected; i += block_threads) { - const OutOffsetT pos = k - 1 - back_pos; - d_keys_out[pos] = key; + d_keys_out[out_base + i] = temp_storage.staging.keys[i]; if constexpr (!keys_only) { - const OffsetT new_idx = load_from_original_input ? i : in_idx_buf[i]; - d_values_out[pos] = d_values_in[new_idx]; + d_values_out[out_base + i] = d_values_in[temp_storage.staging.indices[i]]; } } - } - }; - if (load_from_original_input) - { - process_range(d_keys_in, current_len, f); + const OffsetT cand_start = static_cast(tile_items) - n_candidates; + for (OffsetT i = threadIdx.x; i < n_candidates; i += block_threads) + { + const OutOffsetT back_pos = back_base + static_cast(i); + if (back_pos < num_of_kth_needed) + { + const OutOffsetT pos = k - 1 - back_pos; + d_keys_out[pos] = temp_storage.staging.keys[cand_start + i]; + if constexpr (!keys_only) + { + d_values_out[pos] = d_values_in[temp_storage.staging.indices[cand_start + i]]; + } + } + } + + if (threadIdx.x == 0) + { + temp_storage.smem_filter_cnt = 0; + temp_storage.smem_out_cnt = 0; + } + }; + + if (load_from_original_input) + { + process_range(d_keys_in, current_len, f, epilogue); + } + else + { + process_range(in_buf, current_len, f, epilogue); + } } else { - process_range(in_buf, current_len, f); + auto f = [this, p_out_cnt, in_idx_buf, p_out_back_cnt, num_of_kth_needed, k, load_from_original_input]( + key_in_t key, OffsetT i) { + const candidate_class res = identify_candidates_op(key); + if (res == candidate_class::selected) + { + const OutOffsetT pos = atomicAdd(p_out_cnt, OffsetT{1}); + d_keys_out[pos] = key; + if constexpr (!keys_only) + { + const OffsetT index = load_from_original_input ? i : in_idx_buf[i]; + d_values_out[pos] = d_values_in[index]; + } + } + else if (res == candidate_class::candidate) + { + const OutOffsetT back_pos = atomicAdd(p_out_back_cnt, OffsetT{1}); + + if (back_pos < num_of_kth_needed) + { + const OutOffsetT pos = k - 1 - back_pos; + d_keys_out[pos] = key; + if constexpr (!keys_only) + { + const OffsetT new_idx = load_from_original_input ? i : in_idx_buf[i]; + d_values_out[pos] = d_values_in[new_idx]; + } + } + } + }; + + if (load_from_original_input) + { + process_range(d_keys_in, current_len, f); + } + else + { + process_range(in_buf, current_len, f); + } } } diff --git a/cub/cub/device/dispatch/dispatch_topk.cuh b/cub/cub/device/dispatch/dispatch_topk.cuh index acb1d301c35..da7d365029b 100644 --- a/cub/cub/device/dispatch/dispatch_topk.cuh +++ b/cub/cub/device/dispatch/dispatch_topk.cuh @@ -272,7 +272,8 @@ __launch_bounds__(int(PolicySelector{}(::cuda::arch_id{CUB_PTX_ARCH / 10}).block policy.items_per_thread, policy.bits_per_pass, policy.load_algorithm, - policy.scan_algorithm>; + policy.scan_algorithm, + policy.use_smem_write_coordination>; using agent_topk_t = AgentTopK; + policy.scan_algorithm, + policy.use_smem_write_coordination>; using identify_candidates_op_t = NullType; using agent_topk_t = AgentTopK; + policy.scan_algorithm, + policy.use_smem_write_coordination>; using extract_bin_op_t = NullType; using agent_topk_t = AgentTopK 2, int32 -> 4, int16 -> 8. const int items_per_thread = ::cuda::std::max(1, nominal_4b_items_per_thread * 4 / key_size); - return topk_policy{512, items_per_thread, bits_per_pass, BLOCK_LOAD_VECTORIZE, BLOCK_SCAN_WARP_SCANS}; + return topk_policy{512, items_per_thread, bits_per_pass, BLOCK_LOAD_VECTORIZE, BLOCK_SCAN_WARP_SCANS, true}; } // Default tuning used on older architectures. const int items_per_thread = ::cuda::std::clamp(nominal_4b_items_per_thread * 4 / key_size, 1, nominal_4b_items_per_thread); - return topk_policy{512, items_per_thread, bits_per_pass, BLOCK_LOAD_VECTORIZE, BLOCK_SCAN_WARP_SCANS}; + return topk_policy{512, items_per_thread, bits_per_pass, BLOCK_LOAD_VECTORIZE, BLOCK_SCAN_WARP_SCANS, true}; } }; From 597a429f20ede0538e84e8610492d1ab60209fc3 Mon Sep 17 00:00:00 2001 From: Elias Stehle <3958403+elstehle@users.noreply.github.com> Date: Thu, 9 Apr 2026 22:31:04 -0700 Subject: [PATCH 5/8] fixes determining keys_only --- cub/cub/agent/agent_topk.cuh | 5 +++-- cub/cub/device/dispatch/dispatch_topk.cuh | 3 ++- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/cub/cub/agent/agent_topk.cuh b/cub/cub/agent/agent_topk.cuh index 87551ac2f85..cac6b5764f8 100644 --- a/cub/cub/agent/agent_topk.cuh +++ b/cub/cub/agent/agent_topk.cuh @@ -243,7 +243,8 @@ struct AgentTopK // Types and constants //--------------------------------------------------------------------- // The key and value type - using key_in_t = it_value_t; + using key_in_t = it_value_t; + using value_in_t = it_value_t; static constexpr int block_threads = AgentTopKPolicyT::block_threads; static constexpr int items_per_thread = AgentTopKPolicyT::items_per_thread; @@ -251,7 +252,7 @@ struct AgentTopK static constexpr int tile_items = block_threads * items_per_thread; static constexpr int num_buckets = 1 << bits_per_pass; - static constexpr bool keys_only = ::cuda::std::is_same_v; + static constexpr bool keys_only = ::cuda::std::is_same_v; static constexpr int bins_per_thread = ::cuda::ceil_div(num_buckets, block_threads); static constexpr bool use_smem_write_coordination = AgentTopKPolicyT::use_smem_write_coordination; diff --git a/cub/cub/device/dispatch/dispatch_topk.cuh b/cub/cub/device/dispatch/dispatch_topk.cuh index da7d365029b..b6b8339cce9 100644 --- a/cub/cub/device/dispatch/dispatch_topk.cuh +++ b/cub/cub/device/dispatch/dispatch_topk.cuh @@ -483,7 +483,8 @@ CUB_RUNTIME_FUNCTION _CCCL_FORCEINLINE cudaError_t dispatch( return dispatch_arch(policy_selector, arch_id, [&](auto policy_getter) { static constexpr topk_policy active_policy = policy_getter(); using key_in_t = it_value_t; - static constexpr bool keys_only = ::cuda::std::is_same_v; + using value_in_t = it_value_t; + static constexpr bool keys_only = ::cuda::std::is_same_v; // atomicAdd does not implement overloads for all integer types, so we limit OffsetT to uint32_t or unsigned long // long From d76869fb7fd870e5d6890c69dbc9f1bb11156289 Mon Sep 17 00:00:00 2001 From: Elias Stehle <3958403+elstehle@users.noreply.github.com> Date: Thu, 9 Apr 2026 23:45:39 -0700 Subject: [PATCH 6/8] introduces new smem mode --- cub/cub/agent/agent_topk.cuh | 405 ++++++++++++++++-- cub/cub/device/dispatch/dispatch_topk.cuh | 6 +- .../device/dispatch/tuning/tuning_topk.cuh | 41 +- 3 files changed, 409 insertions(+), 43 deletions(-) diff --git a/cub/cub/agent/agent_topk.cuh b/cub/cub/agent/agent_topk.cuh index cac6b5764f8..bbb188d5ac1 100644 --- a/cub/cub/agent/agent_topk.cuh +++ b/cub/cub/agent/agent_topk.cuh @@ -19,9 +19,11 @@ #include #include #include +#include #include #include +#include CUB_NAMESPACE_BEGIN @@ -49,7 +51,7 @@ template + smem_write_mode WriteMode = smem_write_mode::smem_coalescing_two_phase> struct AgentTopKPolicy { static constexpr int block_threads = BlockThreads; @@ -57,7 +59,7 @@ struct AgentTopKPolicy static constexpr int bits_per_pass = BitsPerPass; static constexpr BlockLoadAlgorithm load_algorithm = LoadAlgorithm; static constexpr BlockScanAlgorithm SCAN_ALGORITHM = ScanAlgorithm; - static constexpr bool use_smem_write_coordination = UseSmemWriteCoordination; + static constexpr smem_write_mode write_mode = WriteMode; }; template > @@ -255,7 +257,13 @@ struct AgentTopK static constexpr bool keys_only = ::cuda::std::is_same_v; static constexpr int bins_per_thread = ::cuda::ceil_div(num_buckets, block_threads); - static constexpr bool use_smem_write_coordination = AgentTopKPolicyT::use_smem_write_coordination; + static constexpr smem_write_mode write_mode = AgentTopKPolicyT::write_mode; + + // For keys_only kernels, two_phase is equivalent to smem_coalescing + static constexpr smem_write_mode effective_write_mode = + (keys_only && write_mode == smem_write_mode::smem_coalescing_two_phase) + ? smem_write_mode::smem_coalescing + : write_mode; // Parameterized BlockLoad type for input data using block_load_input_t = BlockLoad; @@ -270,6 +278,45 @@ struct AgentTopK _CCCL_DEVICE void operator()() const {} }; + //--------------------------------------------------------------------- + // Staging buffer type selection + //--------------------------------------------------------------------- + + struct _StagingDisabled + { + key_in_t keys[1]; + OffsetT indices[1]; + }; + + struct _StagingCoalescing + { + key_in_t keys[tile_items]; + OffsetT indices[tile_items]; + }; + + struct _StagingKeysOnly + { + key_in_t keys[tile_items]; + OffsetT indices[1]; + }; + + // Two-phase: keys and indices share storage via anonymous union + struct _StagingTwoPhase + { + union + { + key_in_t keys[tile_items]; + OffsetT indices[tile_items]; + }; + }; + + using staging_t = ::cuda::std::conditional_t< + effective_write_mode == smem_write_mode::no_smem_coalescing, + _StagingDisabled, + ::cuda::std::conditional_t>>; + // Shared memory struct _TempStorage { @@ -283,16 +330,11 @@ struct AgentTopK // Smem needed for storing typename block_store_trans_t::TempStorage store_trans; - // Staging buffer for coalesced writes (active when use_smem_write_coordination == true) - struct - { - key_in_t keys[use_smem_write_coordination ? tile_items : 1]; - OffsetT indices[(use_smem_write_coordination && !keys_only) ? tile_items : 1]; - } staging; + staging_t staging; }; OffsetT histogram[num_buckets]; - // Write coordination counters (used when use_smem_write_coordination == true) + // Write coordination counters OffsetT smem_filter_cnt; OutOffsetT smem_out_cnt; OffsetT block_filter_base; @@ -499,7 +541,7 @@ struct AgentTopK // Initialize shared memory histogram init_histograms(temp_storage.histogram); - if constexpr (use_smem_write_coordination) + if constexpr (effective_write_mode != smem_write_mode::no_smem_coalescing) { if (threadIdx.x == 0) { @@ -514,9 +556,202 @@ struct AgentTopK OffsetT* p_filter_cnt = &counter->filter_cnt; OutOffsetT* p_out_cnt = &counter->out_cnt; - if constexpr (use_smem_write_coordination) + // Histogram-only lambda is shared across all modes + auto f_no_out_buf = [this](key_in_t key, OffsetT /*i*/) { + const candidate_class pre_res = identify_candidates_op(key); + if (pre_res == candidate_class::candidate) + { + const int bucket = extract_bin_op(key); + atomicAdd(temp_storage.histogram + bucket, OffsetT{1}); + } + }; + + if constexpr (effective_write_mode == smem_write_mode::smem_coalescing_two_phase) { - // === SMEM WRITE COORDINATION PATH === + // === TWO-PHASE PATH (always !keys_only due to effective_write_mode mapping) === + + // Per-thread register arrays that persist across per-element lambda calls within a tile + candidate_class thread_flags[items_per_thread]; + int thread_local_pos[items_per_thread]; + OffsetT thread_resolved_idx[items_per_thread]; + int thread_item_idx = 0; + + _CCCL_PRAGMA_UNROLL_FULL() + for (int j = 0; j < items_per_thread; ++j) + { + thread_flags[j] = candidate_class::rejected; + } + + auto f_early_stop = [&](key_in_t key, OffsetT i) { + const int j = thread_item_idx++; + const candidate_class pre_res = identify_candidates_op(key); + if (pre_res == candidate_class::candidate || pre_res == candidate_class::selected) + { + thread_flags[j] = pre_res; + const auto local_pos = atomicAdd(&temp_storage.smem_out_cnt, OutOffsetT{1}); + thread_local_pos[j] = static_cast(local_pos); + temp_storage.staging.keys[local_pos] = key; + thread_resolved_idx[j] = load_from_original_input ? i : in_idx_buf[i]; + } + }; + + auto early_stop_epilogue = [&]() { + __syncthreads(); + const OutOffsetT n_out = temp_storage.smem_out_cnt; + if (threadIdx.x == 0) + { + temp_storage.block_out_base = atomicAdd(p_out_cnt, n_out); + } + __syncthreads(); + const OutOffsetT out_base = temp_storage.block_out_base; + // Phase 1: flush keys + for (OutOffsetT i = threadIdx.x; i < n_out; i += block_threads) + { + d_keys_out[out_base + i] = temp_storage.staging.keys[i]; + } + // Phase 2: reuse staging for indices + __syncthreads(); + for (int j = 0; j < items_per_thread; ++j) + { + if (thread_flags[j] != candidate_class::rejected) + { + temp_storage.staging.indices[thread_local_pos[j]] = thread_resolved_idx[j]; + } + } + __syncthreads(); + for (OutOffsetT i = threadIdx.x; i < n_out; i += block_threads) + { + d_values_out[out_base + i] = d_values_in[temp_storage.staging.indices[i]]; + } + // Reset for next tile + if (threadIdx.x == 0) + { + temp_storage.smem_out_cnt = 0; + } + thread_item_idx = 0; + _CCCL_PRAGMA_UNROLL_FULL() + for (int j = 0; j < items_per_thread; ++j) + { + thread_flags[j] = candidate_class::rejected; + } + }; + + auto f_with_out_buf = [&](key_in_t key, OffsetT i) { + const int j = thread_item_idx++; + const candidate_class pre_res = identify_candidates_op(key); + if (pre_res == candidate_class::candidate) + { + thread_flags[j] = pre_res; + const auto local_pos = atomicAdd(&temp_storage.smem_filter_cnt, OffsetT{1}); + thread_local_pos[j] = static_cast(local_pos); + temp_storage.staging.keys[tile_items - 1 - local_pos] = key; + thread_resolved_idx[j] = load_from_original_input ? i : in_idx_buf[i]; + + const int bucket = extract_bin_op(key); + atomicAdd(temp_storage.histogram + bucket, OffsetT{1}); + } + else if (pre_res == candidate_class::selected) + { + thread_flags[j] = pre_res; + const auto local_pos = atomicAdd(&temp_storage.smem_out_cnt, OutOffsetT{1}); + thread_local_pos[j] = static_cast(local_pos); + temp_storage.staging.keys[local_pos] = key; + thread_resolved_idx[j] = in_idx_buf ? in_idx_buf[i] : i; + } + }; + + auto with_out_buf_epilogue = [&]() { + __syncthreads(); + const OffsetT n_candidates = temp_storage.smem_filter_cnt; + const OutOffsetT n_selected = temp_storage.smem_out_cnt; + if (threadIdx.x == 0) + { + temp_storage.block_filter_base = atomicAdd(p_filter_cnt, n_candidates); + temp_storage.block_out_base = atomicAdd(p_out_cnt, n_selected); + } + __syncthreads(); + const OffsetT filter_base = temp_storage.block_filter_base; + const OutOffsetT out_base = temp_storage.block_out_base; + // Phase 1: flush keys + for (OutOffsetT i = threadIdx.x; i < n_selected; i += block_threads) + { + d_keys_out[out_base + i] = temp_storage.staging.keys[i]; + } + const OffsetT cand_start = static_cast(tile_items) - n_candidates; + for (OffsetT i = threadIdx.x; i < n_candidates; i += block_threads) + { + out_buf[filter_base + i] = temp_storage.staging.keys[cand_start + i]; + } + // Phase 2: reuse staging for indices + __syncthreads(); + for (int j = 0; j < items_per_thread; ++j) + { + if (thread_flags[j] == candidate_class::selected) + { + temp_storage.staging.indices[thread_local_pos[j]] = thread_resolved_idx[j]; + } + else if (thread_flags[j] == candidate_class::candidate) + { + temp_storage.staging.indices[tile_items - 1 - thread_local_pos[j]] = thread_resolved_idx[j]; + } + } + __syncthreads(); + for (OutOffsetT i = threadIdx.x; i < n_selected; i += block_threads) + { + d_values_out[out_base + i] = d_values_in[temp_storage.staging.indices[i]]; + } + for (OffsetT i = threadIdx.x; i < n_candidates; i += block_threads) + { + out_idx_buf[filter_base + i] = temp_storage.staging.indices[cand_start + i]; + } + // Reset for next tile + if (threadIdx.x == 0) + { + temp_storage.smem_filter_cnt = 0; + temp_storage.smem_out_cnt = 0; + } + thread_item_idx = 0; + _CCCL_PRAGMA_UNROLL_FULL() + for (int j = 0; j < items_per_thread; ++j) + { + thread_flags[j] = candidate_class::rejected; + } + }; + + if (load_from_original_input) + { + if (early_stop) + { + process_range(d_keys_in, previous_len, f_early_stop, early_stop_epilogue); + } + else if (out_buf) + { + process_range(d_keys_in, previous_len, f_with_out_buf, with_out_buf_epilogue); + } + else + { + process_range(d_keys_in, previous_len, f_no_out_buf); + } + } + else + { + if (early_stop) + { + process_range(in_buf, previous_len, f_early_stop, early_stop_epilogue); + } + else if (out_buf) + { + process_range(in_buf, previous_len, f_with_out_buf, with_out_buf_epilogue); + } + else + { + process_range(in_buf, previous_len, f_no_out_buf); + } + } + } + else if constexpr (effective_write_mode == smem_write_mode::smem_coalescing) + { + // === SINGLE-PHASE SMEM WRITE COORDINATION PATH === auto f_early_stop = [load_from_original_input, in_idx_buf, this](key_in_t key, OffsetT i) { const candidate_class pre_res = identify_candidates_op(key); @@ -615,15 +850,6 @@ struct AgentTopK } }; - auto f_no_out_buf = [this](key_in_t key, OffsetT i) { - const candidate_class pre_res = identify_candidates_op(key); - if (pre_res == candidate_class::candidate) - { - const int bucket = extract_bin_op(key); - atomicAdd(temp_storage.histogram + bucket, OffsetT{1}); - } - }; - if (load_from_original_input) { if (early_stop) @@ -657,7 +883,7 @@ struct AgentTopK } else { - // === ORIGINAL PATH (no smem write coordination) === + // === NO SMEM COORDINATION PATH === auto f_early_stop = [load_from_original_input, in_idx_buf, p_out_cnt, this](key_in_t key, OffsetT i) { const candidate_class pre_res = identify_candidates_op(key); @@ -701,15 +927,6 @@ struct AgentTopK } }; - auto f_no_out_buf = [this](key_in_t key, OffsetT i) { - const candidate_class pre_res = identify_candidates_op(key); - if (pre_res == candidate_class::candidate) - { - const int bucket = extract_bin_op(key); - atomicAdd(temp_storage.histogram + bucket, OffsetT{1}); - } - }; - if (load_from_original_input) { if (early_stop) @@ -826,8 +1043,129 @@ struct AgentTopK OutOffsetT* p_out_cnt = &counter->out_cnt; OutOffsetT* p_out_back_cnt = &counter->out_back_cnt; - if constexpr (use_smem_write_coordination) + if constexpr (effective_write_mode == smem_write_mode::smem_coalescing_two_phase) + { + // === TWO-PHASE LAST FILTER (always !keys_only) === + if (threadIdx.x == 0) + { + temp_storage.smem_filter_cnt = 0; + temp_storage.smem_out_cnt = 0; + } + __syncthreads(); + + candidate_class thread_flags[items_per_thread]; + int thread_local_pos[items_per_thread]; + OffsetT thread_resolved_idx[items_per_thread]; + int thread_item_idx = 0; + + _CCCL_PRAGMA_UNROLL_FULL() + for (int j = 0; j < items_per_thread; ++j) + { + thread_flags[j] = candidate_class::rejected; + } + + auto f = [&](key_in_t key, OffsetT i) { + const int j = thread_item_idx++; + const candidate_class res = identify_candidates_op(key); + if (res == candidate_class::selected) + { + thread_flags[j] = res; + const auto local_pos = atomicAdd(&temp_storage.smem_out_cnt, OutOffsetT{1}); + thread_local_pos[j] = static_cast(local_pos); + temp_storage.staging.keys[local_pos] = key; + thread_resolved_idx[j] = load_from_original_input ? i : in_idx_buf[i]; + } + else if (res == candidate_class::candidate) + { + thread_flags[j] = res; + const auto local_pos = atomicAdd(&temp_storage.smem_filter_cnt, OffsetT{1}); + thread_local_pos[j] = static_cast(local_pos); + temp_storage.staging.keys[tile_items - 1 - local_pos] = key; + thread_resolved_idx[j] = load_from_original_input ? i : in_idx_buf[i]; + } + }; + + auto epilogue = [&]() { + __syncthreads(); + const OutOffsetT n_selected = temp_storage.smem_out_cnt; + const OffsetT n_candidates = temp_storage.smem_filter_cnt; + if (threadIdx.x == 0) + { + temp_storage.block_out_base = atomicAdd(p_out_cnt, n_selected); + temp_storage.block_out_back_base = atomicAdd(p_out_back_cnt, static_cast(n_candidates)); + } + __syncthreads(); + const OutOffsetT out_base = temp_storage.block_out_base; + const OutOffsetT back_base = temp_storage.block_out_back_base; + + // Phase 1: flush keys + for (OutOffsetT i = threadIdx.x; i < n_selected; i += block_threads) + { + d_keys_out[out_base + i] = temp_storage.staging.keys[i]; + } + const OffsetT cand_start = static_cast(tile_items) - n_candidates; + for (OffsetT i = threadIdx.x; i < n_candidates; i += block_threads) + { + const OutOffsetT back_pos = back_base + static_cast(i); + if (back_pos < num_of_kth_needed) + { + d_keys_out[k - 1 - back_pos] = temp_storage.staging.keys[cand_start + i]; + } + } + + // Phase 2: reuse staging for indices + __syncthreads(); + for (int j = 0; j < items_per_thread; ++j) + { + if (thread_flags[j] == candidate_class::selected) + { + temp_storage.staging.indices[thread_local_pos[j]] = thread_resolved_idx[j]; + } + else if (thread_flags[j] == candidate_class::candidate) + { + temp_storage.staging.indices[tile_items - 1 - thread_local_pos[j]] = thread_resolved_idx[j]; + } + } + __syncthreads(); + for (OutOffsetT i = threadIdx.x; i < n_selected; i += block_threads) + { + d_values_out[out_base + i] = d_values_in[temp_storage.staging.indices[i]]; + } + for (OffsetT i = threadIdx.x; i < n_candidates; i += block_threads) + { + const OutOffsetT back_pos = back_base + static_cast(i); + if (back_pos < num_of_kth_needed) + { + d_values_out[k - 1 - back_pos] = d_values_in[temp_storage.staging.indices[cand_start + i]]; + } + } + + // Reset for next tile + if (threadIdx.x == 0) + { + temp_storage.smem_filter_cnt = 0; + temp_storage.smem_out_cnt = 0; + } + thread_item_idx = 0; + _CCCL_PRAGMA_UNROLL_FULL() + for (int j = 0; j < items_per_thread; ++j) + { + thread_flags[j] = candidate_class::rejected; + } + }; + + if (load_from_original_input) + { + process_range(d_keys_in, current_len, f, epilogue); + } + else + { + process_range(in_buf, current_len, f, epilogue); + } + } + else if constexpr (effective_write_mode == smem_write_mode::smem_coalescing) { + // === SINGLE-PHASE LAST FILTER === if (threadIdx.x == 0) { temp_storage.smem_filter_cnt = 0; @@ -912,6 +1250,7 @@ struct AgentTopK } else { + // === NO SMEM COORDINATION LAST FILTER === auto f = [this, p_out_cnt, in_idx_buf, p_out_back_cnt, num_of_kth_needed, k, load_from_original_input]( key_in_t key, OffsetT i) { const candidate_class res = identify_candidates_op(key); diff --git a/cub/cub/device/dispatch/dispatch_topk.cuh b/cub/cub/device/dispatch/dispatch_topk.cuh index b6b8339cce9..3f8da85c293 100644 --- a/cub/cub/device/dispatch/dispatch_topk.cuh +++ b/cub/cub/device/dispatch/dispatch_topk.cuh @@ -273,7 +273,7 @@ __launch_bounds__(int(PolicySelector{}(::cuda::arch_id{CUB_PTX_ARCH / 10}).block policy.bits_per_pass, policy.load_algorithm, policy.scan_algorithm, - policy.use_smem_write_coordination>; + policy.write_mode>; using agent_topk_t = AgentTopK; + policy.write_mode>; using identify_candidates_op_t = NullType; using agent_topk_t = AgentTopK; + policy.write_mode>; using extract_bin_op_t = NullType; using agent_topk_t = AgentTopK 2, int32 -> 4, int16 -> 8. const int items_per_thread = ::cuda::std::max(1, nominal_4b_items_per_thread * 4 / key_size); - return topk_policy{512, items_per_thread, bits_per_pass, BLOCK_LOAD_VECTORIZE, BLOCK_SCAN_WARP_SCANS, true}; + return topk_policy{ + 512, + items_per_thread, + bits_per_pass, + BLOCK_LOAD_VECTORIZE, + BLOCK_SCAN_WARP_SCANS, + smem_write_mode::smem_coalescing_two_phase}; } // Default tuning used on older architectures. const int items_per_thread = ::cuda::std::clamp(nominal_4b_items_per_thread * 4 / key_size, 1, nominal_4b_items_per_thread); - return topk_policy{512, items_per_thread, bits_per_pass, BLOCK_LOAD_VECTORIZE, BLOCK_SCAN_WARP_SCANS, true}; + return topk_policy{ + 512, + items_per_thread, + bits_per_pass, + BLOCK_LOAD_VECTORIZE, + BLOCK_SCAN_WARP_SCANS, + smem_write_mode::smem_coalescing_two_phase}; } }; From 5eaa259252e58f82c6d25f911ccb1aad335adce8 Mon Sep 17 00:00:00 2001 From: Elias Stehle <3958403+elstehle@users.noreply.github.com> Date: Fri, 10 Apr 2026 02:19:12 -0700 Subject: [PATCH 7/8] lower shmem requirements --- cub/cub/agent/agent_topk.cuh | 31 ++++++++++++++++++------------- 1 file changed, 18 insertions(+), 13 deletions(-) diff --git a/cub/cub/agent/agent_topk.cuh b/cub/cub/agent/agent_topk.cuh index 15544096a08..04fefbd9b5a 100644 --- a/cub/cub/agent/agent_topk.cuh +++ b/cub/cub/agent/agent_topk.cuh @@ -282,26 +282,26 @@ struct AgentTopK // Staging buffer type selection //--------------------------------------------------------------------- - struct _StagingDisabled + struct staging_disabled_t { key_in_t keys[1]; OffsetT indices[1]; }; - struct _StagingCoalescing + struct staging_coalescing_t { key_in_t keys[tile_items]; OffsetT indices[tile_items]; }; - struct _StagingKeysOnly + struct staging_keys_only_t { key_in_t keys[tile_items]; OffsetT indices[1]; }; // Two-phase: keys and indices share storage via anonymous union - struct _StagingTwoPhase + struct staging_two_phase_t { union { @@ -312,35 +312,40 @@ struct AgentTopK using staging_t = ::cuda::std::conditional_t< effective_write_mode == smem_write_mode::no_smem_coalescing, - _StagingDisabled, + staging_disabled_t, ::cuda::std::conditional_t>>; + staging_two_phase_t, + ::cuda::std::conditional_t>>; + //--------------------------------------------------------------------- // Shared memory - struct _TempStorage + //--------------------------------------------------------------------- + + struct temp_storage_base_t { union { - // Smem needed for loading typename block_load_input_t::TempStorage load_input; typename block_load_trans_t::TempStorage load_trans; - // Smem needed for scan typename block_scan_t::TempStorage scan; - // Smem needed for storing typename block_store_trans_t::TempStorage store_trans; - staging_t staging; }; OffsetT histogram[num_buckets]; + }; - // Write coordination counters + struct temp_storage_smem_t : temp_storage_base_t + { OffsetT smem_filter_cnt; OutOffsetT smem_out_cnt; OffsetT block_filter_base; OutOffsetT block_out_base; OutOffsetT block_out_back_base; }; + + using _TempStorage = ::cuda::std:: + conditional_t; + /// Alias wrapper allowing storage to be unioned struct TempStorage : Uninitialized<_TempStorage> {}; From a4ba5fc9abad6aebb1257af4c6c88794f13d6e85 Mon Sep 17 00:00:00 2001 From: Elias Stehle <3958403+elstehle@users.noreply.github.com> Date: Fri, 10 Apr 2026 02:19:46 -0700 Subject: [PATCH 8/8] [bench-only] no smem-coalescing for 8B offset types --- cub/cub/device/dispatch/dispatch_topk.cuh | 2 +- .../device/dispatch/tuning/tuning_topk.cuh | 25 +++++++------------ 2 files changed, 10 insertions(+), 17 deletions(-) diff --git a/cub/cub/device/dispatch/dispatch_topk.cuh b/cub/cub/device/dispatch/dispatch_topk.cuh index 00af5320546..e82113cc307 100644 --- a/cub/cub/device/dispatch/dispatch_topk.cuh +++ b/cub/cub/device/dispatch/dispatch_topk.cuh @@ -455,7 +455,7 @@ template