diff --git a/xla/stream_executor/rocm/BUILD b/xla/stream_executor/rocm/BUILD index caf64579bad0e..9fcbacc4dcb5b 100644 --- a/xla/stream_executor/rocm/BUILD +++ b/xla/stream_executor/rocm/BUILD @@ -193,6 +193,7 @@ cc_library( "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/hash", "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", "@com_google_absl//absl/numeric:int128", diff --git a/xla/stream_executor/rocm/rocm_executor.cc b/xla/stream_executor/rocm/rocm_executor.cc index 40087f84392b7..e95339f88b6a6 100644 --- a/xla/stream_executor/rocm/rocm_executor.cc +++ b/xla/stream_executor/rocm/rocm_executor.cc @@ -15,8 +15,6 @@ limitations under the License. #include "xla/stream_executor/rocm/rocm_executor.h" -#include - #include #include #include @@ -30,6 +28,7 @@ limitations under the License. #include "absl/base/casts.h" #include "absl/container/inlined_vector.h" +#include "absl/hash/hash.h" #include "absl/log/check.h" #include "absl/log/log.h" #include "absl/numeric/int128.h" @@ -47,6 +46,7 @@ limitations under the License. #include "rocm/include/hip/hip_runtime.h" #include "rocm/include/hip/hip_version.h" #include "rocm/rocm_config.h" +#include #include "xla/stream_executor/activate_context.h" #include "xla/stream_executor/blas.h" #include "xla/stream_executor/command_buffer.h" @@ -176,6 +176,17 @@ absl::StatusOr GetModuleFunction(Context* context, const char* kernel_name) { ScopedActivateContext activated(context); CHECK(module != nullptr && kernel_name != nullptr); + // Check for pre-existing HIP errors before the call. On ROCm 7+ + // the per-thread error state is sticky: successful HIP calls do + // not clear it, so a stale error from a prior operation would + // produce confusing diagnostics if we proceeded. + hipError_t pre_err = ::hipPeekAtLastError(); + if (pre_err != hipSuccess) { + return absl::InternalError( + absl::StrCat("There was a HIP error before calling " + "hipModuleGetFunction for kernel '", + kernel_name, "': ", ToString(pre_err))); + } hipFunction_t function; TF_RETURN_IF_ERROR( ToStatus(wrap::hipModuleGetFunction(&function, module, kernel_name), @@ -197,6 +208,16 @@ absl::Status GetModuleSymbol(Context* context, hipModule_t module, absl::StrCat("Failed to get symbol '", symbol_name, "'")); } +// Compute a content-based ModuleHandle from HSACO bytes. +// Using a content hash instead of the raw data pointer avoids stale cache +// entries when an HSACO buffer is freed and a new one is allocated at the +// same address (pointer-reuse cache collision). +ModuleHandle HsacoModuleHandle(const char* hsaco, size_t size) { + auto hash = absl::HashOf(absl::string_view(hsaco, size)); + // Ensure hash is never 0 (ModuleHandle treats nullptr as invalid) + return ModuleHandle{reinterpret_cast(hash | 1)}; +} + // Unloads module from the current context via cuModuleUnload. void UnloadRocmModule(Context* context, hipModule_t module) { ScopedActivateContext activated(context); @@ -671,10 +692,10 @@ absl::StatusOr> RocmExecutor::LoadKernel( const std::string& kernel_name = spec.kernel_name(); if (spec.has_cuda_cubin_in_memory()) { - const char* hsaco = reinterpret_cast( - spec.cuda_cubin_in_memory()->cubin_bytes.data()); + const auto& cubin = spec.cuda_cubin_in_memory()->cubin_bytes; + const char* hsaco = reinterpret_cast(cubin.data()); absl::MutexLock lock{in_memory_modules_mu_}; - ModuleHandle module_handle{hsaco}; + ModuleHandle module_handle = HsacoModuleHandle(hsaco, cubin.size()); hipModule_t& module = in_memory_modules_[module_handle]; if (module == nullptr) { @@ -749,16 +770,17 @@ absl::StatusOr RocmExecutor::LoadModule( // TODO(ROCm): Need generic term instead of cubin/cuda/ptx if (spec.has_cuda_cubin_in_memory()) { absl::MutexLock lock{in_memory_modules_mu_}; - return LoadModuleFromHsaco( - reinterpret_cast(spec.cuda_cubin_in_memory().data())); + const auto& cubin = spec.cuda_cubin_in_memory(); + return LoadModuleFromHsaco(reinterpret_cast(cubin.data()), + cubin.size()); } else { return absl::InternalError("No HASCO binary found"); } } absl::StatusOr RocmExecutor::LoadModuleFromHsaco( - const char* hsaco) { - ModuleHandle module_handle{hsaco}; + const char* hsaco, size_t size) { + ModuleHandle module_handle = HsacoModuleHandle(hsaco, size); uint64_t module_refcount; hipModule_t module; std::tie(module, module_refcount) = gpu_binary_to_module_[module_handle]; diff --git a/xla/stream_executor/rocm/rocm_executor.h b/xla/stream_executor/rocm/rocm_executor.h index cbf064795206c..f5406670a334b 100644 --- a/xla/stream_executor/rocm/rocm_executor.h +++ b/xla/stream_executor/rocm/rocm_executor.h @@ -140,7 +140,8 @@ class RocmExecutor : public GpuExecutor { absl::Status InitBlas(); // Loads a module in HSACO format. - absl::StatusOr LoadModuleFromHsaco(const char* hsaco) + absl::StatusOr LoadModuleFromHsaco(const char* hsaco, + size_t size) ABSL_EXCLUSIVE_LOCKS_REQUIRED(in_memory_modules_mu_); bool UnloadGpuBinary(ModuleHandle module_handle)