-
Notifications
You must be signed in to change notification settings - Fork 8
[ROCm] Cherry-pick: Fix HSACO module cache using pointer-based key causing stale lookups #780
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: rocm-jaxlib-v0.9.2
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -15,8 +15,6 @@ limitations under the License. | |
|
|
||
| #include "xla/stream_executor/rocm/rocm_executor.h" | ||
|
|
||
| #include <unistd.h> | ||
|
|
||
| #include <algorithm> | ||
| #include <cstddef> | ||
| #include <cstdint> | ||
|
|
@@ -30,6 +28,7 @@ limitations under the License. | |
|
|
||
| #include "absl/base/casts.h" | ||
| #include "absl/container/inlined_vector.h" | ||
| #include "absl/hash/hash.h" | ||
| #include "absl/log/check.h" | ||
| #include "absl/log/log.h" | ||
| #include "absl/numeric/int128.h" | ||
|
|
@@ -47,6 +46,7 @@ limitations under the License. | |
| #include "rocm/include/hip/hip_runtime.h" | ||
| #include "rocm/include/hip/hip_version.h" | ||
| #include "rocm/rocm_config.h" | ||
| #include <unistd.h> | ||
| #include "xla/stream_executor/activate_context.h" | ||
| #include "xla/stream_executor/blas.h" | ||
| #include "xla/stream_executor/command_buffer.h" | ||
|
|
@@ -176,6 +176,17 @@ absl::StatusOr<hipFunction_t> GetModuleFunction(Context* context, | |
| const char* kernel_name) { | ||
| ScopedActivateContext activated(context); | ||
| CHECK(module != nullptr && kernel_name != nullptr); | ||
| // Check for pre-existing HIP errors before the call. On ROCm 7+ | ||
| // the per-thread error state is sticky: successful HIP calls do | ||
| // not clear it, so a stale error from a prior operation would | ||
| // produce confusing diagnostics if we proceeded. | ||
| hipError_t pre_err = ::hipPeekAtLastError(); | ||
| if (pre_err != hipSuccess) { | ||
| return absl::InternalError( | ||
| absl::StrCat("There was a HIP error before calling " | ||
| "hipModuleGetFunction for kernel '", | ||
| kernel_name, "': ", ToString(pre_err))); | ||
| } | ||
| hipFunction_t function; | ||
| TF_RETURN_IF_ERROR( | ||
| ToStatus(wrap::hipModuleGetFunction(&function, module, kernel_name), | ||
|
|
@@ -197,6 +208,16 @@ absl::Status GetModuleSymbol(Context* context, hipModule_t module, | |
| absl::StrCat("Failed to get symbol '", symbol_name, "'")); | ||
| } | ||
|
|
||
| // Compute a content-based ModuleHandle from HSACO bytes. | ||
| // Using a content hash instead of the raw data pointer avoids stale cache | ||
| // entries when an HSACO buffer is freed and a new one is allocated at the | ||
| // same address (pointer-reuse cache collision). | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I am wondering how it happens that we have stale cache entries? I saw we have several flat_hash_maps in rocm_executor: in_memory_modules_, kernel_to_gpu_binary_ and gpu_binary_to_module_. Which one was returning a stale cache entry? I see at least that UnloadGpuBinary() shall cleanup gpu_binary_to_module_ and in_memory_modules_. Though, the logic about erasing in_memory_modules_ is not easy to understand at first glance.. I mean, from my understanding, this shall not happen since GpuExecutable maintains a list of module handles which shall be automatically unloaded at the end.
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. yes there is an issue with unload right now and it doesnt take all cases into account. I'm working on a fix , its not ready yet, might take couple more days. but can we please not revert the upstream PR , without this fix we have very annoying jax bug in pytest run. summary : this fix works but there is more to it, and I have found actual problem with unload an erase routine will work on it and come up with better fix. cc: @i-chaochen There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, I think we need to have a proper fix for that.. I mean, here we use the result of absl::HashOf() as a unique key to identify hipModule_t object. Though, absl::Hash is very fast but prone to collisions (not like SHA256), so this could be a subtle bug which only shows up under a heavy load - when two different hipModules suddenly get the same hash key..
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
| ModuleHandle HsacoModuleHandle(const char* hsaco, size_t size) { | ||
| auto hash = absl::HashOf(absl::string_view(hsaco, size)); | ||
| // Ensure hash is never 0 (ModuleHandle treats nullptr as invalid) | ||
| return ModuleHandle{reinterpret_cast<const void*>(hash | 1)}; | ||
| } | ||
|
|
||
| // Unloads module from the current context via cuModuleUnload. | ||
| void UnloadRocmModule(Context* context, hipModule_t module) { | ||
| ScopedActivateContext activated(context); | ||
|
|
@@ -671,10 +692,10 @@ absl::StatusOr<std::unique_ptr<Kernel>> RocmExecutor::LoadKernel( | |
| const std::string& kernel_name = spec.kernel_name(); | ||
|
|
||
| if (spec.has_cuda_cubin_in_memory()) { | ||
| const char* hsaco = reinterpret_cast<const char*>( | ||
| spec.cuda_cubin_in_memory()->cubin_bytes.data()); | ||
| const auto& cubin = spec.cuda_cubin_in_memory()->cubin_bytes; | ||
| const char* hsaco = reinterpret_cast<const char*>(cubin.data()); | ||
| absl::MutexLock lock{in_memory_modules_mu_}; | ||
| ModuleHandle module_handle{hsaco}; | ||
| ModuleHandle module_handle = HsacoModuleHandle(hsaco, cubin.size()); | ||
| hipModule_t& module = in_memory_modules_[module_handle]; | ||
|
|
||
| if (module == nullptr) { | ||
|
|
@@ -749,16 +770,17 @@ absl::StatusOr<ModuleHandle> RocmExecutor::LoadModule( | |
| // TODO(ROCm): Need generic term instead of cubin/cuda/ptx | ||
| if (spec.has_cuda_cubin_in_memory()) { | ||
| absl::MutexLock lock{in_memory_modules_mu_}; | ||
| return LoadModuleFromHsaco( | ||
| reinterpret_cast<const char*>(spec.cuda_cubin_in_memory().data())); | ||
| const auto& cubin = spec.cuda_cubin_in_memory(); | ||
| return LoadModuleFromHsaco(reinterpret_cast<const char*>(cubin.data()), | ||
| cubin.size()); | ||
| } else { | ||
| return absl::InternalError("No HASCO binary found"); | ||
| } | ||
| } | ||
|
|
||
| absl::StatusOr<ModuleHandle> RocmExecutor::LoadModuleFromHsaco( | ||
| const char* hsaco) { | ||
| ModuleHandle module_handle{hsaco}; | ||
| const char* hsaco, size_t size) { | ||
| ModuleHandle module_handle = HsacoModuleHandle(hsaco, size); | ||
| uint64_t module_refcount; | ||
| hipModule_t module; | ||
| std::tie(module, module_refcount) = gpu_binary_to_module_[module_handle]; | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Oh I am not sure we shall just clean up sticky error bit here. It may hide problems in other libraries. Like the recent one about rocprofiler collector using wrong device ID.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hi Pavel. intention here is to not to clear sticky error but get a peek at it and if there is an error dont proceed and error out here. if we simply proceed with this error we run into complex run to run variation issue under parallel loads. so I added more this check more of a guard.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ah ok, I see, I have confused it wit hipGetLastError()