From da9bea7fb44d3b2c405ebd9a408b5a5dbdfe06b1 Mon Sep 17 00:00:00 2001 From: Dragan Mladjenovic Date: Thu, 19 Mar 2026 08:23:31 -0500 Subject: [PATCH] [ROCm] Use BFCAllocator for scratch allocations needed for MIOpen autotuning --- xla/backends/gpu/autotuner/BUILD | 4 ++- xla/backends/gpu/autotuner/autotuner_main.cc | 12 +++---- xla/backends/gpu/autotuner/factory.h | 3 +- xla/backends/gpu/autotuner/factory_cuda.cc | 1 + xla/backends/gpu/autotuner/factory_rocm.cc | 6 ++-- xla/backends/gpu/autotuner/factory_test.cc | 11 +++--- xla/backends/gpu/autotuner/miopen.cc | 36 +++++++++----------- xla/backends/gpu/autotuner/miopen.h | 8 +++-- xla/backends/gpu/autotuner/miopen_test.cc | 5 ++- xla/service/gpu/gpu_compiler.cc | 11 +++--- xla/service/gpu/gpu_compiler.h | 2 ++ 11 files changed, 58 insertions(+), 41 deletions(-) diff --git a/xla/backends/gpu/autotuner/BUILD b/xla/backends/gpu/autotuner/BUILD index 0c13814d30d30..77d72769515f4 100644 --- a/xla/backends/gpu/autotuner/BUILD +++ b/xla/backends/gpu/autotuner/BUILD @@ -1005,12 +1005,12 @@ cc_library( "//xla/service/gpu:gpu_conv_runner", "//xla/service/gpu:stream_executor_util", "//xla/stream_executor:device_address", + "//xla/stream_executor:device_address_allocator", "//xla/stream_executor:dnn", "//xla/stream_executor:engine_options", "//xla/stream_executor:scratch_allocator", "//xla/stream_executor:stream", "//xla/stream_executor:stream_executor_h", - "//xla/stream_executor:stream_executor_memory_allocator", "//xla/tsl/platform:errors", "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", @@ -1180,6 +1180,7 @@ xla_test( "//xla/stream_executor:device_description_proto_cc", "//xla/stream_executor:platform", "//xla/stream_executor:stream_executor_h", + "//xla/stream_executor:stream_executor_memory_allocator", "//xla/stream_executor/rocm:rocm_platform_id", "//xla/tsl/lib/core:status_test_util", "//xla/tsl/platform:statusor", @@ -1212,6 +1213,7 @@ xla_test( "//xla/stream_executor:platform", "//xla/stream_executor:platform_manager", "//xla/stream_executor:stream_executor_h", + "//xla/stream_executor:stream_executor_memory_allocator", "//xla/stream_executor/platform:platform_object_registry", "//xla/tsl/platform:statusor", "@com_google_absl//absl/strings", diff --git a/xla/backends/gpu/autotuner/autotuner_main.cc b/xla/backends/gpu/autotuner/autotuner_main.cc index d9d9c832a9d66..71bb97523d504 100644 --- a/xla/backends/gpu/autotuner/autotuner_main.cc +++ b/xla/backends/gpu/autotuner/autotuner_main.cc @@ -103,16 +103,16 @@ absl::Status Autotune(HloModule& module) { DebugOptions debug_options = GetDebugOptionsFromFlags(); Compiler::GpuTargetConfig target_config(stream_executor); + std::unique_ptr allocator = + std::make_unique( + stream_executor); + mlir::MLIRContext mlir_context; xla::RegisterSymbolicExprStorage(&mlir_context); TF_ASSIGN_OR_RETURN(std::vector> backends, gpu_compiler->GetAutotunerBackends( - stream_executor, &target_config, alias_info.get(), - debug_options, &mlir_context)); - - std::unique_ptr allocator = - std::make_unique( - stream_executor); + stream_executor, allocator.get(), &target_config, + alias_info.get(), debug_options, &mlir_context)); tsl::thread::ThreadPool thread_pool(tsl::Env::Default(), "autotuner", tsl::port::MaxParallelism()); diff --git a/xla/backends/gpu/autotuner/factory.h b/xla/backends/gpu/autotuner/factory.h index 3208862d0a824..d9e3712e12195 100644 --- a/xla/backends/gpu/autotuner/factory.h +++ b/xla/backends/gpu/autotuner/factory.h @@ -38,7 +38,8 @@ namespace gpu { // returned. struct GetCodegenBackends { using Type = std::function>( - stream_executor::StreamExecutor*, const DebugOptions*, Compiler*, + stream_executor::StreamExecutor*, + stream_executor::DeviceAddressAllocator*, const DebugOptions*, Compiler*, const Compiler::GpuTargetConfig*, const AliasInfo* alias_info, mlir::MLIRContext* mlir_context, absl::Span backend_allowlist)>; diff --git a/xla/backends/gpu/autotuner/factory_cuda.cc b/xla/backends/gpu/autotuner/factory_cuda.cc index e92282c0aec5b..b5440bfac1170 100644 --- a/xla/backends/gpu/autotuner/factory_cuda.cc +++ b/xla/backends/gpu/autotuner/factory_cuda.cc @@ -83,6 +83,7 @@ std::unique_ptr GetCustomKernelRewriterPipeline( std::vector> GetCodegenBackendsForCuda( stream_executor::StreamExecutor* stream_executor, + stream_executor::DeviceAddressAllocator* device_allocator, const DebugOptions* debug_options, Compiler* compiler, const Compiler::GpuTargetConfig* target_config, const AliasInfo* alias_info, MLIRContext* mlir_context, diff --git a/xla/backends/gpu/autotuner/factory_rocm.cc b/xla/backends/gpu/autotuner/factory_rocm.cc index cfe0951727723..d32a3838003c3 100644 --- a/xla/backends/gpu/autotuner/factory_rocm.cc +++ b/xla/backends/gpu/autotuner/factory_rocm.cc @@ -44,6 +44,7 @@ using ::mlir::MLIRContext; std::vector> GetCodegenBackendsForROCm( stream_executor::StreamExecutor* stream_executor, + stream_executor::DeviceAddressAllocator* device_allocator, const DebugOptions* debug_options, Compiler* compiler, const Compiler::GpuTargetConfig* target_config, const AliasInfo* alias_info, MLIRContext* mlir_context, @@ -51,8 +52,9 @@ std::vector> GetCodegenBackendsForROCm( std::vector> backends; backends.push_back(std::make_unique( debug_options, compiler, target_config, alias_info, mlir_context)); - backends.push_back(std::make_unique( - stream_executor, debug_options, compiler, target_config)); + backends.push_back( + std::make_unique(stream_executor, debug_options, compiler, + target_config, device_allocator)); backends.push_back(std::make_unique( stream_executor, debug_options, compiler, target_config)); backends.push_back(std::make_unique( diff --git a/xla/backends/gpu/autotuner/factory_test.cc b/xla/backends/gpu/autotuner/factory_test.cc index 725648163b873..b43e7666eb8cb 100644 --- a/xla/backends/gpu/autotuner/factory_test.cc +++ b/xla/backends/gpu/autotuner/factory_test.cc @@ -32,6 +32,7 @@ limitations under the License. #include "xla/stream_executor/platform/platform_object_registry.h" #include "xla/stream_executor/platform_manager.h" #include "xla/stream_executor/stream_executor.h" +#include "xla/stream_executor/stream_executor_memory_allocator.h" #include "xla/tsl/platform/statusor.h" namespace xla { @@ -55,6 +56,7 @@ class FactoryTest : public xla::HloHardwareIndependentTestBase, se::StreamExecutor* stream_executor_; Compiler::GpuTargetConfig target_config_; DebugOptions debug_options_; + se::StreamExecutorMemoryAllocator allocator_; FactoryTest() : platform_(se::PlatformManager::PlatformWithName( @@ -63,7 +65,8 @@ class FactoryTest : public xla::HloHardwareIndependentTestBase, .value()), compiler_(xla::Compiler::GetForPlatform(platform_->id()).value()), stream_executor_(platform_->ExecutorForDevice(0).value()), - target_config_(stream_executor_) {} + target_config_(stream_executor_), + allocator_(stream_executor_) {} }; TEST_P(FactoryTest, GetCodegenBackends) { @@ -81,9 +84,9 @@ TEST_P(FactoryTest, GetCodegenBackends) { AliasInfo alias_info; xla::RegisterSymbolicExprStorage(&mlir_context); std::vector> backends = - get_codegen_backends(stream_executor_, &debug_options_, compiler_.get(), - &target_config_, &alias_info, &mlir_context, - GetParam().names); + get_codegen_backends(stream_executor_, &allocator_, &debug_options_, + compiler_.get(), &target_config_, &alias_info, + &mlir_context, GetParam().names); EXPECT_EQ(backends.size(), GetParam().expected_num_backends); } else { GTEST_SKIP() << "Skipping test for platform " << platform_->id(); diff --git a/xla/backends/gpu/autotuner/miopen.cc b/xla/backends/gpu/autotuner/miopen.cc index 260a90ee11301..da56b9031c14f 100644 --- a/xla/backends/gpu/autotuner/miopen.cc +++ b/xla/backends/gpu/autotuner/miopen.cc @@ -48,7 +48,6 @@ limitations under the License. #include "xla/stream_executor/scratch_allocator.h" #include "xla/stream_executor/stream.h" #include "xla/stream_executor/stream_executor.h" -#include "xla/stream_executor/stream_executor_memory_allocator.h" #include "xla/tsl/platform/errors.h" #include "xla/tsl/platform/statusor.h" #include "xla/tsl/protobuf/dnn.pb.h" @@ -80,9 +79,8 @@ bool IsCustomCallToDnnFusedConvolution(const HloInstruction& hlo) { absl::Status ApplyConfigAndUpdateWorkspaceInOutputTuple( HloInstruction& instr, const MIOpenBackendConfig& config) { HloComputation* computation = instr.parent(); - std::vector new_call_element_shapes; + absl::InlinedVector new_call_element_shapes; // Add the shapes of the outputs of the convolution. - new_call_element_shapes.reserve(instr.shape().tuple_shapes().size() - 1); for (int i = 0; i < instr.shape().tuple_shapes().size() - 1; ++i) { new_call_element_shapes.emplace_back(instr.shape().tuple_shapes(i)); } @@ -102,8 +100,7 @@ absl::Status ApplyConfigAndUpdateWorkspaceInOutputTuple( *cudnn_conv_config->mutable_algorithm() = config; TF_RETURN_IF_ERROR(new_call->set_backend_config(gpu_backend_config)); - std::vector new_tuple_elements; - new_tuple_elements.reserve(new_call->shape().tuple_shapes().size() - 1); + absl::InlinedVector new_tuple_elements; for (int i = 0; i < new_call->shape().tuple_shapes().size() - 1; ++i) { new_tuple_elements.emplace_back( computation->AddInstruction(HloInstruction::CreateGetTupleElement( @@ -268,6 +265,7 @@ absl::StatusOr>> GetConvolutionCustomCallConfigs(const HloCustomCallInstruction* instr, const HloModule* module, se::StreamExecutor* stream_executor, + se::DeviceAddressAllocator* allocator, se::Stream* stream) { CHECK(instr->custom_call_target() != kCudnnConvForwardGraphCallTarget); ASSIGN_OR_RETURN(GpuConvConfig gpu_conv_config, GetGpuConvConfig(instr)); @@ -279,10 +277,10 @@ GetConvolutionCustomCallConfigs(const HloCustomCallInstruction* instr, se::dnn::DataType output_type, GetDNNDataTypeFromPrimitiveType(gpu_conv_config.output_type)); se::dnn::DnnSupport* dnn = stream_executor->AsDnn(); - se::StreamExecutorMemoryAllocator allocator(stream_executor); + std::unique_ptr owned_stream; if (stream == nullptr) { - TF_ASSIGN_OR_RETURN(stream, - allocator.GetStream(stream_executor->device_ordinal())); + TF_ASSIGN_OR_RETURN(owned_stream, stream_executor->CreateStream()); + stream = owned_stream.get(); } bool allow_tf32 = absl::c_all_of( instr->precision_config().operand_precision(), @@ -291,8 +289,8 @@ GetConvolutionCustomCallConfigs(const HloCustomCallInstruction* instr, allow_tf32, /*require_command_buffer=*/false}; - se::OwningScratchAllocator<> scratch_allocator( - stream_executor->device_ordinal(), &allocator); + se::OwningScratchAllocator<4> scratch_allocator( + stream_executor->device_ordinal(), allocator); const auto initialize_buffer = [stream](se::DeviceAddressBase buffer) { // Although we don't have evidence this matters, zero out the buffers @@ -302,8 +300,7 @@ GetConvolutionCustomCallConfigs(const HloCustomCallInstruction* instr, return stream->MemZero(&buffer, buffer.size()); }; - std::vector operand_buffers; - operand_buffers.reserve(instr->operand_count()); + absl::InlinedVector operand_buffers; for (const auto* operand : instr->operands()) { ASSIGN_OR_RETURN(auto buffer, scratch_allocator.AllocateBytes( ShapeUtil::ByteSizeOf(operand->shape()))); @@ -311,9 +308,8 @@ GetConvolutionCustomCallConfigs(const HloCustomCallInstruction* instr, operand_buffers.push_back(buffer); } - std::vector result_buffers; - size_t result_buffers_count = instr->shape().tuple_shapes().size(); - result_buffers.reserve(result_buffers_count); + absl::InlinedVector result_buffers; + size_t result_buffers_count = instr->shape().tuple_shapes().size() - 1; for (int i = 0; i < result_buffers_count; ++i) { ASSIGN_OR_RETURN(auto buffer, scratch_allocator.AllocateBytes(ShapeUtil::ByteSizeOf( @@ -351,7 +347,8 @@ GetConvolutionCustomCallConfigs(const HloCustomCallInstruction* instr, absl::StatusOr>> GetFusedConvolutionCustomCallConfigs(const HloCustomCallInstruction* instr, const HloModule* module, - se::StreamExecutor* stream_executor) { + se::StreamExecutor* stream_executor, + se::DeviceAddressAllocator* allocator) { ASSIGN_OR_RETURN(GpuConvConfig gpu_conv_config, GetGpuConvConfig(instr)); ASSIGN_OR_RETURN(se::dnn::DataType input_type, GetDNNDataTypeFromPrimitiveType(gpu_conv_config.input_type)); @@ -409,7 +406,7 @@ GetFusedConvolutionCustomCallConfigs(const HloCustomCallInstruction* instr, return GetConvolutionCustomCallConfigs( static_cast(new_conv.get()), module, - stream_executor, owned_stream.get()); + stream_executor, allocator, owned_stream.get()); } absl::StatusOr>> @@ -418,7 +415,8 @@ MIOpenBackend::GetSupportedConfigs(const HloInstruction& instr) { auto custom_call_instr = Cast(&instr); if (IsCustomCallToDnnFusedConvolution(*custom_call_instr)) { return GetFusedConvolutionCustomCallConfigs( - custom_call_instr, custom_call_instr->GetModule(), stream_executor()); + custom_call_instr, custom_call_instr->GetModule(), stream_executor(), + allocator_); } if (do_not_autotune_) { @@ -430,7 +428,7 @@ MIOpenBackend::GetSupportedConfigs(const HloInstruction& instr) { return GetConvolutionCustomCallConfigs( custom_call_instr, custom_call_instr->GetModule(), stream_executor(), - /* stream */ nullptr); + allocator_, /* stream */ nullptr); } return std::vector>(); } diff --git a/xla/backends/gpu/autotuner/miopen.h b/xla/backends/gpu/autotuner/miopen.h index 0d56d62db176c..d0406e9e46713 100644 --- a/xla/backends/gpu/autotuner/miopen.h +++ b/xla/backends/gpu/autotuner/miopen.h @@ -26,6 +26,7 @@ limitations under the License. #include "xla/backends/gpu/autotuner/gpu_codegen_backend.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/service/compiler.h" +#include "xla/stream_executor/device_address_allocator.h" #include "xla/stream_executor/stream_executor.h" #include "xla/xla.pb.h" @@ -37,10 +38,12 @@ class MIOpenBackend : public GpuCodegenBackend { public: explicit MIOpenBackend(stream_executor::StreamExecutor* stream_executor, const DebugOptions* debug_options, Compiler* compiler, - const Compiler::GpuTargetConfig* target_config) + const Compiler::GpuTargetConfig* target_config, + stream_executor::DeviceAddressAllocator* allocator) : GpuCodegenBackend(autotuner::Backend::MIOPEN, debug_options, compiler, target_config, stream_executor), - do_not_autotune_(debug_options->xla_gpu_autotune_level() == 0) {} + do_not_autotune_(debug_options->xla_gpu_autotune_level() == 0), + allocator_(allocator) {} absl::StatusOr>> GetSupportedConfigs(const HloInstruction& instr) override; @@ -54,6 +57,7 @@ class MIOpenBackend : public GpuCodegenBackend { private: bool IsSupported(const HloInstruction& instr) override; bool do_not_autotune_; + stream_executor::DeviceAddressAllocator* allocator_; }; } // namespace gpu diff --git a/xla/backends/gpu/autotuner/miopen_test.cc b/xla/backends/gpu/autotuner/miopen_test.cc index 314975092b3be..b2c05dbb5ccf6 100644 --- a/xla/backends/gpu/autotuner/miopen_test.cc +++ b/xla/backends/gpu/autotuner/miopen_test.cc @@ -35,6 +35,7 @@ limitations under the License. #include "xla/stream_executor/platform.h" #include "xla/stream_executor/rocm/rocm_platform_id.h" #include "xla/stream_executor/stream_executor.h" +#include "xla/stream_executor/stream_executor_memory_allocator.h" #include "xla/tsl/lib/core/status_test_util.h" #include "xla/tsl/platform/statusor.h" #include "xla/tsl/protobuf/dnn.pb.h" @@ -77,6 +78,7 @@ class MIOpenBackendTest : public HloHardwareIndependentTestBase { AMDGPUCompiler compiler_; se::StreamExecutor* stream_executor_; Compiler::GpuTargetConfig target_config_; + se::StreamExecutorMemoryAllocator allocator_; MIOpenBackend backend_; MIOpenBackendTest() @@ -85,13 +87,14 @@ class MIOpenBackendTest : public HloHardwareIndependentTestBase { ->ExecutorForDevice(0) .value()), target_config_(stream_executor_), + allocator_(stream_executor_), backend_( stream_executor_, [](auto& opts) { opts.set_xla_gpu_autotune_level(1); return &opts; }(debug_options_), - &compiler_, &target_config_) {} + &compiler_, &target_config_, &allocator_) {} bool IsRocm() { return stream_executor_->GetPlatform()->id() == se::rocm::kROCmPlatformId; diff --git a/xla/service/gpu/gpu_compiler.cc b/xla/service/gpu/gpu_compiler.cc index c35d1d0d7fdc1..0b29c983f89b1 100644 --- a/xla/service/gpu/gpu_compiler.cc +++ b/xla/service/gpu/gpu_compiler.cc @@ -3225,8 +3225,8 @@ absl::Status GpuCompiler::AddConvAndGemmAutotuningPass( HloCostAnalysis::ShapeSizeFunction shape_size_fn) { TF_ASSIGN_OR_RETURN( std::vector> backends, - GetAutotunerBackends(stream_exec, target_config, alias_info, - debug_options, mlir_context)); + GetAutotunerBackends(stream_exec, options.device_allocator, target_config, + alias_info, debug_options, mlir_context)); bool do_not_autotune_cublas = debug_options.xla_gpu_experimental_disable_binary_libraries() || @@ -3282,6 +3282,7 @@ absl::Status GpuCompiler::AddConvAndGemmAutotuningPass( absl::StatusOr>> GpuCompiler::GetAutotunerBackends( se::StreamExecutor* stream_exec, + se::DeviceAddressAllocator* device_allocator, const Compiler::GpuTargetConfig* target_config, const AliasInfo* alias_info, const DebugOptions& debug_options, mlir::MLIRContext* mlir_context) { std::vector autotune_backends; @@ -3328,9 +3329,9 @@ GpuCompiler::GetAutotunerBackends( auto& registry = stream_executor::PlatformObjectRegistry::GetGlobalRegistry(); TF_ASSIGN_OR_RETURN(const GetCodegenBackends::Type& get_codegen_backends, registry.FindObject(PlatformId())); - std::vector> backends = - get_codegen_backends(stream_exec, &debug_options, this, target_config, - alias_info, mlir_context, autotune_backends); + std::vector> backends = get_codegen_backends( + stream_exec, device_allocator, &debug_options, this, target_config, + alias_info, mlir_context, autotune_backends); return backends; } diff --git a/xla/service/gpu/gpu_compiler.h b/xla/service/gpu/gpu_compiler.h index d8b465f8cfb37..580f5216feeaf 100644 --- a/xla/service/gpu/gpu_compiler.h +++ b/xla/service/gpu/gpu_compiler.h @@ -45,6 +45,7 @@ limitations under the License. #include "xla/service/hlo_cost_analysis.h" #include "xla/service/hlo_module_config.h" #include "xla/service/llvm_compiler.h" +#include "xla/stream_executor/device_address_allocator.h" #include "xla/stream_executor/device_description.h" #include "xla/stream_executor/device_description.pb.h" #include "xla/stream_executor/dnn.h" @@ -150,6 +151,7 @@ class GpuCompiler : public LLVMCompiler { absl::StatusOr>> GetAutotunerBackends(se::StreamExecutor* stream_exec, + se::DeviceAddressAllocator* device_allocator, const Compiler::GpuTargetConfig* target_config, const AliasInfo* alias_info, const DebugOptions& debug_options,