-
Notifications
You must be signed in to change notification settings - Fork 8
[ROCm] Add waves_per_eu support to Triton GEMM config #769
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -161,6 +161,9 @@ std::vector<TritonGemmConfig> TritonDotFusionSearchSpace::GenerateConfigs( | |
| ExtendConfigs(configs, | ||
| &TritonDotFusionSearchSpace::AddWarpSpecializationParameter); | ||
| } | ||
| if (device_description_.gpu_compute_capability().IsRocm()) { | ||
| ExtendConfigs(configs, &TritonDotFusionSearchSpace::AddWavesPerEuParameter); | ||
| } | ||
|
|
||
| std::vector<TritonGemmConfig> result; | ||
| result.reserve(configs.size()); | ||
|
|
@@ -738,4 +741,38 @@ void TritonDotFusionSearchSpace::EliminateLowOccupancyConfigs( | |
| VLOG(10) << "Eliminated " << num_configs - configs.size() << " configs."; | ||
| } | ||
|
|
||
| void TritonDotFusionSearchSpace::AddWavesPerEuParameter( | ||
| const ConfigWithNotes& config, | ||
| std::vector<ConfigWithNotes>& updated_configs) const { | ||
| // Hints the LLVM backend to reduce VGPR usage so that the given number of | ||
| // wavefronts can run concurrently on each Execution Unit (EU). On MI300X | ||
| // each EU has 512 VGPRs allocated in blocks of 16, giving: | ||
| // | ||
| // Num VGPRs Waves/EU Waves/CU | ||
| // <= 64 8 32 | ||
| // <= 96 5 20 | ||
| // <= 128 4 16 | ||
| // <= 168 3 12 | ||
| // <= 256 2 8 | ||
| // > 256 1 4 | ||
| // | ||
| // 0 means no restriction (the default). Higher values force tighter register | ||
| // budgets, potentially increasing occupancy at the cost of more spilling. | ||
| // | ||
| // Actual occupancy also depends on LDS (shared memory) and num_warps: | ||
| // occ = min(floor(occ_vgpr * 4 / num_warps), | ||
| // floor(65536 / lds_bytes)) * num_warps / 4 | ||
| // where occ_vgpr is from the table above. Neither lds_bytes nor | ||
| // the VGPR count are known at search space construction time, so we | ||
| // enumerate a few values and let the autotuner pick the best one. | ||
| static constexpr int kWavesPerEuValues[] = {0, 1, 2, 4}; | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Could you add this one as the reference? https://rocm.docs.amd.com/en/latest/how-to/rocm-for-ai/inference-optimization/workload.html#auto-tunable-kernel-configurations
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I have added the comment based on both links here. I think it is better for the code to be self-contained, and I believe that a code comment will outlive any URL. |
||
| for (int waves : kWavesPerEuValues) { | ||
| ConfigWithNotes new_config = config; | ||
| new_config.config.waves_per_eu = waves; | ||
| VLOG(10) << "Adding waves_per_eu parameter: config = " | ||
| << new_config.ToString(); | ||
| updated_configs.push_back(new_config); | ||
| } | ||
| } | ||
|
|
||
| } // namespace xla::gpu | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -525,6 +525,16 @@ absl::StatusOr<TritonWrapperResult> CompileTritonToLLVM( | |
| VerifyModule(*ll_triton_module); | ||
| } | ||
|
|
||
| // Apply ROCm-specific waves_per_eu attribute if set. | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
| if (gpu_cc.IsRocm() && block_level_parameters.waves_per_eu > 0) { | ||
| if (auto* fn = ll_triton_module->getFunction(kernel_name)) { | ||
| std::string waves_attr = | ||
| absl::StrCat(block_level_parameters.waves_per_eu, ", ", | ||
| block_level_parameters.waves_per_eu); | ||
| fn->addFnAttr("amdgpu-waves-per-eu", waves_attr); | ||
| } | ||
| } | ||
|
|
||
| // Integrate LLVM matmul kernel into XLA's LLVM module. | ||
| captured_nvvm_annotations = | ||
| xgt::ExtractNvvmAnnotations(ll_triton_module.get()); | ||
|
|
||
| Original file line number | Diff line number | Diff line change | ||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -1080,11 +1080,13 @@ absl::StatusOr<se::gpu::BlasLt::Epilogue> AsBlasLtEpilogue( | |||||||||||
| TF_RET_CHECK(proto.num_stages() > 0); | ||||||||||||
| TF_RET_CHECK(proto.num_warps() > 0); | ||||||||||||
| TF_RET_CHECK(proto.num_ctas() > 0); | ||||||||||||
| TF_RET_CHECK(proto.waves_per_eu() >= 0); | ||||||||||||
|
|
||||||||||||
| return TritonGemmConfig( | ||||||||||||
| proto.block_m(), proto.block_n(), proto.block_k(), proto.split_k(), | ||||||||||||
| proto.num_stages(), proto.num_warps(), proto.num_ctas(), | ||||||||||||
| proto.is_tma_allowed(), proto.is_warp_specialization_allowed()); | ||||||||||||
| proto.is_tma_allowed(), proto.is_warp_specialization_allowed(), | ||||||||||||
| proto.waves_per_eu()); | ||||||||||||
|
Comment on lines
+1088
to
+1089
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. All other numeric config fields validated above have a `TF_RET_CHECK`, but `waves_per_eu` is passed through unchecked. Consider adding: `TF_RET_CHECK(proto.waves_per_eu() >= 0);`
Suggested change
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Done. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Resolved — the `TF_RET_CHECK(proto.waves_per_eu() >= 0)` validation is now present. |
||||||||||||
| } | ||||||||||||
|
|
||||||||||||
| AutotuneResult::TritonGemmKey TritonGemmConfig::ToProto() const { | ||||||||||||
|
|
@@ -1098,6 +1100,7 @@ AutotuneResult::TritonGemmKey TritonGemmConfig::ToProto() const { | |||||||||||
| key.set_num_ctas(num_ctas); | ||||||||||||
| key.set_is_tma_allowed(is_tma_allowed); | ||||||||||||
| key.set_is_warp_specialization_allowed(is_warp_specialization_allowed); | ||||||||||||
| key.set_waves_per_eu(waves_per_eu); | ||||||||||||
| return key; | ||||||||||||
| } | ||||||||||||
|
|
||||||||||||
|
|
@@ -1107,7 +1110,8 @@ std::string TritonGemmConfig::ToString() const { | |||||||||||
| ",split_k:", split_k, ",num_stages:", num_stages, | ||||||||||||
| ",num_warps:", num_warps, ",num_ctas:", num_ctas, | ||||||||||||
| ",is_tma_allowed:", is_tma_allowed, | ||||||||||||
| ",is_warp_specialization_allowed:", is_warp_specialization_allowed, "}"); | ||||||||||||
| ",is_warp_specialization_allowed:", is_warp_specialization_allowed, | ||||||||||||
| ",waves_per_eu:", waves_per_eu, "}"); | ||||||||||||
| } | ||||||||||||
|
|
||||||||||||
| absl::StatusOr<bool> IsMatrixMultiplicationTooSmallForRewriting( | ||||||||||||
|
|
||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit:
`waves_per_eu` is declared as `int64` here, but as `int32` in `backend_configs.proto` (`BlockLevelFusionConfig`) and as `int` in the C++ structs (`TritonGemmConfig`, `BlockLevelParameters`). Since the max meaningful value is single digits ({0,1,2,4}), consider using `int32` to match `backend_configs.proto` and the pattern used by `num_ctas`/`num_stages` there.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It is fine; I follow the `num_ctas`/`num_stages` pattern exactly.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Resolved — you're right,
`num_ctas`, `num_stages`, and `num_warps` all use `int64` in this proto. Consistent with the existing pattern.