Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 8 additions & 4 deletions ci/bench.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@ benchmarks:
# Inclusive regex filters (required).
filters:
# Examples:
# - '^cub\.bench\.for_each\.base'
# - '^cub\.bench\.reduce\.(sum|min)\.'
- '^cub\.bench\.topk\.keys\.base'
- '^cub\.bench\.topk\.pairs\.base'

# Select GPUs. These are limited and shared, be intentional and conservative.
gpus:
Expand All @@ -37,8 +37,8 @@ benchmarks:
# - "rtxa6000" # sm_86, 48 GB
# - "l4" # sm_89, 24 GB
# - "rtx4090" # sm_89, 24 GB
# - "h100" # sm_90, 80 GB
# - "rtxpro6000" # sm_120
- "h100" # sm_90, 80 GB
- "rtxpro6000" # sm_120

# Extra .devcontainer/launch.sh -d args
# launch_args: "--cuda 13.1 --host gcc14"
Expand All @@ -54,4 +54,8 @@ benchmarks:
--stopping-criterion entropy
--throttle-threshold 90
--throttle-recovery-delay 0.15
--axis "OffsetT{ct}=I32"
--axis "SelectedElements[pow2]=[3:23:5]"
--axis "Entropy=[1.000,0.201,0.000]"
--axis "OutOffsetT{ct}=I32"
nvbench_compare_args: ""
279 changes: 161 additions & 118 deletions cub/cub/agent/agent_topk.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -447,7 +447,6 @@ struct AgentTopK
}

