Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions MODULE.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ bazel_dep(name = "googletest", version = "1.16.0")
bazel_dep(name = "platforms", version = "0.0.11")
bazel_dep(name = "rules_cc", version = "0.1.1")
bazel_dep(name = "rules_python", version = "1.3.0")
bazel_dep(name = "bazel_skylib", version = "1.7.1")

python = use_extension("@rules_python//python/extensions:python.bzl", "python")
python.toolchain(
Expand All @@ -26,6 +27,10 @@ new_local_repository = use_repo_rule("@bazel_tools//tools/build_defs/repo:local.

local_torch = use_repo_rule("//toolchains:local_torch.bzl", "local_torch")

torch_nccl_detect = use_repo_rule("//toolchains/torch_nccl:defs.bzl", "torch_nccl_detect")

torch_nccl_detect(name = "torch_nccl")

# External dependency for torch_tensorrt if you already have precompiled binaries.
new_local_repository(
name = "torch_tensorrt",
Expand Down
5 changes: 4 additions & 1 deletion core/runtime/BUILD
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
load("@rules_cc//cc:defs.bzl", "cc_library")
load("@rules_pkg//:pkg.bzl", "pkg_tar")
load("@rules_pkg//pkg:mappings.bzl", "pkg_files")
load("//toolchains/torch_nccl:defs.bzl", "if_torch_nccl")

package(default_visibility = ["//visibility:public"])

config_setting(
Expand Down Expand Up @@ -77,6 +79,7 @@ cc_library(
"TRTEngineProfiler.h",
"runtime.h",
],
copts = if_torch_nccl(["-DUSE_C10D_NCCL"]),
linkopts = [
"-lstdc++fs",
],
Expand Down Expand Up @@ -121,6 +124,6 @@ pkg_tar(
pkg_files(
name = "include_pkg_files",
srcs = [":include_files"],
visibility = ["//visibility:public"],
prefix = "include/torch_tensorrt/core/runtime/",
visibility = ["//visibility:public"],
)
76 changes: 75 additions & 1 deletion core/runtime/TRTEngine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,13 @@
#include "core/util/prelude.h"
#include "torch/torch.h"

#ifdef ENABLE_TRT_NCCL_COLLECTIVES
#include "torch/csrc/distributed/c10d/GroupRegistry.hpp"
#include "torch/csrc/distributed/c10d/NCCLUtils.hpp"
#include "torch/csrc/distributed/c10d/ProcessGroup.hpp"
#include "torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp"
#endif

namespace torch_tensorrt {
namespace core {
namespace runtime {
Expand Down Expand Up @@ -88,7 +95,14 @@ TRTEngine::TRTEngine(std::vector<std::string> serialized_info)
serialized_info[SERIALIZED_METADATA_IDX],
(static_cast<bool>(std::stoi(serialized_info[RESOURCE_ALLOCATION_STRATEGY_IDX]))
? ResourceAllocationStrategy::kDynamic
: ResourceAllocationStrategy::kStatic)) {}
: ResourceAllocationStrategy::kStatic)) {
this->is_md = std::stoi(serialized_info[IS_MD_ENGINE_IDX]);
if (this->is_md) {
LOG_INFO(
"Loaded distributed engine (built on rank " << serialized_info[OPTIONAL_RANK_IDX] << " of "
<< serialized_info[OPTIONAL_WORLD_SIZE_IDX] << ")");
}
}

TRTEngine::TRTEngine(
const std::string& mod_name,
Expand Down Expand Up @@ -497,6 +511,11 @@ std::vector<std::string> TRTEngine::serialize() {
serialized_info[TARGET_PLATFORM_IDX] = this->target_platform.serialize();
serialized_info[RESOURCE_ALLOCATION_STRATEGY_IDX] =
this->resource_allocation_strategy == ResourceAllocationStrategy::kDynamic ? "1" : "0";
serialized_info[IS_MD_ENGINE_IDX] = this->is_md ? "1" : "0";
if (this->is_md) {
serialized_info[OPTIONAL_RANK_IDX] = std::to_string(this->rank);
serialized_info[OPTIONAL_WORLD_SIZE_IDX] = std::to_string(this->world_size);
}

return serialized_info;
}
Expand All @@ -519,6 +538,61 @@ void TRTEngine::set_resource_allocation_strategy(TRTEngine::ResourceAllocationSt
}
}

#ifdef ENABLE_TRT_NCCL_COLLECTIVES
void TRTEngine::detect_distributed_context(const std::string& group_name) {
  // Look up the named c10d ProcessGroup and cache its rank/world size on this
  // engine. The engine is flagged as multi-device ("md") only when the group
  // actually spans more than one rank. If no group is resolved, the engine's
  // distributed state is left untouched.
  const auto process_group = c10d::resolve_process_group(group_name);
  if (process_group == nullptr) {
    return;
  }
  this->rank = process_group->getRank();
  this->world_size = process_group->getSize();
  this->is_md = this->world_size > 1;
  LOG_DEBUG("Detected distributed context: rank=" << this->rank << ", world_size=" << this->world_size);
}

void TRTEngine::setup_nccl_comm(const std::string& group_name) {
  // Resolve the named c10d ProcessGroup, extract its NCCL communicator for
  // this engine's device, and bind it to the TensorRT execution context so
  // in-engine collectives run on the caller's communicator.
  //
  // Preconditions (checked below):
  //  - `group_name` is registered in the c10d group registry
  //  - the group has a NCCL backend whose communicator for this device has
  //    already been created (i.e. at least one collective has run on it)
  auto pg = c10d::resolve_process_group(group_name);
  TORCHTRT_CHECK(pg != nullptr, "ProcessGroup '" << group_name << "' not found in registry");

  // Set rank/world_size if not already set (e.g. load from disk without setup_engine)
  if (this->rank < 0) {
    this->rank = pg->getRank();
    this->world_size = pg->getSize();
    // Keep is_md consistent with detect_distributed_context(): an engine is
    // multi-device only when the group spans more than one rank. Without this,
    // an engine that never went through detect_distributed_context() could
    // carry a stale is_md == false and skip communicator rebinding at
    // execution time.
    this->is_md = this->world_size > 1;
    LOG_DEBUG("Set distributed context in setup_nccl_comm: rank=" << this->rank << ", world_size=" << this->world_size);
  }

  auto backend = pg->getBackend(c10d::ProcessGroup::BackendType::NCCL);
  TORCHTRT_CHECK(backend != nullptr, "ProcessGroup '" << group_name << "' has no NCCL backend");

  auto* nccl_pg = dynamic_cast<c10d::ProcessGroupNCCL*>(backend.get());
  TORCHTRT_CHECK(nccl_pg != nullptr, "Backend is not ProcessGroupNCCL");

  // The NCCL communicator is per-device; select the device this engine was
  // compiled for before querying it.
  at::cuda::set_device(this->device_info.id);

  int64_t comm_ptr = nccl_pg->getCommPtr();
  TORCHTRT_CHECK(
      comm_ptr != 0,
      "NCCL communicator not initialized for device " << this->device_info.id
                                                      << ". Ensure a collective operation has been performed first.");

  // Stash the raw communicator handle and bind it to the TRT context now.
  this->nccl_comm = reinterpret_cast<void*>(comm_ptr);
  set_nccl_communicator_to_trt_context();
  LOG_INFO("NCCL comm set up (rank=" << this->rank << ", device=" << this->device_info.id << ")");
}

bool TRTEngine::set_nccl_communicator_to_trt_context() {
  // Bind the previously-cached NCCL communicator (see setup_nccl_comm) to the
  // TensorRT execution context via IExecutionContext::setCommunicator, so the
  // engine's built-in collective layers use the caller's communicator.
  // Returns true on success; both preconditions below hard-fail via
  // TORCHTRT_CHECK rather than returning false.
  TORCHTRT_CHECK(exec_ctx != nullptr, "Cannot set NCCL communicator: execution context is null");
  TORCHTRT_CHECK(this->nccl_comm != nullptr, "NCCL communicator is not set");

  exec_ctx->setCommunicator(this->nccl_comm);

  LOG_INFO(
      "NCCL communicator set on TensorRT execution context "
      "(rank="
      << this->rank << ", device=" << this->device_info.id << ")");
  return true;
}
#endif // ENABLE_TRT_NCCL_COLLECTIVES

} // namespace runtime
} // namespace core
} // namespace torch_tensorrt
29 changes: 29 additions & 0 deletions core/runtime/TRTEngine.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,25 @@
#include "ATen/core/function_schema.h"
#include "ATen/cuda/CUDAGraph.h"
#include "NvInfer.h"
#include "NvInferVersion.h"
#include "c10/cuda/CUDAStream.h"
#include "torch/custom_class.h"

#include "core/runtime/TRTEngineProfiler.h"
#include "core/util/prelude.h"

// TensorRT 10.16+ has native NCCL collective support via IExecutionContext::setCommunicator()
#if NV_TENSORRT_MAJOR > 10 || (NV_TENSORRT_MAJOR == 10 && NV_TENSORRT_MINOR >= 16)
#define TRT_HAS_NATIVE_NCCL 1
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this work for say jetson or windows?

#endif

// Full TRT NCCL collectives support requires both:
// 1. PyTorch built with NCCL (USE_C10D_NCCL defined via Bazel)
// 2. TensorRT 10.16+ (TRT_HAS_NATIVE_NCCL defined above)
#if defined(USE_C10D_NCCL) && defined(TRT_HAS_NATIVE_NCCL)
#define ENABLE_TRT_NCCL_COLLECTIVES 1
#endif

namespace torch_tensorrt {
namespace core {
namespace runtime {
Expand Down Expand Up @@ -196,6 +209,22 @@ struct TRTEngine : torch::CustomClassHolder {
bool use_output_allocator_outputs = false; // users specify to use output allocator
std::shared_ptr<DynamicOutputAllocator> output_allocator;

// Member variables for distributed inference (-1 indicates non-distributed mode)
bool is_md = false;
int64_t rank = -1;
int64_t world_size = -1;

#ifdef ENABLE_TRT_NCCL_COLLECTIVES
void* nccl_comm = nullptr;

// Detect rank and world_size from ProcessGroup
void detect_distributed_context(const std::string& group_name);

// Resolve ProcessGroup, get NCCL communicator, and bind to TRT context
void setup_nccl_comm(const std::string& group_name);
bool set_nccl_communicator_to_trt_context();
#endif

// TODO: Implement a call method
// c10::List<at::Tensor> Run(c10::List<at::Tensor> inputs);

Expand Down
8 changes: 8 additions & 0 deletions core/runtime/execute_engine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -311,6 +311,14 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intr
std::make_unique<torch::autograd::profiler::RecordProfile>(compiled_engine->enqueue_profile_path);
}

// Distributed setup - bind NCCL communicator to TRT execution context
// setup_nccl_comm must have been called from Python before first forward
#ifdef ENABLE_TRT_NCCL_COLLECTIVES
if (compiled_engine->is_md && compiled_engine->nccl_comm != nullptr) {
compiled_engine->set_nccl_communicator_to_trt_context();
}
#endif

// Block engine stream until results are available on caller stream
at::cuda::CUDAEvent caller_exec_complete;
caller_exec_complete.record(compiled_engine->caller_stream);
Expand Down
23 changes: 23 additions & 0 deletions core/runtime/register_jit_hooks.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,19 @@ static auto TORCHTRT_UNUSED TRTEngineTSRegistrtion =
&TRTEngine::set_device_memory_budget)
.def_property("streamable_device_memory_budget", &TRTEngine::get_streamable_device_memory_budget)
.def_property("automatic_device_memory_budget", &TRTEngine::get_automatic_device_memory_budget)
.def_readonly("is_md", &TRTEngine::is_md)
.def_readonly("rank", &TRTEngine::rank)
.def_readonly("world_size", &TRTEngine::world_size)
#ifdef ENABLE_TRT_NCCL_COLLECTIVES
.def(
"detect_distributed_context",
[](c10::intrusive_ptr<TRTEngine> self, std::string group_name) {
self->detect_distributed_context(group_name);
})
.def(
"setup_nccl_comm",
[](c10::intrusive_ptr<TRTEngine> self, std::string group_name) { self->setup_nccl_comm(group_name); })
#endif
.def_pickle(
[](const c10::intrusive_ptr<TRTEngine>& self) -> std::vector<std::string> { return self->serialize(); },
[](std::vector<std::string> serialized_info) -> c10::intrusive_ptr<TRTEngine> {
Expand Down Expand Up @@ -150,6 +163,16 @@ TORCH_LIBRARY(tensorrt, m) {
m.def("REQUIRES_OUTPUT_ALLOCATOR_IDX", []() -> int64_t { return REQUIRES_OUTPUT_ALLOCATOR_IDX; });
m.def("SERIALIZATION_LEN", []() -> int64_t { return SERIALIZATION_LEN; });
m.def("RESOURCE_ALLOCATION_STRATEGY_IDX", []() -> int64_t { return RESOURCE_ALLOCATION_STRATEGY_IDX; });
m.def("IS_MD_ENGINE_IDX", []() -> int64_t { return IS_MD_ENGINE_IDX; });
m.def("OPTIONAL_RANK_IDX", []() -> int64_t { return OPTIONAL_RANK_IDX; });
m.def("OPTIONAL_WORLD_SIZE_IDX", []() -> int64_t { return OPTIONAL_WORLD_SIZE_IDX; });
m.def("NATIVE_TRT_COLLECTIVES_AVAIL", []() -> bool {
#ifdef ENABLE_TRT_NCCL_COLLECTIVES
return true;
#else
return false;
#endif
});
m.def("_platform_linux_x86_64", []() -> std::string {
auto it = get_platform_name_map().find(Platform::PlatformEnum::kLINUX_X86_64);
return it->second;
Expand Down
5 changes: 4 additions & 1 deletion core/runtime/runtime.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ namespace core {
namespace runtime {

using EngineID = int64_t;
const std::string ABI_VERSION = "8";
const std::string ABI_VERSION = "9";
extern bool MULTI_DEVICE_SAFE_MODE;

typedef enum {
Expand All @@ -39,6 +39,9 @@ typedef enum {
TARGET_PLATFORM_IDX,
REQUIRES_OUTPUT_ALLOCATOR_IDX,
RESOURCE_ALLOCATION_STRATEGY_IDX,
IS_MD_ENGINE_IDX,
OPTIONAL_RANK_IDX,
OPTIONAL_WORLD_SIZE_IDX,
SERIALIZATION_LEN, // NEVER USED FOR DATA, USED TO DETERMINE LENGTH OF SERIALIZED INFO
} SerializedInfoIndex;

Expand Down
Loading
Loading