diff --git a/xla/stream_executor/rocm/rocm_executor.cc b/xla/stream_executor/rocm/rocm_executor.cc index 4ae09dda8254a..928a1cfdbf348 100644 --- a/xla/stream_executor/rocm/rocm_executor.cc +++ b/xla/stream_executor/rocm/rocm_executor.cc @@ -682,12 +682,9 @@ absl::StatusOr> RocmExecutor::LoadKernel( const auto& cubin = spec.cuda_cubin_in_memory()->cubin_bytes; const char* hsaco = reinterpret_cast(cubin.data()); absl::MutexLock lock{in_memory_modules_mu_}; - ModuleHandle module_handle = HsacoModuleHandle(hsaco, cubin.size()); - hipModule_t& module = in_memory_modules_[module_handle]; - - if (module == nullptr) { - TF_ASSIGN_OR_RETURN(module, LoadHsaco(&rocm_context_, hsaco)); - } + TF_ASSIGN_OR_RETURN(ModuleHandle module_handle, + LoadModuleFromHsaco(hsaco, cubin.size())); + hipModule_t module = gpu_binary_to_module_.at(module_handle).first; kernel_to_gpu_binary_[rocm_kernel.get()] = module_handle; VLOG(2) << "getting function " << kernel_name << " from module " << module;