// Fused filtering of the current pass and building histogram for the next pass
template <bool IsFirstPass>
_CCCL_DEVICE _CCCL_FORCEINLINE void filter_and_histogram(
key_in_t* in_buf,
OffsetT* in_idx_buf,
Expand All @@ -465,116 +464,103 @@ struct AgentTopK
// Make sure the histogram was initialized
__syncthreads();

if constexpr (IsFirstPass)
{
// During the first pass, compute per-thread block histograms over the full input. The per-thread block histograms
// are being added to the global histogram further down below.
auto f = [this](key_in_t key, OffsetT /*index*/) {
const int bucket = extract_bin_op(key);
atomicAdd(temp_storage.histogram + bucket, OffsetT{1});
};
process_range(d_keys_in, previous_len, f);
}
else
{
OffsetT* p_filter_cnt = &counter->filter_cnt;
OutOffsetT* p_out_cnt = &counter->out_cnt;

// Lambda for early_stop = true (i.e., we have identified the exact "splitter" key):
// Select all items that fall into the bin of the k-th item (i.e., the 'candidates') and the ones that fall into
// bins preceding the k-th item bin (i.e., 'selected' items), write them to output.
// We can skip histogram computation because we don't need further passes to refine the candidates.
auto f_early_stop = [load_from_original_input, in_idx_buf, p_out_cnt, this](key_in_t key, OffsetT i) {
const candidate_class pre_res = identify_candidates_op(key);
if (pre_res == candidate_class::candidate || pre_res == candidate_class::selected)
{
const OutOffsetT pos = atomicAdd(p_out_cnt, OutOffsetT{1});
d_keys_out[pos] = key;
if constexpr (!keys_only)
{
const OffsetT index = load_from_original_input ? i : in_idx_buf[i];
d_values_out[pos] = d_values_in[index];
}
}
};

// Lambda for early_stop = false, out_buf != nullptr (i.e., we need to further refine the candidates in the next
// pass): Write out selected items to output, write candidates to out_buf, and build histogram for candidates.
auto f_with_out_buf = [load_from_original_input, in_idx_buf, out_buf, out_idx_buf, p_filter_cnt, p_out_cnt, this](
key_in_t key, OffsetT i) {
const candidate_class pre_res = identify_candidates_op(key);
if (pre_res == candidate_class::candidate)
{
const OffsetT pos = atomicAdd(p_filter_cnt, OffsetT{1});
out_buf[pos] = key;
if constexpr (!keys_only)
{
const OffsetT index = load_from_original_input ? i : in_idx_buf[i];
out_idx_buf[pos] = index;
}
OffsetT* p_filter_cnt = &counter->filter_cnt;
OutOffsetT* p_out_cnt = &counter->out_cnt;

const int bucket = extract_bin_op(key);
atomicAdd(temp_storage.histogram + bucket, OffsetT{1});
}
else if (pre_res == candidate_class::selected)
{
const OutOffsetT pos = atomicAdd(p_out_cnt, OutOffsetT{1});
d_keys_out[pos] = key;
if constexpr (!keys_only)
{
const OffsetT index = in_idx_buf ? in_idx_buf[i] : i;
d_values_out[pos] = d_values_in[index];
}
}
};

// Lambda for early_stop = false, out_buf = nullptr (i.e., we need to further refine the candidates in the next
// pass, but we skip writing candidates to out_buf):
// Just build histogram for candidates.
// Note: We will only begin writing to d_keys_out starting from the pass in which the number of output-candidates
// is small enough to fit into the output buffer (otherwise, we would be writing the same items to d_keys_out
// multiple times).
auto f_no_out_buf = [this](key_in_t key, OffsetT i) {
const candidate_class pre_res = identify_candidates_op(key);
if (pre_res == candidate_class::candidate)
// Lambda for early_stop = true (i.e., we have identified the exact "splitter" key):
// Select all items that fall into the bin of the k-th item (i.e., the 'candidates') and the ones that fall into
// bins preceding the k-th item bin (i.e., 'selected' items), write them to output.
// We can skip histogram computation because we don't need further passes to refine the candidates.
auto f_early_stop = [load_from_original_input, in_idx_buf, p_out_cnt, this](key_in_t key, OffsetT i) {
const candidate_class pre_res = identify_candidates_op(key);
if (pre_res == candidate_class::candidate || pre_res == candidate_class::selected)
{
const OutOffsetT pos = atomicAdd(p_out_cnt, OutOffsetT{1});
d_keys_out[pos] = key;
if constexpr (!keys_only)
{
const int bucket = extract_bin_op(key);
atomicAdd(temp_storage.histogram + bucket, OffsetT{1});
const OffsetT index = load_from_original_input ? i : in_idx_buf[i];
d_values_out[pos] = d_values_in[index];
}
};
}
};

// Choose and invoke the appropriate lambda with the correct input source
// If the input size exceeds the allocated buffer size, we know for sure we haven't started writing candidates to
// the output buffer yet
if (load_from_original_input)
// Lambda for early_stop = false, out_buf != nullptr (i.e., we need to further refine the candidates in the next
// pass): Write out selected items to output, write candidates to out_buf, and build histogram for candidates.
auto f_with_out_buf = [load_from_original_input, in_idx_buf, out_buf, out_idx_buf, p_filter_cnt, p_out_cnt, this](
key_in_t key, OffsetT i) {
const candidate_class pre_res = identify_candidates_op(key);
if (pre_res == candidate_class::candidate)
{
if (early_stop)
{
process_range(d_keys_in, previous_len, f_early_stop);
}
else if (out_buf)
const OffsetT pos = atomicAdd(p_filter_cnt, OffsetT{1});
out_buf[pos] = key;
if constexpr (!keys_only)
{
process_range(d_keys_in, previous_len, f_with_out_buf);
const OffsetT index = load_from_original_input ? i : in_idx_buf[i];
out_idx_buf[pos] = index;
}
else

const int bucket = extract_bin_op(key);
atomicAdd(temp_storage.histogram + bucket, OffsetT{1});
}
else if (pre_res == candidate_class::selected)
{
const OutOffsetT pos = atomicAdd(p_out_cnt, OutOffsetT{1});
d_keys_out[pos] = key;
if constexpr (!keys_only)
{
process_range(d_keys_in, previous_len, f_no_out_buf);
const OffsetT index = in_idx_buf ? in_idx_buf[i] : i;
d_values_out[pos] = d_values_in[index];
}
}
};

// Lambda for early_stop = false, out_buf = nullptr (i.e., we need to further refine the candidates in the next
// pass, but we skip writing candidates to out_buf):
// Just build histogram for candidates.
// Note: We will only begin writing to d_keys_out starting from the pass in which the number of output-candidates
// is small enough to fit into the output buffer (otherwise, we would be writing the same items to d_keys_out
// multiple times).
auto f_no_out_buf = [this](key_in_t key, OffsetT i) {
const candidate_class pre_res = identify_candidates_op(key);
if (pre_res == candidate_class::candidate)
{
const int bucket = extract_bin_op(key);
atomicAdd(temp_storage.histogram + bucket, OffsetT{1});
}
};

// Choose and invoke the appropriate lambda with the correct input source
// If the input size exceeds the allocated buffer size, we know for sure we haven't started writing candidates to
// the output buffer yet
if (load_from_original_input)
{
if (early_stop)
{
process_range(d_keys_in, previous_len, f_early_stop);
}
else if (out_buf)
{
process_range(d_keys_in, previous_len, f_with_out_buf);
}
else
{
if (early_stop)
{
process_range(in_buf, previous_len, f_early_stop);
}
else if (out_buf)
{
process_range(in_buf, previous_len, f_with_out_buf);
}
else
{
process_range(in_buf, previous_len, f_no_out_buf);
}
process_range(d_keys_in, previous_len, f_no_out_buf);
}
}
else
{
if (early_stop)
{
process_range(in_buf, previous_len, f_early_stop);
}
else if (out_buf)
{
process_range(in_buf, previous_len, f_with_out_buf);
}
else
{
process_range(in_buf, previous_len, f_no_out_buf);
}
}

Expand Down Expand Up @@ -703,7 +689,6 @@ struct AgentTopK
}
}

