1 change: 1 addition & 0 deletions xla/stream_executor/rocm/BUILD
@@ -193,6 +193,7 @@ cc_library(
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/container:inlined_vector",
+"@com_google_absl//absl/hash",
"@com_google_absl//absl/log",
"@com_google_absl//absl/log:check",
"@com_google_absl//absl/numeric:int128",
40 changes: 31 additions & 9 deletions xla/stream_executor/rocm/rocm_executor.cc
@@ -15,8 +15,6 @@ limitations under the License.

#include "xla/stream_executor/rocm/rocm_executor.h"

-#include <unistd.h>
-
#include <algorithm>
#include <cstddef>
#include <cstdint>
@@ -30,6 +28,7 @@ limitations under the License.

#include "absl/base/casts.h"
#include "absl/container/inlined_vector.h"
+#include "absl/hash/hash.h"
#include "absl/log/check.h"
#include "absl/log/log.h"
#include "absl/numeric/int128.h"
@@ -47,6 +46,7 @@ limitations under the License.
#include "rocm/include/hip/hip_runtime.h"
#include "rocm/include/hip/hip_version.h"
#include "rocm/rocm_config.h"
+#include <unistd.h>
#include "xla/stream_executor/activate_context.h"
#include "xla/stream_executor/blas.h"
#include "xla/stream_executor/command_buffer.h"
@@ -176,6 +176,17 @@ absl::StatusOr<hipFunction_t> GetModuleFunction(Context* context,
const char* kernel_name) {
ScopedActivateContext activated(context);
CHECK(module != nullptr && kernel_name != nullptr);
+// Check for pre-existing HIP errors before the call. On ROCm 7+
+// the per-thread error state is sticky: successful HIP calls do
+// not clear it, so a stale error from a prior operation would
+// produce confusing diagnostics if we proceeded.
+hipError_t pre_err = ::hipPeekAtLastError();

Oh, I am not sure we should just clear the sticky error bit here. It may hide problems in other libraries, like the recent one where the rocprofiler collector used the wrong device ID.

Author

Hi Pavel. The intention here is not to clear the sticky error but to peek at it and, if there is an error, stop and return an error here instead of proceeding. If we simply proceed with this error pending, we run into complex run-to-run variation under parallel loads, so I added this check as a guard.

Ah OK, I see. I confused it with hipGetLastError().
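
For readers comparing the two calls discussed above: `hipPeekAtLastError()` reports the last error and leaves it in place, while `hipGetLastError()` reports it and resets the state to `hipSuccess`. The sketch below models that difference with a hypothetical `FakeErrorState` class. It does not call the HIP runtime, and every name in it is an assumption made for illustration.

```cpp
#include <cassert>
#include <string>

// Hypothetical stand-in for a per-thread runtime error slot. This only models
// the peek/get semantics; it is not the HIP API.
enum class FakeError { kSuccess, kInvalidDevice };

class FakeErrorState {
 public:
  // Simulates an earlier call (possibly in another library) failing.
  void Record(FakeError error) { last_ = error; }

  // Peek: report the last error and leave it in place.
  FakeError Peek() const { return last_; }

  // Get: report the last error and reset the slot to success.
  FakeError GetAndClear() {
    FakeError error = last_;
    last_ = FakeError::kSuccess;
    return error;
  }

 private:
  FakeError last_ = FakeError::kSuccess;
};

// Returns true if it is safe to proceed. On failure it fills in `message`
// and leaves the pending error untouched, so later code can still see it.
bool GuardBeforeCall(const FakeErrorState& state, const char* what,
                     std::string* message) {
  assert(what != nullptr && message != nullptr);
  if (state.Peek() != FakeError::kSuccess) {
    *message = std::string("pending error before ") + what;
    return false;
  }
  return true;
}
```

Because the guard only peeks, a pending error stays visible to whichever component inspects the error state next, which addresses the concern about hiding problems in other libraries.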

+if (pre_err != hipSuccess) {
+return absl::InternalError(
+absl::StrCat("There was a HIP error before calling "
+"hipModuleGetFunction for kernel '",
+kernel_name, "': ", ToString(pre_err)));
+}
hipFunction_t function;
TF_RETURN_IF_ERROR(
ToStatus(wrap::hipModuleGetFunction(&function, module, kernel_name),
@@ -197,6 +208,16 @@ absl::Status GetModuleSymbol(Context* context, hipModule_t module,
absl::StrCat("Failed to get symbol '", symbol_name, "'"));
}

+// Compute a content-based ModuleHandle from HSACO bytes.
+// Using a content hash instead of the raw data pointer avoids stale cache
+// entries when an HSACO buffer is freed and a new one is allocated at the
+// same address (pointer-reuse cache collision).

@pemeliya Apr 9, 2026

I am wondering how we end up with stale cache entries. I see we have several flat_hash_maps in rocm_executor: in_memory_modules_, kernel_to_gpu_binary_ and gpu_binary_to_module_.

Which one was returning a stale cache entry? As far as I can see, UnloadGpuBinary() should clean up gpu_binary_to_module_ and in_memory_modules_, though the logic for erasing entries from in_memory_modules_ is not easy to understand at first glance.

From my understanding, this should not happen, since GpuExecutable maintains a list of module handles which should be unloaded automatically at the end.

Author

Yes, there is an issue with unload right now: it doesn't take all cases into account. I'm working on a fix, but it's not ready yet and might take a couple more days.
I agree that with that fix we may not need hashing at all. I only discovered this yesterday, so we don't have to merge this PR now. I'll work on an improved one.

But can we please not revert the upstream PR? Without this fix we have a very annoying JAX bug in the pytest run.

Summary: this fix works, but there is more to it. I have found the actual problem in the unload and erase routine, and I will work on it and come up with a better fix.

cc: @i-chaochen

Yes, I think we need a proper fix for that. Here we use the result of absl::HashOf() as a unique key to identify a hipModule_t object. absl::Hash is very fast but, unlike SHA-256, is not collision-resistant, so this could be a subtle bug that only shows up under heavy load, when two different hipModules suddenly get the same hash key.
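
One way to address the collision concern, sketched under assumptions: keep the hash as a fast lookup key, but store the original bytes and compare them on a hit, so two different binaries with the same hash never share an entry. `VerifiedCache`, `WeakHash`, and the integer module ids are hypothetical names, not part of XLA. `WeakHash` is deliberately bad (it returns the byte count) so that collisions are easy to trigger.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <string>
#include <unordered_map>
#include <vector>

// Deliberately weak hash for demonstration: any two blobs of equal length
// collide. A real hash collides far less often, but still can.
inline std::uint64_t WeakHash(const std::string& bytes) { return bytes.size(); }

class VerifiedCache {
 public:
  // Returns the module id for the given bytes, "loading" a new module on a
  // miss. A hit counts only if the stored bytes match exactly.
  int GetOrLoad(const char* data, std::size_t size) {
    assert(data != nullptr);
    const std::string bytes(data, size);
    std::vector<Entry>& bucket = by_hash_[WeakHash(bytes)];
    for (const Entry& entry : bucket) {
      if (entry.bytes == bytes) return entry.module_id;
    }
    bucket.push_back(Entry{bytes, next_id_++});
    return bucket.back().module_id;
  }

 private:
  struct Entry {
    std::string bytes;  // Copy of the binary, kept for verification.
    int module_id;      // Stand-in for a loaded module handle.
  };
  std::unordered_map<std::uint64_t, std::vector<Entry>> by_hash_;
  int next_id_ = 1;
};
```

The trade-off is memory: the cache keeps a copy of every binary it has seen. Whether that cost is acceptable here is a judgment for the maintainers.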

Author

Hi @pemeliya, I fixed the root problem that surfaced as the stale cache and opened a downstream PR for review feedback. If the XLA team agrees with the fix, I'll open a PR upstream.
PR: #798
Could you please review it at your convenience?

+ModuleHandle HsacoModuleHandle(const char* hsaco, size_t size) {
+auto hash = absl::HashOf(absl::string_view(hsaco, size));
+// Ensure hash is never 0 (ModuleHandle treats nullptr as invalid)
+return ModuleHandle{reinterpret_cast<const void*>(hash | 1)};
+}
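
To illustrate the pointer-reuse problem that the code comment on `HsacoModuleHandle` describes, the following sketch contrasts an address key with a content key. It is a toy model: `ToyModuleCache` and the helper names are hypothetical, `std::hash` stands in for `absl::HashOf`, and reuse of an address is simulated by overwriting a buffer in place rather than by freeing and reallocating.

```cpp
#include <cassert>
#include <cstddef>
#include <cstring>
#include <functional>
#include <string>
#include <unordered_map>

// Toy cache: remembers the bytes that were "loaded" the first time a key
// was seen and returns them on every later lookup with the same key.
template <typename Key>
class ToyModuleCache {
 public:
  const std::string& GetOrLoad(const Key& key, const char* data,
                               std::size_t size) {
    assert(data != nullptr);
    auto it = cache_.find(key);
    if (it == cache_.end()) {
      it = cache_.emplace(key, std::string(data, size)).first;
    }
    return it->second;
  }

 private:
  std::unordered_map<Key, std::string> cache_;
};

// Content-based key. std::hash is an assumption standing in for absl::HashOf.
inline std::size_t ContentKey(const char* data, std::size_t size) {
  return std::hash<std::string>{}(std::string(data, size));
}

// Keyed by address: new bytes at the same address hit the old entry.
bool AddressKeyServesStaleBytes() {
  char buffer[4] = {'o', 'l', 'd', '!'};
  ToyModuleCache<const void*> cache;
  cache.GetOrLoad(buffer, buffer, sizeof(buffer));
  std::memcpy(buffer, "new!", sizeof(buffer));  // Same address, new content.
  return cache.GetOrLoad(buffer, buffer, sizeof(buffer)) == "old!";
}

// Keyed by content: new bytes produce a new key and a fresh entry.
bool ContentKeyServesStaleBytes() {
  char buffer[4] = {'o', 'l', 'd', '!'};
  ToyModuleCache<std::size_t> cache;
  cache.GetOrLoad(ContentKey(buffer, sizeof(buffer)), buffer, sizeof(buffer));
  std::memcpy(buffer, "new!", sizeof(buffer));
  return cache.GetOrLoad(ContentKey(buffer, sizeof(buffer)), buffer,
                         sizeof(buffer)) == "old!";
}
```

In this model the address-keyed cache returns the old bytes, while the content-keyed cache returns the new ones, barring a hash collision between the two contents. Note that a content key only masks a missing cleanup on unload; it does not replace one.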

// Unloads module from the current context via cuModuleUnload.
void UnloadRocmModule(Context* context, hipModule_t module) {
ScopedActivateContext activated(context);
@@ -671,10 +692,10 @@ absl::StatusOr<std::unique_ptr<Kernel>> RocmExecutor::LoadKernel(
const std::string& kernel_name = spec.kernel_name();

if (spec.has_cuda_cubin_in_memory()) {
-const char* hsaco = reinterpret_cast<const char*>(
-spec.cuda_cubin_in_memory()->cubin_bytes.data());
+const auto& cubin = spec.cuda_cubin_in_memory()->cubin_bytes;
+const char* hsaco = reinterpret_cast<const char*>(cubin.data());
absl::MutexLock lock{in_memory_modules_mu_};
-ModuleHandle module_handle{hsaco};
+ModuleHandle module_handle = HsacoModuleHandle(hsaco, cubin.size());
hipModule_t& module = in_memory_modules_[module_handle];

if (module == nullptr) {
@@ -749,16 +770,17 @@ absl::StatusOr<ModuleHandle> RocmExecutor::LoadModule(
// TODO(ROCm): Need generic term instead of cubin/cuda/ptx
if (spec.has_cuda_cubin_in_memory()) {
absl::MutexLock lock{in_memory_modules_mu_};
-return LoadModuleFromHsaco(
-reinterpret_cast<const char*>(spec.cuda_cubin_in_memory().data()));
+const auto& cubin = spec.cuda_cubin_in_memory();
+return LoadModuleFromHsaco(reinterpret_cast<const char*>(cubin.data()),
+cubin.size());
} else {
return absl::InternalError("No HASCO binary found");
}
}

absl::StatusOr<ModuleHandle> RocmExecutor::LoadModuleFromHsaco(
-const char* hsaco) {
-ModuleHandle module_handle{hsaco};
+const char* hsaco, size_t size) {
+ModuleHandle module_handle = HsacoModuleHandle(hsaco, size);
uint64_t module_refcount;
hipModule_t module;
std::tie(module, module_refcount) = gpu_binary_to_module_[module_handle];
3 changes: 2 additions & 1 deletion xla/stream_executor/rocm/rocm_executor.h
@@ -140,7 +140,8 @@ class RocmExecutor : public GpuExecutor {
absl::Status InitBlas();

// Loads a module in HSACO format.
-absl::StatusOr<ModuleHandle> LoadModuleFromHsaco(const char* hsaco)
+absl::StatusOr<ModuleHandle> LoadModuleFromHsaco(const char* hsaco,
+size_t size)
ABSL_EXCLUSIVE_LOCKS_REQUIRED(in_memory_modules_mu_);

bool UnloadGpuBinary(ModuleHandle module_handle)