diff --git a/xla/stream_executor/rocm/BUILD b/xla/stream_executor/rocm/BUILD index 885b25d22be9e..7e97c2a193f1e 100644 --- a/xla/stream_executor/rocm/BUILD +++ b/xla/stream_executor/rocm/BUILD @@ -51,7 +51,6 @@ xla_cc_test( cc_library( name = "rocm_context", - srcs = ["rocm_context.cc"], hdrs = ["rocm_context.h"], tags = [ "gpu", @@ -61,14 +60,10 @@ cc_library( ":rocm_driver_wrapper", ":rocm_status", "//xla/stream_executor/gpu:context", - "//xla/stream_executor/gpu:context_map", "//xla/stream_executor/gpu:scoped_activate_context", "//xla/tsl/platform:errors", - "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", "@local_config_rocm//rocm:rocm_headers", ], ) diff --git a/xla/stream_executor/rocm/rocm_context.cc b/xla/stream_executor/rocm/rocm_context.cc deleted file mode 100644 index 4bcfe51e8e5a0..0000000000000 --- a/xla/stream_executor/rocm/rocm_context.cc +++ /dev/null @@ -1,223 +0,0 @@ -/* Copyright 2024 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "xla/stream_executor/rocm/rocm_context.h" - -#include -#include -#include -#include - -#include "absl/log/check.h" -#include "absl/log/log.h" -#include "absl/status/status.h" -#include "absl/strings/str_cat.h" -#include "rocm/include/hip/driver_types.h" -#include "rocm/include/hip/hip_runtime_api.h" -#include "xla/stream_executor/gpu/context_map.h" -#include "xla/stream_executor/gpu/scoped_activate_context.h" -#include "xla/stream_executor/rocm/rocm_driver_wrapper.h" -#include "xla/stream_executor/rocm/rocm_status.h" -#include "xla/tsl/platform/errors.h" - -namespace stream_executor::gpu { - -namespace { - -// Returns the current context or dies if it fails. -hipCtx_t CurrentContextOrDie() { - hipCtx_t current = nullptr; - CHECK_OK( - ToStatus(hipCtxGetCurrent(¤t), "Failed to query current context")); - return current; -} - -// Returns the current context and checks that it is in the set of HIP contexts -// created by StreamExecutor (to ensure that the HIP runtime didn't create a -// context behind our backs). -hipCtx_t CurrentContext() { - hipCtx_t current = CurrentContextOrDie(); - if (current != nullptr && !RocmContext::GetContextMap()->Has(current)) { - LOG(FATAL) << "current context was not created by the StreamExecutor " - "rocm_driver API: " - << current - << "; a HIP runtime call " - "was likely performed without using a StreamExecutor context"; - } - return current; -} - -} // namespace - -// Returns the singleton ContextMap. -ContextMap* RocmContext::GetContextMap() { - static ContextMap* context_map = - new ContextMap([](void* ptr) { - int device_ordinal; - hipError_t result = - hipPointerGetAttribute(static_cast(&device_ordinal), - HIP_POINTER_ATTRIBUTE_DEVICE_ORDINAL, - reinterpret_cast(ptr)); - if (result != hipSuccess) { - LOG(FATAL) << "Not able to get the device_ordinal for ptr: " << ptr - << ". Error: " << ToString(result); - } - return device_ordinal; - }); - return context_map; -} - -bool RocmContext::GetDeviceTotalMemory(hipDevice_t device, uint64_t* result) { - size_t value = -1; - hipError_t res = wrap::hipDeviceTotalMem(&value, device); - if (res != hipSuccess) { - LOG(ERROR) << "failed to query total available memory: " << ToString(res); - return false; - } - *result = value; - return true; -} - -bool RocmContext::GetDeviceMemoryUsage(int64_t* free_out, int64_t* total_out) { - ScopedActivateContext activation(this); - size_t free = 0; - size_t total = 0; - hipError_t res = wrap::hipMemGetInfo(&free, &total); - if (res != hipSuccess) { - LOG(ERROR) << "failed to query device memory info: " << ToString(res); - return false; - } - - VLOG(1) << "Device memory: " << total / 1048576 << " MB total, " - << free / 1048576 << " MB free"; - - // overflow check - if (free > std::numeric_limits::max()) { - LOG(ERROR) << "free memory (" << free << ") is overflow int64_t"; - return false; - } - - *free_out = free; - *total_out = total; - return true; -} - -RocmContext::~RocmContext() { - hipCtx_t former_context = CurrentContext(); - // Explicitly call RocmContext::SetActive() to silence clang-tidy warnings - // about calling a virtual method in the destructor. - RocmContext::SetActive(); - hipDevice_t device; - CHECK_EQ(hipSuccess, wrap::hipCtxGetDevice(&device)); - CHECK_EQ(hipSuccess, wrap::hipCtxSetCurrent(former_context)); - - auto res = wrap::hipDevicePrimaryCtxRelease(device); - - if (res != hipSuccess) { - LOG(ERROR) << "failed to release HIP context; leaking: " << ToString(res); - } - - GetContextMap()->Remove(context()); -} - -void RocmContext::SetActive() { - CHECK_OK( - ToStatus(wrap::hipCtxSetCurrent(context_), "Failed setting context")); -} - -bool RocmContext::IsActive() const { return CurrentContext() == context_; } - -absl::Status RocmContext::Synchronize() { - ScopedActivateContext activation(this); - TF_RETURN_IF_ERROR(ToStatus(wrap::hipDeviceSynchronize(), - "could not synchronize on ROCM device")); - return absl::OkStatus(); -} - -absl::StatusOr RocmContext::Create(int device_ordinal, - hipDevice_t device) { - RocmContext* context = nullptr; - - int flags = 0; - - hipError_t res; - hipCtx_t former_context; - hipCtx_t new_context; - - unsigned int former_primary_context_flags; - int former_primary_context_is_active; - CHECK_EQ(hipSuccess, wrap::hipDevicePrimaryCtxGetState( - device, &former_primary_context_flags, - &former_primary_context_is_active)); - if (former_primary_context_flags != flags) { - if (former_primary_context_is_active) { - LOG(ERROR) - << "The primary context is active and has a different flag set (" - << former_primary_context_flags << ") than the desired flag set (" - << flags << ")."; - } else { - CHECK_EQ(hipSuccess, wrap::hipDevicePrimaryCtxSetFlags(device, flags)); - } - } - - former_context = CurrentContextOrDie(); - res = wrap::hipDevicePrimaryCtxRetain(&new_context, device); - if (former_context != nullptr) { - hipDevice_t former_device; - if (wrap::hipCtxGetDevice(&former_device) == hipSuccess) { - if (former_device == device) { - if (former_context == new_context) { - VLOG(2) << "The primary context " << former_context << " for device " - << device - << " exists before initializing the StreamExecutor."; - } else { - LOG(WARNING) << "A non-primary context " << former_context - << " for device " << device - << " exists before initializing the StreamExecutor. The " - << "primary context is now " << new_context << ". We " - << "haven't verified StreamExecutor works with that."; - } - } - } else { - LOG(ERROR) << "Failed to get the device of the current context " - << former_context; - } - } - CHECK_EQ(hipSuccess, wrap::hipCtxSetCurrent(former_context)); - - if (res == hipSuccess) { - context = GetContextMap()->Add(new_context, device_ordinal); - CHECK(context != nullptr) - << "success in this call must entail non-null result"; - VLOG(2) << "created or reused context " << new_context - << " for this thread"; - return context; - } - - std::string message = - "failed call to hipDevicePrimaryCtxRetain: " + ToString(res); - if (res == hipErrorOutOfMemory) { - uint64_t total_memory; - if (GetDeviceTotalMemory(device, &total_memory)) { - absl::StrAppend(&message, "; total memory reported: ", total_memory); - } else { - absl::StrAppend(&message, "; could not query total memory"); - } - } - - return absl::InternalError(message); -} - -} // namespace stream_executor::gpu diff --git a/xla/stream_executor/rocm/rocm_context.h b/xla/stream_executor/rocm/rocm_context.h index 60480f4905412..efae197cc9611 100644 --- a/xla/stream_executor/rocm/rocm_context.h +++ b/xla/stream_executor/rocm/rocm_context.h @@ -1,4 +1,3 @@ -#include "absl/status/statusor.h" /* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); @@ -14,56 +13,62 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -// The ROCM-specific Driver library support, implementing the general Driver -// interface. +// On ROCm, hipCtx_t is a thin wrapper around a device ordinal and the entire +// context lifecycle (Retain/Release/SetCurrent/GetCurrent) is a no-op. AMD +// has deprecated every hipCtx* and hipDevicePrimaryCtx* API since ROCm 1.9 +// with the recommendation to use hipSetDevice / hipGetDevice instead. +// +// RocmContext is a trivial implementation of the Context interface that +// delegates to hipSetDevice/hipGetDevice. It is intended to be owned as +// a plain value field inside RocmExecutor. #ifndef XLA_STREAM_EXECUTOR_ROCM_ROCM_CONTEXT_H_ #define XLA_STREAM_EXECUTOR_ROCM_ROCM_CONTEXT_H_ -#include - +#include "absl/log/check.h" #include "absl/status/status.h" #include "rocm/include/hip/hip_runtime.h" #include "xla/stream_executor/gpu/context.h" -#include "xla/stream_executor/gpu/context_map.h" +#include "xla/stream_executor/gpu/scoped_activate_context.h" +#include "xla/stream_executor/rocm/rocm_driver_wrapper.h" +#include "xla/stream_executor/rocm/rocm_status.h" +#include "xla/tsl/platform/errors.h" namespace stream_executor::gpu { -// RocmContext implements the Context class for ROCm GPUs. class RocmContext : public Context { public: - RocmContext(hipCtx_t context, const int ordinal) - : context_(context), device_ordinal_(ordinal) {} - ~RocmContext() override; + explicit RocmContext(int device_ordinal) : device_ordinal_(device_ordinal) {} + ~RocmContext() override = default; + + void SetActive() override { + CHECK_OK( + ToStatus(wrap::hipSetDevice(device_ordinal_), "Failed to set device")); + } + + bool IsActive() const override { + int current_device; + if (wrap::hipGetDevice(¤t_device) != hipSuccess) { + return false; + } + return current_device == device_ordinal_; + } - hipCtx_t context() const { return context_; } - void SetActive() override; - bool IsActive() const override; int device_ordinal() const override { return device_ordinal_; } - absl::Status Synchronize() override; - // Disallow copying and moving. + absl::Status Synchronize() override { + ScopedActivateContext activation(this); + TF_RETURN_IF_ERROR(ToStatus(wrap::hipDeviceSynchronize(), + "could not synchronize on ROCM device")); + return absl::OkStatus(); + } + RocmContext(RocmContext&&) = delete; RocmContext(const RocmContext&) = delete; RocmContext& operator=(RocmContext&&) = delete; RocmContext& operator=(const RocmContext&) = delete; - // Returns the free amount of memory and total amount of memory, as reported - // by hipDeviceTotalMem. - bool GetDeviceMemoryUsage(int64_t* free_out, int64_t* total_out); - - // Returns the total amount of memory available on the device. - static bool GetDeviceTotalMemory(hipDevice_t device, uint64_t* result); - - // Returns the context map for all XLA-known ROCm contexts. - static ContextMap* GetContextMap(); - - // Creates a new context for the given device. - static absl::StatusOr Create(int device_ordinal, - hipDevice_t device); - private: - hipCtx_t const context_; const int device_ordinal_; }; diff --git a/xla/stream_executor/rocm/rocm_driver_wrapper.h b/xla/stream_executor/rocm/rocm_driver_wrapper.h index 83064ef830fef..36f1f6e86a9a5 100644 --- a/xla/stream_executor/rocm/rocm_driver_wrapper.h +++ b/xla/stream_executor/rocm/rocm_driver_wrapper.h @@ -66,8 +66,6 @@ namespace wrap { // IMPORTANT: if you add a new HIP API to this list, please notify // the rocm-profiler developers to track the API traces. #define HIP_ROUTINE_EACH(__macro) \ - __macro(hipCtxGetDevice) \ - __macro(hipCtxSetCurrent) \ __macro(hipCtxEnablePeerAccess) \ __macro(hipDeviceCanAccessPeer) \ __macro(hipDeviceEnablePeerAccess) \ @@ -78,10 +76,6 @@ namespace wrap { __macro(hipDeviceGetSharedMemConfig) \ __macro(hipDeviceGetStreamPriorityRange) \ __macro(hipDeviceGraphMemTrim) \ - __macro(hipDevicePrimaryCtxGetState) \ - __macro(hipDevicePrimaryCtxSetFlags) \ - __macro(hipDevicePrimaryCtxRetain) \ - __macro(hipDevicePrimaryCtxRelease) \ __macro(hipDeviceSetSharedMemConfig) \ __macro(hipDeviceSynchronize) \ __macro(hipDeviceTotalMem) \ diff --git a/xla/stream_executor/rocm/rocm_executor.cc b/xla/stream_executor/rocm/rocm_executor.cc index 7d3d2c8ebaf45..e66d5f7d23dca 100644 --- a/xla/stream_executor/rocm/rocm_executor.cc +++ b/xla/stream_executor/rocm/rocm_executor.cc @@ -20,6 +20,7 @@ limitations under the License. #include #include #include +#include #include #include #include @@ -333,17 +334,6 @@ absl::StatusOr GetDevice(int device_ordinal) { absl::StrCat("failed call to hipDeviceGet: ", ToString(res))); } -// Returns the device associated with the given context. -absl::StatusOr DeviceFromContext(Context* context) { - ScopedActivateContext activated(context); - hipDevice_t device = -1; - hipError_t result = wrap::hipCtxGetDevice(&device); - if (result == hipSuccess) return device; - - return absl::InternalError( - absl::StrCat("failed to get device for context: ", ToString(result))); -} - bool CanEnablePeerAccess(hipDevice_t from, hipDevice_t to) { int can_access_peer = -1; hipError_t result = wrap::hipDeviceCanAccessPeer(&can_access_peer, from, to); @@ -356,20 +346,19 @@ bool CanEnablePeerAccess(hipDevice_t from, hipDevice_t to) { } bool CanEnablePeerAccess(Context* from, Context* to) { - // A context can always access its own memory. if (from == to) return true; - auto from_device = DeviceFromContext(from); + auto from_device = GetDevice(from->device_ordinal()); if (!from_device.ok()) { - LOG(ERROR) << "failed to resolve 'from' peer access context to a device: " - << from_device.status(); + LOG(ERROR) << "failed to get device for ordinal " << from->device_ordinal() + << ": " << from_device.status(); return false; } - auto to_device = DeviceFromContext(to); + auto to_device = GetDevice(to->device_ordinal()); if (!to_device.ok()) { - LOG(ERROR) << "failed to resolve 'to' peer access context to a device: " - << to_device.status(); + LOG(ERROR) << "failed to get device for ordinal " << to->device_ordinal() + << ": " << to_device.status(); return false; } return CanEnablePeerAccess(from_device.value(), to_device.value()); @@ -497,7 +486,7 @@ absl::StatusOr HostAllocate(Context* context, uint64_t bytes) { } absl::StatusOr> AllocateHostMemory( - RocmContext* rocm_context, uint64_t size) { + Context* rocm_context, uint64_t size) { TF_ASSIGN_OR_RETURN(void* ptr, HostAllocate(rocm_context, size)); return std::make_unique( ptr, size, [rocm_context](void* location, uint64_t size) { @@ -515,14 +504,14 @@ absl::StatusOr> AllocateHostMemory( RocmExecutor::~RocmExecutor() { for (auto& it : in_memory_modules_) { - UnloadRocmModule(rocm_context_, it.second); + UnloadRocmModule(&rocm_context_, it.second); } CHECK(kernel_to_gpu_binary_.empty()) << "RocmExecutor has live kernels."; CHECK(gpu_binary_to_module_.empty()) << "RocmExecutor has loaded modules."; } std::unique_ptr RocmExecutor::Activate() { - return std::make_unique(rocm_context_); + return std::make_unique(&rocm_context_); } bool RocmExecutor::UnloadModule(ModuleHandle module_handle) { @@ -623,7 +612,7 @@ bool RocmExecutor::UnloadGpuBinary(ModuleHandle module_handle) { VLOG(3) << "Found HSACO module " << module << " with refcount " << refcount; if (--refcount == 0) { VLOG(3) << "Unloading HSACO module " << module; - UnloadRocmModule(rocm_context_, module); + UnloadRocmModule(&rocm_context_, module); gpu_binary_to_module_.erase(module_it); ModuleHandle mem_it{}; for (auto x : in_memory_modules_) { @@ -653,9 +642,6 @@ void RocmExecutor::UnloadKernel(const Kernel* kernel) { absl::Status RocmExecutor::Init() { TF_ASSIGN_OR_RETURN(device_, GetDevice(device_ordinal())); - - TF_ASSIGN_OR_RETURN(rocm_context_, - RocmContext::Create(device_ordinal(), device_)); TF_ASSIGN_OR_RETURN(version_, GetGpuISAVersion(device_)); // We initialize BLAS interfaces early here since otherwise it might create // us problems during hipBlasLt initialization under graph capture. @@ -678,14 +664,14 @@ absl::StatusOr> RocmExecutor::LoadKernel( hipModule_t& module = in_memory_modules_[module_handle]; if (module == nullptr) { - TF_ASSIGN_OR_RETURN(module, LoadHsaco(rocm_context_, hsaco)); + TF_ASSIGN_OR_RETURN(module, LoadHsaco(&rocm_context_, hsaco)); } kernel_to_gpu_binary_[rocm_kernel.get()] = module_handle; VLOG(2) << "getting function " << kernel_name << " from module " << module; TF_ASSIGN_OR_RETURN( hipFunction_t function, - GetModuleFunction(rocm_context_, module, kernel_name.c_str())); + GetModuleFunction(&rocm_context_, module, kernel_name.c_str())); rocm_kernel->set_gpu_function(function); } else if (spec.has_in_process_symbol()) { void* symbol = spec.in_process_symbol()->symbol; @@ -764,7 +750,7 @@ absl::StatusOr RocmExecutor::LoadModuleFromHsaco( std::tie(module, module_refcount) = gpu_binary_to_module_[module_handle]; if (module == nullptr) { - TF_ASSIGN_OR_RETURN(module, LoadHsaco(rocm_context_, hsaco)); + TF_ASSIGN_OR_RETURN(module, LoadHsaco(&rocm_context_, hsaco)); module_refcount = 1; in_memory_modules_[module_handle] = module; VLOG(3) << "Loaded HSACO " << static_cast(hsaco) @@ -783,7 +769,8 @@ DeviceAddressBase RocmExecutor::Allocate(uint64_t size, int64_t memory_space) { case MemorySpace::kCollective: case MemorySpace::kDevice: return DeviceAddressBase( - DeviceAllocate(rocm_context_, size, /*is_fine_grained*/ false), size); + DeviceAllocate(&rocm_context_, size, /*is_fine_grained*/ false), + size); case MemorySpace::kP2P: // On the ROCm platform, differences in cache design (e.g., coherence // protocol) can cause cache coherence issues for some archs (e.g., MI200) @@ -791,9 +778,9 @@ DeviceAddressBase RocmExecutor::Allocate(uint64_t size, int64_t memory_space) { // fine-grained memory in P2P communication for all archs to make sure of // the correctness. return DeviceAddressBase( - DeviceAllocate(rocm_context_, size, /*is_fine_grained*/ true), size); + DeviceAllocate(&rocm_context_, size, /*is_fine_grained*/ true), size); case MemorySpace::kHost: - if (auto result = HostAllocate(rocm_context_, size); result.ok()) { + if (auto result = HostAllocate(&rocm_context_, size); result.ok()) { return DeviceAddressBase(*result, size); } return DeviceAddressBase(nullptr, 0); @@ -803,11 +790,11 @@ DeviceAddressBase RocmExecutor::Allocate(uint64_t size, int64_t memory_space) { } absl::StatusOr> RocmExecutor::HostMemoryAllocate(uint64_t size) { - return AllocateHostMemory(rocm_context_, size); + return AllocateHostMemory(&rocm_context_, size); } void RocmExecutor::Deallocate(DeviceAddressBase* mem) { - DeviceDeallocate(rocm_context_, mem->opaque()); + DeviceDeallocate(&rocm_context_, mem->opaque()); } absl::StatusOr> @@ -824,8 +811,9 @@ RocmExecutor::CreateMemoryAllocator(MemorySpace type) { wrap::hipMallocManaged(&result, size, hipMemAttachGlobal), "Failed to allocate managed memory")); void* ptr = reinterpret_cast(result); - VLOG(2) << "allocated " << ptr << " for context " << rocm_context_ - << " of " << size << " bytes in unified memory"; + VLOG(2) << "allocated " << ptr << " for device " + << rocm_context_.device_ordinal() << " of " << size + << " bytes in unified memory"; return std::make_unique( ptr, size, [this](void* location, uint64_t size) { std::unique_ptr activation = Activate(); @@ -837,7 +825,7 @@ RocmExecutor::CreateMemoryAllocator(MemorySpace type) { << location << "; result: " << ToString(res); } else { VLOG(2) << "deallocated unified memory at " << location - << " for context " << rocm_context_; + << " for device " << rocm_context_.device_ordinal(); } }); }); @@ -870,7 +858,7 @@ RocmExecutor::CreateMemoryAllocator(MemorySpace type) { }); case MemorySpace::kHost: return std::make_unique([this](uint64_t size) { - return AllocateHostMemory(rocm_context_, size); + return AllocateHostMemory(&rocm_context_, size); }); default: return absl::UnimplementedError( @@ -879,7 +867,7 @@ RocmExecutor::CreateMemoryAllocator(MemorySpace type) { } bool RocmExecutor::SynchronizeAllActivity() { - return rocm_context_->Synchronize().ok(); + return rocm_context_.Synchronize().ok(); } bool RocmExecutor::HostMemoryRegister(void* location, uint64_t size) { @@ -1024,16 +1012,31 @@ fft::FftSupport* RocmExecutor::AsFft() { bool RocmExecutor::CanEnablePeerAccessTo(StreamExecutor* other) { RocmExecutor* rocm_other = static_cast(other); - return CanEnablePeerAccess(rocm_context_, rocm_other->rocm_context_); + return CanEnablePeerAccess(&rocm_context_, &rocm_other->rocm_context_); } absl::Status RocmExecutor::EnablePeerAccessTo(StreamExecutor* other) { RocmExecutor* rocm_other = static_cast(other); - return EnablePeerAccess(rocm_context_, rocm_other->rocm_context_); + return EnablePeerAccess(&rocm_context_, &rocm_other->rocm_context_); } -bool RocmExecutor::DeviceMemoryUsage(int64_t* free, int64_t* total) const { - return rocm_context_->GetDeviceMemoryUsage(free, total); +bool RocmExecutor::DeviceMemoryUsage(int64_t* free_out, + int64_t* total_out) const { + ScopedActivateContext activation(&rocm_context_); + size_t free = 0; + size_t total = 0; + hipError_t res = wrap::hipMemGetInfo(&free, &total); + if (res != hipSuccess) { + LOG(ERROR) << "failed to query device memory info: " << ToString(res); + return false; + } + if (free > std::numeric_limits::max()) { + LOG(ERROR) << "free memory (" << free << ") overflows int64_t"; + return false; + } + *free_out = free; + *total_out = total; + return true; } absl::StatusOr RocmExecutor::GetSymbol( @@ -1046,14 +1049,14 @@ absl::StatusOr RocmExecutor::GetSymbol( auto it = gpu_binary_to_module_.find(module_handle); CHECK(it != gpu_binary_to_module_.end()); TF_RETURN_IF_ERROR( - GetModuleSymbol(rocm_context_, it->second.first, symbol_name.c_str(), + GetModuleSymbol(&rocm_context_, it->second.first, symbol_name.c_str(), reinterpret_cast(&mem), &bytes)); return DeviceAddressBase(mem, bytes); } for (auto& it : gpu_binary_to_module_) { TF_RETURN_IF_ERROR( - GetModuleSymbol(rocm_context_, it.second.first, symbol_name.c_str(), + GetModuleSymbol(&rocm_context_, it.second.first, symbol_name.c_str(), reinterpret_cast(&mem), &bytes)); return DeviceAddressBase(mem, bytes); } @@ -1173,8 +1176,16 @@ RocmExecutor::CreateDeviceDescription(int device_ordinal) { desc.set_ecc_enabled(*ecc_enabled_or); } - uint64_t device_memory_size = -1; - (void)RocmContext::GetDeviceTotalMemory(device, &device_memory_size); + uint64_t device_memory_size = 0; + { + size_t value = 0; + hipError_t res = wrap::hipDeviceTotalMem(&value, device); + if (res == hipSuccess) { + device_memory_size = value; + } else { + LOG(ERROR) << "failed to query total available memory: " << ToString(res); + } + } desc.set_device_memory_size(device_memory_size); { diff --git a/xla/stream_executor/rocm/rocm_executor.h b/xla/stream_executor/rocm/rocm_executor.h index cbf064795206c..423e5e9af9a48 100644 --- a/xla/stream_executor/rocm/rocm_executor.h +++ b/xla/stream_executor/rocm/rocm_executor.h @@ -61,7 +61,7 @@ namespace stream_executor::gpu { class RocmExecutor : public GpuExecutor { public: RocmExecutor(Platform* platform, int device_ordinal) - : GpuExecutor(platform, device_ordinal) {} + : GpuExecutor(platform, device_ordinal), rocm_context_(device_ordinal) {} ~RocmExecutor() override; std::unique_ptr Activate() override; @@ -203,8 +203,11 @@ class RocmExecutor : public GpuExecutor { // GPU ISA version for device_. int version_; - // RocmContext for this device. - RocmContext* rocm_context_; + // RocmContext for this device. Owned as a value — on ROCm a "context" + // is just a device ordinal, so there is no heavyweight object to manage. + // Mutable because SetActive() (hipSetDevice) is logically const — it does + // not change RocmContext state, but ScopedActivateContext takes non-const. + mutable RocmContext rocm_context_; }; } // namespace stream_executor::gpu