template <bool IsFirstPass>
_CCCL_DEVICE _CCCL_FORCEINLINE void invoke_filter_and_histogram(
key_in_t* in_buf,
OffsetT* in_idx_buf,
Expand All @@ -713,22 +698,9 @@ struct AgentTopK
OffsetT* histogram,
int pass)
{
OutOffsetT current_k;
OffsetT previous_len;
OffsetT current_len;

if constexpr (IsFirstPass)
{
current_k = k;
previous_len = num_items;
current_len = num_items;
}
else
{
current_k = counter->k;
current_len = counter->len;
previous_len = counter->previous_len;
}
const OutOffsetT current_k = counter->k;
const OffsetT current_len = counter->len;
OffsetT previous_len = counter->previous_len;

// If current_len is 0, it means all the candidates have been found in previous passes.
if (current_len == 0)
Expand All @@ -739,7 +711,7 @@ struct AgentTopK
// Early stop means that the bin containing the k-th element has been identified, and all
// the elements in this bin are exactly the remaining k items we need to find. So we can
// stop the process right here.
const bool early_stop = ((!IsFirstPass) && current_len == static_cast<OffsetT>(current_k));
const bool early_stop = (current_len == static_cast<OffsetT>(current_k));

// If previous_len > buffer_length, it means we haven't started writing candidates to out_buf yet,
// so we have to make sure to load input directly from the original input.
Expand All @@ -761,7 +733,7 @@ struct AgentTopK
}

// Fused filtering of candidates and histogram computation over the output-candidates
filter_and_histogram<IsFirstPass>(
filter_and_histogram(
in_buf, in_idx_buf, out_buf, out_idx_buf, previous_len, counter, histogram, early_stop, load_from_original_input);

// We need this `__threadfence()` to make sure all writes to the global memory-histogram are visible to all
Expand Down Expand Up @@ -827,6 +799,77 @@ struct AgentTopK
}
}
}

