Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 0 additions & 5 deletions xla/stream_executor/rocm/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,6 @@ xla_cc_test(

cc_library(
name = "rocm_context",
srcs = ["rocm_context.cc"],
hdrs = ["rocm_context.h"],
tags = [
"gpu",
Expand All @@ -61,14 +60,10 @@ cc_library(
":rocm_driver_wrapper",
":rocm_status",
"//xla/stream_executor/gpu:context",
"//xla/stream_executor/gpu:context_map",
"//xla/stream_executor/gpu:scoped_activate_context",
"//xla/tsl/platform:errors",
"@com_google_absl//absl/log",
"@com_google_absl//absl/log:check",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
"@local_config_rocm//rocm:rocm_headers",
],
)
Expand Down
223 changes: 0 additions & 223 deletions xla/stream_executor/rocm/rocm_context.cc

This file was deleted.

65 changes: 35 additions & 30 deletions xla/stream_executor/rocm/rocm_context.h
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
#include "absl/status/statusor.h"
/* Copyright 2023 The OpenXLA Authors.

Licensed under the Apache License, Version 2.0 (the "License");
Expand All @@ -14,56 +13,62 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// The ROCM-specific Driver library support, implementing the general Driver
// interface.
// On ROCm, hipCtx_t is a thin wrapper around a device ordinal and the entire
// context lifecycle (Retain/Release/SetCurrent/GetCurrent) is a no-op. AMD
// has deprecated every hipCtx* and hipDevicePrimaryCtx* API since ROCm 1.9
// with the recommendation to use hipSetDevice / hipGetDevice instead.
//
// RocmContext is a trivial implementation of the Context interface that
// delegates to hipSetDevice/hipGetDevice. It is intended to be owned as
// a plain value field inside RocmExecutor.

#ifndef XLA_STREAM_EXECUTOR_ROCM_ROCM_CONTEXT_H_
#define XLA_STREAM_EXECUTOR_ROCM_ROCM_CONTEXT_H_

#include <cstdint>

#include "absl/log/check.h"
#include "absl/status/status.h"
#include "rocm/include/hip/hip_runtime.h"
#include "xla/stream_executor/gpu/context.h"
#include "xla/stream_executor/gpu/context_map.h"
#include "xla/stream_executor/gpu/scoped_activate_context.h"
#include "xla/stream_executor/rocm/rocm_driver_wrapper.h"
#include "xla/stream_executor/rocm/rocm_status.h"
#include "xla/tsl/platform/errors.h"

namespace stream_executor::gpu {

// RocmContext implements the Context class for ROCm GPUs.
class RocmContext : public Context {
public:
RocmContext(hipCtx_t context, const int ordinal)
: context_(context), device_ordinal_(ordinal) {}
~RocmContext() override;
explicit RocmContext(int device_ordinal) : device_ordinal_(device_ordinal) {}
~RocmContext() override = default;

// Makes this context's device the current HIP device by passing
// device_ordinal_ to hipSetDevice (called through the wrap:: layer).
//
// The HIP result is converted to a status with ToStatus and handed to
// CHECK_OK, so a failure is not reported to the caller; it trips the check
// with the "Failed to set device" context attached.
void SetActive() override {
CHECK_OK(
ToStatus(wrap::hipSetDevice(device_ordinal_), "Failed to set device"));
}

// Reports whether the device currently returned by hipGetDevice is the one
// this context was constructed for.
//
// Returns false both when the query itself fails and when it succeeds but
// names a different ordinal.
bool IsActive() const override {
  int queried_ordinal = -1;
  // Short-circuit evaluation guarantees queried_ordinal is only compared
  // after hipGetDevice has reported success and filled it in.
  return wrap::hipGetDevice(&queried_ordinal) == hipSuccess &&
         queried_ordinal == device_ordinal_;
}

hipCtx_t context() const { return context_; }
void SetActive() override;
bool IsActive() const override;
int device_ordinal() const override { return device_ordinal_; }
absl::Status Synchronize() override;

// Disallow copying and moving.
// Calls hipDeviceSynchronize with this context activated and reports the
// outcome as an absl::Status.
//
// Returns absl::OkStatus() on success; otherwise the status produced by
// ToStatus for the failing HIP call.
absl::Status Synchronize() override {
  // Keep this context activated for the lifetime of the guard, i.e. until
  // the function returns.
  ScopedActivateContext scoped_device(this);
  const absl::Status sync_status = ToStatus(
      wrap::hipDeviceSynchronize(), "could not synchronize on ROCM device");
  TF_RETURN_IF_ERROR(sync_status);
  return absl::OkStatus();
}

RocmContext(RocmContext&&) = delete;
RocmContext(const RocmContext&) = delete;
RocmContext& operator=(RocmContext&&) = delete;
RocmContext& operator=(const RocmContext&) = delete;

// Returns the free amount of memory and total amount of memory, as reported
// by hipDeviceTotalMem.
bool GetDeviceMemoryUsage(int64_t* free_out, int64_t* total_out);

// Returns the total amount of memory available on the device.
static bool GetDeviceTotalMemory(hipDevice_t device, uint64_t* result);

// Returns the context map for all XLA-known ROCm contexts.
static ContextMap<hipCtx_t, RocmContext>* GetContextMap();

// Creates a new context for the given device.
static absl::StatusOr<RocmContext*> Create(int device_ordinal,
hipDevice_t device);

private:
hipCtx_t const context_;
const int device_ordinal_;
};

Expand Down
6 changes: 0 additions & 6 deletions xla/stream_executor/rocm/rocm_driver_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,6 @@ namespace wrap {
// IMPORTANT: if you add a new HIP API to this list, please notify
// the rocm-profiler developers to track the API traces.
#define HIP_ROUTINE_EACH(__macro) \
__macro(hipCtxGetDevice) \
__macro(hipCtxSetCurrent) \
__macro(hipCtxEnablePeerAccess) \
__macro(hipDeviceCanAccessPeer) \
__macro(hipDeviceEnablePeerAccess) \
Expand All @@ -78,10 +76,6 @@ namespace wrap {
__macro(hipDeviceGetSharedMemConfig) \
__macro(hipDeviceGetStreamPriorityRange) \
__macro(hipDeviceGraphMemTrim) \
__macro(hipDevicePrimaryCtxGetState) \
__macro(hipDevicePrimaryCtxSetFlags) \
__macro(hipDevicePrimaryCtxRetain) \
__macro(hipDevicePrimaryCtxRelease) \
__macro(hipDeviceSetSharedMemConfig) \
__macro(hipDeviceSynchronize) \
__macro(hipDeviceTotalMem) \
Expand Down
Loading
Loading