Skip to content
Draft
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions MODULE.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ bazel_dep(name = "googletest", version = "1.16.0")
bazel_dep(name = "platforms", version = "0.0.11")
bazel_dep(name = "rules_cc", version = "0.1.1")
bazel_dep(name = "rules_python", version = "1.3.0")
bazel_dep(name = "bazel_skylib", version = "1.7.1")

python = use_extension("@rules_python//python/extensions:python.bzl", "python")
python.toolchain(
Expand All @@ -26,6 +27,10 @@ new_local_repository = use_repo_rule("@bazel_tools//tools/build_defs/repo:local.

local_torch = use_repo_rule("//toolchains:local_torch.bzl", "local_torch")

torch_nccl_detect = use_repo_rule("//toolchains/torch_nccl:defs.bzl", "torch_nccl_detect")

torch_nccl_detect(name = "torch_nccl")

# External dependency for torch_tensorrt if you already have precompiled binaries.
new_local_repository(
name = "torch_tensorrt",
Expand Down
5 changes: 4 additions & 1 deletion core/runtime/BUILD
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
load("@rules_cc//cc:defs.bzl", "cc_library")
load("@rules_pkg//:pkg.bzl", "pkg_tar")
load("@rules_pkg//pkg:mappings.bzl", "pkg_files")
load("//toolchains/torch_nccl:defs.bzl", "if_torch_nccl")

package(default_visibility = ["//visibility:public"])

config_setting(
Expand Down Expand Up @@ -77,6 +79,7 @@ cc_library(
"TRTEngineProfiler.h",
"runtime.h",
],
copts = if_torch_nccl(["-DUSE_C10D_NCCL"]),
linkopts = [
"-lstdc++fs",
],
Expand Down Expand Up @@ -121,6 +124,6 @@ pkg_tar(
pkg_files(
name = "include_pkg_files",
srcs = [":include_files"],
visibility = ["//visibility:public"],
prefix = "include/torch_tensorrt/core/runtime/",
visibility = ["//visibility:public"],
)
160 changes: 159 additions & 1 deletion core/runtime/TRTEngine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,13 @@
#include "core/util/prelude.h"
#include "torch/torch.h"

#ifdef ENABLE_TRT_NCCL_COLLECTIVES
#include "torch/csrc/distributed/c10d/GroupRegistry.hpp"
#include "torch/csrc/distributed/c10d/NCCLUtils.hpp"
#include "torch/csrc/distributed/c10d/ProcessGroup.hpp"
#include "torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp"
#endif

namespace torch_tensorrt {
namespace core {
namespace runtime {
Expand Down Expand Up @@ -88,7 +95,15 @@ TRTEngine::TRTEngine(std::vector<std::string> serialized_info)
serialized_info[SERIALIZED_METADATA_IDX],
(static_cast<bool>(std::stoi(serialized_info[RESOURCE_ALLOCATION_STRATEGY_IDX]))
? ResourceAllocationStrategy::kDynamic
: ResourceAllocationStrategy::kStatic)) {}
: ResourceAllocationStrategy::kStatic)) {
// Load distributed info if available (backward compatible with older ABI versions)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We don't need backwards compat unless some semantic definition changed; just bump the version.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The rank and world_size fields are optional — they are only populated for distributed engines. The size/empty guards ensure backward compatibility with non-distributed engines and older ABI versions where these fields don't exist. Since these fields are not always present, bumping the ABI version would unnecessarily break non-distributed engines that don't use them.
Do you want something like we bump the ABI version and introduce a field to handle the rank and world_size fields? Something like IS_MD_ENGINE as you said? Would that not just be a different approach?

if (serialized_info.size() > RANK_IDX && !serialized_info[RANK_IDX].empty()) {
this->rank = std::stoll(serialized_info[RANK_IDX]);
}
if (serialized_info.size() > WORLD_SIZE_IDX && !serialized_info[WORLD_SIZE_IDX].empty()) {
this->world_size = std::stoll(serialized_info[WORLD_SIZE_IDX]);
}
}

TRTEngine::TRTEngine(
const std::string& mod_name,
Expand Down Expand Up @@ -519,6 +534,149 @@ void TRTEngine::set_resource_allocation_strategy(TRTEngine::ResourceAllocationSt
}
}

void TRTEngine::set_rank(int64_t rank_val) {
this->rank = rank_val;
LOG_DEBUG("Rank set on TRTEngine: " << this->rank);
}

void TRTEngine::set_world_size(int64_t world_size_val) {
this->world_size = world_size_val;
LOG_DEBUG("World size set on TRTEngine: " << this->world_size);
}

#ifdef ENABLE_TRT_NCCL_COLLECTIVES
void TRTEngine::set_nccl_comm(int64_t comm_ptr) {
this->nccl_comm = reinterpret_cast<ncclComm_t>(comm_ptr);
LOG_DEBUG("NCCL communicator stored on TRTEngine (rank=" << this->rank << ")");

// Also set on TensorRT execution context
set_nccl_communicator_to_trt_context();
}

bool TRTEngine::set_nccl_communicator_to_trt_context() {
// Set NCCL communicator on TensorRT execution context
// The communicator should be set from Python via set_nccl_comm() or set_process_group()

if (!exec_ctx) {
LOG_ERROR("Cannot set NCCL communicator: execution context is null");
return false;
}

if (this->nccl_comm == nullptr) {
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if it can be avoided, do not let anything be null. Should be a sentinel value or use a smart pointer

LOG_WARNING(
"Distributed inference enabled but no NCCL communicator set. "
"Call set_process_group() or set_nccl_comm() from Python first.");
return false;
}

// Set NCCL communicator on TensorRT execution context
try {
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We need like a real state machine or some true semantics here. Nothing else in the runtime just try catches

// Cast ncclComm_t to void* for TensorRT API
void* comm_ptr = static_cast<void*>(this->nccl_comm);

// Set the NCCL communicator on the execution context
// The device ID is used to identify which GPU's communicator this is
exec_ctx->setCommunicator(comm_ptr);

LOG_INFO(
"NCCL communicator set on TensorRT execution context "
"(rank="
<< this->rank << ", device=" << this->device_info.id << ")");
return true;
} catch (const std::exception& e) {
LOG_ERROR("Failed to set NCCL communicator on execution context: " << e.what());
return false;
}
}

void TRTEngine::init_nccl_comm(const std::string& group_name) {
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why do we need a method that transparently calls another method?

// Use C++ registry to get NCCL communicator
set_process_group_from_registry(group_name);
}

// Resolve the named ProcessGroup from the native c10d registry, extract its
// NCCL communicator for this engine's device, and install it on the TensorRT
// execution context. Avoids having to pass the ProcessGroup in from Python.
//
// Side effects: back-fills this->rank / this->world_size from the
// ProcessGroup when they are still unset (-1), switches the current CUDA
// device to this engine's device (getCommPtr() reads the current device),
// and stores the communicator in this->nccl_comm.
// Returns true on success; false (with diagnostics logged) on any failure.
bool TRTEngine::set_process_group_from_registry(const std::string& group_name) {
  LOG_INFO("TRTEngine::set_process_group_from_registry() called with group_name: " << group_name);
  LOG_INFO(" Current rank: " << this->rank);
  LOG_INFO(" Current world_size: " << this->world_size);
  LOG_INFO(" Current device_id: " << this->device_info.id);

  try {
    // Resolve ProcessGroup from the native registry
    auto pg = c10d::resolve_process_group(group_name);
    if (!pg) {
      LOG_ERROR("Failed to resolve ProcessGroup '" << group_name << "' from registry");
      return false;
    }
    LOG_INFO(" Resolved ProcessGroup from registry: rank=" << pg->getRank() << ", size=" << pg->getSize());

    // Update rank and world_size from the ProcessGroup if not already set
    if (this->rank < 0) {
      this->rank = pg->getRank();
      LOG_INFO(" Set rank from ProcessGroup: " << this->rank);
    }
    if (this->world_size < 0) {
      this->world_size = pg->getSize();
      LOG_INFO(" Set world_size from ProcessGroup: " << this->world_size);
    }

    // ProcessGroup wraps Backend objects - request the NCCL backend explicitly
    c10::intrusive_ptr<c10d::Backend> backend;
    try {
      backend = pg->getBackend(c10d::ProcessGroup::BackendType::NCCL);
    } catch (const std::exception& e) {
      LOG_ERROR("Failed to get NCCL backend from ProcessGroup: " << e.what());
      return false;
    }

    if (!backend) {
      LOG_ERROR("ProcessGroup '" << group_name << "' does not have an NCCL backend");
      return false;
    }
    LOG_INFO(" Got NCCL backend from ProcessGroup");

    // Cast the backend to ProcessGroupNCCL
    auto* nccl_pg = dynamic_cast<c10d::ProcessGroupNCCL*>(backend.get());
    if (!nccl_pg) {
      LOG_ERROR("Backend is not ProcessGroupNCCL (unexpected)");
      return false;
    }
    LOG_INFO(" Successfully cast to ProcessGroupNCCL");

    // getCommPtr() uses at::cuda::current_device() internally, so make the
    // current CUDA device match this engine's device before querying it.
    at::cuda::set_device(this->device_info.id);
    LOG_INFO(" Set current CUDA device to: " << this->device_info.id);

    // Get NCCL comm pointer using the public getCommPtr() method;
    // this returns the communicator for the current CUDA device.
    int64_t comm_ptr = nccl_pg->getCommPtr();
    if (comm_ptr == 0) {
      LOG_ERROR(
          "Failed to get NCCL communicator for device " << this->device_info.id
                                                        << ". The communicator may not be initialized yet.");
      LOG_ERROR("Hint: Ensure a collective operation has been performed on this device first.");
      return false;
    }

    this->nccl_comm = reinterpret_cast<ncclComm_t>(comm_ptr);
    LOG_INFO(" Successfully extracted NCCL communicator from registry");
    LOG_INFO(" nccl_comm: " << (void*)this->nccl_comm);

    // FIX: the original `return True;` is not valid C++ (Python-style
    // literal). Per the preceding comment's intent ("Set on TensorRT
    // execution context") and for consistency with set_nccl_comm(), install
    // the communicator on the execution context and propagate its status.
    return set_nccl_communicator_to_trt_context();

  } catch (const std::exception& e) {
    LOG_ERROR("Failed to get ProcessGroup from registry: " << e.what());
    return false;
  }
}
#endif // ENABLE_TRT_NCCL_COLLECTIVES

} // namespace runtime
} // namespace core
} // namespace torch_tensorrt
33 changes: 33 additions & 0 deletions core/runtime/TRTEngine.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,29 @@
#include "ATen/core/function_schema.h"
#include "ATen/cuda/CUDAGraph.h"
#include "NvInfer.h"
#include "NvInferVersion.h"
#include "c10/cuda/CUDAStream.h"
#include "torch/custom_class.h"

#include "core/runtime/TRTEngineProfiler.h"
#include "core/util/prelude.h"

// TensorRT 10.16+ has native NCCL collective support via IExecutionContext::setCommunicator()
#if NV_TENSORRT_MAJOR > 10 || (NV_TENSORRT_MAJOR == 10 && NV_TENSORRT_MINOR >= 16)
#define TRT_HAS_NATIVE_NCCL 1
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this work for, say, Jetson or Windows?

#endif

// Full TRT NCCL collectives support requires both:
// 1. PyTorch built with NCCL (USE_C10D_NCCL defined via Bazel)
// 2. TensorRT 10.16+ (TRT_HAS_NATIVE_NCCL defined above)
#if defined(USE_C10D_NCCL) && defined(TRT_HAS_NATIVE_NCCL)
#define ENABLE_TRT_NCCL_COLLECTIVES 1
#endif

#ifdef ENABLE_TRT_NCCL_COLLECTIVES
#include <nccl.h>
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You are pulling in nccl.h, but this is not a Bazel dep — where is it coming from?

#endif

namespace torch_tensorrt {
namespace core {
namespace runtime {
Expand Down Expand Up @@ -196,6 +213,22 @@ struct TRTEngine : torch::CustomClassHolder {
bool use_output_allocator_outputs = false; // users specify to use output allocator
std::shared_ptr<DynamicOutputAllocator> output_allocator;

// Member variables for distributed inference (-1 indicates non-distributed mode)
int64_t rank = -1;
int64_t world_size = -1;

// Set rank and world_size for distributed inference
void set_rank(int64_t rank_val);
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why are these APIs always available but not the NCCL ones? Is there a way to use MD-TRT without the APIs below?

void set_world_size(int64_t world_size_val);

#ifdef ENABLE_TRT_NCCL_COLLECTIVES
ncclComm_t nccl_comm = nullptr;
void set_nccl_comm(int64_t comm_ptr);
void init_nccl_comm(const std::string& group_name = "default");
bool set_process_group_from_registry(const std::string& group_name = "default");
bool set_nccl_communicator_to_trt_context();
#endif

// TODO: Implement a call method
// c10::List<at::Tensor> Run(c10::List<at::Tensor> inputs);

Expand Down
16 changes: 16 additions & 0 deletions core/runtime/execute_engine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -311,6 +311,22 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intr
std::make_unique<torch::autograd::profiler::RecordProfile>(compiled_engine->enqueue_profile_path);
}

// Distributed setup - set NCCL communicator on TensorRT execution context
#ifdef ENABLE_TRT_NCCL_COLLECTIVES
if (compiled_engine->rank >= 0 && compiled_engine->world_size > 1) {
bool result = compiled_engine->set_nccl_communicator_to_trt_context();
if (!result) {
LOG_ERROR("Failed to set NCCL communicator on TRT context");
LOG_ERROR("This will cause collective operations to fail at runtime");
LOG_ERROR("Make sure to call module.init_nccl_comm() after compilation");
}
} else {
LOG_DEBUG(
"Single-device mode (rank=" << compiled_engine->rank << ", world_size=" << compiled_engine->world_size
<< ") - skipping NCCL setup");
}
#endif

// Block engine stream until results are available on caller stream
at::cuda::CUDAEvent caller_exec_complete;
caller_exec_complete.record(compiled_engine->caller_stream);
Expand Down
21 changes: 21 additions & 0 deletions core/runtime/register_jit_hooks.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,18 @@ static auto TORCHTRT_UNUSED TRTEngineTSRegistrtion =
&TRTEngine::set_device_memory_budget)
.def_property("streamable_device_memory_budget", &TRTEngine::get_streamable_device_memory_budget)
.def_property("automatic_device_memory_budget", &TRTEngine::get_automatic_device_memory_budget)
.def_readonly("rank", &TRTEngine::rank)
.def_readonly("world_size", &TRTEngine::world_size)
.def("set_rank", &TRTEngine::set_rank)
.def("set_world_size", &TRTEngine::set_world_size)
#ifdef ENABLE_TRT_NCCL_COLLECTIVES
.def("set_nccl_comm", &TRTEngine::set_nccl_comm)
.def(
"init_nccl_comm",
[](c10::intrusive_ptr<TRTEngine> self, std::string group_name = "default") {
self->init_nccl_comm(group_name);
})
#endif
.def_pickle(
[](const c10::intrusive_ptr<TRTEngine>& self) -> std::vector<std::string> { return self->serialize(); },
[](std::vector<std::string> serialized_info) -> c10::intrusive_ptr<TRTEngine> {
Expand Down Expand Up @@ -150,6 +162,15 @@ TORCH_LIBRARY(tensorrt, m) {
m.def("REQUIRES_OUTPUT_ALLOCATOR_IDX", []() -> int64_t { return REQUIRES_OUTPUT_ALLOCATOR_IDX; });
m.def("SERIALIZATION_LEN", []() -> int64_t { return SERIALIZATION_LEN; });
m.def("RESOURCE_ALLOCATION_STRATEGY_IDX", []() -> int64_t { return RESOURCE_ALLOCATION_STRATEGY_IDX; });
m.def("RANK_IDX", []() -> int64_t { return RANK_IDX; });
m.def("WORLD_SIZE_IDX", []() -> int64_t { return WORLD_SIZE_IDX; });
m.def("NATIVE_TRT_COLLECTIVES_AVAIL", []() -> bool {
#ifdef ENABLE_TRT_NCCL_COLLECTIVES
return true;
#else
return false;
#endif
});
m.def("_platform_linux_x86_64", []() -> std::string {
auto it = get_platform_name_map().find(Platform::PlatformEnum::kLINUX_X86_64);
return it->second;
Expand Down
2 changes: 2 additions & 0 deletions core/runtime/runtime.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ typedef enum {
TARGET_PLATFORM_IDX,
REQUIRES_OUTPUT_ALLOCATOR_IDX,
RESOURCE_ALLOCATION_STRATEGY_IDX,
RANK_IDX,
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You need a field like IS_MD_ENGINE that gates these optional values.

WORLD_SIZE_IDX,
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Make sure to bump the ABI version

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fields that are only sometimes applicable should be prefixed with OPTIONAL_ and need guards

SERIALIZATION_LEN, // NEVER USED FOR DATA, USED TO DETERMINE LENGTH OF SERIALIZED INFO
} SerializedInfoIndex;

Expand Down
Loading
Loading