// Histogram-only pass: computes the histogram over the full input without filtering.
// Used for the first radix pass before any candidates have been identified.
//
// Every thread block accumulates bin counts for d_keys_in[0, num_items) in temp_storage.histogram and merges
// them into `histogram`. The block that finishes last then computes the bin offsets and calls choose_bucket()
// to identify the bin that the k-th item falls into.
//
// counter:   Shared bookkeeping state. This function updates finished_block_cnt, previous_len and filter_cnt
//            directly and forwards the pointer to choose_bucket().
// histogram: Histogram that the per-block counts are merged into; it is re-initialized at the end of the pass
//            unless this is the last pass for a fundamental key type.
// pass:      Index of the current radix pass; forwarded to choose_bucket() and used to detect the last pass.
//
// NOTE(review): this function contains block-wide barriers, so it presumably must be called by all threads of
// the block - confirm against the call site.
_CCCL_DEVICE _CCCL_FORCEINLINE void
invoke_histogram_only(Counter<key_in_t, OffsetT, OutOffsetT>* counter, OffsetT* histogram, int pass)
{
// Initialize shared memory histogram
init_histograms(temp_storage.histogram);
__syncthreads();

// Compute the per-thread-block histogram over the full input: every key handed to `f` by process_range() adds
// one to the bin returned by extract_bin_op(). The item index is not needed because nothing is filtered or
// written out in this pass.
auto f = [this](key_in_t key, OffsetT /*index*/) {
const int bucket = extract_bin_op(key);
atomicAdd(temp_storage.histogram + bucket, OffsetT{1});
};
process_range(d_keys_in, num_items, f);

// Ensure all threads have contributed to the histogram before accumulating in global memory
__syncthreads();

// Merge the locally aggregated histogram into the global histogram
merge_histograms(histogram);

// We need this `__threadfence()` to make sure all writes to the global memory-histogram are visible to all
// threads before we proceed to compute the prefix sum over the histogram.
__threadfence();

// Identify the last block in the grid to perform the prefix sum over the histogram and identify the bin that
// the k-th item falls into.
// Only thread 0 of each block takes a ticket. atomicInc() wraps the counter back to 0 once the previous value
// reaches gridDim.x - 1, so exactly one block observes gridDim.x - 1 and the counter is left at 0 afterwards
// (presumably so that it can be reused by the next pass - confirm).
bool is_last_block = false;
if (threadIdx.x == 0)
{
unsigned int finished = atomicInc(&counter->finished_block_cnt, gridDim.x - 1);
is_last_block = (finished == (gridDim.x - 1));
}

// `__syncthreads_or()` makes thread 0's result uniform across the block, so either all threads of the block
// enter the branch below or none do.
if (__syncthreads_or(is_last_block))
{
if (threadIdx.x == 0)
{
// The whole input was processed in this pass, so the length recorded as "previous" is num_items; the
// filter counter is reset to zero.
counter->previous_len = num_items;
counter->filter_cnt = 0;
}

// Compute prefix sum over the histogram's bin counts
compute_bin_offsets(histogram);

// Make sure the prefix sum has been written to shared memory before choose_bucket()
__syncthreads();

// Identify the bin that the k-th item falls into
choose_bucket(counter, k, pass);

// Reset histogram for the next pass
// For fundamental key types the number of passes is known at compile time, so the reset is skipped after the
// last pass; for all other key types the histogram is always reset.
// TODO: Refactor calc_start_bit, calc_mask, and calc_num_passes to uniformly work with
// total_bits (passed as a kernel parameter) instead of sizeof(KeyT), then use a single
// unconditional path for both fundamental and non-fundamental types.
if constexpr (detail::radix::is_fundamental_type_v<key_in_t>)
{
constexpr int num_passes = calc_num_passes<key_in_t>(bits_per_pass);
if (pass != num_passes - 1)
{
init_histograms(histogram);
}
}
else
{
init_histograms(histogram);
}
}
}
};
} // namespace detail::topk
CUB_NAMESPACE_END
Loading
Loading