From 6152d3a6b4befcd94a7fae54bcbbd09edb59c6b3 Mon Sep 17 00:00:00 2001 From: apbose Date: Wed, 1 Apr 2026 16:46:58 -0700 Subject: [PATCH 1/7] Multi-Device TensorRT Runtime with Native NCCL Collectives - C++ runtime: NCCL communicator init via c10d, rank/world_size serialization, DynamicOutputAllocator, ABI version bump to 8 - Python runtime: distributed support in PythonTorchTensorRTModule and TorchTensorRTModule, NCCL library auto-detection - Conversion: native TRT DistCollective API (AllGather, ReduceScatter, AllReduce) with TRT-LLM plugin fallback - Graph lowering: fuse c10d_functional collectives + wait_tensor into single ops - Feature detection: native_trt_collectives flag, platform validation, graceful fallback chain - Build: conditional NCCL compilation via torch_nccl toolchain - Examples: tensor_parallel_simple_example.py, tensor_parallel_llama_llm.py --- MODULE.bazel | 5 + core/runtime/BUILD | 5 +- core/runtime/TRTEngine.cpp | 160 +++++++- core/runtime/TRTEngine.h | 33 ++ core/runtime/execute_engine.cpp | 16 + core/runtime/register_jit_hooks.cpp | 21 ++ core/runtime/runtime.h | 2 + .../tensor_parallel_simple_example.py | 117 ++++-- py/torch_tensorrt/_features.py | 16 +- py/torch_tensorrt/_utils.py | 33 ++ .../dynamo/conversion/_TRTInterpreter.py | 8 +- .../dynamo/conversion/_conversion.py | 42 ++- .../conversion/custom_ops_converters.py | 76 +++- .../dynamo/conversion/impl/nccl_ops.py | 351 +++++++++++++++++- .../lowering/passes/fuse_distributed_ops.py | 24 +- .../runtime/_PythonTorchTensorRTModule.py | 257 ++++++++++++- .../dynamo/runtime/_TorchTensorRTModule.py | 170 ++++++++- py/torch_tensorrt/dynamo/runtime/__init__.py | 5 + .../dynamo/runtime/_nccl_utils.py | 176 +++++++++ third_party/libtorch/BUILD | 2 + toolchains/torch_nccl/BUILD | 1 + toolchains/torch_nccl/defs.bzl | 60 +++ tools/llm/tensor_parallel_llama_llm.py | 340 +++++++++++++++++ 23 files changed, 1861 insertions(+), 59 deletions(-) create mode 100644 
py/torch_tensorrt/dynamo/runtime/_nccl_utils.py create mode 100644 toolchains/torch_nccl/BUILD create mode 100644 toolchains/torch_nccl/defs.bzl create mode 100644 tools/llm/tensor_parallel_llama_llm.py diff --git a/MODULE.bazel b/MODULE.bazel index cd557d8233..333429ec65 100644 --- a/MODULE.bazel +++ b/MODULE.bazel @@ -8,6 +8,7 @@ bazel_dep(name = "googletest", version = "1.16.0") bazel_dep(name = "platforms", version = "0.0.11") bazel_dep(name = "rules_cc", version = "0.1.1") bazel_dep(name = "rules_python", version = "1.3.0") +bazel_dep(name = "bazel_skylib", version = "1.7.1") python = use_extension("@rules_python//python/extensions:python.bzl", "python") python.toolchain( @@ -26,6 +27,10 @@ new_local_repository = use_repo_rule("@bazel_tools//tools/build_defs/repo:local. local_torch = use_repo_rule("//toolchains:local_torch.bzl", "local_torch") +torch_nccl_detect = use_repo_rule("//toolchains/torch_nccl:defs.bzl", "torch_nccl_detect") + +torch_nccl_detect(name = "torch_nccl") + # External dependency for torch_tensorrt if you already have precompiled binaries. 
new_local_repository( name = "torch_tensorrt", diff --git a/core/runtime/BUILD b/core/runtime/BUILD index 19260149ae..5fd06e7150 100644 --- a/core/runtime/BUILD +++ b/core/runtime/BUILD @@ -1,6 +1,8 @@ load("@rules_cc//cc:defs.bzl", "cc_library") load("@rules_pkg//:pkg.bzl", "pkg_tar") load("@rules_pkg//pkg:mappings.bzl", "pkg_files") +load("//toolchains/torch_nccl:defs.bzl", "if_torch_nccl") + package(default_visibility = ["//visibility:public"]) config_setting( @@ -77,6 +79,7 @@ cc_library( "TRTEngineProfiler.h", "runtime.h", ], + copts = if_torch_nccl(["-DUSE_C10D_NCCL"]), linkopts = [ "-lstdc++fs", ], @@ -121,6 +124,6 @@ pkg_tar( pkg_files( name = "include_pkg_files", srcs = [":include_files"], - visibility = ["//visibility:public"], prefix = "include/torch_tensorrt/core/runtime/", + visibility = ["//visibility:public"], ) diff --git a/core/runtime/TRTEngine.cpp b/core/runtime/TRTEngine.cpp index d29daa112b..6937cbfa33 100644 --- a/core/runtime/TRTEngine.cpp +++ b/core/runtime/TRTEngine.cpp @@ -10,6 +10,13 @@ #include "core/util/prelude.h" #include "torch/torch.h" +#ifdef ENABLE_TRT_NCCL_COLLECTIVES +#include "torch/csrc/distributed/c10d/GroupRegistry.hpp" +#include "torch/csrc/distributed/c10d/NCCLUtils.hpp" +#include "torch/csrc/distributed/c10d/ProcessGroup.hpp" +#include "torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp" +#endif + namespace torch_tensorrt { namespace core { namespace runtime { @@ -88,7 +95,15 @@ TRTEngine::TRTEngine(std::vector serialized_info) serialized_info[SERIALIZED_METADATA_IDX], (static_cast(std::stoi(serialized_info[RESOURCE_ALLOCATION_STRATEGY_IDX])) ? 
ResourceAllocationStrategy::kDynamic - : ResourceAllocationStrategy::kStatic)) {} + : ResourceAllocationStrategy::kStatic)) { + // Load distributed info if available (backward compatible with older ABI versions) + if (serialized_info.size() > RANK_IDX && !serialized_info[RANK_IDX].empty()) { + this->rank = std::stoll(serialized_info[RANK_IDX]); + } + if (serialized_info.size() > WORLD_SIZE_IDX && !serialized_info[WORLD_SIZE_IDX].empty()) { + this->world_size = std::stoll(serialized_info[WORLD_SIZE_IDX]); + } +} TRTEngine::TRTEngine( const std::string& mod_name, @@ -519,6 +534,149 @@ void TRTEngine::set_resource_allocation_strategy(TRTEngine::ResourceAllocationSt } } +void TRTEngine::set_rank(int64_t rank_val) { + this->rank = rank_val; + LOG_DEBUG("Rank set on TRTEngine: " << this->rank); +} + +void TRTEngine::set_world_size(int64_t world_size_val) { + this->world_size = world_size_val; + LOG_DEBUG("World size set on TRTEngine: " << this->world_size); +} + +#ifdef ENABLE_TRT_NCCL_COLLECTIVES +void TRTEngine::set_nccl_comm(int64_t comm_ptr) { + this->nccl_comm = reinterpret_cast(comm_ptr); + LOG_DEBUG("NCCL communicator stored on TRTEngine (rank=" << this->rank << ")"); + + // Also set on TensorRT execution context + set_nccl_communicator_to_trt_context(); +} + +bool TRTEngine::set_nccl_communicator_to_trt_context() { + // Set NCCL communicator on TensorRT execution context + // The communicator should be set from Python via set_nccl_comm() or set_process_group() + + if (!exec_ctx) { + LOG_ERROR("Cannot set NCCL communicator: execution context is null"); + return false; + } + + if (this->nccl_comm == nullptr) { + LOG_WARNING( + "Distributed inference enabled but no NCCL communicator set. 
" + "Call set_process_group() or set_nccl_comm() from Python first."); + return false; + } + + // Set NCCL communicator on TensorRT execution context + try { + // Cast ncclComm_t to void* for TensorRT API + void* comm_ptr = static_cast(this->nccl_comm); + + // Set the NCCL communicator on the execution context + // The device ID is used to identify which GPU's communicator this is + exec_ctx->setCommunicator(comm_ptr); + + LOG_INFO( + "NCCL communicator set on TensorRT execution context " + "(rank=" + << this->rank << ", device=" << this->device_info.id << ")"); + return true; + } catch (const std::exception& e) { + LOG_ERROR("Failed to set NCCL communicator on execution context: " << e.what()); + return false; + } +} + +void TRTEngine::init_nccl_comm(const std::string& group_name) { + // Use C++ registry to get NCCL communicator + set_process_group_from_registry(group_name); +} + +bool TRTEngine::set_process_group_from_registry(const std::string& group_name) { + // Get ProcessGroup from C++ registry and extract NCCL communicator + // This avoids the need to pass the ProcessGroup from Python + LOG_INFO("TRTEngine::set_process_group_from_registry() called with group_name: " << group_name); + LOG_INFO(" Current rank: " << this->rank); + LOG_INFO(" Current world_size: " << this->world_size); + LOG_INFO(" Current device_id: " << this->device_info.id); + + try { + // Resolve ProcessGroup from the native registry + auto pg = c10d::resolve_process_group(group_name); + if (!pg) { + LOG_ERROR("Failed to resolve ProcessGroup '" << group_name << "' from registry"); + return false; + } + LOG_INFO(" Resolved ProcessGroup from registry: rank=" << pg->getRank() << ", size=" << pg->getSize()); + + // Update rank and world_size from the ProcessGroup if not already set + if (this->rank < 0) { + this->rank = pg->getRank(); + LOG_INFO(" Set rank from ProcessGroup: " << this->rank); + } + if (this->world_size < 0) { + this->world_size = pg->getSize(); + LOG_INFO(" Set world_size from 
ProcessGroup: " << this->world_size); + } + + // Get the NCCL backend from the ProcessGroup + // ProcessGroup wraps Backend objects - we need to get the NCCL backend explicitly + c10::intrusive_ptr backend; + try { + backend = pg->getBackend(c10d::ProcessGroup::BackendType::NCCL); + } catch (const std::exception& e) { + LOG_ERROR("Failed to get NCCL backend from ProcessGroup: " << e.what()); + return false; + } + + if (!backend) { + LOG_ERROR("ProcessGroup '" << group_name << "' does not have an NCCL backend"); + return false; + } + LOG_INFO(" Got NCCL backend from ProcessGroup"); + + // Cast the backend to ProcessGroupNCCL + auto* nccl_pg = dynamic_cast(backend.get()); + if (!nccl_pg) { + LOG_ERROR("Backend is not ProcessGroupNCCL (unexpected)"); + return false; + } + LOG_INFO(" Successfully cast to ProcessGroupNCCL"); + + // Set current CUDA device to match the engine's device before getting comm + // getCommPtr() uses at::cuda::current_device() internally + at::cuda::set_device(this->device_info.id); + LOG_INFO(" Set current CUDA device to: " << this->device_info.id); + + // Get NCCL comm pointer using the public getCommPtr() method + // This returns the communicator for the current CUDA device + int64_t comm_ptr = nccl_pg->getCommPtr(); + if (comm_ptr == 0) { + LOG_ERROR( + "Failed to get NCCL communicator for device " << this->device_info.id + << ". 
The communicator may not be initialized yet."); + LOG_ERROR("Hint: Ensure a collective operation has been performed on this device first."); + return false; + } + + // Convert int64_t pointer to ncclComm_t + ncclComm_t comm = reinterpret_cast(comm_ptr); + + this->nccl_comm = comm; + LOG_INFO(" Successfully extracted NCCL communicator from registry"); + LOG_INFO(" nccl_comm: " << (void*)this->nccl_comm); + // Set on TensorRT execution context + return true; + + } catch (const std::exception& e) { + LOG_ERROR("Failed to get ProcessGroup from registry: " << e.what()); + return false; + } +} +#endif // ENABLE_TRT_NCCL_COLLECTIVES + } // namespace runtime } // namespace core } // namespace torch_tensorrt diff --git a/core/runtime/TRTEngine.h b/core/runtime/TRTEngine.h index 363631863f..eb7e1f46d4 100644 --- a/core/runtime/TRTEngine.h +++ b/core/runtime/TRTEngine.h @@ -9,12 +9,29 @@ #include "ATen/core/function_schema.h" #include "ATen/cuda/CUDAGraph.h" #include "NvInfer.h" +#include "NvInferVersion.h" #include "c10/cuda/CUDAStream.h" #include "torch/custom_class.h" #include "core/runtime/TRTEngineProfiler.h" #include "core/util/prelude.h" +// TensorRT 10.16+ has native NCCL collective support via IExecutionContext::setCommunicator() +#if NV_TENSORRT_MAJOR > 10 || (NV_TENSORRT_MAJOR == 10 && NV_TENSORRT_MINOR >= 16) +#define TRT_HAS_NATIVE_NCCL 1 +#endif + +// Full TRT NCCL collectives support requires both: +// 1. PyTorch built with NCCL (USE_C10D_NCCL defined via Bazel) +// 2. 
TensorRT 10.16+ (TRT_HAS_NATIVE_NCCL defined above) +#if defined(USE_C10D_NCCL) && defined(TRT_HAS_NATIVE_NCCL) +#define ENABLE_TRT_NCCL_COLLECTIVES 1 +#endif + +#ifdef ENABLE_TRT_NCCL_COLLECTIVES +#include +#endif + namespace torch_tensorrt { namespace core { namespace runtime { @@ -196,6 +213,22 @@ struct TRTEngine : torch::CustomClassHolder { bool use_output_allocator_outputs = false; // users specify to use output allocator std::shared_ptr output_allocator; + // Member variables for distributed inference (-1 indicates non-distributed mode) + int64_t rank = -1; + int64_t world_size = -1; + + // Set rank and world_size for distributed inference + void set_rank(int64_t rank_val); + void set_world_size(int64_t world_size_val); + +#ifdef ENABLE_TRT_NCCL_COLLECTIVES + ncclComm_t nccl_comm = nullptr; + void set_nccl_comm(int64_t comm_ptr); + void init_nccl_comm(const std::string& group_name = "default"); + bool set_process_group_from_registry(const std::string& group_name = "default"); + bool set_nccl_communicator_to_trt_context(); +#endif + // TODO: Implement a call method // c10::List Run(c10::List inputs); diff --git a/core/runtime/execute_engine.cpp b/core/runtime/execute_engine.cpp index 553469392b..4868b092f4 100644 --- a/core/runtime/execute_engine.cpp +++ b/core/runtime/execute_engine.cpp @@ -311,6 +311,22 @@ std::vector execute_engine(std::vector inputs, c10::intr std::make_unique(compiled_engine->enqueue_profile_path); } + // Distributed setup - set NCCL communicator on TensorRT execution context +#ifdef ENABLE_TRT_NCCL_COLLECTIVES + if (compiled_engine->rank >= 0 && compiled_engine->world_size > 1) { + bool result = compiled_engine->set_nccl_communicator_to_trt_context(); + if (!result) { + LOG_ERROR("Failed to set NCCL communicator on TRT context"); + LOG_ERROR("This will cause collective operations to fail at runtime"); + LOG_ERROR("Make sure to call module.init_nccl_comm() after compilation"); + } + } else { + LOG_DEBUG( + "Single-device mode (rank=" << 
compiled_engine->rank << ", world_size=" << compiled_engine->world_size + << ") - skipping NCCL setup"); + } +#endif + // Block engine stream until results are available on caller stream at::cuda::CUDAEvent caller_exec_complete; caller_exec_complete.record(compiled_engine->caller_stream); diff --git a/core/runtime/register_jit_hooks.cpp b/core/runtime/register_jit_hooks.cpp index e8f6217a21..ffae7c7455 100644 --- a/core/runtime/register_jit_hooks.cpp +++ b/core/runtime/register_jit_hooks.cpp @@ -108,6 +108,18 @@ static auto TORCHTRT_UNUSED TRTEngineTSRegistrtion = &TRTEngine::set_device_memory_budget) .def_property("streamable_device_memory_budget", &TRTEngine::get_streamable_device_memory_budget) .def_property("automatic_device_memory_budget", &TRTEngine::get_automatic_device_memory_budget) + .def_readonly("rank", &TRTEngine::rank) + .def_readonly("world_size", &TRTEngine::world_size) + .def("set_rank", &TRTEngine::set_rank) + .def("set_world_size", &TRTEngine::set_world_size) +#ifdef ENABLE_TRT_NCCL_COLLECTIVES + .def("set_nccl_comm", &TRTEngine::set_nccl_comm) + .def( + "init_nccl_comm", + [](c10::intrusive_ptr self, std::string group_name = "default") { + self->init_nccl_comm(group_name); + }) +#endif .def_pickle( [](const c10::intrusive_ptr& self) -> std::vector { return self->serialize(); }, [](std::vector serialized_info) -> c10::intrusive_ptr { @@ -150,6 +162,15 @@ TORCH_LIBRARY(tensorrt, m) { m.def("REQUIRES_OUTPUT_ALLOCATOR_IDX", []() -> int64_t { return REQUIRES_OUTPUT_ALLOCATOR_IDX; }); m.def("SERIALIZATION_LEN", []() -> int64_t { return SERIALIZATION_LEN; }); m.def("RESOURCE_ALLOCATION_STRATEGY_IDX", []() -> int64_t { return RESOURCE_ALLOCATION_STRATEGY_IDX; }); + m.def("RANK_IDX", []() -> int64_t { return RANK_IDX; }); + m.def("WORLD_SIZE_IDX", []() -> int64_t { return WORLD_SIZE_IDX; }); + m.def("NATIVE_TRT_COLLECTIVES_AVAIL", []() -> bool { +#ifdef ENABLE_TRT_NCCL_COLLECTIVES + return true; +#else + return false; +#endif + }); 
m.def("_platform_linux_x86_64", []() -> std::string { auto it = get_platform_name_map().find(Platform::PlatformEnum::kLINUX_X86_64); return it->second; diff --git a/core/runtime/runtime.h b/core/runtime/runtime.h index d8f71683d3..61e4362289 100644 --- a/core/runtime/runtime.h +++ b/core/runtime/runtime.h @@ -39,6 +39,8 @@ typedef enum { TARGET_PLATFORM_IDX, REQUIRES_OUTPUT_ALLOCATOR_IDX, RESOURCE_ALLOCATION_STRATEGY_IDX, + RANK_IDX, + WORLD_SIZE_IDX, SERIALIZATION_LEN, // NEVER USED FOR DATA, USED TO DETERMINE LENGTH OF SERIALIZED INFO } SerializedInfoIndex; diff --git a/examples/distributed_inference/tensor_parallel_simple_example.py b/examples/distributed_inference/tensor_parallel_simple_example.py index f2dc6861cb..298862f3eb 100755 --- a/examples/distributed_inference/tensor_parallel_simple_example.py +++ b/examples/distributed_inference/tensor_parallel_simple_example.py @@ -16,27 +16,48 @@ ----- .. code-block:: bash - mpirun -n 2 --allow-run-as-root python tensor_parallel_simple_example.py + # JIT mode python runtime + mpirun -n 2 python tensor_parallel_simple_example.py --mode jit_python + + # JIT mode cpp runtime + mpirun -n 2 python tensor_parallel_simple_example.py --mode jit_cpp + + WIP: Export and load mode + mpirun -n 2 python tensor_parallel_simple_example.py --mode export --save-path /tmp/tp_model.ep + mpirun -n 2 python tensor_parallel_simple_example.py --mode load --save-path /tmp/tp_model.ep + """ +import argparse import time -import tensorrt as trt import torch import torch.distributed as dist import torch.nn as nn +import torch.utils._pytree from tensor_parallel_initialize_dist import ( cleanup_distributed_env, - get_tensor_parallel_device_mesh, initialize_distributed_env, ) -# Initialize distributed environment and logger BEFORE importing torch_tensorrt -# This ensures logging is configured before any import-time log messages +torch.utils._pytree.register_constant( + torch.distributed.tensor._dtensor_spec.DTensorSpec +) + +parser = 
argparse.ArgumentParser(description="Tensor Parallel Simple Example") +parser.add_argument( + "--mode", + type=str, + choices=["jit_python", "jit_cpp", "export", "load"], + default="jit_python", +) +parser.add_argument("--save-path", type=str, default="/tmp/tp_model.ep") +args = parser.parse_args() + device_mesh, _world_size, _rank, logger = initialize_distributed_env( "tensor_parallel_simple_example" ) - +import torch_tensorrt from torch.distributed._tensor import Shard from torch.distributed.tensor.parallel import ( ColwiseParallel, @@ -92,29 +113,65 @@ def forward(self, x): inp = torch.rand(20, 10, device="cuda") python_result = tp_model(inp) -backend = "torch_tensorrt" -tp_model = torch.compile( - tp_model, - backend=backend, - options={ - "truncate_long_and_double": True, - "enabled_precisions": {torch.float32, torch.float16}, - "use_python_runtime": True, - "min_block_size": 1, - "use_distributed_mode_trace": True, - }, - dynamic=None, -) - -# For TP, input needs to be same across all TP ranks. -# Setting the random seed is to mimic the behavior of dataloader. -torch.manual_seed(0) -inp = torch.rand(20, 10, device="cuda") -start = time.time() -output = tp_model(inp) -end = time.time() -logger.info(f"Compilation time is {end - start}") -assert (python_result - output).std() < 0.01, "Result is not correct." 
+if args.mode == "load": + # Load per-rank model: /tmp/tp_model.ep -> /tmp/tp_model_rank0_of_2.ep + logger.info(f"Loading from {args.save_path}") + loaded_model = torch_tensorrt.load(args.save_path) + output = loaded_model(inp) + assert (python_result - output).std() < 0.01, "Result mismatch" + logger.info("Load successful!") + +elif args.mode == "jit_python": + trt_model = torch.compile( + tp_model, + backend="torch_tensorrt", + options={ + "truncate_long_and_double": True, + "enabled_precisions": {torch.float32, torch.float16}, + "use_python_runtime": True, + "min_block_size": 1, + "use_distributed_mode_trace": True, + }, + ) + output = trt_model(inp) + assert (python_result - output).std() < 0.01, "Result mismatch" + logger.info("JIT compile successful!") + +elif args.mode == "jit_cpp": + trt_model = torch.compile( + tp_model, + backend="torch_tensorrt", + options={ + "truncate_long_and_double": True, + "enabled_precisions": {torch.float32, torch.float16}, + "use_python_runtime": False, + "min_block_size": 1, + "use_distributed_mode_trace": True, + }, + ) + output = trt_model(inp) + assert (python_result - output).std() < 0.01, "Result mismatch" + logger.info("JIT compile successful!") + +elif args.mode == "export": + # Export: torch.export + dynamo.compile - AOT compilation, can save + exported_program = torch.export.export(tp_model, (inp,), strict=False) + trt_model = torch_tensorrt.dynamo.compile( + exported_program, + inputs=[inp], + # enabled_precisions={torch.float32, torch.float16}, + truncate_double=True, + use_python_runtime=True, + min_block_size=1, + use_distributed_mode_trace=True, + ) + output = trt_model(inp) + assert (python_result - output).std() < 0.01, "Result mismatch" + + # Save per-rank: /tmp/tp_model.ep -> /tmp/tp_model_rank0_of_2.ep + save_path = torch_tensorrt.save(trt_model, args.save_path, inputs=[inp]) + logger.info(f"Saved to {save_path}") + dist.barrier() -# This cleans up the distributed process group cleanup_distributed_env() 
+logger.info("Done!") diff --git a/py/torch_tensorrt/_features.py b/py/torch_tensorrt/_features.py index 318fc79461..1f7c389d8c 100644 --- a/py/torch_tensorrt/_features.py +++ b/py/torch_tensorrt/_features.py @@ -7,6 +7,7 @@ import tensorrt from torch_tensorrt._utils import ( check_cross_compile_trt_win_lib, + check_native_trt_collectives, load_tensorrt_llm_for_nccl, sanitized_torch_version, ) @@ -25,6 +26,7 @@ "windows_cross_compile", "tensorrt_rtx", "trtllm_for_nccl", + "native_trt_collectives", ], ) @@ -50,7 +52,16 @@ _FX_FE_AVAIL = False if _TENSORRT_RTX else True _REFIT_AVAIL = True _WINDOWS_CROSS_COMPILE = check_cross_compile_trt_win_lib() -_TRTLLM_AVAIL = load_tensorrt_llm_for_nccl() + +# Check if native TRT collectives are available (TRT 10.16+ with NCCL) +_NATIVE_TRT_COLLECTIVES_AVAIL = check_native_trt_collectives( + linked_file_full_path, linked_file_runtime_full_path +) + +# Only load TRT-LLM for NCCL if native TRT collectives are not available +_TRTLLM_AVAIL = False +if not _NATIVE_TRT_COLLECTIVES_AVAIL: + _TRTLLM_AVAIL = load_tensorrt_llm_for_nccl() if _TENSORRT_RTX: @@ -78,6 +89,7 @@ _WINDOWS_CROSS_COMPILE, _TENSORRT_RTX, _TRTLLM_AVAIL, + _NATIVE_TRT_COLLECTIVES_AVAIL, ) T = TypeVar("T") @@ -85,7 +97,7 @@ def _enabled_features_str() -> str: enabled = lambda x: "ENABLED" if x else "DISABLED" - out_str: str = f"Enabled Features:\n - Dynamo Frontend: {enabled(_DYNAMO_FE_AVAIL)}\n - Torch-TensorRT Runtime: {enabled(_TORCHTRT_RT_AVAIL)}\n - FX Frontend: {enabled(_FX_FE_AVAIL)}\n - TorchScript Frontend: {enabled(_TS_FE_AVAIL)}\n - Refit: {enabled(_REFIT_AVAIL)}\n - QDP Plugin: {enabled(_QDP_PLUGIN_AVAIL)} \n - TensorRT-RTX: {enabled(_TENSORRT_RTX)}\n - TensorRT-LLM for NCCL: {enabled(_TRTLLM_AVAIL)}\n" # type: ignore[no-untyped-call] + out_str: str = f"Enabled Features:\n - Dynamo Frontend: {enabled(_DYNAMO_FE_AVAIL)}\n - Torch-TensorRT Runtime: {enabled(_TORCHTRT_RT_AVAIL)}\n - FX Frontend: {enabled(_FX_FE_AVAIL)}\n - TorchScript Frontend: 
{enabled(_TS_FE_AVAIL)}\n - Refit: {enabled(_REFIT_AVAIL)}\n - QDP Plugin: {enabled(_QDP_PLUGIN_AVAIL)} \n - TensorRT-RTX: {enabled(_TENSORRT_RTX)}\n - TensorRT-LLM for NCCL: {enabled(_TRTLLM_AVAIL)}\n - Native TRT Collectives: {enabled(_NATIVE_TRT_COLLECTIVES_AVAIL)}\n" # type: ignore[no-untyped-call] return out_str diff --git a/py/torch_tensorrt/_utils.py b/py/torch_tensorrt/_utils.py index 20521590ba..9b2993b56d 100644 --- a/py/torch_tensorrt/_utils.py +++ b/py/torch_tensorrt/_utils.py @@ -299,6 +299,39 @@ def load_and_initialize_trtllm_plugin(plugin_lib_path: str) -> bool: return False +def check_native_trt_collectives( + torchtrt_lib_path: Optional[str] = None, + torchtrt_runtime_lib_path: Optional[str] = None, +) -> bool: + """ + Check if native TRT collectives are available (TRT 10.16+ with NCCL). + + This function loads the torch_tensorrt runtime library to register torch ops, + then calls NATIVE_TRT_COLLECTIVES_AVAIL() to check if the runtime was compiled + with NCCL support and TensorRT 10.16+. + + Args: + torchtrt_lib_path: Path to libtorchtrt.so (full library) + torchtrt_runtime_lib_path: Path to libtorchtrt_runtime.so (runtime-only library) + + Returns: + bool: True if native TRT collectives are available, False otherwise. + """ + try: + # Load the runtime library to register torch ops + # Prefer full library if available, otherwise runtime-only + if torchtrt_lib_path and os.path.isfile(torchtrt_lib_path): + torch.ops.load_library(torchtrt_lib_path) + elif torchtrt_runtime_lib_path and os.path.isfile(torchtrt_runtime_lib_path): + torch.ops.load_library(torchtrt_runtime_lib_path) + else: + return False + + return bool(torch.ops.tensorrt.NATIVE_TRT_COLLECTIVES_AVAIL()) + except Exception: + return False + + def load_tensorrt_llm_for_nccl() -> bool: """ Attempts to load the TensorRT-LLM plugin and initialize it. 
diff --git a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py index f04836c52c..ba0cc90c21 100644 --- a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py +++ b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py @@ -225,6 +225,12 @@ def _populate_trt_builder_config( ) -> trt.IBuilderConfig: builder_config = self.builder.create_builder_config() + if ENABLED_FEATURES.native_trt_collectives: + _LOGGER.info("Using native TRT collectives") + builder_config.set_preview_feature( + trt.PreviewFeature.MULTIDEVICE_RUNTIME_10_16, True + ) + if self._debugger_config and self._debugger_config.engine_builder_monitor: builder_config.progress_monitor = TRTBulderMonitor() @@ -453,7 +459,7 @@ def check_weight_equal( except Exception: return torch.all(sd_weight == network_weight) - @needs_refit # type: ignore[misc] + @needs_refit def _save_weight_mapping(self) -> None: """ Construct the weight name mapping from engine weight name to state_dict weight name. 
diff --git a/py/torch_tensorrt/dynamo/conversion/_conversion.py b/py/torch_tensorrt/dynamo/conversion/_conversion.py index e47d3f404f..1be2b92520 100644 --- a/py/torch_tensorrt/dynamo/conversion/_conversion.py +++ b/py/torch_tensorrt/dynamo/conversion/_conversion.py @@ -4,6 +4,7 @@ import logging from typing import Any, Dict, List, NamedTuple, Optional, Sequence +import tensorrt as trt import torch from torch_tensorrt._enums import dtype from torch_tensorrt._features import ENABLED_FEATURES @@ -25,8 +26,6 @@ ) from torch_tensorrt.logging import TRT_LOGGER -import tensorrt as trt - logger = logging.getLogger(__name__) @@ -49,7 +48,7 @@ def infer_module_output_dtypes( """ outputs = [node for node in module.graph.nodes if node.op == "output"] outputs = outputs[0].args - return get_output_dtypes(outputs, truncate_double) + return list(get_output_dtypes(outputs, truncate_double)) def insert_engine_to_cache( @@ -358,6 +357,41 @@ def convert_module( "Since Torch-TensorRT runtime is not available, using Python Runtime, some features may not be available" ) + rank = -1 + world_size = -1 + if settings.use_distributed_mode_trace: + import os + + import torch.distributed as dist + + # Check if distributed backends are available + if ENABLED_FEATURES.native_trt_collectives: + logger.info( + "Native TRT collectives available (TRT 10.16+) for distributed execution" + ) + elif ENABLED_FEATURES.trtllm_for_nccl: + logger.info("TRT-LLM NCCL plugins available for distributed execution") + else: + logger.warning( + "Distributed mode requested but neither native TRT collectives nor TRT-LLM NCCL plugins are available. " + "Distributed execution may not work correctly. " + "For native TRT collectives, ensure TensorRT 10.16+ and torch_tensorrt built with NCCL support. " + "For TRT-LLM fallback, set TRTLLM_PLUGINS_PATH or USE_TRTLLM_PLUGINS=1." 
+ ) + + if dist.is_initialized(): + rank = dist.get_rank() + world_size = dist.get_world_size() + else: + # Fallback to environment variables + rank = int(os.environ.get("RANK", -1)) + world_size = int(os.environ.get("WORLD_SIZE", -1)) + + if rank >= 0 and world_size > 0: + logger.info( + f"Creating TRT module for distributed execution: rank={rank}, world_size={world_size}" + ) + return rt_cls( serialized_engine=serialized_interpreter_result.serialized_engine, input_binding_names=list(serialized_interpreter_result.input_names), @@ -367,4 +401,6 @@ def convert_module( weight_name_map=serialized_interpreter_result.weight_name_map, requires_output_allocator=serialized_interpreter_result.requires_output_allocator, symbolic_shape_expressions=serialized_interpreter_result.symbolic_shape_expressions, + rank=rank, + world_size=world_size, ) diff --git a/py/torch_tensorrt/dynamo/conversion/custom_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/custom_ops_converters.py index 302a254f60..2324b04806 100644 --- a/py/torch_tensorrt/dynamo/conversion/custom_ops_converters.py +++ b/py/torch_tensorrt/dynamo/conversion/custom_ops_converters.py @@ -14,11 +14,69 @@ ) from torch_tensorrt.dynamo.lowering.passes.fuse_distributed_ops import ( tensorrt_fused_nccl_all_gather_op, + tensorrt_fused_nccl_all_reduce_op, tensorrt_fused_nccl_reduce_scatter_op, ) _LOGGER: logging.Logger = logging.getLogger(__name__) +if ENABLED_FEATURES.native_trt_collectives: + # Use native TensorRT DistCollective API (no TensorRT-LLM dependency) + _LOGGER.info("Using native TensorRT DistCollective API for distributed operations") + + @dynamo_tensorrt_converter(tensorrt_fused_nccl_all_gather_op) + def fused_nccl_gather( + ctx: ConversionContext, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, + ) -> Union[trt.ITensor, Sequence[trt.ITensor]]: + """All-gather using native TensorRT DistCollective API""" + return impl.nccl_ops.nccl_gather_native( + ctx, + target, + 
SourceIR.ATEN, + name, + [args[0]], + ) + + @dynamo_tensorrt_converter(tensorrt_fused_nccl_reduce_scatter_op) + def fused_nccl_reduce_scatter( + ctx: ConversionContext, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, + ) -> Union[trt.ITensor, Sequence[trt.ITensor]]: + """Reduce-scatter using native TensorRT DistCollective API""" + return impl.nccl_ops.nccl_reduce_scatter_native( + ctx, + target, + SourceIR.ATEN, + name, + [args[0]], + ) + + @dynamo_tensorrt_converter(tensorrt_fused_nccl_all_reduce_op) + def fused_nccl_all_reduce( + ctx: ConversionContext, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, + ) -> Union[trt.ITensor, Sequence[trt.ITensor]]: + """All-reduce using native TensorRT DistCollective API""" + reduce_op = args[1] if len(args) > 1 else "sum" + return impl.nccl_ops.nccl_all_reduce_native( + ctx, + target, + SourceIR.ATEN, + name, + [args[0]], + reduce_op=reduce_op, + ) + # Conditionally register NCCL converters only if TensorRT-LLM plugin is available. # We use an `if` statement instead of @needs_trtllm_for_nccl decorator because @@ -28,7 +86,7 @@ # Order 1: @needs_trtllm_for_nccl followed by registering the converter leads to plugin registry not finding nccl ops plugins since we register the bare converter, without the decorator # Order 2: registering the converter first followed by @needs_trtllm_for_nccl leads to "NotImplementedError: TensorRT-LLM plugin for NCCL is not available :TensorRT-LLM plugin for NCCL is not available" and no fall back to pytorch -if ENABLED_FEATURES.trtllm_for_nccl: +elif ENABLED_FEATURES.trtllm_for_nccl: _LOGGER.debug( "TensorRT-LLM plugin for NCCL is available. Registering NCCL converters." 
) @@ -65,6 +123,22 @@ def fused_nccl_reduce_scatter( [args[0]], ) + @dynamo_tensorrt_converter(tensorrt_fused_nccl_all_reduce_op) + def fused_nccl_all_reduce( + ctx: ConversionContext, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, + ) -> Union[trt.ITensor, Sequence[trt.ITensor]]: + return impl.nccl_ops.nccl_all_reduce( + ctx, + target, + SourceIR.ATEN, + name, + [args[0]], + ) + else: _LOGGER.info( "TensorRT-LLM plugin for NCCL is not available. " diff --git a/py/torch_tensorrt/dynamo/conversion/impl/nccl_ops.py b/py/torch_tensorrt/dynamo/conversion/impl/nccl_ops.py index e64c06ca39..55df569f09 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/nccl_ops.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/nccl_ops.py @@ -1,3 +1,4 @@ +import logging import os from enum import IntEnum, IntFlag, auto from typing import Optional, Tuple, Union @@ -9,6 +10,8 @@ from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext from torch_tensorrt.dynamo.conversion.converter_utils import set_layer_name +logger = logging.getLogger(__name__) + # class for AllReduce class AllReduceStrategy(IntEnum): @@ -33,6 +36,34 @@ class AllReduceConfig(IntFlag): PUSH_MODE = auto() +def _get_distributed_rank_and_world_size() -> Tuple[int, int]: + """Get rank and world_size from environment variables. + + Returns: + (rank, world_size) tuple. + + Raises: + RuntimeError: If WORLD_SIZE is not set. 
+ """ + _world_size = os.environ.get("WORLD_SIZE") + if _world_size is None: + raise RuntimeError( + "The WORLD_SIZE env variable is not set in distributed environment" + ) + world_size = int(_world_size) + + # Get rank from environment + _rank = int(os.environ.get("RANK", 0)) + if _rank is not None: + rank = int(_rank) + else: + raise RuntimeError( + "The RANK env variable is not set in distributed environment" + ) + + return rank, world_size + + def nccl_gather( ctx: ConversionContext, target: Union[Target, str], @@ -44,13 +75,11 @@ def nccl_gather( "AllGather", "1", "tensorrt_llm" ) assert allgather_plg_creator is not None - _world_size = os.environ.get("WORLD_SIZE") - if _world_size is not None: - world_size = int(_world_size) - else: - raise RuntimeError( - "The WORLD_SIZE env variable is not set in distributed environment" - ) + rank, world_size = _get_distributed_rank_and_world_size() + logger.debug( + f"Adding TRT-LLM NCCL gather: name={name}, rank={rank}, world_size={world_size}" + ) + group = list(range(world_size)) group = trt.PluginField( "group", np.array(group, dtype=np.int32), trt.PluginFieldType.INT32 @@ -66,6 +95,56 @@ def nccl_gather( return layer.get_output(0) +def nccl_all_reduce( + ctx: ConversionContext, + target: Union[Target, str], + source_ir: Optional[SourceIR], + name: str, + plug_inputs: Tuple[Argument, ...], +) -> trt.ITensor: + allreduce_plg_creator = trt.get_plugin_registry().get_plugin_creator( + "AllReduce", "1", "tensorrt_llm" + ) + assert allreduce_plg_creator is not None + + counter = 0 + strategy = AllReduceStrategy.NCCL + config = AllReduceConfig(0) + rank, world_size = _get_distributed_rank_and_world_size() + logger.debug( + f"Adding TRT-LLM NCCL all reduce: name={name}, rank={rank}, world_size={world_size}" + ) + group = list(range(world_size)) + group = trt.PluginField( + "group", np.array(group, dtype=np.int32), trt.PluginFieldType.INT32 + ) + + p_dtype = trt.float32 + pf_dtype = trt.PluginField( + "type_id", 
np.array([int(p_dtype)], np.int32), trt.PluginFieldType.INT32 + ) + pfc = [group, pf_dtype] + p_strategy = trt.PluginField( + "strategy", np.array([int(strategy)], np.int8), trt.PluginFieldType.INT8 + ) + pfc.append(p_strategy) + p_config = trt.PluginField( + "config", np.array([int(config)], np.int8), trt.PluginFieldType.INT8 + ) + pfc.append(p_config) + p_counter = trt.PluginField( + "counter", np.array([counter], np.int32), trt.PluginFieldType.INT32 + ) + pfc.append(p_counter) + + pfc = trt.PluginFieldCollection(pfc) + ar_plug = allreduce_plg_creator.create_plugin("allreduce", pfc) + + layer = ctx.net.add_plugin_v2(plug_inputs, ar_plug) + set_layer_name(layer, target, name, source_ir) + return layer.get_output(0) + + def nccl_reduce_scatter( ctx: ConversionContext, target: Union[Target, str], @@ -82,13 +161,10 @@ def nccl_reduce_scatter( counter = 0 strategy = AllReduceStrategy.NCCL config = AllReduceConfig(0) - _world_size = os.environ.get("WORLD_SIZE") - if _world_size is not None: - world_size = int(_world_size) - else: - raise RuntimeError( - "The WORLD_SIZE env variable is not set in distributed environment" - ) + rank, world_size = _get_distributed_rank_and_world_size() + logger.debug( + f"Adding TRT-LLM NCCL reduce scatter: name={name}, rank={rank}, world_size={world_size}" + ) group = list(range(world_size)) group = trt.PluginField( "group", np.array(group, dtype=np.int32), trt.PluginFieldType.INT32 @@ -118,3 +194,250 @@ def nccl_reduce_scatter( layer = ctx.net.add_plugin_v2(plug_inputs, ar_plug) set_layer_name(layer, target, name, source_ir) return layer.get_output(0) + + +def nccl_gather_native( + ctx: ConversionContext, + target: Union[Target, str], + source_ir: Optional[SourceIR], + name: str, + plug_inputs: Tuple[Argument, ...], +) -> trt.ITensor: + """ + Implement all_gather using native TensorRT DistCollective API. + + This operation gathers tensors from all ranks and concatenates them. 
+ Each rank contributes a tensor, and all ranks receive the concatenated result. + + Returns: + Output tensor after all_gather operation + + Example: + Input on rank 0: [1, 2] shape=(2,) + Input on rank 1: [3, 4] shape=(2,) + Output on all ranks: [1, 2, 3, 4] shape=(4,) + """ + rank, world_size = _get_distributed_rank_and_world_size() + logger.debug( + f"Adding native all_gather: name={name}, rank={rank}, world_size={world_size}" + ) + + # Get the input tensor + input_tensor = plug_inputs[0] + + try: + # Use native TensorRT DistCollective API for ALL_GATHER + # For ALL_GATHER, the reduce operation and root rank parameters are ignored + # The last parameter (group) can be None to include all ranks + import numpy as np + + # Create array of all participating rank IDs [0, 1, 2, ..., world_size-1] + groups = np.arange(world_size, dtype=np.int64) + + logger.debug( + f"Creating ALL_GATHER layer: groups={groups.tolist()}, groupSize={world_size}" + ) + layer = ctx.net.add_dist_collective( + input_tensor, + trt.CollectiveOperation.ALL_GATHER, + trt.ReduceOperation.NONE, # Ignored for ALL_GATHER + -1, # Root rank - ignored for ALL_GATHER + groups, # None means all ranks participate (world_size ranks) + ) + + logger.debug(f"Successfully created native ALL_GATHER layer: {name}") + logger.debug( + f"Calling add_dist_collective: input_shape={input_tensor.shape}, " + f"groups={groups.tolist()}, groupSize={len(groups)} (inferred from array)" + ) + + set_layer_name(layer, target, name, source_ir) + + output = layer.get_output(0) + layer.num_ranks = world_size + + return output + + except AttributeError as e: + error_msg = ( + f"Native ALL_GATHER failed: {e}. " + "This usually means TensorRT doesn't support native distributed collectives. " + f"Your TensorRT version: {trt.__version__}. " + "Native collectives require TensorRT 10.16 or later. 
" + "Consider using TensorRT-LLM plugins instead by setting USE_NATIVE_TRT_COLLECTIVES=0" + ) + logger.error(error_msg) + raise RuntimeError(error_msg) from e + + except Exception as e: + logger.error(f"Native ALL_GATHER failed: {e} (type: {type(e).__name__})") + raise + + +def nccl_reduce_scatter_native( + ctx: ConversionContext, + target: Union[Target, str], + source_ir: Optional[SourceIR], + name: str, + plug_inputs: Tuple[Argument, ...], + reduce_op: str = "sum", +) -> trt.ITensor: + """ + Implement reduce_scatter using native TensorRT DistCollective API. + + This operation reduces tensors from all ranks and scatters the result. + The input is split along dimension 0 and each rank receives one chunk + after the reduction operation. + + Returns: + Output tensor after reduce_scatter operation + reduce_op: Reduction operation ("sum", "prod", "min", "max", "avg") + + Example (with SUM reduction): + Input on rank 0: [1, 2, 3, 4] shape=(4,) + Input on rank 1: [5, 6, 7, 8] shape=(4,) + Output on rank 0: [1+5, 2+6] = [6, 8] shape=(2,) + Output on rank 1: [3+7, 4+8] = [10, 12] shape=(2,) + """ + rank, world_size = _get_distributed_rank_and_world_size() + logger.debug( + f"Adding native reduce_scatter: name={name}, rank={rank}, world_size={world_size}, reduce_op={reduce_op}" + ) + + # Get the input tensor + input_tensor = plug_inputs[0] + + # Map string reduction op to TensorRT ReduceOperation enum + reduce_op_map = { + "sum": trt.ReduceOperation.SUM, + "prod": trt.ReduceOperation.PROD, + "min": trt.ReduceOperation.MIN, + "max": trt.ReduceOperation.MAX, + "avg": trt.ReduceOperation.AVG, + } + + if reduce_op.lower() not in reduce_op_map: + raise ValueError( + f"Unsupported reduce operation: {reduce_op}. 
" + f"Supported: {list(reduce_op_map.keys())}" + ) + + trt_reduce_op = reduce_op_map[reduce_op.lower()] + + try: + layer = ctx.net.add_dist_collective( + input_tensor, + trt.CollectiveOperation.REDUCE_SCATTER, + trt_reduce_op, + -1, + None, # None means all ranks participate + ) + + set_layer_name(layer, target, name, source_ir) + + output = layer.get_output(0) + layer.num_ranks = world_size + logger.debug( + f"Successfully created native REDUCE_SCATTER layer: {name}, reduce_op={reduce_op}" + ) + + return output + + except AttributeError as e: + error_msg = ( + f"Native ALL_REDUCE_SCATTER failed: {e}. " + "This usually means TensorRT doesn't support native distributed collectives. " + f"Your TensorRT version: {trt.__version__}. " + "Native collectives require TensorRT 10.16 or later. " + "Consider using TensorRT-LLM plugins instead by setting USE_NATIVE_TRT_COLLECTIVES=0" + ) + logger.error(error_msg) + raise RuntimeError(error_msg) from e + + except Exception as e: + logger.error(f"Native REDUCE_SCATTER failed: {e} (type: {type(e).__name__})") + raise + + +def nccl_all_reduce_native( + ctx: ConversionContext, + target: Union[Target, str], + source_ir: Optional[SourceIR], + name: str, + plug_inputs: Tuple[Argument, ...], + reduce_op: str = "sum", +) -> trt.ITensor: + """ + Implement all_reduce using native TensorRT DistCollective API. + + This operation reduces tensors across all ranks in-place. Every rank + receives the same reduced result. 
+ + Returns: + Output tensor after all_reduce operation + + Args: + reduce_op: Reduction operation ("sum", "prod", "min", "max", "avg") + + Example (with SUM reduction): + Input on rank 0: [1, 2, 3, 4] shape=(4,) + Input on rank 1: [5, 6, 7, 8] shape=(4,) + Output on rank 0: [6, 8, 10, 12] shape=(4,) + Output on rank 1: [6, 8, 10, 12] shape=(4,) + """ + rank, world_size = _get_distributed_rank_and_world_size() + logger.debug( + f"Adding native all_reduce: name={name}, rank={rank}, world_size={world_size}, reduce_op={reduce_op}" + ) + + input_tensor = plug_inputs[0] + + reduce_op_map = { + "sum": trt.ReduceOperation.SUM, + "prod": trt.ReduceOperation.PROD, + "min": trt.ReduceOperation.MIN, + "max": trt.ReduceOperation.MAX, + "avg": trt.ReduceOperation.AVG, + } + + if reduce_op.lower() not in reduce_op_map: + raise ValueError( + f"Unsupported reduce operation: {reduce_op}. " + f"Supported: {list(reduce_op_map.keys())}" + ) + + trt_reduce_op = reduce_op_map[reduce_op.lower()] + + try: + layer = ctx.net.add_dist_collective( + input_tensor, + trt.CollectiveOperation.ALL_REDUCE, + trt_reduce_op, + -1, + None, + ) + + set_layer_name(layer, target, name, source_ir) + + output = layer.get_output(0) + layer.num_ranks = world_size + logger.debug( + f"Successfully created native ALL_REDUCE layer: {name}, reduce_op={reduce_op}" + ) + + return output + + except AttributeError as e: + error_msg = ( + f"Native ALL_REDUCE failed: {e}. " + "This usually means TensorRT doesn't support native distributed collectives. " + f"Your TensorRT version: {trt.__version__}. " + "Native collectives require TensorRT 10.16 or later. 
" + "Consider using TensorRT-LLM plugins instead by setting USE_NATIVE_TRT_COLLECTIVES=0" + ) + logger.error(error_msg) + raise RuntimeError(error_msg) from e + + except Exception as e: + logger.error(f"Native ALL_REDUCE failed: {e} (type: {type(e).__name__})") + raise diff --git a/py/torch_tensorrt/dynamo/lowering/passes/fuse_distributed_ops.py b/py/torch_tensorrt/dynamo/lowering/passes/fuse_distributed_ops.py index 02cb2ccd56..a6a826dbb6 100644 --- a/py/torch_tensorrt/dynamo/lowering/passes/fuse_distributed_ops.py +++ b/py/torch_tensorrt/dynamo/lowering/passes/fuse_distributed_ops.py @@ -33,6 +33,16 @@ def tensorrt_fused_nccl_reduce_scatter_op( ) +def tensorrt_fused_nccl_all_reduce_op( + inp: Any, reduce_op: str, group_size: int, group_name: str +) -> torch.Tensor: + return torch.ops._c10d_functional.wait_tensor.default( + torch.ops._c10d_functional.all_reduce.default( + inp, reduce_op, group_size, group_name + ) + ) + + def fuse_distributed_ops( gm: torch.fx.GraphModule, settings: CompilationSettings ) -> torch.fx.GraphModule: @@ -43,6 +53,7 @@ def fuse_distributed_ops( in ( torch.ops._c10d_functional.all_gather_into_tensor.default, torch.ops._c10d_functional.reduce_scatter_tensor.default, + torch.ops._c10d_functional.all_reduce.default, ) and len(node.users) == 1 and list(node.users)[0].target @@ -53,14 +64,23 @@ def fuse_distributed_ops( with gm.graph.inserting_after(wait_tensor_node): fused_node = gm.graph.create_node( op="call_function", - target=tensorrt_fused_nccl_all_gather_op, # Define your custom fused function + target=tensorrt_fused_nccl_all_gather_op, args=(node.args[0], node.args[1], node.args[2]), ) + elif ( + node.target == torch.ops._c10d_functional.reduce_scatter_tensor.default + ): + with gm.graph.inserting_after(wait_tensor_node): + fused_node = gm.graph.create_node( + op="call_function", + target=tensorrt_fused_nccl_reduce_scatter_op, + args=(node.args[0], node.args[1], node.args[2], node.args[3]), + ) else: with 
gm.graph.inserting_after(wait_tensor_node): fused_node = gm.graph.create_node( op="call_function", - target=tensorrt_fused_nccl_reduce_scatter_op, # Define your custom fused function + target=tensorrt_fused_nccl_all_reduce_op, args=(node.args[0], node.args[1], node.args[2], node.args[3]), ) diff --git a/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py index 31182bbe21..bd8203fe8f 100644 --- a/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py +++ b/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py @@ -4,15 +4,18 @@ from contextlib import nullcontext from typing import Any, Dict, List, Optional, Sequence, Tuple +import tensorrt as trt import torch import torch_tensorrt from torch.nn import Module from torch_tensorrt._Device import Device from torch_tensorrt._enums import Platform, dtype +from torch_tensorrt._features import ENABLED_FEATURES from torch_tensorrt.dynamo._defaults import DEBUG_LOGGING_DIR from torch_tensorrt.dynamo._settings import CompilationSettings from torch_tensorrt.dynamo.debug._DebuggerConfig import DebuggerConfig from torch_tensorrt.dynamo.debug._supports_debugger import cls_supports_debugger +from torch_tensorrt.dynamo.runtime._nccl_utils import setup_nccl_library from torch_tensorrt.dynamo.utils import DYNAMIC_DIM from torch_tensorrt.logging import TRT_LOGGER from torch_tensorrt.runtime._utils import ( @@ -21,8 +24,6 @@ multi_gpu_device_check, ) -import tensorrt as trt - logger = logging.getLogger(__name__) @@ -134,6 +135,8 @@ def __init__( requires_output_allocator: bool = False, symbolic_shape_expressions: Optional[Dict[str, List[Dict[str, Any]]]] = None, _debugger_config: Optional[DebuggerConfig] = None, + rank: int = -1, + world_size: int = 1, ): """Takes a name, target device, serialized TensorRT engine, and binding names / order and constructs a PyTorch ``torch.nn.Module`` around it. 
Uses TensorRT Python APIs to run the engine @@ -149,7 +152,8 @@ def __init__( weight_name_map (dict): Mapping of engine weight name to state_dict weight name requires_output_allocator (bool): Boolean flag indicating if the converter creates operators which require an Output Allocator to run (e.g. data dependent operators) symbolic_shape_expressions (List[str]): List of symbolic shape expressions for each output binding - + rank (int): Rank of the current process, applicable for distributed inference + world_size (int): World size of the distributed process, applicable for distributed inference Example: .. code-block:: py @@ -229,6 +233,10 @@ def __init__( if self.serialized_engine is not None and not self.settings.lazy_engine_init: self.setup_engine() + self.rank = rank + self.world_size = world_size + self._nccl_comm: Optional[Any] = None + def set_output_tensors_as_unowned(self, enabled: bool) -> None: """ Flag to set if the output tensors of this engine are solely owned by the Torch-TensorRT Runtime or if they might be shared with a user. 
@@ -280,6 +288,222 @@ def set_default_device_memory_budget(self) -> int: logger.debug(f"Weight streaming budget set to {budget_bytes}B") return self._set_device_memory_budget(budget_bytes) + # Distributed functions + @property + def is_distributed(self) -> bool: + """Check if this module is configured for distributed execution.""" + return self.rank >= 0 and self.world_size > 1 + + @property + def has_native_trt_collectives(self) -> bool: + """Check if native TRT collectives are available (TRT 10.16+ with NCCL).""" + return bool(ENABLED_FEATURES.native_trt_collectives) + + def get_rank(self) -> int: + """Get the rank of this process in distributed execution.""" + return self.rank + + def get_world_size(self) -> int: + """Get the total number of processes in distributed execution.""" + return self.world_size + + def set_nccl_communicator(self, comm: Any) -> None: + if not self.is_distributed: + logger.warning( + "Setting NCCL communicator on non-distributed module " + f"(rank={self.rank}, world_size={self.world_size})" + ) + self._nccl_comm = comm + # Only set communicator on context if native TRT collectives are available (TRT 10.16+) + if not self.has_native_trt_collectives: + logger.debug( + "Native TRT collectives not available, skipping set_communicator on TensorRT context" + ) + return + + if self.context is not None: + try: + # TensorRT's set_communicator expects a PyCapsule, not an integer pointer + # Convert integer pointer to PyCapsule if needed + comm_to_pass = comm + if isinstance(comm, int): + import ctypes + + # Create a PyCapsule from the pointer value + ctypes.pythonapi.PyCapsule_New.restype = ctypes.py_object + ctypes.pythonapi.PyCapsule_New.argtypes = [ + ctypes.c_void_p, + ctypes.c_char_p, + ctypes.c_void_p, + ] + comm_to_pass = ctypes.pythonapi.PyCapsule_New(comm, None, None) + logger.debug( + f"Converted integer pointer {comm} to PyCapsule for TensorRT" + ) + + success = self.context.set_communicator(comm_to_pass) + if success: + logger.debug( + 
f"NCCL communicator set on TensorRT context (rank={self.rank})" + ) + else: + logger.warning( + f"set_communicator returned False (rank={self.rank})" + ) + except AttributeError: + logger.warning("TensorRT context does not support set_communicator") + except TypeError as e: + logger.error(f"Failed to set NCCL communicator: {e}") + raise + + def get_nccl_communicator(self) -> Optional[Any]: + """Get the NCCL communicator if set.""" + return self._nccl_comm + + def setup_nccl(self, use_pytorch_comm: bool = True) -> None: + # to check if we need try block for this + # Ensure NCCL library path is configured for TensorRT + # This handles the case where pip-installed PyTorch has NCCL in a non-standard location + setup_nccl_library() + try: + import torch.distributed as dist + except ImportError as e: + raise RuntimeError( + "torch.distributed is required for setup_nccl(). " f"Import error: {e}" + ) + if not dist.is_initialized(): + raise RuntimeError( + "torch.distributed must be initialized before calling setup_nccl(). " + "Call dist.init_process_group('nccl') first." + ) + + if not self.is_distributed: + raise RuntimeError( + f"Module is not configured for distributed execution " + f"(rank={self.rank}, world_size={self.world_size}). " + "Pass rank and world_size to constructor." + ) + + # Check if native TRT collectives are available + if self.has_native_trt_collectives: + logger.info( + f"Using native TRT collectives (TRT 10.16+) for distributed execution (rank={self.rank})" + ) + elif ENABLED_FEATURES.trtllm_for_nccl: + logger.info(f"Using TRT-LLM plugins for NCCL backend (rank={self.rank})") + else: + logger.warning( + "Neither native TRT collectives nor TRT-LLM NCCL plugins are available. " + "Distributed execution may not work correctly. " + "For native TRT collectives, ensure TensorRT 10.16+ is installed and " + "torch_tensorrt was built with NCCL support. " + "For TRT-LLM fallback, set TRTLLM_PLUGINS_PATH or USE_TRTLLM_PLUGINS=1." 
+ ) + + # Try to get communicator from PyTorch's ProcessGroupNCCL which is preferred + nccl_comm = self._get_nccl_comm_from_process_group() + + # Fall back to creating via NCCL library if process group method fails + # Note: this is fallback mechanism which is to be tested + if nccl_comm is None: + logger.debug("Falling back to creating NCCL communicator via nccl library") + nccl_comm = self._create_nccl_comm_via_nccl_lib() + + # Set the communicator + self.set_nccl_communicator(nccl_comm) + + def _get_nccl_comm_from_process_group(self) -> Optional[Any]: + # expectation is that dist.init_process_group has been called + # In there, Rank 0 generated ncclUniqueId + # Broadcasted it to all ranks via the store + # Each rank called ncclCommInitRank() + import torch.distributed as dist + + pg = dist.group.WORLD + if pg is None: + logger.debug("No default process group available") + return None + + # Check if backend is NCCL + if dist.get_backend(pg) != "nccl": + logger.debug("ProcessGroup backend is not NCCL, cannot reuse communicator") + return None + + # Get the NCCL backend object via _get_backend (internal API) + if not hasattr(pg, "_get_backend"): + logger.debug("ProcessGroup does not have _get_backend method") + return None + + try: + backend = pg._get_backend(torch.device("cuda")) + except Exception as e: + logger.debug(f"Failed to get NCCL backend: {e}") + return None + + # now we have the backend + # Get comm pointer from the backend (internal API) + if not hasattr(backend, "_comm_ptr"): + logger.debug("NCCL backend does not have _comm_ptr method") + return None + + # Force NCCL communicator initialization with a dummy collective. + # PyTorch's ProcessGroupNCCL uses lazy initialization - the NCCL + # communicator is only created when the first collective operation + # is performed. Without this, _comm_ptr() returns 0. 
+ try: + dummy = torch.zeros(1, device="cuda") + dist.all_reduce(dummy) + logger.debug("Forced NCCL initialization with dummy all_reduce") + except Exception as e: + logger.debug(f"Failed to force NCCL initialization: {e}") + return None + + try: + comm_ptr = backend._comm_ptr() + except Exception as e: + logger.debug(f"Failed to call _comm_ptr: {e}") + return None + + if comm_ptr is None or comm_ptr == 0: + logger.debug("_comm_ptr returned None or 0") + return None + + logger.info( + f"Reusing PyTorch's NCCL communicator (ptr={comm_ptr}, rank={self.rank})" + ) + return comm_ptr + + def _create_nccl_comm_via_nccl_lib(self) -> Any: + import nccl.core as nccl + import torch.distributed as dist + + rank = self.rank + world_size = self.world_size + + # Generate unique ID on rank 0 and broadcast as a tensor + if rank == 0: + uid = nccl.get_unique_id() + uid_bytes = uid.as_bytes + uid_tensor = torch.frombuffer( + bytearray(uid_bytes), dtype=torch.uint8 + ).cuda() + logger.debug(f"Rank {rank} created NCCL unique ID ({len(uid_bytes)} bytes)") + else: + uid_tensor = torch.zeros(128, dtype=torch.uint8, device="cuda") + + dist.broadcast(uid_tensor, src=0) + logger.debug(f"Rank {rank} received NCCL unique ID") + + uid = nccl.UniqueId.from_bytes(bytes(uid_tensor.cpu().numpy())) + + comm = nccl.Communicator.init(world_size, rank, uid) + logger.info( + f"Created new NCCL communicator via nccl library (rank={rank}, world_size={world_size})" + ) + + self._nccl_comm_handle = comm + return comm.ptr + def setup_engine(self) -> None: assert ( self.target_platform == Platform.current_platform() @@ -333,6 +557,9 @@ def _on_state_dict(self, state_dict: Dict[str, Any], prefix: str, _: Any) -> Non state_dict[prefix + "input_names"] = self.input_names state_dict[prefix + "output_names"] = self.output_names state_dict[prefix + "platform"] = self.target_platform + # Distributed info (always saved, -1 indicates non-distributed) + state_dict[prefix + "rank"] = self.rank + state_dict[prefix + 
"world_size"] = self.world_size def _load_from_state_dict( self, @@ -348,6 +575,9 @@ def _load_from_state_dict( self.input_names = state_dict[prefix + "input_names"] self.output_names = state_dict[prefix + "output_names"] self.target_platform = state_dict[prefix + "platform"] + # Distributed info (optional, backward compatible with non-distributed models) + self.rank = state_dict.get(prefix + "rank", -1) + self.world_size = state_dict.get(prefix + "world_size", -1) # Run multi-gpu device check to validate engine instantiation multi_gpu_device_check() @@ -357,10 +587,14 @@ def __getstate__(self) -> Dict[str, Any]: state = self.__dict__.copy() state.pop("engine", None) state.pop("context", None) + # NCCLcomm cannot be pickled + state.pop("_nccl_comm", None) return state def __setstate__(self, state: Dict[str, Any]) -> None: self.__dict__.update(state) + # reset after unpickling, apbose: is this required though? + self._nccl_comm = None self.setup_engine() def __deepcopy__(self, memo: Any) -> PythonTorchTensorRTModule: @@ -699,6 +933,23 @@ def run_output_allocator() -> torch.Tensor | Tuple[torch.Tensor, ...]: ): self._check_initialized() + if self.is_distributed and self._nccl_comm is None: + nccl_type = ( + "native TRT collectives" + if self.has_native_trt_collectives + else ( + "TRT-LLM NCCL plugins" + if ENABLED_FEATURES.trtllm_for_nccl + else "unknown backend" + ) + ) + logger.info( + f"Setting up NCCL for distributed execution using {nccl_type} " + f"(rank={self.rank}, world_size={self.world_size})" + ) + self.setup_nccl() + logger.info(f"NCCL setup complete, comm={self._nccl_comm}") + # If in safe mode, check at each iteration for whether a switch is required if ( torch_tensorrt.runtime._multi_device_safe_mode._PY_RT_MULTI_DEVICE_SAFE_MODE diff --git a/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py index d77c0bf39f..dc898ee6af 100644 --- 
a/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py +++ b/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py @@ -15,6 +15,7 @@ needs_torch_tensorrt_runtime, ) from torch_tensorrt.dynamo._settings import CompilationSettings +from torch_tensorrt.dynamo.runtime._nccl_utils import setup_nccl_library logger = logging.getLogger(__name__) @@ -53,7 +54,9 @@ RESOURCE_ALLOCATION_STRATEGY_IDX = ( torch.ops.tensorrt.RESOURCE_ALLOCATION_STRATEGY_IDX() ) # 10 - SERIALIZATION_LEN = torch.ops.tensorrt.SERIALIZATION_LEN() # 11 + RANK_IDX = torch.ops.tensorrt.RANK_IDX() # 11 + WORLD_SIZE_IDX = torch.ops.tensorrt.WORLD_SIZE_IDX() # 12 + SERIALIZATION_LEN = torch.ops.tensorrt.SERIALIZATION_LEN() # 13 @for_all_methods(needs_torch_tensorrt_runtime) @@ -88,6 +91,8 @@ def __init__( weight_name_map: Optional[dict[Any, Any]] = None, requires_output_allocator: bool = False, symbolic_shape_expressions: Optional[Dict[str, List[Dict[str, Any]]]] = None, + rank: int = -1, + world_size: int = 1, ): """Takes a name, target device, serialized TensorRT engine, and binding names / order and constructs a PyTorch ``torch.nn.Module`` around it. 
Uses the Torch-TensorRT runtime extension to run the engines @@ -146,6 +151,11 @@ def __init__( self.requires_output_allocator = requires_output_allocator self.dynamically_allocate_resources = settings.dynamically_allocate_resources self.symbolic_shape_expressions = symbolic_shape_expressions + self.rank = rank + self.world_size = world_size + self._nccl_setup_done = ( + False # Track if NCCL has been setup for distributed mode + ) if ( serialized_engine @@ -203,6 +213,8 @@ def _pack_engine_info(self) -> List[str | bytes]: engine_info[RESOURCE_ALLOCATION_STRATEGY_IDX] = str( int(self.dynamically_allocate_resources) ) + engine_info[RANK_IDX] = str(self.rank) + engine_info[WORLD_SIZE_IDX] = str(self.world_size) return engine_info @@ -329,6 +341,159 @@ def set_pre_allocated_outputs(self, enable: bool) -> None: def set_use_output_allocator(self, enable: bool) -> None: self.engine.use_output_allocator_outputs = enable + def _auto_init_distributed(self) -> None: + """ + Automatically initialize distributed inference if the engine was compiled with + rank and world_size set (from torch.distributed). + + This is called automatically after setup_engine() to configure NCCL communicators + for distributed inference without requiring manual setup. 
+ """ + if self.engine is None: + return + + logger.debug( + f"In _auto_init_distributed: _nccl_setup_done={self._nccl_setup_done}" + ) + + if not ENABLED_FEATURES.native_trt_collectives: + logger.debug( + "TRT native NCCL collectives not available, skipping distributed setup" + ) + return + + # Check if the engine has distributed info set (rank >= 0 and world_size > 1) + if ( + self.engine.rank >= 0 + and self.engine.world_size > 1 + and not self._nccl_setup_done + ): + try: + import torch.distributed as dist + + self.set_distributed_info() + + # Only auto-initialize if torch.distributed is initialized + if dist.is_available() and dist.is_initialized(): + logger.debug( + f"Auto-initializing distributed inference " + f"(rank={self.engine.rank}, world_size={self.engine.world_size})" + ) + # this calls self.engine.set_process_group(process_group) + pg = dist.group.WORLD + self.int_nccl_comm(pg) + self._nccl_setup_done = True + logger.debug(f"NCCL setup complete (rank={self.engine.rank})") + else: + logger.warning( + f"Engine has distributed info (rank={self.engine.rank}, world_size={self.engine.world_size}) " + f"but torch.distributed is not initialized. " + f"Call dist.init_process_group() and then module.set_process_group() manually." + ) + except RuntimeError as e: + # Catch tracing errors specifically (e.g., "Tracer cannot infer type of ProcessGroup") + if "Tracer cannot infer" in str(e) or "traced functions" in str(e): + logger.debug("Skipping NCCL auto-init during tracing/compilation") + else: + logger.warning( + f"Failed to auto-initialize distributed inference: {e}. " + f"Call module.set_process_group() manually if needed." + ) + + def set_distributed_info( + self, rank: Optional[int] = None, world_size: Optional[int] = None + ) -> None: + """ + Set rank and world_size for distributed inference. + + This method sets the rank and world_size on the TensorRT engine. If not provided, + they will be auto-detected from torch.distributed. 
+ + Args: + rank: Rank of the current process (auto-detects if None) + world_size: Total number of processes (auto-detects if None) + + """ + if self.engine is None: + raise RuntimeError( + "Engine has not been setup yet. Call setup_engine() first." + ) + + # Auto-detect if not provided + if rank is None or world_size is None: + import torch.distributed as dist + + if dist.is_available() and dist.is_initialized(): + if rank is None: + rank = dist.get_rank() + if world_size is None: + world_size = dist.get_world_size() + else: + raise RuntimeError( + "torch.distributed is not initialized and rank/world_size not provided. " + "Call dist.init_process_group() first or provide rank/world_size explicitly." + ) + + # Set on C++ TRTEngine + self.engine.set_rank(rank) + self.engine.set_world_size(world_size) + logger.debug( + f"Distributed info set on TRTEngine: rank={rank}, world_size={world_size}" + ) + + def init_nccl_comm(self, process_group: Optional[Any] = None) -> None: + """ + Initialize NCCL communicator for distributed execution. + + This method initializes the NCCL communicator from the C++ ProcessGroup registry. + The ProcessGroup must be registered in PyTorch's native registry (which happens + automatically when using torch.distributed). + """ + if not ENABLED_FEATURES.native_trt_collectives: + raise RuntimeError( + "Native TRT NCCL collectives are not available. " + "Requires TensorRT 10.16+ and PyTorch built with NCCL support." + ) + if self.engine is None: + raise RuntimeError( + "Engine has not been setup yet. Call setup_engine() first." + ) + + setup_nccl_library() + + # Get the process group if not provided + if process_group is None: + try: + import torch.distributed as dist + + if not dist.is_initialized(): + raise RuntimeError( + "torch.distributed is not initialized. Call dist.init_process_group() first." 
+ ) + process_group = dist.distributed_c10d._get_default_group() + logger.debug("Using default ProcessGroup from torch.distributed") + except Exception as e: + raise RuntimeError(f"Failed to get default process group: {e}") + + # Get the group name from the ProcessGroup + # This is the name used to register the group in the C++ registry + group_name = "default" + if ( + hasattr(process_group, "group_name") + and process_group.group_name is not None + ): + group_name = process_group.group_name + logger.debug(f"Using ProcessGroup with group_name: {group_name}") + + # Initialize NCCL communicator from C++ registry + # This uses c10d::resolve_process_group() to get the ProcessGroup and extract the NCCL comm + self.engine.init_nccl_comm(group_name) + + self._nccl_setup_done = True + logger.debug( + f"NCCL comm initialized from ProcessGroup (rank={self.engine.rank}, world_size={self.engine.world_size})" + ) + def forward(self, *inputs: Any) -> torch.Tensor | Tuple[torch.Tensor, ...]: """Implementation of the forward pass for a TensorRT engine @@ -341,6 +506,9 @@ def forward(self, *inputs: Any) -> torch.Tensor | Tuple[torch.Tensor, ...]: if self.engine is None: raise RuntimeError("Engine has not been setup yet.") + if not self._nccl_setup_done: + self._auto_init_distributed() + assert len(inputs) == len( self.input_binding_names ), f"Wrong number of inputs, expected {len(self.input_binding_names)} got {len(inputs)}." 
diff --git a/py/torch_tensorrt/dynamo/runtime/__init__.py b/py/torch_tensorrt/dynamo/runtime/__init__.py index 0eb66b24b0..63cabbef67 100644 --- a/py/torch_tensorrt/dynamo/runtime/__init__.py +++ b/py/torch_tensorrt/dynamo/runtime/__init__.py @@ -1,4 +1,9 @@ import torch_tensorrt +from torch_tensorrt.dynamo.runtime._nccl_utils import ( # noqa: F401 + check_nccl_library_path, + get_nccl_library_path, + setup_nccl_library, +) from torch_tensorrt.dynamo.runtime._PythonTorchTensorRTModule import ( # noqa: F401 PythonTorchTensorRTModule, ) diff --git a/py/torch_tensorrt/dynamo/runtime/_nccl_utils.py b/py/torch_tensorrt/dynamo/runtime/_nccl_utils.py new file mode 100644 index 0000000000..9f8519d715 --- /dev/null +++ b/py/torch_tensorrt/dynamo/runtime/_nccl_utils.py @@ -0,0 +1,176 @@ +""" +NCCL Library Utilities for Distributed TensorRT Inference + +This module handles NCCL library path resolution to ensure TensorRT and PyTorch +use the same NCCL library instance. This is critical for sharing NCCL communicators +between PyTorch's distributed backend and TensorRT's native NCCL collectives. + +Background: +----------- +TensorRT's dlopen("libnccl.so") may load a different NCCL library than PyTorch, +causing crashes when sharing NCCL communicators. + +- PyTorch loads NCCL via RPATH baked at compile time (libnccl.so.2) +- TensorRT lazy-loads NCCL via dlopen("libnccl.so") at runtime + +The mismatch occurs because: +1. pip's nvidia-nccl-cu* package only ships libnccl.so.2 (no libnccl.so symlink) +2. TRT specifically looks for libnccl.so, misses pip's copy, falls back to system NCCL + +Environments: +------------- +- NGC containers: No action needed (both use system NCCL) +- pip install torch: Requires symlink + LD_LIBRARY_PATH setup + +Future: +------- +TensorRT 11.0 will support TRT_NCCL_LIBRARY env var, eliminating the need for +symlink workarounds. 
+""" + +import logging +import os +from typing import Optional + +logger = logging.getLogger(__name__) + +_nccl_setup_checked = False + + +def get_nccl_library_path() -> Optional[str]: + """ + Get the path to PyTorch's NCCL library directory. + + Returns: + Path to NCCL lib directory if nvidia.nccl package exists, None otherwise. + None indicates system NCCL is being used (e.g., NGC containers). + """ + try: + import nvidia.nccl + + nccl_lib_dir = os.path.join(list(nvidia.nccl.__path__)[0], "lib") + if os.path.isdir(nccl_lib_dir): + return nccl_lib_dir + return None + except ImportError: + # nvidia.nccl not installed - using system NCCL (e.g., NGC container) + return None + + +def ensure_nccl_symlink(nccl_lib_dir: str) -> bool: + """ + Ensure libnccl.so symlink exists pointing to libnccl.so.2. + + TensorRT's dlopen looks for "libnccl.so", but pip's nvidia-nccl package + only ships "libnccl.so.2". This creates the necessary symlink. + + Args: + nccl_lib_dir: Path to the NCCL library directory + + Returns: + True if symlink exists or was created, False otherwise. + """ + nccl_so = os.path.join(nccl_lib_dir, "libnccl.so") + nccl_so_2 = os.path.join(nccl_lib_dir, "libnccl.so.2") + + # Check if symlink already exists + if os.path.lexists(nccl_so): + return True + + # Check if target exists + if not os.path.exists(nccl_so_2): + logger.warning(f"NCCL library not found at {nccl_so_2}") + return False + + # Try to create symlink + try: + os.symlink("libnccl.so.2", nccl_so) + logger.info(f"Created NCCL symlink: {nccl_so} -> libnccl.so.2") + return True + except PermissionError: + logger.warning( + f"Cannot create NCCL symlink at {nccl_so} (permission denied). " + f"Please run: ln -sf libnccl.so.2 {nccl_so}" + ) + return False + except OSError as e: + logger.warning(f"Failed to create NCCL symlink: {e}") + return False + + +def check_nccl_library_path() -> bool: + """ + Check if LD_LIBRARY_PATH includes PyTorch's NCCL directory. 
+ + Returns: + True if configuration is correct, False if LD_LIBRARY_PATH needs updating. + """ + nccl_lib_dir = get_nccl_library_path() + + if nccl_lib_dir is None: + # System NCCL - no action needed + return True + + ld_library_path = os.environ.get("LD_LIBRARY_PATH", "") + return nccl_lib_dir in ld_library_path + + +def setup_nccl_library() -> None: + """ + Setup NCCL library path for TensorRT distributed inference. + + This function: + 1. Detects if nvidia.nccl pip package is installed + 2. Creates libnccl.so symlink if needed + 3. Warns if LD_LIBRARY_PATH is not configured + + Call this before initializing NCCL communicators for TensorRT. + + For NGC containers (system NCCL), this is a no-op. + For pip-installed PyTorch, this ensures TensorRT can find the correct NCCL. + """ + global _nccl_setup_checked + + # Only check once per process + if _nccl_setup_checked: + return + _nccl_setup_checked = True + + nccl_lib_dir = get_nccl_library_path() + + if nccl_lib_dir is None: + # NGC container or system NCCL - no action needed + logger.debug( + "nvidia.nccl package not found. " + "Assuming system NCCL is used by both PyTorch and TensorRT." 
+ ) + return + + logger.debug(f"Found nvidia.nccl package at: {nccl_lib_dir}") + + # Ensure symlink exists + symlink_ok = ensure_nccl_symlink(nccl_lib_dir) + + # Check LD_LIBRARY_PATH + ld_library_path = os.environ.get("LD_LIBRARY_PATH", "") + if nccl_lib_dir not in ld_library_path: + logger.warning( + f"\n" + f"{'=' * 70}\n" + f"NCCL LIBRARY PATH WARNING\n" + f"{'=' * 70}\n" + f"PyTorch's NCCL library directory is not in LD_LIBRARY_PATH.\n" + f"TensorRT may load a different NCCL library than PyTorch,\n" + f"causing distributed inference to fail.\n" + f"\n" + f"NCCL directory: {nccl_lib_dir}\n" + f"\n" + f"Please set before running:\n" + f" export LD_LIBRARY_PATH={nccl_lib_dir}:$LD_LIBRARY_PATH\n" + f"{'=' * 70}\n" + ) + else: + logger.debug(f"LD_LIBRARY_PATH includes NCCL directory: {nccl_lib_dir}") + + if symlink_ok: + logger.debug("NCCL library setup complete") diff --git a/third_party/libtorch/BUILD b/third_party/libtorch/BUILD index 37309f7209..45d45646a9 100644 --- a/third_party/libtorch/BUILD +++ b/third_party/libtorch/BUILD @@ -34,10 +34,12 @@ cc_library( hdrs = glob( [ "include/torch/**/*.h", + "include/torch/**/*.hpp", ], allow_empty = True, exclude = [ "include/torch/csrc/api/include/**/*.h", + "include/torch/csrc/api/include/**/*.hpp", ], ) + glob([ "include/torch/csrc/api/include/**/*.h", diff --git a/toolchains/torch_nccl/BUILD b/toolchains/torch_nccl/BUILD new file mode 100644 index 0000000000..ffd0fb0cdc --- /dev/null +++ b/toolchains/torch_nccl/BUILD @@ -0,0 +1 @@ +package(default_visibility = ["//visibility:public"]) diff --git a/toolchains/torch_nccl/defs.bzl b/toolchains/torch_nccl/defs.bzl new file mode 100644 index 0000000000..9a7a015c2c --- /dev/null +++ b/toolchains/torch_nccl/defs.bzl @@ -0,0 +1,60 @@ +"""NCCL detection for PyTorch builds.""" + +def _torch_nccl_detect_impl(repository_ctx): + """Detect if PyTorch was built with NCCL support.""" + + # Skip detection on non-Linux (NCCL not available) + os_name = repository_ctx.os.name.lower() 
+ if "linux" not in os_name: + has_nccl = False + else: + # Find libtorch path + result = repository_ctx.execute([ + "python3", + "-c", + "import torch; import os; print(os.path.dirname(torch.__file__))", + ]) + + if result.return_code != 0: + has_nccl = False + else: + torch_path = result.stdout.strip() + lib_path = torch_path + "/lib/libtorch_cuda.so" + + # Check for ProcessGroupNCCL symbol + result = repository_ctx.execute([ + "grep", + "-q", + "ProcessGroupNCCL", + lib_path, + ]) + has_nccl = (result.return_code == 0) + + # Generate BUILD file with config_setting + repository_ctx.file("BUILD", """ +load("@bazel_skylib//rules:common_settings.bzl", "bool_flag") + +package(default_visibility = ["//visibility:public"]) + +bool_flag( + name = "use_nccl", + build_setting_default = {has_nccl}, +) + +config_setting( + name = "nccl_enabled", + flag_values = {{":use_nccl": "True"}}, +) +""".format(has_nccl = has_nccl)) + +torch_nccl_detect = repository_rule( + implementation = _torch_nccl_detect_impl, + local = True, # Re-run on each build to detect changes +) + +def if_torch_nccl(if_true, if_false = []): + """Returns if_true if PyTorch has NCCL, else if_false.""" + return select({ + "@torch_nccl//:nccl_enabled": if_true, + "//conditions:default": if_false, + }) diff --git a/tools/llm/tensor_parallel_llama_llm.py b/tools/llm/tensor_parallel_llama_llm.py new file mode 100644 index 0000000000..04ac05fd79 --- /dev/null +++ b/tools/llm/tensor_parallel_llama_llm.py @@ -0,0 +1,340 @@ +""" +.. _run_llm_tp: + +Tensor Parallel LLM inference with Torch-TensorRT +================================================== + +This script extends run_llm.py to support Tensor Parallelism (TP) across multiple GPUs. +Weights in Attention (Q, K, V, O projections) and MLP (gate, up, down projections) +are sharded across ranks using PyTorch's parallelize_module API. AllReduce is inserted +automatically by RowwiseParallel at the output projection of each sub-block. 
+Torch-TensorRT lowers the resulting collective ops into TRT ncclwrapper via TRT-MD. + +Usage +----- +.. code-block:: bash + + mpirun -n 2 python3 tensor_parallel_llama_llm.py \\ + --model meta-llama/Llama-3.2-1B-Instruct \\ + --prompt "What is parallel programming?" \\ + --model_precision FP16 --num_tokens 128 +""" + +import argparse +import logging +import os +from contextlib import nullcontext + +# Distributed init must happen before importing torch_tensorrt. +# mpirun sets OMPI_COMM_WORLD_LOCAL_RANK / OMPI_COMM_WORLD_SIZE but NOT the +# RANK/WORLD_SIZE/MASTER_ADDR/MASTER_PORT vars that PyTorch's env:// rendezvous +# expects, so we translate them here. +import torch +import torch.distributed as dist +import torch.distributed.tensor._dtensor_spec +import torch.utils._pytree +from torch.distributed.device_mesh import init_device_mesh + +# DTensorSpec appears in the graph during torch.export of a TP model. +# Register it as a pytree constant so the exporter treats it as a +# compile-time constant rather than a dynamic input. 
+torch.utils._pytree.register_constant(
+    torch.distributed.tensor._dtensor_spec.DTensorSpec
+)
+
+# NOTE(review): OMPI_COMM_WORLD_LOCAL_RANK equals the global rank only for
+# single-node launches; multi-node runs should map OMPI_COMM_WORLD_RANK to
+# RANK instead (LOCAL_RANK restarts at 0 on every node).
+_ompi_rank = int(os.environ.get("OMPI_COMM_WORLD_LOCAL_RANK", 0))
+_ompi_size = int(os.environ.get("OMPI_COMM_WORLD_SIZE", 1))
+os.environ.setdefault("RANK", str(_ompi_rank))
+os.environ.setdefault("WORLD_SIZE", str(_ompi_size))
+os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
+os.environ.setdefault("MASTER_PORT", "29501")
+
+dist.init_process_group(backend="nccl")
+rank = dist.get_rank()
+world_size = dist.get_world_size()
+DEVICE = torch.device(f"cuda:{rank}")
+torch.cuda.set_device(DEVICE)
+
+
+def initialize_logger(
+    rank, logger_file_name, file_level=logging.DEBUG, console_level=logging.INFO
+):
+    """Initialize rank-specific Torch-TensorRT logger with configurable handler levels."""
+    logger = logging.getLogger("torch_tensorrt")
+    logger.setLevel(logging.DEBUG)
+    logger.handlers.clear()
+
+    # File handler
+    fh = logging.FileHandler(logger_file_name + f"_{rank}.log", mode="w")
+    fh.setLevel(file_level)
+    fh.setFormatter(
+        logging.Formatter(
+            f"[Rank {rank}] %(asctime)s - %(name)s - %(levelname)s - %(message)s"
+        )
+    )
+    logger.addHandler(fh)
+
+    # Console handler
+    ch = logging.StreamHandler()
+    ch.setLevel(console_level)
+    ch.setFormatter(logging.Formatter(f"[Rank {rank}] %(levelname)s: %(message)s"))
+    logger.addHandler(ch)
+
+    logger.propagate = False
+    return logger
+
+
+# Initialize logger for this rank
+logger = initialize_logger(
+    rank, "llm_tp_log_mod", file_level=logging.DEBUG, console_level=logging.INFO
+)
+logger.info(
+    f"Initialized distributed environment: rank={rank}, world_size={world_size}, device={DEVICE}"
+)
+
+import torch_tensorrt
+from torch.distributed._tensor import Replicate, Shard
+from torch.distributed.tensor.parallel import (
+    ColwiseParallel,
+    RowwiseParallel,
+    parallelize_module,
+)
+from torchtrt_ext import register_sdpa
+from transformers import AutoModelForCausalLM, AutoTokenizer
+from utils import generate, record_stats,
time_generate + + +def get_model(args, device_mesh): + with torch.no_grad(): + model = ( + AutoModelForCausalLM.from_pretrained( + args.model, + use_cache=False, + attn_implementation="sdpa", + ignore_mismatched_sizes=True, + ) + .eval() + .to(DEVICE) + ) + register_sdpa.enable_sdpa_converter(args.model, model.config) + + if args.model_precision == "FP16": + model = model.to(torch.float16) + elif args.model_precision == "BF16": + model = model.to(torch.bfloat16) + + # Build TP plan: ColwiseParallel for first linear in each pair, + # RowwiseParallel for second linear (inserts AllReduce when output_layouts=Replicate). + tp_plan = {} + for i in range(model.config.num_hidden_layers): + tp_plan.update( + { + f"model.layers.{i}.self_attn.q_proj": ColwiseParallel(), + f"model.layers.{i}.self_attn.k_proj": ColwiseParallel(), + f"model.layers.{i}.self_attn.v_proj": ColwiseParallel(), + f"model.layers.{i}.self_attn.o_proj": RowwiseParallel(), + f"model.layers.{i}.mlp.gate_proj": ColwiseParallel(), + f"model.layers.{i}.mlp.up_proj": ColwiseParallel(), + f"model.layers.{i}.mlp.down_proj": RowwiseParallel(), + } + ) + parallelize_module(model, device_mesh, tp_plan) + + # HuggingFace attention uses self.num_heads / self.num_key_value_heads to + # reshape Q/K/V outputs. After weight sharding each rank only holds + # num_heads // world_size columns, so these attributes must be updated or + # the reshape will produce a shape mismatch error. 
+    for layer in model.model.layers:
+        layer.self_attn.num_heads = model.config.num_attention_heads // world_size
+        layer.self_attn.num_key_value_heads = (
+            model.config.num_key_value_heads // world_size
+        )
+
+    return model
+
+
+def compile_torchtrt(model, input_ids, args):
+    use_fp32_acc = False
+    use_explicit_typing = False
+    if args.model_precision == "FP16":
+        enabled_precisions = {torch.float32}
+        use_fp32_acc = True
+        use_explicit_typing = True
+    elif args.model_precision == "BF16":
+        enabled_precisions = {torch.bfloat16}
+    else:
+        enabled_precisions = {torch.float32}
+
+    # torch.export does not support DTensor-parallelized models (sharding propagation
+    # fails during run_decompositions), so trace with torch.compile instead, via
+    # aot_autograd (use_distributed_mode_trace=True path). dynamic=False: new input
+    # shapes recompile; assume_dynamic_shape_support=True is passed to the backend.
+    with torch_tensorrt.logging.debug() if args.debug else nullcontext():
+        trt_model = torch.compile(
+            model,
+            backend="torch_tensorrt",
+            dynamic=False,
+            options={
+                "enabled_precisions": enabled_precisions,
+                "use_explicit_typing": use_explicit_typing,
+                "use_fp32_acc": use_fp32_acc,
+                "device": DEVICE,
+                "disable_tf32": True,
+                "use_python_runtime": True,
+                "use_distributed_mode_trace": True,
+                "debug": args.debug,
+                "min_block_size": args.min_block_size,
+                "assume_dynamic_shape_support": True,
+            },
+        )
+
+    return trt_model
+
+
+def print_outputs(backend_name, gen_tokens, tokenizer):
+    print(f"========= {backend_name} =========")
+    print(
+        f"{backend_name} model generated text: ",
+        tokenizer.decode(gen_tokens[0], skip_special_tokens=True),
+    )
+    print("===================================")
+
+
+if __name__ == "__main__":
+    arg_parser = argparse.ArgumentParser(
+        description="Run tensor parallel LLM inference with Torch-TensorRT"
+    )
+    arg_parser.add_argument(
+        "--model",
+        type=str,
+        default="meta-llama/Llama-3.2-1B-Instruct",
+        help="Name of LLM model",
+    )
+    
arg_parser.add_argument( + "--tokenizer", + type=str, + default="", + help="Name of LLM model tokenizer", + ) + arg_parser.add_argument( + "--prompt", + type=str, + default="What is parallel programming?", + help="Prompt", + ) + arg_parser.add_argument( + "--model_precision", + type=str, + default="FP16", + help="Precision to use in the model. Options: FP16, BF16, FP32", + ) + arg_parser.add_argument( + "--num_tokens", + type=int, + default=128, + help="Number of output tokens to generate", + ) + arg_parser.add_argument( + "--min_block_size", + type=int, + default=1, + help="Minimum block size for TensorRT compilation", + ) + arg_parser.add_argument( + "--benchmark", + action="store_true", + help="Enable benchmarking mode (default: False)", + ) + arg_parser.add_argument( + "--iterations", + type=int, + default=5, + help="Number of benchmark iterations", + ) + arg_parser.add_argument( + "--batch_size", + type=int, + default=1, + help="Batch size for benchmarking", + ) + arg_parser.add_argument( + "--isl", + type=int, + default=2048, + help="Input sequence length for benchmarking", + ) + arg_parser.add_argument( + "--debug", + action="store_true", + help="Enable debug logging (default: False)", + ) + args = arg_parser.parse_args() + + device_mesh = init_device_mesh("cuda", (world_size,)) + + with torch.inference_mode(): + model = get_model(args, device_mesh) + + assert model.config.num_key_value_heads % world_size == 0, ( + f"num_key_value_heads ({model.config.num_key_value_heads}) must be " + f"divisible by world_size ({world_size})." 
+ ) + + tokenizer = AutoTokenizer.from_pretrained(args.tokenizer or args.model) + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + + if args.benchmark: + input_ids = torch.randint( + 1, 10000, (args.batch_size, args.isl), dtype=torch.int64 + ).to(DEVICE) + else: + model_inputs = tokenizer(args.prompt, return_tensors="pt") + input_ids = model_inputs["input_ids"].to(DEVICE) + + MAX_OUTPUT_SEQ_LENGTH = input_ids.shape[1] + args.num_tokens + + # Run uncompiled torch model first for comparison + torch_gen_tokens = generate( + model, + input_ids.clone(), + MAX_OUTPUT_SEQ_LENGTH, + tokenizer.eos_token_id, + ) + if rank == 0: + print_outputs("Torch-TP (uncompiled)", torch_gen_tokens, tokenizer) + + trt_model = compile_torchtrt(model, input_ids, args) + + trt_gen_tokens = generate( + trt_model, + input_ids.clone(), + MAX_OUTPUT_SEQ_LENGTH, + tokenizer.eos_token_id, + ) + + if args.benchmark: + trt_timings = time_generate( + generate, + trt_model, + input_ids.clone(), + MAX_OUTPUT_SEQ_LENGTH, + tokenizer.eos_token_id, + iterations=args.iterations, + ) + + if rank == 0: + if not args.benchmark: + print_outputs("TensorRT-TP", trt_gen_tokens, tokenizer) + else: + trt_stats = record_stats( + "TensorRT-TP", + trt_timings, + args.model_precision, + batch_size=args.batch_size, + compile_time_s=None, + ) + print("=========TensorRT-TP PERFORMANCE============") + print(trt_stats) + + dist.destroy_process_group() From d3a2ef5f97a59b2a59fd52d2e3f79328720aa1da Mon Sep 17 00:00:00 2001 From: apbose Date: Wed, 8 Apr 2026 14:29:40 -0700 Subject: [PATCH 2/7] removing the try-except block in TRTengine.cpp and correcting the typis --- core/runtime/TRTEngine.cpp | 136 +++++------------- .../dynamo/runtime/_TorchTensorRTModule.py | 2 +- 2 files changed, 35 insertions(+), 103 deletions(-) diff --git a/core/runtime/TRTEngine.cpp b/core/runtime/TRTEngine.cpp index 6937cbfa33..f7787f0433 100644 --- a/core/runtime/TRTEngine.cpp +++ b/core/runtime/TRTEngine.cpp @@ -554,39 
+554,20 @@ void TRTEngine::set_nccl_comm(int64_t comm_ptr) { } bool TRTEngine::set_nccl_communicator_to_trt_context() { - // Set NCCL communicator on TensorRT execution context - // The communicator should be set from Python via set_nccl_comm() or set_process_group() + TORCHTRT_CHECK(exec_ctx != nullptr, "Cannot set NCCL communicator: execution context is null"); + TORCHTRT_CHECK( + this->nccl_comm != nullptr, + "Distributed inference enabled but no NCCL communicator set. " + "Call set_process_group() or set_nccl_comm() from Python first."); - if (!exec_ctx) { - LOG_ERROR("Cannot set NCCL communicator: execution context is null"); - return false; - } + void* comm_ptr = static_cast(this->nccl_comm); + exec_ctx->setCommunicator(comm_ptr); - if (this->nccl_comm == nullptr) { - LOG_WARNING( - "Distributed inference enabled but no NCCL communicator set. " - "Call set_process_group() or set_nccl_comm() from Python first."); - return false; - } - - // Set NCCL communicator on TensorRT execution context - try { - // Cast ncclComm_t to void* for TensorRT API - void* comm_ptr = static_cast(this->nccl_comm); - - // Set the NCCL communicator on the execution context - // The device ID is used to identify which GPU's communicator this is - exec_ctx->setCommunicator(comm_ptr); - - LOG_INFO( - "NCCL communicator set on TensorRT execution context " - "(rank=" - << this->rank << ", device=" << this->device_info.id << ")"); - return true; - } catch (const std::exception& e) { - LOG_ERROR("Failed to set NCCL communicator on execution context: " << e.what()); - return false; - } + LOG_INFO( + "NCCL communicator set on TensorRT execution context " + "(rank=" + << this->rank << ", device=" << this->device_info.id << ")"); + return true; } void TRTEngine::init_nccl_comm(const std::string& group_name) { @@ -595,85 +576,36 @@ void TRTEngine::init_nccl_comm(const std::string& group_name) { } bool TRTEngine::set_process_group_from_registry(const std::string& group_name) { - // Get 
ProcessGroup from C++ registry and extract NCCL communicator - // This avoids the need to pass the ProcessGroup from Python LOG_INFO("TRTEngine::set_process_group_from_registry() called with group_name: " << group_name); - LOG_INFO(" Current rank: " << this->rank); - LOG_INFO(" Current world_size: " << this->world_size); - LOG_INFO(" Current device_id: " << this->device_info.id); - - try { - // Resolve ProcessGroup from the native registry - auto pg = c10d::resolve_process_group(group_name); - if (!pg) { - LOG_ERROR("Failed to resolve ProcessGroup '" << group_name << "' from registry"); - return false; - } - LOG_INFO(" Resolved ProcessGroup from registry: rank=" << pg->getRank() << ", size=" << pg->getSize()); - // Update rank and world_size from the ProcessGroup if not already set - if (this->rank < 0) { - this->rank = pg->getRank(); - LOG_INFO(" Set rank from ProcessGroup: " << this->rank); - } - if (this->world_size < 0) { - this->world_size = pg->getSize(); - LOG_INFO(" Set world_size from ProcessGroup: " << this->world_size); - } + auto pg = c10d::resolve_process_group(group_name); + TORCHTRT_CHECK(pg != nullptr, "ProcessGroup '" << group_name << "' not found in registry"); + LOG_INFO(" Resolved ProcessGroup: rank=" << pg->getRank() << ", size=" << pg->getSize()); - // Get the NCCL backend from the ProcessGroup - // ProcessGroup wraps Backend objects - we need to get the NCCL backend explicitly - c10::intrusive_ptr backend; - try { - backend = pg->getBackend(c10d::ProcessGroup::BackendType::NCCL); - } catch (const std::exception& e) { - LOG_ERROR("Failed to get NCCL backend from ProcessGroup: " << e.what()); - return false; - } + if (this->rank < 0) { + this->rank = pg->getRank(); + } + if (this->world_size < 0) { + this->world_size = pg->getSize(); + } - if (!backend) { - LOG_ERROR("ProcessGroup '" << group_name << "' does not have an NCCL backend"); - return false; - } - LOG_INFO(" Got NCCL backend from ProcessGroup"); + auto backend = 
pg->getBackend(c10d::ProcessGroup::BackendType::NCCL); + TORCHTRT_CHECK(backend != nullptr, "ProcessGroup '" << group_name << "' has no NCCL backend"); - // Cast the backend to ProcessGroupNCCL - auto* nccl_pg = dynamic_cast(backend.get()); - if (!nccl_pg) { - LOG_ERROR("Backend is not ProcessGroupNCCL (unexpected)"); - return false; - } - LOG_INFO(" Successfully cast to ProcessGroupNCCL"); - - // Set current CUDA device to match the engine's device before getting comm - // getCommPtr() uses at::cuda::current_device() internally - at::cuda::set_device(this->device_info.id); - LOG_INFO(" Set current CUDA device to: " << this->device_info.id); - - // Get NCCL comm pointer using the public getCommPtr() method - // This returns the communicator for the current CUDA device - int64_t comm_ptr = nccl_pg->getCommPtr(); - if (comm_ptr == 0) { - LOG_ERROR( - "Failed to get NCCL communicator for device " << this->device_info.id - << ". The communicator may not be initialized yet."); - LOG_ERROR("Hint: Ensure a collective operation has been performed on this device first."); - return false; - } + auto* nccl_pg = dynamic_cast(backend.get()); + TORCHTRT_CHECK(nccl_pg != nullptr, "Backend is not ProcessGroupNCCL"); - // Convert int64_t pointer to ncclComm_t - ncclComm_t comm = reinterpret_cast(comm_ptr); + at::cuda::set_device(this->device_info.id); - this->nccl_comm = comm; - LOG_INFO(" Successfully extracted NCCL communicator from registry"); - LOG_INFO(" nccl_comm: " << (void*)this->nccl_comm); - // Set on TensorRT execution context - return True; + int64_t comm_ptr = nccl_pg->getCommPtr(); + TORCHTRT_CHECK( + comm_ptr != 0, + "NCCL communicator not initialized for device " << this->device_info.id + << ". 
Ensure a collective operation has been performed first."); - } catch (const std::exception& e) { - LOG_ERROR("Failed to get ProcessGroup from registry: " << e.what()); - return false; - } + this->nccl_comm = reinterpret_cast(comm_ptr); + LOG_INFO(" NCCL communicator set: " << (void*)this->nccl_comm); + return true; } #endif // ENABLE_TRT_NCCL_COLLECTIVES diff --git a/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py index dc898ee6af..4099e35a91 100644 --- a/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py +++ b/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py @@ -381,7 +381,7 @@ def _auto_init_distributed(self) -> None: ) # this calls self.engine.set_process_group(process_group) pg = dist.group.WORLD - self.int_nccl_comm(pg) + self.init_nccl_comm(pg) self._nccl_setup_done = True logger.debug(f"NCCL setup complete (rank={self.engine.rank})") else: From 02268c0036495f430fa4830cf73fa6bef21a0f5d Mon Sep 17 00:00:00 2001 From: apbose Date: Thu, 9 Apr 2026 12:13:35 -0700 Subject: [PATCH 3/7] Redesign distributed inference API: auto-detect rank, lazy NCCL setup, ABI v9 --- core/runtime/TRTEngine.cpp | 96 +++---- core/runtime/TRTEngine.h | 13 +- core/runtime/execute_engine.cpp | 16 +- core/runtime/register_jit_hooks.cpp | 17 +- core/runtime/runtime.h | 7 +- .../dynamo/conversion/_conversion.py | 21 -- .../runtime/_PythonTorchTensorRTModule.py | 270 +++++------------- .../dynamo/runtime/_TorchTensorRTModule.py | 211 +++----------- 8 files changed, 176 insertions(+), 475 deletions(-) diff --git a/core/runtime/TRTEngine.cpp b/core/runtime/TRTEngine.cpp index f7787f0433..b15f2093c5 100644 --- a/core/runtime/TRTEngine.cpp +++ b/core/runtime/TRTEngine.cpp @@ -96,12 +96,17 @@ TRTEngine::TRTEngine(std::vector serialized_info) (static_cast(std::stoi(serialized_info[RESOURCE_ALLOCATION_STRATEGY_IDX])) ? 
ResourceAllocationStrategy::kDynamic : ResourceAllocationStrategy::kStatic)) { - // Load distributed info if available (backward compatible with older ABI versions) - if (serialized_info.size() > RANK_IDX && !serialized_info[RANK_IDX].empty()) { - this->rank = std::stoll(serialized_info[RANK_IDX]); - } - if (serialized_info.size() > WORLD_SIZE_IDX && !serialized_info[WORLD_SIZE_IDX].empty()) { - this->world_size = std::stoll(serialized_info[WORLD_SIZE_IDX]); + if (std::stoi(serialized_info[IS_MD_ENGINE_IDX])) { + int64_t build_rank = std::stoll(serialized_info[OPTIONAL_RANK_IDX]); + int64_t build_world_size = std::stoll(serialized_info[OPTIONAL_WORLD_SIZE_IDX]); + if (build_rank != this->rank) { + LOG_INFO( + "Distributed engine originally built on rank " << build_rank << " of " << build_world_size + << ", now running on rank " << this->rank << " of " + << this->world_size); + } else { + LOG_INFO("Distributed engine: rank " << this->rank << " of " << this->world_size); + } } } @@ -512,6 +517,12 @@ std::vector TRTEngine::serialize() { serialized_info[TARGET_PLATFORM_IDX] = this->target_platform.serialize(); serialized_info[RESOURCE_ALLOCATION_STRATEGY_IDX] = this->resource_allocation_strategy == ResourceAllocationStrategy::kDynamic ? "1" : "0"; + bool is_md = this->world_size > 1; + serialized_info[IS_MD_ENGINE_IDX] = is_md ? 
"1" : "0"; + if (is_md) { + serialized_info[OPTIONAL_RANK_IDX] = std::to_string(this->rank); + serialized_info[OPTIONAL_WORLD_SIZE_IDX] = std::to_string(this->world_size); + } return serialized_info; } @@ -534,60 +545,19 @@ void TRTEngine::set_resource_allocation_strategy(TRTEngine::ResourceAllocationSt } } -void TRTEngine::set_rank(int64_t rank_val) { - this->rank = rank_val; - LOG_DEBUG("Rank set on TRTEngine: " << this->rank); -} - -void TRTEngine::set_world_size(int64_t world_size_val) { - this->world_size = world_size_val; - LOG_DEBUG("World size set on TRTEngine: " << this->world_size); -} - #ifdef ENABLE_TRT_NCCL_COLLECTIVES -void TRTEngine::set_nccl_comm(int64_t comm_ptr) { - this->nccl_comm = reinterpret_cast(comm_ptr); - LOG_DEBUG("NCCL communicator stored on TRTEngine (rank=" << this->rank << ")"); - - // Also set on TensorRT execution context - set_nccl_communicator_to_trt_context(); -} - -bool TRTEngine::set_nccl_communicator_to_trt_context() { - TORCHTRT_CHECK(exec_ctx != nullptr, "Cannot set NCCL communicator: execution context is null"); - TORCHTRT_CHECK( - this->nccl_comm != nullptr, - "Distributed inference enabled but no NCCL communicator set. 
" - "Call set_process_group() or set_nccl_comm() from Python first."); - - void* comm_ptr = static_cast(this->nccl_comm); - exec_ctx->setCommunicator(comm_ptr); - - LOG_INFO( - "NCCL communicator set on TensorRT execution context " - "(rank=" - << this->rank << ", device=" << this->device_info.id << ")"); - return true; -} - -void TRTEngine::init_nccl_comm(const std::string& group_name) { - // Use C++ registry to get NCCL communicator - set_process_group_from_registry(group_name); -} - -bool TRTEngine::set_process_group_from_registry(const std::string& group_name) { - LOG_INFO("TRTEngine::set_process_group_from_registry() called with group_name: " << group_name); - +void TRTEngine::detect_distributed_context(const std::string& group_name) { auto pg = c10d::resolve_process_group(group_name); - TORCHTRT_CHECK(pg != nullptr, "ProcessGroup '" << group_name << "' not found in registry"); - LOG_INFO(" Resolved ProcessGroup: rank=" << pg->getRank() << ", size=" << pg->getSize()); - - if (this->rank < 0) { + if (pg) { this->rank = pg->getRank(); - } - if (this->world_size < 0) { this->world_size = pg->getSize(); + LOG_DEBUG("Detected distributed context: rank=" << this->rank << ", world_size=" << this->world_size); } +} + +void TRTEngine::setup_nccl_comm(const std::string& group_name) { + auto pg = c10d::resolve_process_group(group_name); + TORCHTRT_CHECK(pg != nullptr, "ProcessGroup '" << group_name << "' not found in registry"); auto backend = pg->getBackend(c10d::ProcessGroup::BackendType::NCCL); TORCHTRT_CHECK(backend != nullptr, "ProcessGroup '" << group_name << "' has no NCCL backend"); @@ -604,7 +574,21 @@ bool TRTEngine::set_process_group_from_registry(const std::string& group_name) { << ". 
Ensure a collective operation has been performed first."); this->nccl_comm = reinterpret_cast(comm_ptr); - LOG_INFO(" NCCL communicator set: " << (void*)this->nccl_comm); + set_nccl_communicator_to_trt_context(); + LOG_INFO("NCCL comm set up (rank=" << this->rank << ", device=" << this->device_info.id << ")"); +} + +bool TRTEngine::set_nccl_communicator_to_trt_context() { + TORCHTRT_CHECK(exec_ctx != nullptr, "Cannot set NCCL communicator: execution context is null"); + TORCHTRT_CHECK(this->nccl_comm != nullptr, "NCCL communicator is not set"); + + void* comm_ptr = static_cast(this->nccl_comm); + exec_ctx->setCommunicator(comm_ptr); + + LOG_INFO( + "NCCL communicator set on TensorRT execution context " + "(rank=" + << this->rank << ", device=" << this->device_info.id << ")"); return true; } #endif // ENABLE_TRT_NCCL_COLLECTIVES diff --git a/core/runtime/TRTEngine.h b/core/runtime/TRTEngine.h index eb7e1f46d4..8e5dd78676 100644 --- a/core/runtime/TRTEngine.h +++ b/core/runtime/TRTEngine.h @@ -217,15 +217,14 @@ struct TRTEngine : torch::CustomClassHolder { int64_t rank = -1; int64_t world_size = -1; - // Set rank and world_size for distributed inference - void set_rank(int64_t rank_val); - void set_world_size(int64_t world_size_val); - #ifdef ENABLE_TRT_NCCL_COLLECTIVES ncclComm_t nccl_comm = nullptr; - void set_nccl_comm(int64_t comm_ptr); - void init_nccl_comm(const std::string& group_name = "default"); - bool set_process_group_from_registry(const std::string& group_name = "default"); + + // Detect rank and world_size from ProcessGroup + void detect_distributed_context(const std::string& group_name); + + // Resolve ProcessGroup, get NCCL communicator, and bind to TRT context + void setup_nccl_comm(const std::string& group_name); bool set_nccl_communicator_to_trt_context(); #endif diff --git a/core/runtime/execute_engine.cpp b/core/runtime/execute_engine.cpp index 4868b092f4..137358efa3 100644 --- a/core/runtime/execute_engine.cpp +++ 
b/core/runtime/execute_engine.cpp @@ -311,19 +311,11 @@ std::vector execute_engine(std::vector inputs, c10::intr std::make_unique(compiled_engine->enqueue_profile_path); } - // Distributed setup - set NCCL communicator on TensorRT execution context + // Distributed setup - bind NCCL communicator to TRT execution context + // setup_nccl_comm must have been called from Python before first forward #ifdef ENABLE_TRT_NCCL_COLLECTIVES - if (compiled_engine->rank >= 0 && compiled_engine->world_size > 1) { - bool result = compiled_engine->set_nccl_communicator_to_trt_context(); - if (!result) { - LOG_ERROR("Failed to set NCCL communicator on TRT context"); - LOG_ERROR("This will cause collective operations to fail at runtime"); - LOG_ERROR("Make sure to call module.init_nccl_comm() after compilation"); - } - } else { - LOG_DEBUG( - "Single-device mode (rank=" << compiled_engine->rank << ", world_size=" << compiled_engine->world_size - << ") - skipping NCCL setup"); + if (compiled_engine->world_size > 1 && compiled_engine->nccl_comm != nullptr) { + compiled_engine->set_nccl_communicator_to_trt_context(); } #endif diff --git a/core/runtime/register_jit_hooks.cpp b/core/runtime/register_jit_hooks.cpp index ffae7c7455..91331a8e52 100644 --- a/core/runtime/register_jit_hooks.cpp +++ b/core/runtime/register_jit_hooks.cpp @@ -110,15 +110,15 @@ static auto TORCHTRT_UNUSED TRTEngineTSRegistrtion = .def_property("automatic_device_memory_budget", &TRTEngine::get_automatic_device_memory_budget) .def_readonly("rank", &TRTEngine::rank) .def_readonly("world_size", &TRTEngine::world_size) - .def("set_rank", &TRTEngine::set_rank) - .def("set_world_size", &TRTEngine::set_world_size) #ifdef ENABLE_TRT_NCCL_COLLECTIVES - .def("set_nccl_comm", &TRTEngine::set_nccl_comm) .def( - "init_nccl_comm", - [](c10::intrusive_ptr self, std::string group_name = "default") { - self->init_nccl_comm(group_name); + "detect_distributed_context", + [](c10::intrusive_ptr self, std::string group_name) { + 
self->detect_distributed_context(group_name); }) + .def( + "setup_nccl_comm", + [](c10::intrusive_ptr self, std::string group_name) { self->setup_nccl_comm(group_name); }) #endif .def_pickle( [](const c10::intrusive_ptr& self) -> std::vector { return self->serialize(); }, @@ -162,8 +162,9 @@ TORCH_LIBRARY(tensorrt, m) { m.def("REQUIRES_OUTPUT_ALLOCATOR_IDX", []() -> int64_t { return REQUIRES_OUTPUT_ALLOCATOR_IDX; }); m.def("SERIALIZATION_LEN", []() -> int64_t { return SERIALIZATION_LEN; }); m.def("RESOURCE_ALLOCATION_STRATEGY_IDX", []() -> int64_t { return RESOURCE_ALLOCATION_STRATEGY_IDX; }); - m.def("RANK_IDX", []() -> int64_t { return RANK_IDX; }); - m.def("WORLD_SIZE_IDX", []() -> int64_t { return WORLD_SIZE_IDX; }); + m.def("IS_MD_ENGINE_IDX", []() -> int64_t { return IS_MD_ENGINE_IDX; }); + m.def("OPTIONAL_RANK_IDX", []() -> int64_t { return OPTIONAL_RANK_IDX; }); + m.def("OPTIONAL_WORLD_SIZE_IDX", []() -> int64_t { return OPTIONAL_WORLD_SIZE_IDX; }); m.def("NATIVE_TRT_COLLECTIVES_AVAIL", []() -> bool { #ifdef ENABLE_TRT_NCCL_COLLECTIVES return true; diff --git a/core/runtime/runtime.h b/core/runtime/runtime.h index 61e4362289..741d8dee3b 100644 --- a/core/runtime/runtime.h +++ b/core/runtime/runtime.h @@ -16,7 +16,7 @@ namespace core { namespace runtime { using EngineID = int64_t; -const std::string ABI_VERSION = "8"; +const std::string ABI_VERSION = "9"; extern bool MULTI_DEVICE_SAFE_MODE; typedef enum { @@ -39,8 +39,9 @@ typedef enum { TARGET_PLATFORM_IDX, REQUIRES_OUTPUT_ALLOCATOR_IDX, RESOURCE_ALLOCATION_STRATEGY_IDX, - RANK_IDX, - WORLD_SIZE_IDX, + IS_MD_ENGINE_IDX, + OPTIONAL_RANK_IDX, + OPTIONAL_WORLD_SIZE_IDX, SERIALIZATION_LEN, // NEVER USED FOR DATA, USED TO DETERMINE LENGTH OF SERIALIZED INFO } SerializedInfoIndex; diff --git a/py/torch_tensorrt/dynamo/conversion/_conversion.py b/py/torch_tensorrt/dynamo/conversion/_conversion.py index 1be2b92520..069ae3f43c 100644 --- a/py/torch_tensorrt/dynamo/conversion/_conversion.py +++ 
b/py/torch_tensorrt/dynamo/conversion/_conversion.py @@ -357,13 +357,7 @@ def convert_module( "Since Torch-TensorRT runtime is not available, using Python Runtime, some features may not be available" ) - rank = -1 - world_size = -1 if settings.use_distributed_mode_trace: - import os - - import torch.distributed as dist - # Check if distributed backends are available if ENABLED_FEATURES.native_trt_collectives: logger.info( @@ -379,19 +373,6 @@ def convert_module( "For TRT-LLM fallback, set TRTLLM_PLUGINS_PATH or USE_TRTLLM_PLUGINS=1." ) - if dist.is_initialized(): - rank = dist.get_rank() - world_size = dist.get_world_size() - else: - # Fallback to environment variables - rank = int(os.environ.get("RANK", -1)) - world_size = int(os.environ.get("WORLD_SIZE", -1)) - - if rank >= 0 and world_size > 0: - logger.info( - f"Creating TRT module for distributed execution: rank={rank}, world_size={world_size}" - ) - return rt_cls( serialized_engine=serialized_interpreter_result.serialized_engine, input_binding_names=list(serialized_interpreter_result.input_names), @@ -401,6 +382,4 @@ def convert_module( weight_name_map=serialized_interpreter_result.weight_name_map, requires_output_allocator=serialized_interpreter_result.requires_output_allocator, symbolic_shape_expressions=serialized_interpreter_result.symbolic_shape_expressions, - rank=rank, - world_size=world_size, ) diff --git a/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py index bd8203fe8f..51f804a854 100644 --- a/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py +++ b/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py @@ -135,8 +135,6 @@ def __init__( requires_output_allocator: bool = False, symbolic_shape_expressions: Optional[Dict[str, List[Dict[str, Any]]]] = None, _debugger_config: Optional[DebuggerConfig] = None, - rank: int = -1, - world_size: int = 1, ): """Takes a name, target device, serialized TensorRT 
engine, and binding names / order and constructs a PyTorch ``torch.nn.Module`` around it. Uses TensorRT Python APIs to run the engine @@ -152,8 +150,6 @@ def __init__( weight_name_map (dict): Mapping of engine weight name to state_dict weight name requires_output_allocator (bool): Boolean flag indicating if the converter creates operators which require an Output Allocator to run (e.g. data dependent operators) symbolic_shape_expressions (List[str]): List of symbolic shape expressions for each output binding - rank (int): Rank of the current process, applicable for distributed inference - world_size (int): World size of the distributed process, applicable for distributed inference Example: .. code-block:: py @@ -233,8 +229,15 @@ def __init__( if self.serialized_engine is not None and not self.settings.lazy_engine_init: self.setup_engine() - self.rank = rank - self.world_size = world_size + # Auto-detect distributed context + import torch.distributed as dist + + if dist.is_initialized(): + self.rank = dist.get_rank() + self.world_size = dist.get_world_size() + else: + self.rank = -1 + self.world_size = -1 self._nccl_comm: Optional[Any] = None def set_output_tensors_as_unowned(self, enabled: bool) -> None: @@ -292,218 +295,67 @@ def set_default_device_memory_budget(self) -> int: @property def is_distributed(self) -> bool: """Check if this module is configured for distributed execution.""" - return self.rank >= 0 and self.world_size > 1 + return bool(self.world_size > 1) - @property - def has_native_trt_collectives(self) -> bool: - """Check if native TRT collectives are available (TRT 10.16+ with NCCL).""" - return bool(ENABLED_FEATURES.native_trt_collectives) + def setup_nccl_comm(self) -> None: + """Set up NCCL communicator from PyTorch's ProcessGroup. 
- def get_rank(self) -> int: - """Get the rank of this process in distributed execution.""" - return self.rank + In PythonTorchTensorRTModule, this is a single call that gets the NCCL comm + and binds it to the TRT context. rank/world_size are already set in __init__ + via dist.get_rank(). - def get_world_size(self) -> int: - """Get the total number of processes in distributed execution.""" - return self.world_size - - def set_nccl_communicator(self, comm: Any) -> None: + In TorchTensorRTModule (C++ runtime), this is split into two calls: + - detect_distributed_context(group_name): sets rank/world_size on the C++ engine + (called in setup_engine, needed for serialization before forward) + - setup_nccl_comm(group_name): gets NCCL comm and binds to TRT context + (called lazily on first forward) + """ if not self.is_distributed: - logger.warning( - "Setting NCCL communicator on non-distributed module " - f"(rank={self.rank}, world_size={self.world_size})" - ) - self._nccl_comm = comm - # Only set communicator on context if native TRT collectives are available (TRT 10.16+) - if not self.has_native_trt_collectives: - logger.debug( - "Native TRT collectives not available, skipping set_communicator on TensorRT context" - ) return - if self.context is not None: - try: - # TensorRT's set_communicator expects a PyCapsule, not an integer pointer - # Convert integer pointer to PyCapsule if needed - comm_to_pass = comm - if isinstance(comm, int): - import ctypes - - # Create a PyCapsule from the pointer value - ctypes.pythonapi.PyCapsule_New.restype = ctypes.py_object - ctypes.pythonapi.PyCapsule_New.argtypes = [ - ctypes.c_void_p, - ctypes.c_char_p, - ctypes.c_void_p, - ] - comm_to_pass = ctypes.pythonapi.PyCapsule_New(comm, None, None) - logger.debug( - f"Converted integer pointer {comm} to PyCapsule for TensorRT" - ) - - success = self.context.set_communicator(comm_to_pass) - if success: - logger.debug( - f"NCCL communicator set on TensorRT context (rank={self.rank})" - ) - 
else: - logger.warning( - f"set_communicator returned False (rank={self.rank})" - ) - except AttributeError: - logger.warning("TensorRT context does not support set_communicator") - except TypeError as e: - logger.error(f"Failed to set NCCL communicator: {e}") - raise - - def get_nccl_communicator(self) -> Optional[Any]: - """Get the NCCL communicator if set.""" - return self._nccl_comm - - def setup_nccl(self, use_pytorch_comm: bool = True) -> None: - # to check if we need try block for this - # Ensure NCCL library path is configured for TensorRT - # This handles the case where pip-installed PyTorch has NCCL in a non-standard location setup_nccl_library() - try: - import torch.distributed as dist - except ImportError as e: - raise RuntimeError( - "torch.distributed is required for setup_nccl(). " f"Import error: {e}" - ) + + import torch.distributed as dist + if not dist.is_initialized(): raise RuntimeError( - "torch.distributed must be initialized before calling setup_nccl(). " + "torch.distributed must be initialized before calling setup_nccl_comm(). " "Call dist.init_process_group('nccl') first." ) - if not self.is_distributed: - raise RuntimeError( - f"Module is not configured for distributed execution " - f"(rank={self.rank}, world_size={self.world_size}). " - "Pass rank and world_size to constructor." - ) - - # Check if native TRT collectives are available - if self.has_native_trt_collectives: - logger.info( - f"Using native TRT collectives (TRT 10.16+) for distributed execution (rank={self.rank})" - ) - elif ENABLED_FEATURES.trtllm_for_nccl: - logger.info(f"Using TRT-LLM plugins for NCCL backend (rank={self.rank})") - else: - logger.warning( - "Neither native TRT collectives nor TRT-LLM NCCL plugins are available. " - "Distributed execution may not work correctly. " - "For native TRT collectives, ensure TensorRT 10.16+ is installed and " - "torch_tensorrt was built with NCCL support. 
" - "For TRT-LLM fallback, set TRTLLM_PLUGINS_PATH or USE_TRTLLM_PLUGINS=1." - ) - - # Try to get communicator from PyTorch's ProcessGroupNCCL which is preferred - nccl_comm = self._get_nccl_comm_from_process_group() - - # Fall back to creating via NCCL library if process group method fails - # Note: this is fallback mechanism which is to be tested - if nccl_comm is None: - logger.debug("Falling back to creating NCCL communicator via nccl library") - nccl_comm = self._create_nccl_comm_via_nccl_lib() - - # Set the communicator - self.set_nccl_communicator(nccl_comm) - - def _get_nccl_comm_from_process_group(self) -> Optional[Any]: - # expectation is that dist.init_process_group has been called - # In there, Rank 0 generated ncclUniqueId - # Broadcasted it to all ranks via the store - # Each rank called ncclCommInitRank() - import torch.distributed as dist - pg = dist.group.WORLD - if pg is None: - logger.debug("No default process group available") - return None - - # Check if backend is NCCL - if dist.get_backend(pg) != "nccl": - logger.debug("ProcessGroup backend is not NCCL, cannot reuse communicator") - return None - - # Get the NCCL backend object via _get_backend (internal API) - if not hasattr(pg, "_get_backend"): - logger.debug("ProcessGroup does not have _get_backend method") - return None - - try: - backend = pg._get_backend(torch.device("cuda")) - except Exception as e: - logger.debug(f"Failed to get NCCL backend: {e}") - return None - - # now we have the backend - # Get comm pointer from the backend (internal API) - if not hasattr(backend, "_comm_ptr"): - logger.debug("NCCL backend does not have _comm_ptr method") - return None - - # Force NCCL communicator initialization with a dummy collective. - # PyTorch's ProcessGroupNCCL uses lazy initialization - the NCCL - # communicator is only created when the first collective operation - # is performed. Without this, _comm_ptr() returns 0. 
- try: - dummy = torch.zeros(1, device="cuda") - dist.all_reduce(dummy) - logger.debug("Forced NCCL initialization with dummy all_reduce") - except Exception as e: - logger.debug(f"Failed to force NCCL initialization: {e}") - return None - - try: - comm_ptr = backend._comm_ptr() - except Exception as e: - logger.debug(f"Failed to call _comm_ptr: {e}") - return None + if pg is None or dist.get_backend(pg) != "nccl": + raise RuntimeError("Default ProcessGroup must use NCCL backend") - if comm_ptr is None or comm_ptr == 0: - logger.debug("_comm_ptr returned None or 0") - return None + backend = pg._get_backend(torch.device("cuda")) - logger.info( - f"Reusing PyTorch's NCCL communicator (ptr={comm_ptr}, rank={self.rank})" - ) - return comm_ptr + # Force NCCL communicator initialization with a dummy collective + dummy = torch.zeros(1, device="cuda") + dist.all_reduce(dummy) - def _create_nccl_comm_via_nccl_lib(self) -> Any: - import nccl.core as nccl - import torch.distributed as dist + comm_ptr = backend._comm_ptr() + if comm_ptr is None or comm_ptr == 0: + raise RuntimeError("Failed to get NCCL communicator from ProcessGroup") - rank = self.rank - world_size = self.world_size - - # Generate unique ID on rank 0 and broadcast as a tensor - if rank == 0: - uid = nccl.get_unique_id() - uid_bytes = uid.as_bytes - uid_tensor = torch.frombuffer( - bytearray(uid_bytes), dtype=torch.uint8 - ).cuda() - logger.debug(f"Rank {rank} created NCCL unique ID ({len(uid_bytes)} bytes)") - else: - uid_tensor = torch.zeros(128, dtype=torch.uint8, device="cuda") + self._nccl_comm = comm_ptr - dist.broadcast(uid_tensor, src=0) - logger.debug(f"Rank {rank} received NCCL unique ID") + # Bind communicator to TRT execution context (PyCapsule required by TRT Python API) + if self.context is not None: + import ctypes - uid = nccl.UniqueId.from_bytes(bytes(uid_tensor.cpu().numpy())) + ctypes.pythonapi.PyCapsule_New.restype = ctypes.py_object + ctypes.pythonapi.PyCapsule_New.argtypes = [ + 
ctypes.c_void_p, + ctypes.c_char_p, + ctypes.c_void_p, + ] + comm_capsule = ctypes.pythonapi.PyCapsule_New(comm_ptr, None, None) + self.context.set_communicator(comm_capsule) - comm = nccl.Communicator.init(world_size, rank, uid) logger.info( - f"Created new NCCL communicator via nccl library (rank={rank}, world_size={world_size})" + f"NCCL comm set up (rank={self.rank}, world_size={self.world_size})" ) - self._nccl_comm_handle = comm - return comm.ptr - def setup_engine(self) -> None: assert ( self.target_platform == Platform.current_platform() @@ -575,9 +427,27 @@ def _load_from_state_dict( self.input_names = state_dict[prefix + "input_names"] self.output_names = state_dict[prefix + "output_names"] self.target_platform = state_dict[prefix + "platform"] - # Distributed info (optional, backward compatible with non-distributed models) - self.rank = state_dict.get(prefix + "rank", -1) - self.world_size = state_dict.get(prefix + "world_size", -1) + + build_rank = state_dict.get(prefix + "rank", -1) + build_world_size = state_dict.get(prefix + "world_size", -1) + import torch.distributed as dist + + if dist.is_initialized(): + self.rank = dist.get_rank() + self.world_size = dist.get_world_size() + else: + self.rank = -1 + self.world_size = -1 + if build_world_size > 1: + if build_rank != self.rank: + logger.info( + f"Distributed engine originally built on rank {build_rank} of {build_world_size}, " + f"now running on rank {self.rank} of {self.world_size}" + ) + else: + logger.info( + f"Distributed engine: rank {self.rank} of {self.world_size}" + ) # Run multi-gpu device check to validate engine instantiation multi_gpu_device_check() @@ -936,7 +806,7 @@ def run_output_allocator() -> torch.Tensor | Tuple[torch.Tensor, ...]: if self.is_distributed and self._nccl_comm is None: nccl_type = ( "native TRT collectives" - if self.has_native_trt_collectives + if ENABLED_FEATURES.native_trt_collectives else ( "TRT-LLM NCCL plugins" if ENABLED_FEATURES.trtllm_for_nccl @@ -947,7 
+817,7 @@ def run_output_allocator() -> torch.Tensor | Tuple[torch.Tensor, ...]: f"Setting up NCCL for distributed execution using {nccl_type} " f"(rank={self.rank}, world_size={self.world_size})" ) - self.setup_nccl() + self.setup_nccl_comm() logger.info(f"NCCL setup complete, comm={self._nccl_comm}") # If in safe mode, check at each iteration for whether a switch is required diff --git a/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py index 4099e35a91..07ba016503 100644 --- a/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py +++ b/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py @@ -15,7 +15,6 @@ needs_torch_tensorrt_runtime, ) from torch_tensorrt.dynamo._settings import CompilationSettings -from torch_tensorrt.dynamo.runtime._nccl_utils import setup_nccl_library logger = logging.getLogger(__name__) @@ -37,6 +36,9 @@ TARGET_PLATFORM_IDX = -1 # Not implemented REQUIRES_OUTPUT_ALLOCATOR_IDX = -1 # Not implemented SERIALIZATION_LEN = -1 # Not implemented +IS_MD_ENGINE_IDX = -1 # Not implemented +OPTIONAL_RANK_IDX = -1 # Not implemented +OPTIONAL_WORLD_SIZE_IDX = -1 # Not implemented if ENABLED_FEATURES.torch_tensorrt_runtime: ABI_TARGET_IDX = torch.ops.tensorrt.ABI_TARGET_IDX() # 0 @@ -54,9 +56,10 @@ RESOURCE_ALLOCATION_STRATEGY_IDX = ( torch.ops.tensorrt.RESOURCE_ALLOCATION_STRATEGY_IDX() ) # 10 - RANK_IDX = torch.ops.tensorrt.RANK_IDX() # 11 - WORLD_SIZE_IDX = torch.ops.tensorrt.WORLD_SIZE_IDX() # 12 - SERIALIZATION_LEN = torch.ops.tensorrt.SERIALIZATION_LEN() # 13 + IS_MD_ENGINE_IDX = torch.ops.tensorrt.IS_MD_ENGINE_IDX() # 11 + OPTIONAL_RANK_IDX = torch.ops.tensorrt.OPTIONAL_RANK_IDX() # 12 + OPTIONAL_WORLD_SIZE_IDX = torch.ops.tensorrt.OPTIONAL_WORLD_SIZE_IDX() # 13 + SERIALIZATION_LEN = torch.ops.tensorrt.SERIALIZATION_LEN() # 14 @for_all_methods(needs_torch_tensorrt_runtime) @@ -91,8 +94,6 @@ def __init__( weight_name_map: Optional[dict[Any, Any]] = None, 
requires_output_allocator: bool = False, symbolic_shape_expressions: Optional[Dict[str, List[Dict[str, Any]]]] = None, - rank: int = -1, - world_size: int = 1, ): """Takes a name, target device, serialized TensorRT engine, and binding names / order and constructs a PyTorch ``torch.nn.Module`` around it. Uses the Torch-TensorRT runtime extension to run the engines @@ -151,11 +152,6 @@ def __init__( self.requires_output_allocator = requires_output_allocator self.dynamically_allocate_resources = settings.dynamically_allocate_resources self.symbolic_shape_expressions = symbolic_shape_expressions - self.rank = rank - self.world_size = world_size - self._nccl_setup_done = ( - False # Track if NCCL has been setup for distributed mode - ) if ( serialized_engine @@ -213,8 +209,14 @@ def _pack_engine_info(self) -> List[str | bytes]: engine_info[RESOURCE_ALLOCATION_STRATEGY_IDX] = str( int(self.dynamically_allocate_resources) ) - engine_info[RANK_IDX] = str(self.rank) - engine_info[WORLD_SIZE_IDX] = str(self.world_size) + import torch.distributed as dist + + is_md = dist.is_initialized() and dist.get_world_size() > 1 + engine_info[IS_MD_ENGINE_IDX] = str(int(is_md)) + # serialized engine info for build time rank and world size + if is_md: + engine_info[OPTIONAL_RANK_IDX] = str(dist.get_rank()) + engine_info[OPTIONAL_WORLD_SIZE_IDX] = str(dist.get_world_size()) return engine_info @@ -251,6 +253,16 @@ def use_dynamically_allocated_resources( self.dynamically_allocate_resources ) + def _get_default_group_name(self) -> str: + """Get the group name of the default ProcessGroup.""" + import torch.distributed as dist + + if dist.is_available() and dist.is_initialized(): + pg = dist.group.WORLD + if pg is not None and hasattr(pg, "group_name"): + return str(pg.group_name) + return "" + def setup_engine(self) -> None: """ Setup engine for a module which has deferred engine setup. 
@@ -264,6 +276,18 @@ def setup_engine(self) -> None: return self.engine = torch.classes.tensorrt.Engine(self._pack_engine_info()) + # Distributed setup is split into two calls for TorchTensorRTModule (C++ runtime): + # 1. detect_distributed_context: sets rank/world_size on C++ engine (here, at setup time) + # Needed so rank/world_size are available for serialization before any forward call. + # 2. setup_nccl_comm: gets NCCL comm and binds to TRT context (lazily, in forward) + # Deferred because NCCL comm is only needed at execution time. + # + # In PythonTorchTensorRTModule, this is a single setup_nccl_comm() call in forward + # because rank/world_size are set in __init__ via dist.get_rank(). + group_name = self._get_default_group_name() + if group_name: + self.engine.detect_distributed_context(group_name) + def encode_metadata(self, metadata: Any) -> str: metadata = copy.deepcopy(metadata) dumped_metadata = pickle.dumps(metadata) @@ -341,159 +365,6 @@ def set_pre_allocated_outputs(self, enable: bool) -> None: def set_use_output_allocator(self, enable: bool) -> None: self.engine.use_output_allocator_outputs = enable - def _auto_init_distributed(self) -> None: - """ - Automatically initialize distributed inference if the engine was compiled with - rank and world_size set (from torch.distributed). - - This is called automatically after setup_engine() to configure NCCL communicators - for distributed inference without requiring manual setup. 
- """ - if self.engine is None: - return - - logger.debug( - f"In _auto_init_distributed: _nccl_setup_done={self._nccl_setup_done}" - ) - - if not ENABLED_FEATURES.native_trt_collectives: - logger.debug( - "TRT native NCCL collectives not available, skipping distributed setup" - ) - return - - # Check if the engine has distributed info set (rank >= 0 and world_size > 1) - if ( - self.engine.rank >= 0 - and self.engine.world_size > 1 - and not self._nccl_setup_done - ): - try: - import torch.distributed as dist - - self.set_distributed_info() - - # Only auto-initialize if torch.distributed is initialized - if dist.is_available() and dist.is_initialized(): - logger.debug( - f"Auto-initializing distributed inference " - f"(rank={self.engine.rank}, world_size={self.engine.world_size})" - ) - # this calls self.engine.set_process_group(process_group) - pg = dist.group.WORLD - self.init_nccl_comm(pg) - self._nccl_setup_done = True - logger.debug(f"NCCL setup complete (rank={self.engine.rank})") - else: - logger.warning( - f"Engine has distributed info (rank={self.engine.rank}, world_size={self.engine.world_size}) " - f"but torch.distributed is not initialized. " - f"Call dist.init_process_group() and then module.set_process_group() manually." - ) - except RuntimeError as e: - # Catch tracing errors specifically (e.g., "Tracer cannot infer type of ProcessGroup") - if "Tracer cannot infer" in str(e) or "traced functions" in str(e): - logger.debug("Skipping NCCL auto-init during tracing/compilation") - else: - logger.warning( - f"Failed to auto-initialize distributed inference: {e}. " - f"Call module.set_process_group() manually if needed." - ) - - def set_distributed_info( - self, rank: Optional[int] = None, world_size: Optional[int] = None - ) -> None: - """ - Set rank and world_size for distributed inference. - - This method sets the rank and world_size on the TensorRT engine. If not provided, - they will be auto-detected from torch.distributed. 
- - Args: - rank: Rank of the current process (auto-detects if None) - world_size: Total number of processes (auto-detects if None) - - """ - if self.engine is None: - raise RuntimeError( - "Engine has not been setup yet. Call setup_engine() first." - ) - - # Auto-detect if not provided - if rank is None or world_size is None: - import torch.distributed as dist - - if dist.is_available() and dist.is_initialized(): - if rank is None: - rank = dist.get_rank() - if world_size is None: - world_size = dist.get_world_size() - else: - raise RuntimeError( - "torch.distributed is not initialized and rank/world_size not provided. " - "Call dist.init_process_group() first or provide rank/world_size explicitly." - ) - - # Set on C++ TRTEngine - self.engine.set_rank(rank) - self.engine.set_world_size(world_size) - logger.debug( - f"Distributed info set on TRTEngine: rank={rank}, world_size={world_size}" - ) - - def init_nccl_comm(self, process_group: Optional[Any] = None) -> None: - """ - Initialize NCCL communicator for distributed execution. - - This method initializes the NCCL communicator from the C++ ProcessGroup registry. - The ProcessGroup must be registered in PyTorch's native registry (which happens - automatically when using torch.distributed). - """ - if not ENABLED_FEATURES.native_trt_collectives: - raise RuntimeError( - "Native TRT NCCL collectives are not available. " - "Requires TensorRT 10.16+ and PyTorch built with NCCL support." - ) - if self.engine is None: - raise RuntimeError( - "Engine has not been setup yet. Call setup_engine() first." - ) - - setup_nccl_library() - - # Get the process group if not provided - if process_group is None: - try: - import torch.distributed as dist - - if not dist.is_initialized(): - raise RuntimeError( - "torch.distributed is not initialized. Call dist.init_process_group() first." 
- ) - process_group = dist.distributed_c10d._get_default_group() - logger.debug("Using default ProcessGroup from torch.distributed") - except Exception as e: - raise RuntimeError(f"Failed to get default process group: {e}") - - # Get the group name from the ProcessGroup - # This is the name used to register the group in the C++ registry - group_name = "default" - if ( - hasattr(process_group, "group_name") - and process_group.group_name is not None - ): - group_name = process_group.group_name - logger.debug(f"Using ProcessGroup with group_name: {group_name}") - - # Initialize NCCL communicator from C++ registry - # This uses c10d::resolve_process_group() to get the ProcessGroup and extract the NCCL comm - self.engine.init_nccl_comm(group_name) - - self._nccl_setup_done = True - logger.debug( - f"NCCL comm initialized from ProcessGroup (rank={self.engine.rank}, world_size={self.engine.world_size})" - ) - def forward(self, *inputs: Any) -> torch.Tensor | Tuple[torch.Tensor, ...]: """Implementation of the forward pass for a TensorRT engine @@ -506,8 +377,12 @@ def forward(self, *inputs: Any) -> torch.Tensor | Tuple[torch.Tensor, ...]: if self.engine is None: raise RuntimeError("Engine has not been setup yet.") - if not self._nccl_setup_done: - self._auto_init_distributed() + # Lazy NCCL setup on first forward + if self.engine.world_size > 1 and not hasattr(self, "_nccl_initialized"): + group_name = self._get_default_group_name() + if group_name: + self.engine.setup_nccl_comm(group_name) + self._nccl_initialized = True assert len(inputs) == len( self.input_binding_names From 5b6abf4bd177fbb9bd5807e3db6e0014870eab6c Mon Sep 17 00:00:00 2001 From: apbose Date: Thu, 9 Apr 2026 12:31:52 -0700 Subject: [PATCH 4/7] remove nccl.h dependancy --- core/runtime/TRTEngine.cpp | 5 ++--- core/runtime/TRTEngine.h | 6 +----- 2 files changed, 3 insertions(+), 8 deletions(-) diff --git a/core/runtime/TRTEngine.cpp b/core/runtime/TRTEngine.cpp index b15f2093c5..ae20402810 100644 --- 
a/core/runtime/TRTEngine.cpp +++ b/core/runtime/TRTEngine.cpp @@ -573,7 +573,7 @@ void TRTEngine::setup_nccl_comm(const std::string& group_name) { "NCCL communicator not initialized for device " << this->device_info.id << ". Ensure a collective operation has been performed first."); - this->nccl_comm = reinterpret_cast(comm_ptr); + this->nccl_comm = reinterpret_cast(comm_ptr); set_nccl_communicator_to_trt_context(); LOG_INFO("NCCL comm set up (rank=" << this->rank << ", device=" << this->device_info.id << ")"); } @@ -582,8 +582,7 @@ bool TRTEngine::set_nccl_communicator_to_trt_context() { TORCHTRT_CHECK(exec_ctx != nullptr, "Cannot set NCCL communicator: execution context is null"); TORCHTRT_CHECK(this->nccl_comm != nullptr, "NCCL communicator is not set"); - void* comm_ptr = static_cast(this->nccl_comm); - exec_ctx->setCommunicator(comm_ptr); + exec_ctx->setCommunicator(this->nccl_comm); LOG_INFO( "NCCL communicator set on TensorRT execution context " diff --git a/core/runtime/TRTEngine.h b/core/runtime/TRTEngine.h index 8e5dd78676..7fc4cc564b 100644 --- a/core/runtime/TRTEngine.h +++ b/core/runtime/TRTEngine.h @@ -28,10 +28,6 @@ #define ENABLE_TRT_NCCL_COLLECTIVES 1 #endif -#ifdef ENABLE_TRT_NCCL_COLLECTIVES -#include -#endif - namespace torch_tensorrt { namespace core { namespace runtime { @@ -218,7 +214,7 @@ struct TRTEngine : torch::CustomClassHolder { int64_t world_size = -1; #ifdef ENABLE_TRT_NCCL_COLLECTIVES - ncclComm_t nccl_comm = nullptr; + void* nccl_comm = nullptr; // Detect rank and world_size from ProcessGroup void detect_distributed_context(const std::string& group_name); From 8f911f577bcc74574d92f1dcb8e46cbc7791a6a5 Mon Sep 17 00:00:00 2001 From: apbose Date: Thu, 9 Apr 2026 13:31:29 -0700 Subject: [PATCH 5/7] clean up import and add comment --- .../dynamo/runtime/_PythonTorchTensorRTModule.py | 9 +++------ py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py | 5 +---- 2 files changed, 4 insertions(+), 10 deletions(-) diff --git 
a/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py index 51f804a854..4326f067cc 100644 --- a/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py +++ b/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py @@ -6,6 +6,7 @@ import tensorrt as trt import torch +import torch.distributed as dist import torch_tensorrt from torch.nn import Module from torch_tensorrt._Device import Device @@ -230,8 +231,6 @@ def __init__( self.setup_engine() # Auto-detect distributed context - import torch.distributed as dist - if dist.is_initialized(): self.rank = dist.get_rank() self.world_size = dist.get_world_size() @@ -315,8 +314,6 @@ def setup_nccl_comm(self) -> None: setup_nccl_library() - import torch.distributed as dist - if not dist.is_initialized(): raise RuntimeError( "torch.distributed must be initialized before calling setup_nccl_comm(). " @@ -428,10 +425,10 @@ def _load_from_state_dict( self.output_names = state_dict[prefix + "output_names"] self.target_platform = state_dict[prefix + "platform"] + # Same rule as C++ TRTEngine: serialized rank/world_size are build-time + # metadata for logging. Runtime rank is auto-detected from ProcessGroup. 
build_rank = state_dict.get(prefix + "rank", -1) build_world_size = state_dict.get(prefix + "world_size", -1) - import torch.distributed as dist - if dist.is_initialized(): self.rank = dist.get_rank() self.world_size = dist.get_world_size() diff --git a/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py index 07ba016503..b1cd146881 100644 --- a/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py +++ b/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py @@ -7,6 +7,7 @@ from typing import Any, Dict, List, Optional, Tuple, Union import torch +import torch.distributed as dist from torch_tensorrt._Device import Device from torch_tensorrt._enums import Platform from torch_tensorrt._features import ( @@ -209,8 +210,6 @@ def _pack_engine_info(self) -> List[str | bytes]: engine_info[RESOURCE_ALLOCATION_STRATEGY_IDX] = str( int(self.dynamically_allocate_resources) ) - import torch.distributed as dist - is_md = dist.is_initialized() and dist.get_world_size() > 1 engine_info[IS_MD_ENGINE_IDX] = str(int(is_md)) # serialized engine info for build time rank and world size @@ -255,8 +254,6 @@ def use_dynamically_allocated_resources( def _get_default_group_name(self) -> str: """Get the group name of the default ProcessGroup.""" - import torch.distributed as dist - if dist.is_available() and dist.is_initialized(): pg = dist.group.WORLD if pg is not None and hasattr(pg, "group_name"): From b791b40f48304294cfedac8d600e3f15930d0c07 Mon Sep 17 00:00:00 2001 From: apbose Date: Thu, 9 Apr 2026 13:47:12 -0700 Subject: [PATCH 6/7] moving setup_nccl_library call to example script --- .../tensor_parallel_simple_example.py | 3 + .../runtime/_PythonTorchTensorRTModule.py | 3 - tools/llm/tensor_parallel_llama_llm.py | 3 + trtengine_change.md | 751 ++++++++++++++++++ 4 files changed, 757 insertions(+), 3 deletions(-) create mode 100644 trtengine_change.md diff --git 
a/examples/distributed_inference/tensor_parallel_simple_example.py b/examples/distributed_inference/tensor_parallel_simple_example.py index 298862f3eb..4072e16fd1 100755 --- a/examples/distributed_inference/tensor_parallel_simple_example.py +++ b/examples/distributed_inference/tensor_parallel_simple_example.py @@ -58,6 +58,9 @@ "tensor_parallel_simple_example" ) import torch_tensorrt +from torch_tensorrt.dynamo.runtime._nccl_utils import setup_nccl_library + +setup_nccl_library() from torch.distributed._tensor import Shard from torch.distributed.tensor.parallel import ( ColwiseParallel, diff --git a/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py index 4326f067cc..39764c6653 100644 --- a/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py +++ b/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py @@ -16,7 +16,6 @@ from torch_tensorrt.dynamo._settings import CompilationSettings from torch_tensorrt.dynamo.debug._DebuggerConfig import DebuggerConfig from torch_tensorrt.dynamo.debug._supports_debugger import cls_supports_debugger -from torch_tensorrt.dynamo.runtime._nccl_utils import setup_nccl_library from torch_tensorrt.dynamo.utils import DYNAMIC_DIM from torch_tensorrt.logging import TRT_LOGGER from torch_tensorrt.runtime._utils import ( @@ -312,8 +311,6 @@ def setup_nccl_comm(self) -> None: if not self.is_distributed: return - setup_nccl_library() - if not dist.is_initialized(): raise RuntimeError( "torch.distributed must be initialized before calling setup_nccl_comm(). 
" diff --git a/tools/llm/tensor_parallel_llama_llm.py b/tools/llm/tensor_parallel_llama_llm.py index 04ac05fd79..4ce2524e6d 100644 --- a/tools/llm/tensor_parallel_llama_llm.py +++ b/tools/llm/tensor_parallel_llama_llm.py @@ -93,6 +93,9 @@ def initialize_logger( ) import torch_tensorrt +from torch_tensorrt.dynamo.runtime._nccl_utils import setup_nccl_library + +setup_nccl_library() from torch.distributed._tensor import Replicate, Shard from torch.distributed.tensor.parallel import ( ColwiseParallel, diff --git a/trtengine_change.md b/trtengine_change.md new file mode 100644 index 0000000000..3cd10c8036 --- /dev/null +++ b/trtengine_change.md @@ -0,0 +1,751 @@ +# Design Changes for PR #4157 + +This document contains the exact code changes for the redesigned Multi-Device TensorRT Runtime, addressing review comments. + +**Note:** Build config changes (MODULE.bazel, pyproject.toml, setup.py, py/requirements.txt) and debug logging additions (backends.py, remove_sym_nodes.py, partitioning/common.py, utils.py) are NOT included — those are local environment changes. + +--- + +## 1. `core/runtime/runtime.h` + +**Change:** ABI version bump and renamed serialization indices. + +```diff +-const std::string ABI_VERSION = "8"; ++const std::string ABI_VERSION = "9"; +``` + +```diff +- RANK_IDX, +- WORLD_SIZE_IDX, ++ IS_MD_ENGINE_IDX, ++ OPTIONAL_RANK_IDX, ++ OPTIONAL_WORLD_SIZE_IDX, + SERIALIZATION_LEN, +``` + +--- + +## 2. `core/runtime/TRTEngine.h` + +**Change:** Removed `set_rank`, `set_world_size`, `set_nccl_comm`, `init_nccl_comm`, `set_process_group_from_registry`. Added `detect_distributed_context` and `setup_nccl_comm`. 
+ +```diff + // Distributed inference fields (-1 indicates non-distributed mode) + int64_t rank = -1; + int64_t world_size = -1; + +- // Set rank and world_size for distributed inference +- void set_rank(int64_t rank_val); +- void set_world_size(int64_t world_size_val); +- + #ifdef ENABLE_TRT_NCCL_COLLECTIVES + ncclComm_t nccl_comm = nullptr; +- void set_nccl_comm(int64_t comm_ptr); +- void init_nccl_comm(const std::string& group_name = "default"); +- bool set_process_group_from_registry(const std::string& group_name = "default"); ++ ++ // Detect rank and world_size from ProcessGroup ++ void detect_distributed_context(const std::string& group_name); ++ ++ // Resolve ProcessGroup, get NCCL communicator, and bind to TRT context ++ void setup_nccl_comm(const std::string& group_name); + bool set_nccl_communicator_to_trt_context(); + #endif +``` + +--- + +## 3. `core/runtime/TRTEngine.cpp` + +### 3a. Constructor 2 (deserialization) — log build-time rank, don't overwrite + +```diff + : ResourceAllocationStrategy::kStatic)) { +- // Load distributed info if available (backward compatible with older ABI versions) +- if (serialized_info.size() > RANK_IDX && !serialized_info[RANK_IDX].empty()) { +- this->rank = std::stoll(serialized_info[RANK_IDX]); +- } +- if (serialized_info.size() > WORLD_SIZE_IDX && !serialized_info[WORLD_SIZE_IDX].empty()) { +- this->world_size = std::stoll(serialized_info[WORLD_SIZE_IDX]); +- } ++ if (std::stoi(serialized_info[IS_MD_ENGINE_IDX])) { ++ int64_t build_rank = std::stoll(serialized_info[OPTIONAL_RANK_IDX]); ++ int64_t build_world_size = std::stoll(serialized_info[OPTIONAL_WORLD_SIZE_IDX]); ++ if (build_rank != this->rank) { ++ LOG_INFO( ++ "Distributed engine originally built on rank " << build_rank << " of " << build_world_size ++ << ", now running on rank " << this->rank << " of " << this->world_size); ++ } else { ++ LOG_INFO( ++ "Distributed engine: rank " << this->rank << " of " << this->world_size); ++ } ++ } + } +``` + +### 3b. 
Constructor 3 — no distributed logic (removed detect_distributed_context call) + +No changes to constructor 3. It is clean — no distributed code. + +### 3c. Removed set_rank, set_world_size + +```diff +-void TRTEngine::set_rank(int64_t rank_val) { +- this->rank = rank_val; +- LOG_DEBUG("Rank set on TRTEngine: " << this->rank); +-} +- +-void TRTEngine::set_world_size(int64_t world_size_val) { +- this->world_size = world_size_val; +- LOG_DEBUG("World size set on TRTEngine: " << this->world_size); +-} +``` + +### 3d. Removed set_nccl_comm, init_nccl_comm, set_process_group_from_registry + +All three functions removed entirely. + +### 3e. New: detect_distributed_context + +```cpp +#ifdef ENABLE_TRT_NCCL_COLLECTIVES +void TRTEngine::detect_distributed_context(const std::string& group_name) { + auto pg = c10d::resolve_process_group(group_name); + if (pg) { + this->rank = pg->getRank(); + this->world_size = pg->getSize(); + LOG_DEBUG("Detected distributed context: rank=" << this->rank << ", world_size=" << this->world_size); + } +} +``` + +### 3f. New: setup_nccl_comm (replaces set_process_group_from_registry) + +```cpp +void TRTEngine::setup_nccl_comm(const std::string& group_name) { + auto pg = c10d::resolve_process_group(group_name); + TORCHTRT_CHECK(pg != nullptr, "ProcessGroup '" << group_name << "' not found in registry"); + + auto backend = pg->getBackend(c10d::ProcessGroup::BackendType::NCCL); + TORCHTRT_CHECK(backend != nullptr, "ProcessGroup '" << group_name << "' has no NCCL backend"); + + auto* nccl_pg = dynamic_cast(backend.get()); + TORCHTRT_CHECK(nccl_pg != nullptr, "Backend is not ProcessGroupNCCL"); + + at::cuda::set_device(this->device_info.id); + + int64_t comm_ptr = nccl_pg->getCommPtr(); + TORCHTRT_CHECK( + comm_ptr != 0, + "NCCL communicator not initialized for device " << this->device_info.id + << ". 
Ensure a collective operation has been performed first."); + + this->nccl_comm = reinterpret_cast(comm_ptr); + set_nccl_communicator_to_trt_context(); + LOG_INFO("NCCL comm set up (rank=" << this->rank << ", device=" << this->device_info.id << ")"); +} +``` + +### 3g. set_nccl_communicator_to_trt_context — replaced try-catch with TORCHTRT_CHECK + +```diff + bool TRTEngine::set_nccl_communicator_to_trt_context() { +- if (!exec_ctx) { +- LOG_ERROR("Cannot set NCCL communicator: execution context is null"); +- return false; +- } +- if (this->nccl_comm == nullptr) { +- LOG_WARNING(...); +- return false; +- } +- try { +- void* comm_ptr = static_cast(this->nccl_comm); +- exec_ctx->setCommunicator(comm_ptr); +- LOG_INFO(...); +- return true; +- } catch (const std::exception& e) { +- LOG_ERROR(...); +- return false; +- } ++ TORCHTRT_CHECK(exec_ctx != nullptr, "Cannot set NCCL communicator: execution context is null"); ++ TORCHTRT_CHECK(this->nccl_comm != nullptr, "NCCL communicator is not set"); ++ ++ void* comm_ptr = static_cast(this->nccl_comm); ++ exec_ctx->setCommunicator(comm_ptr); ++ ++ LOG_INFO( ++ "NCCL communicator set on TensorRT execution context " ++ "(rank=" << this->rank << ", device=" << this->device_info.id << ")"); ++ return true; + } +``` + +### 3h. serialize() — write IS_MD_ENGINE and optional rank/world_size + +```diff + serialized_info[RESOURCE_ALLOCATION_STRATEGY_IDX] = + this->resource_allocation_strategy == ResourceAllocationStrategy::kDynamic ? "1" : "0"; ++ bool is_md = this->world_size > 1; ++ serialized_info[IS_MD_ENGINE_IDX] = is_md ? "1" : "0"; ++ if (is_md) { ++ serialized_info[OPTIONAL_RANK_IDX] = std::to_string(this->rank); ++ serialized_info[OPTIONAL_WORLD_SIZE_IDX] = std::to_string(this->world_size); ++ } + + return serialized_info; +``` + +--- + +## 4. `core/runtime/register_jit_hooks.cpp` + +### 4a. 
Removed old bindings, added new ones + +```diff + .def_readonly("rank", &TRTEngine::rank) + .def_readonly("world_size", &TRTEngine::world_size) +- .def("set_rank", &TRTEngine::set_rank) +- .def("set_world_size", &TRTEngine::set_world_size) + #ifdef ENABLE_TRT_NCCL_COLLECTIVES +- .def("set_nccl_comm", &TRTEngine::set_nccl_comm) + .def( +- "init_nccl_comm", +- [](c10::intrusive_ptr self, std::string group_name = "default") { +- self->init_nccl_comm(group_name); ++ "detect_distributed_context", ++ [](c10::intrusive_ptr self, std::string group_name) { ++ self->detect_distributed_context(group_name); ++ }) ++ .def( ++ "setup_nccl_comm", ++ [](c10::intrusive_ptr self, std::string group_name) { ++ self->setup_nccl_comm(group_name); + }) + #endif +``` + +### 4b. Updated constant names + +```diff +- m.def("RANK_IDX", []() -> int64_t { return RANK_IDX; }); +- m.def("WORLD_SIZE_IDX", []() -> int64_t { return WORLD_SIZE_IDX; }); ++ m.def("IS_MD_ENGINE_IDX", []() -> int64_t { return IS_MD_ENGINE_IDX; }); ++ m.def("OPTIONAL_RANK_IDX", []() -> int64_t { return OPTIONAL_RANK_IDX; }); ++ m.def("OPTIONAL_WORLD_SIZE_IDX", []() -> int64_t { return OPTIONAL_WORLD_SIZE_IDX; }); +``` + +--- + +## 5. `core/runtime/execute_engine.cpp` + +**Change:** Only binds NCCL comm to TRT context. Does NOT call `setup_nccl_comm` — Python handles it. 
+ +```diff +- // Distributed setup - set NCCL communicator on TensorRT execution context ++ // Distributed setup - bind NCCL communicator to TRT execution context ++ // setup_nccl_comm must have been called from Python before first forward + #ifdef ENABLE_TRT_NCCL_COLLECTIVES +- if (compiled_engine->rank >= 0 && compiled_engine->world_size > 1) { +- bool result = compiled_engine->set_nccl_communicator_to_trt_context(); +- if (!result) { +- LOG_ERROR("Failed to set NCCL communicator on TRT context"); +- LOG_ERROR("This will cause collective operations to fail at runtime"); +- LOG_ERROR("Make sure to call module.init_nccl_comm() after compilation"); +- } +- } else { +- LOG_DEBUG( +- "Single-device mode (rank=" << compiled_engine->rank << ", world_size=" << compiled_engine->world_size +- << ") - skipping NCCL setup"); ++ if (compiled_engine->world_size > 1 && compiled_engine->nccl_comm != nullptr) { ++ compiled_engine->set_nccl_communicator_to_trt_context(); + } + #endif +``` + +--- + +## 6. `py/torch_tensorrt/dynamo/conversion/_conversion.py` + +**Change:** Removed rank/world_size detection and passing to module constructor. + +```diff +- rank = -1 +- world_size = -1 + if settings.use_distributed_mode_trace: +- import os +- import torch.distributed as dist + # Check if distributed backends are available + ... + + return rt_cls( + serialized_engine=..., + ... +- rank=rank, +- world_size=world_size, + ) +``` + +--- + +## 7. `py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py` + +### 7a. 
Updated constants + +```diff +- RANK_IDX = torch.ops.tensorrt.RANK_IDX() # 11 +- WORLD_SIZE_IDX = torch.ops.tensorrt.WORLD_SIZE_IDX() # 12 +- SERIALIZATION_LEN = torch.ops.tensorrt.SERIALIZATION_LEN() # 13 ++ IS_MD_ENGINE_IDX = torch.ops.tensorrt.IS_MD_ENGINE_IDX() # 11 ++ OPTIONAL_RANK_IDX = torch.ops.tensorrt.OPTIONAL_RANK_IDX() # 12 ++ OPTIONAL_WORLD_SIZE_IDX = torch.ops.tensorrt.OPTIONAL_WORLD_SIZE_IDX() # 13 ++ SERIALIZATION_LEN = torch.ops.tensorrt.SERIALIZATION_LEN() # 14 +``` + +### 7b. Constructor — removed rank/world_size args + +```diff + def __init__( + self, + serialized_engine: Optional[bytes] = None, + ... +- rank: int = -1, +- world_size: int = 1, + ): +``` + +Removed `self.rank = rank`, `self.world_size = world_size`, `self._nccl_setup_done`. + +### 7c. New helper: _get_default_group_name + +```python +def _get_default_group_name(self) -> str: + """Get the group name of the default ProcessGroup.""" + import torch.distributed as dist + if dist.is_available() and dist.is_initialized(): + pg = dist.group.WORLD + if pg is not None and hasattr(pg, "group_name"): + return pg.group_name + return "" +``` + +### 7d. setup_engine — calls detect_distributed_context + +```diff + def setup_engine(self) -> None: + if self.engine is not None: + return + self.engine = torch.classes.tensorrt.Engine(self._pack_engine_info()) ++ ++ # Detect distributed context (rank/world_size) from ProcessGroup ++ group_name = self._get_default_group_name() ++ if group_name: ++ self.engine.detect_distributed_context(group_name) +``` + +### 7e. 
_pack_engine_info — uses dist.is_initialized + +```diff +- engine_info[RANK_IDX] = str(self.rank) +- engine_info[WORLD_SIZE_IDX] = str(self.world_size) ++ import torch.distributed as dist ++ is_md = dist.is_initialized() and dist.get_world_size() > 1 ++ engine_info[IS_MD_ENGINE_IDX] = str(int(is_md)) ++ if is_md: ++ engine_info[OPTIONAL_RANK_IDX] = str(dist.get_rank()) ++ engine_info[OPTIONAL_WORLD_SIZE_IDX] = str(dist.get_world_size()) +``` + +### 7f. forward — lazy NCCL setup + +```diff + def forward(self, *inputs): + if self.engine is None: + raise RuntimeError("Engine has not been setup yet.") + ++ # Lazy NCCL setup on first forward ++ if self.engine.world_size > 1 and not hasattr(self, '_nccl_initialized'): ++ group_name = self._get_default_group_name() ++ if group_name: ++ self.engine.setup_nccl_comm(group_name) ++ self._nccl_initialized = True ++ + assert len(inputs) == len(self.input_binding_names), ... +``` + +### 7g. Removed functions + +- `_auto_init_distributed()` — replaced by lazy setup in forward +- `set_distributed_info()` — called removed `set_rank`/`set_world_size` +- `init_nccl_comm()` — replaced by `setup_nccl_comm` in forward +- `setup_nccl_library` import — no longer needed + +--- + +## 8. `py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py` + +### 8a. Constructor — removed rank/world_size args, auto-detect + +```diff + def __init__( + self, + ... +- rank: int = -1, +- world_size: int = 1, + ): + ... +- self.rank = rank +- self.world_size = world_size ++ # Auto-detect distributed context ++ import torch.distributed as dist ++ if dist.is_initialized(): ++ self.rank = dist.get_rank() ++ self.world_size = dist.get_world_size() ++ else: ++ self.rank = -1 ++ self.world_size = -1 + self._nccl_comm: Optional[Any] = None +``` + +### 8b. 
Simplified setup_nccl_comm + +Replaced `setup_nccl`, `set_nccl_communicator`, `get_nccl_communicator`, `_get_nccl_comm_from_process_group`, `_create_nccl_comm_via_nccl_lib` with a single function: + +```python +def setup_nccl_comm(self) -> None: + """Set up NCCL communicator from PyTorch's ProcessGroup. + + In PythonTorchTensorRTModule, this is a single call that gets the NCCL comm + and binds it to the TRT context. rank/world_size are already set in __init__ + via dist.get_rank(). + + In TorchTensorRTModule (C++ runtime), this is split into two calls: + - detect_distributed_context(group_name): sets rank/world_size on the C++ engine + (called in setup_engine, needed for serialization before forward) + - setup_nccl_comm(group_name): gets NCCL comm and binds to TRT context + (called lazily on first forward) + """ + if not self.is_distributed: + return + + setup_nccl_library() + + import torch.distributed as dist + if not dist.is_initialized(): + raise RuntimeError( + "torch.distributed must be initialized before calling setup_nccl(). " + "Call dist.init_process_group('nccl') first." 
+ ) + + pg = dist.group.WORLD + if pg is None or dist.get_backend(pg) != "nccl": + raise RuntimeError("Default ProcessGroup must use NCCL backend") + + backend = pg._get_backend(torch.device("cuda")) + + # Force NCCL communicator initialization with a dummy collective + dummy = torch.zeros(1, device="cuda") + dist.all_reduce(dummy) + + comm_ptr = backend._comm_ptr() + if comm_ptr is None or comm_ptr == 0: + raise RuntimeError("Failed to get NCCL communicator from ProcessGroup") + + self._nccl_comm = comm_ptr + + # Bind communicator to TRT execution context (PyCapsule required by TRT Python API) + if self.context is not None: + import ctypes + ctypes.pythonapi.PyCapsule_New.restype = ctypes.py_object + ctypes.pythonapi.PyCapsule_New.argtypes = [ + ctypes.c_void_p, ctypes.c_char_p, ctypes.c_void_p, + ] + comm_capsule = ctypes.pythonapi.PyCapsule_New(comm_ptr, None, None) + self.context.set_communicator(comm_capsule) + + logger.info(f"NCCL comm set up (rank={self.rank}, world_size={self.world_size})") +``` + +### 8c. Removed functions + +- `get_rank()`, `get_world_size()` — fields are public +- `set_nccl_communicator()` — merged into `setup_nccl` +- `get_nccl_communicator()` — `_nccl_comm` is accessible +- `has_native_trt_collectives` property — use `ENABLED_FEATURES.native_trt_collectives` +- `_create_nccl_comm_via_nccl_lib()` — removed `nccl.core` dependency +- `_get_nccl_comm_from_process_group()` — merged into `setup_nccl` + +### 8d. 
_load_from_state_dict — auto-detect rank, log build-time rank + +```diff + self.target_platform = state_dict[prefix + "platform"] +- self.rank = state_dict.get(prefix + "rank", -1) +- self.world_size = state_dict.get(prefix + "world_size", -1) ++ ++ build_rank = state_dict.get(prefix + "rank", -1) ++ build_world_size = state_dict.get(prefix + "world_size", -1) ++ import torch.distributed as dist ++ if dist.is_initialized(): ++ self.rank = dist.get_rank() ++ self.world_size = dist.get_world_size() ++ else: ++ self.rank = -1 ++ self.world_size = -1 ++ if build_world_size > 1: ++ if build_rank != self.rank: ++ logger.info( ++ f"Distributed engine originally built on rank {build_rank} of {build_world_size}, " ++ f"now running on rank {self.rank} of {self.world_size}" ++ ) ++ else: ++ logger.info(f"Distributed engine: rank {self.rank} of {self.world_size}") +``` + +### 8e. forward — uses ENABLED_FEATURES directly + +```diff + if self.is_distributed and self._nccl_comm is None: + nccl_type = ( + "native TRT collectives" +- if self.has_native_trt_collectives ++ if ENABLED_FEATURES.native_trt_collectives + else ( +``` + +--- + +## 9. Remove `nccl.h` dependency — use `void*` for NCCL communicator + +**Rationale:** `nccl.h` is not a Bazel dependency — it's picked up from the system/PyTorch install path. Using `void*` instead of `ncclComm_t` removes this fragile dependency. We don't own the communicator (PyTorch's ProcessGroupNCCL owns it), so we just pass it as an opaque pointer to TRT's `setCommunicator(void*)`. + +### `core/runtime/TRTEngine.h` + +```diff + #ifdef ENABLE_TRT_NCCL_COLLECTIVES +-#include ++// Using void* instead of ncclComm_t to avoid nccl.h dependency. ++// We don't own the communicator — it's owned by PyTorch's ProcessGroupNCCL. 
+ #endif +``` + +```diff + #ifdef ENABLE_TRT_NCCL_COLLECTIVES +- ncclComm_t nccl_comm = nullptr; ++ void* nccl_comm = nullptr; +``` + +### `core/runtime/TRTEngine.cpp` + +In `setup_nccl_comm`: +```diff +- this->nccl_comm = reinterpret_cast(comm_ptr); ++ this->nccl_comm = reinterpret_cast(comm_ptr); +``` + +In `set_nccl_communicator_to_trt_context`: +```diff +- void* comm_ptr = static_cast(this->nccl_comm); +- exec_ctx->setCommunicator(comm_ptr); ++ exec_ctx->setCommunicator(this->nccl_comm); +``` + +Also update section 3f `setup_nccl_comm` code to use `void*`: + +In section 3f above, replace: +```cpp + this->nccl_comm = reinterpret_cast(comm_ptr); +``` +with: +```cpp + this->nccl_comm = reinterpret_cast(comm_ptr); +``` + +And section 3g `set_nccl_communicator_to_trt_context` simplifies to: +```cpp +bool TRTEngine::set_nccl_communicator_to_trt_context() { + TORCHTRT_CHECK(exec_ctx != nullptr, "Cannot set NCCL communicator: execution context is null"); + TORCHTRT_CHECK(this->nccl_comm != nullptr, "NCCL communicator is not set"); + + exec_ctx->setCommunicator(this->nccl_comm); + + LOG_INFO( + "NCCL communicator set on TensorRT execution context " + "(rank=" << this->rank << ", device=" << this->device_info.id << ")"); + return true; +} +``` + +--- + +## Compatibility bug fixes (for PyTorch 2.10 / NGC 26.01) + +These are separate from the design changes but needed to run on the test environment: + +### `_FakeTensorUpdater.py` — guard for `torch._inductor.fx_passes.reinplace` + +```python +is_scatter = False +if hasattr(torch._inductor.fx_passes, "reinplace"): + is_scatter = ( + node.target + is torch._inductor.fx_passes.reinplace._generalized_scatter + ) +``` + +### `fuse_distributed_ops.py` — handle all_reduce with 3 args + +```python +fused_args = tuple(node.args) +if len(fused_args) < 4: + logger.debug(f"all_reduce node has {len(fused_args)} args instead of 4") +``` + +### `_TorchTensorRTModule.py` — typo fix + +```diff +- self.int_nccl_comm(pg) ++ 
self.init_nccl_comm(pg) +``` + +(This line is now removed entirely in the redesign, but was needed for the original PR.) + +--- + +## 10. Move `import torch.distributed as dist` to top-level + +Both Python runtime modules had `import torch.distributed as dist` scattered as local imports +inside multiple functions. Moved to top-level since `torch.distributed` is part of PyTorch +(no external dependency). + +### `_TorchTensorRTModule.py` + +```diff + import torch ++import torch.distributed as dist + from torch_tensorrt._Device import Device +``` + +Removed local imports from `_pack_engine_info()` and `_get_default_group_name()`. + +### `_PythonTorchTensorRTModule.py` + +```diff + import torch ++import torch.distributed as dist + import torch_tensorrt +``` + +Removed local imports from `__init__()`, `setup_nccl_comm()`, and `_load_from_state_dict()`. + +Added comment in `_load_from_state_dict` explaining the design rule: + +```python +# Same rule as C++ TRTEngine: serialized rank/world_size are build-time +# metadata for logging. Runtime rank is auto-detected from ProcessGroup. +build_rank = state_dict.get(prefix + "rank", -1) +build_world_size = state_dict.get(prefix + "world_size", -1) +if dist.is_initialized(): + self.rank = dist.get_rank() + self.world_size = dist.get_world_size() +else: + self.rank = -1 + self.world_size = -1 +if build_world_size > 1: + if build_rank != self.rank: + logger.info( + f"Distributed engine originally built on rank {build_rank} of {build_world_size}, " + f"now running on rank {self.rank} of {self.world_size}" + ) + else: + logger.info(f"Distributed engine: rank {self.rank} of {self.world_size}") +``` + +--- + +## 11. 
Function naming: `setup_nccl_comm` in both runtimes + +Both runtime modules use `setup_nccl_comm` as the function name for setting up NCCL, +but they work differently due to the C++ vs Python runtime distinction: + +### `_TorchTensorRTModule` (C++ runtime) — two separate calls + +```python +# In setup_engine(): +self.engine.detect_distributed_context(group_name) # sets rank/world_size on C++ engine + +# In forward() (lazily): +self.engine.setup_nccl_comm(group_name) # gets NCCL comm, binds to TRT context +``` + +**Why split:** rank/world_size must be available for serialization before any forward call. +The NCCL communicator is only needed at execution time. + +### `_PythonTorchTensorRTModule` (Python runtime) — single call + +```python +# In forward() (lazily): +self.setup_nccl_comm() # gets NCCL comm, converts to PyCapsule, sets on TRT context +``` + +**Why single:** rank/world_size are already set in `__init__` via `dist.get_rank()`. +No C++ engine to populate. Only need to get the NCCL comm and bind it. + +Comment in `_PythonTorchTensorRTModule.setup_nccl_comm`: +```python +def setup_nccl_comm(self) -> None: + """Set up NCCL communicator from PyTorch's ProcessGroup. + + In PythonTorchTensorRTModule, this is a single call that gets the NCCL comm + and binds it to the TRT context. rank/world_size are already set in __init__ + via dist.get_rank(). + + In TorchTensorRTModule (C++ runtime), this is split into two calls: + - detect_distributed_context(group_name): sets rank/world_size on the C++ engine + (called in setup_engine, needed for serialization before forward) + - setup_nccl_comm(group_name): gets NCCL comm and binds to TRT context + (called lazily on first forward) + """ +``` + +## 12. Move `setup_nccl_library()` to user scripts + +**Rationale:** `setup_nccl_library()` sets `LD_LIBRARY_PATH` so TensorRT can find `libnccl.so`. +This is a one-time environment setup, not an engine-level concern. 
The reviewer said this +should be a utility the user calls, not hidden inside engine code. + +### `_PythonTorchTensorRTModule.py` — removed call and import + +```diff +-from torch_tensorrt.dynamo.runtime._nccl_utils import setup_nccl_library +``` + +```diff + def setup_nccl_comm(self) -> None: + if not self.is_distributed: + return + +- setup_nccl_library() +- + if not dist.is_initialized(): +``` + +### Example scripts — added call after imports + +`examples/distributed_inference/tensor_parallel_simple_example.py`: +```python +import torch_tensorrt +from torch_tensorrt.dynamo.runtime._nccl_utils import setup_nccl_library + +setup_nccl_library() +``` + +`tools/llm/tensor_parallel_llama_llm.py`: +```python +import torch_tensorrt +from torch_tensorrt.dynamo.runtime._nccl_utils import setup_nccl_library + +setup_nccl_library() +``` + +The user is now responsible for calling `setup_nccl_library()` once before +distributed TRT inference. The function remains in `_nccl_utils` as a public utility. 
From b5c203effbfaeea08f9d41731a58f2a0e5d9a54f Mon Sep 17 00:00:00 2001 From: apbose Date: Thu, 9 Apr 2026 17:25:10 -0700 Subject: [PATCH 7/7] work on the save/load export part-add is_md flag, guard export tracing and enable DTensor decomposition --- core/runtime/TRTEngine.cpp | 29 ++++++++++--------- core/runtime/TRTEngine.h | 1 + core/runtime/execute_engine.cpp | 2 +- core/runtime/register_jit_hooks.cpp | 1 + .../tensor_parallel_simple_example.py | 6 ++-- py/torch_tensorrt/dynamo/_compiler.py | 18 ++++++++++-- py/torch_tensorrt/dynamo/_refit.py | 4 ++- py/torch_tensorrt/dynamo/backend/backends.py | 5 +++- .../dynamo/lowering/_decompositions.py | 13 +++++++-- .../dynamo/runtime/_TorchTensorRTModule.py | 6 +++- 10 files changed, 59 insertions(+), 26 deletions(-) diff --git a/core/runtime/TRTEngine.cpp b/core/runtime/TRTEngine.cpp index ae20402810..e42f8268cc 100644 --- a/core/runtime/TRTEngine.cpp +++ b/core/runtime/TRTEngine.cpp @@ -96,17 +96,11 @@ TRTEngine::TRTEngine(std::vector serialized_info) (static_cast(std::stoi(serialized_info[RESOURCE_ALLOCATION_STRATEGY_IDX])) ? 
ResourceAllocationStrategy::kDynamic : ResourceAllocationStrategy::kStatic)) { - if (std::stoi(serialized_info[IS_MD_ENGINE_IDX])) { - int64_t build_rank = std::stoll(serialized_info[OPTIONAL_RANK_IDX]); - int64_t build_world_size = std::stoll(serialized_info[OPTIONAL_WORLD_SIZE_IDX]); - if (build_rank != this->rank) { - LOG_INFO( - "Distributed engine originally built on rank " << build_rank << " of " << build_world_size - << ", now running on rank " << this->rank << " of " - << this->world_size); - } else { - LOG_INFO("Distributed engine: rank " << this->rank << " of " << this->world_size); - } + this->is_md = std::stoi(serialized_info[IS_MD_ENGINE_IDX]); + if (this->is_md) { + LOG_INFO( + "Loaded distributed engine (built on rank " << serialized_info[OPTIONAL_RANK_IDX] << " of " + << serialized_info[OPTIONAL_WORLD_SIZE_IDX] << ")"); } } @@ -517,9 +511,8 @@ std::vector TRTEngine::serialize() { serialized_info[TARGET_PLATFORM_IDX] = this->target_platform.serialize(); serialized_info[RESOURCE_ALLOCATION_STRATEGY_IDX] = this->resource_allocation_strategy == ResourceAllocationStrategy::kDynamic ? "1" : "0"; - bool is_md = this->world_size > 1; - serialized_info[IS_MD_ENGINE_IDX] = is_md ? "1" : "0"; - if (is_md) { + serialized_info[IS_MD_ENGINE_IDX] = this->is_md ? 
"1" : "0"; + if (this->is_md) { serialized_info[OPTIONAL_RANK_IDX] = std::to_string(this->rank); serialized_info[OPTIONAL_WORLD_SIZE_IDX] = std::to_string(this->world_size); } @@ -551,6 +544,7 @@ void TRTEngine::detect_distributed_context(const std::string& group_name) { if (pg) { this->rank = pg->getRank(); this->world_size = pg->getSize(); + this->is_md = this->world_size > 1; LOG_DEBUG("Detected distributed context: rank=" << this->rank << ", world_size=" << this->world_size); } } @@ -559,6 +553,13 @@ void TRTEngine::setup_nccl_comm(const std::string& group_name) { auto pg = c10d::resolve_process_group(group_name); TORCHTRT_CHECK(pg != nullptr, "ProcessGroup '" << group_name << "' not found in registry"); + // Set rank/world_size if not already set (e.g. load from disk without setup_engine) + if (this->rank < 0) { + this->rank = pg->getRank(); + this->world_size = pg->getSize(); + LOG_DEBUG("Set distributed context in setup_nccl_comm: rank=" << this->rank << ", world_size=" << this->world_size); + } + auto backend = pg->getBackend(c10d::ProcessGroup::BackendType::NCCL); TORCHTRT_CHECK(backend != nullptr, "ProcessGroup '" << group_name << "' has no NCCL backend"); diff --git a/core/runtime/TRTEngine.h b/core/runtime/TRTEngine.h index 7fc4cc564b..35591931aa 100644 --- a/core/runtime/TRTEngine.h +++ b/core/runtime/TRTEngine.h @@ -210,6 +210,7 @@ struct TRTEngine : torch::CustomClassHolder { std::shared_ptr output_allocator; // Member variables for distributed inference (-1 indicates non-distributed mode) + bool is_md = false; int64_t rank = -1; int64_t world_size = -1; diff --git a/core/runtime/execute_engine.cpp b/core/runtime/execute_engine.cpp index 137358efa3..0c7d91848e 100644 --- a/core/runtime/execute_engine.cpp +++ b/core/runtime/execute_engine.cpp @@ -314,7 +314,7 @@ std::vector execute_engine(std::vector inputs, c10::intr // Distributed setup - bind NCCL communicator to TRT execution context // setup_nccl_comm must have been called from Python before 
first forward #ifdef ENABLE_TRT_NCCL_COLLECTIVES - if (compiled_engine->world_size > 1 && compiled_engine->nccl_comm != nullptr) { + if (compiled_engine->is_md && compiled_engine->nccl_comm != nullptr) { compiled_engine->set_nccl_communicator_to_trt_context(); } #endif diff --git a/core/runtime/register_jit_hooks.cpp b/core/runtime/register_jit_hooks.cpp index 91331a8e52..aaaecbabec 100644 --- a/core/runtime/register_jit_hooks.cpp +++ b/core/runtime/register_jit_hooks.cpp @@ -108,6 +108,7 @@ static auto TORCHTRT_UNUSED TRTEngineTSRegistrtion = &TRTEngine::set_device_memory_budget) .def_property("streamable_device_memory_budget", &TRTEngine::get_streamable_device_memory_budget) .def_property("automatic_device_memory_budget", &TRTEngine::get_automatic_device_memory_budget) + .def_readonly("is_md", &TRTEngine::is_md) .def_readonly("rank", &TRTEngine::rank) .def_readonly("world_size", &TRTEngine::world_size) #ifdef ENABLE_TRT_NCCL_COLLECTIVES diff --git a/examples/distributed_inference/tensor_parallel_simple_example.py b/examples/distributed_inference/tensor_parallel_simple_example.py index 4072e16fd1..a7a8d55b8b 100755 --- a/examples/distributed_inference/tensor_parallel_simple_example.py +++ b/examples/distributed_inference/tensor_parallel_simple_example.py @@ -119,8 +119,8 @@ def forward(self, x): if args.mode == "load": # Load per-rank model: /tmp/tp_model.ep -> /tmp/tp_model_rank0_of_2.ep logger.info(f"Loading from {args.save_path}") - loaded_model = torch_tensorrt.load(args.save_path) - output = loaded_model(inp) + loaded_program = torch_tensorrt.load(args.save_path) + output = loaded_program.module()(inp) assert (python_result - output).std() < 0.01, "Result mismatch" logger.info("Load successful!") @@ -164,7 +164,7 @@ def forward(self, x): inputs=[inp], # enabled_precisions={torch.float32, torch.float16}, truncate_double=True, - use_python_runtime=True, + use_python_runtime=False, min_block_size=1, use_distributed_mode_trace=True, ) diff --git 
a/py/torch_tensorrt/dynamo/_compiler.py b/py/torch_tensorrt/dynamo/_compiler.py index bc3cdc5721..88c941e52f 100644 --- a/py/torch_tensorrt/dynamo/_compiler.py +++ b/py/torch_tensorrt/dynamo/_compiler.py @@ -370,7 +370,11 @@ def cross_compile_for_windows( logger.info("Compilation Settings: %s\n", settings) exported_program = pre_export_lowering(exported_program, settings) exported_program = exported_program.run_decompositions( - get_decompositions(enable_experimental_decompositions, decompose_attention) + get_decompositions( + enable_experimental_decompositions, + decompose_attention, + use_distributed_mode_trace, + ) ) gm = exported_program.module() @@ -769,7 +773,11 @@ def compile( logger.info("Compilation Settings: %s\n", settings) exported_program = pre_export_lowering(exported_program, settings) exported_program = exported_program.run_decompositions( - get_decompositions(enable_experimental_decompositions, decompose_attention) + get_decompositions( + enable_experimental_decompositions, + decompose_attention, + use_distributed_mode_trace, + ) ) gm = exported_program.module() @@ -1418,7 +1426,11 @@ def convert_exported_program_to_serialized_trt_engine( logger.info("Compilation Settings: %s\n", settings) exported_program = pre_export_lowering(exported_program, settings) exported_program = exported_program.run_decompositions( - get_decompositions(enable_experimental_decompositions, decompose_attention) + get_decompositions( + enable_experimental_decompositions, + decompose_attention, + use_distributed_mode_trace, + ) ) gm = exported_program.module() diff --git a/py/torch_tensorrt/dynamo/_refit.py b/py/torch_tensorrt/dynamo/_refit.py index 0b6af849fa..c179982b3d 100644 --- a/py/torch_tensorrt/dynamo/_refit.py +++ b/py/torch_tensorrt/dynamo/_refit.py @@ -360,7 +360,9 @@ def refit_module_weights( new_weight_module = pre_export_lowering(new_weight_module, settings) new_weight_module = new_weight_module.run_decompositions( get_decompositions( - 
settings.enable_experimental_decompositions, settings.decompose_attention + settings.enable_experimental_decompositions, + settings.decompose_attention, + settings.use_distributed_mode_trace, ) ) new_gm = new_weight_module.module() diff --git a/py/torch_tensorrt/dynamo/backend/backends.py b/py/torch_tensorrt/dynamo/backend/backends.py index 00fb6977e8..b737b0c2d9 100644 --- a/py/torch_tensorrt/dynamo/backend/backends.py +++ b/py/torch_tensorrt/dynamo/backend/backends.py @@ -62,7 +62,9 @@ def aot_torch_tensorrt_aten_backend( ) settings_aot_autograd = {} settings_aot_autograd["decompositions"] = get_decompositions( - settings.enable_experimental_decompositions, settings.decompose_attention + settings.enable_experimental_decompositions, + settings.decompose_attention, + settings.use_distributed_mode_trace, ) # This is added since detach lowering leads to alias nodes # Error - View operation returned a tensor that is the same as the input base tensor @@ -143,6 +145,7 @@ def _pretraced_backend( decompositions=get_decompositions( settings.enable_experimental_decompositions, settings.decompose_attention, + settings.use_distributed_mode_trace, ), ) diff --git a/py/torch_tensorrt/dynamo/lowering/_decompositions.py b/py/torch_tensorrt/dynamo/lowering/_decompositions.py index f6226385de..5df901a300 100644 --- a/py/torch_tensorrt/dynamo/lowering/_decompositions.py +++ b/py/torch_tensorrt/dynamo/lowering/_decompositions.py @@ -647,6 +647,7 @@ def masked_scatter_decomposition( def get_decompositions( enable_experimental_decompositions: bool = False, decompose_attention: bool = False, + use_distributed_mode_trace: bool = False, ) -> Dict[OpOverload, Callable[[Any], Any]]: trt_decomps = ( TORCH_TRT_DECOMPOSITIONS @@ -658,11 +659,19 @@ def get_decompositions( } ) + # For distributed (DTensor) models, allow aten.linear to decompose to addmm + # so that DTensor's sharding dispatch can find a strategy. 
+ discard_decompositions = ( + torch_disabled_decompositions - {aten.linear.default} + if use_distributed_mode_trace + else torch_disabled_decompositions + ) + if enable_experimental_decompositions: CORE_ATEN_DECOMPOSITIONS_FILTERED: Dict[OpOverload, Callable[[Any], Any]] = { decomp: _core_aten_decompositions[decomp] for decomp in _core_aten_decompositions - if decomp not in torch_disabled_decompositions + if decomp not in discard_decompositions } return {**CORE_ATEN_DECOMPOSITIONS_FILTERED, **trt_decomps} else: @@ -674,7 +683,7 @@ def get_decompositions( DECOMP_TABLE_FILTERED: Dict[OpOverload, Callable[[Any], Any]] = { decomp: decomp_table[decomp] for decomp in decomp_table - if decomp not in torch_disabled_decompositions + if decomp not in discard_decompositions } return { diff --git a/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py index b1cd146881..05aa0f05e8 100644 --- a/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py +++ b/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py @@ -375,7 +375,11 @@ def forward(self, *inputs: Any) -> torch.Tensor | Tuple[torch.Tensor, ...]: raise RuntimeError("Engine has not been setup yet.") # Lazy NCCL setup on first forward - if self.engine.world_size > 1 and not hasattr(self, "_nccl_initialized"): + if ( + not torch.compiler.is_exporting() + and self.engine.is_md + and not hasattr(self, "_nccl_initialized") + ): group_name = self._get_default_group_name() if group_name: self.engine.setup_nccl_comm(group_name)