diff --git a/MODULE.bazel b/MODULE.bazel index cd557d8233..333429ec65 100644 --- a/MODULE.bazel +++ b/MODULE.bazel @@ -8,6 +8,7 @@ bazel_dep(name = "googletest", version = "1.16.0") bazel_dep(name = "platforms", version = "0.0.11") bazel_dep(name = "rules_cc", version = "0.1.1") bazel_dep(name = "rules_python", version = "1.3.0") +bazel_dep(name = "bazel_skylib", version = "1.7.1") python = use_extension("@rules_python//python/extensions:python.bzl", "python") python.toolchain( @@ -26,6 +27,10 @@ new_local_repository = use_repo_rule("@bazel_tools//tools/build_defs/repo:local. local_torch = use_repo_rule("//toolchains:local_torch.bzl", "local_torch") +torch_nccl_detect = use_repo_rule("//toolchains/torch_nccl:defs.bzl", "torch_nccl_detect") + +torch_nccl_detect(name = "torch_nccl") + # External dependency for torch_tensorrt if you already have precompiled binaries. new_local_repository( name = "torch_tensorrt", diff --git a/core/runtime/BUILD b/core/runtime/BUILD index 19260149ae..5fd06e7150 100644 --- a/core/runtime/BUILD +++ b/core/runtime/BUILD @@ -1,6 +1,8 @@ load("@rules_cc//cc:defs.bzl", "cc_library") load("@rules_pkg//:pkg.bzl", "pkg_tar") load("@rules_pkg//pkg:mappings.bzl", "pkg_files") +load("//toolchains/torch_nccl:defs.bzl", "if_torch_nccl") + package(default_visibility = ["//visibility:public"]) config_setting( @@ -77,6 +79,7 @@ cc_library( "TRTEngineProfiler.h", "runtime.h", ], + copts = if_torch_nccl(["-DUSE_C10D_NCCL"]), linkopts = [ "-lstdc++fs", ], @@ -121,6 +124,6 @@ pkg_tar( pkg_files( name = "include_pkg_files", srcs = [":include_files"], - visibility = ["//visibility:public"], prefix = "include/torch_tensorrt/core/runtime/", + visibility = ["//visibility:public"], ) diff --git a/core/runtime/TRTEngine.cpp b/core/runtime/TRTEngine.cpp index d29daa112b..e42f8268cc 100644 --- a/core/runtime/TRTEngine.cpp +++ b/core/runtime/TRTEngine.cpp @@ -10,6 +10,13 @@ #include "core/util/prelude.h" #include "torch/torch.h" +#ifdef ENABLE_TRT_NCCL_COLLECTIVES 
+#include "torch/csrc/distributed/c10d/GroupRegistry.hpp" +#include "torch/csrc/distributed/c10d/NCCLUtils.hpp" +#include "torch/csrc/distributed/c10d/ProcessGroup.hpp" +#include "torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp" +#endif + namespace torch_tensorrt { namespace core { namespace runtime { @@ -88,7 +95,14 @@ TRTEngine::TRTEngine(std::vector serialized_info) serialized_info[SERIALIZED_METADATA_IDX], (static_cast(std::stoi(serialized_info[RESOURCE_ALLOCATION_STRATEGY_IDX])) ? ResourceAllocationStrategy::kDynamic - : ResourceAllocationStrategy::kStatic)) {} + : ResourceAllocationStrategy::kStatic)) { + this->is_md = std::stoi(serialized_info[IS_MD_ENGINE_IDX]); + if (this->is_md) { + LOG_INFO( + "Loaded distributed engine (built on rank " << serialized_info[OPTIONAL_RANK_IDX] << " of " + << serialized_info[OPTIONAL_WORLD_SIZE_IDX] << ")"); + } +} TRTEngine::TRTEngine( const std::string& mod_name, @@ -497,6 +511,11 @@ std::vector TRTEngine::serialize() { serialized_info[TARGET_PLATFORM_IDX] = this->target_platform.serialize(); serialized_info[RESOURCE_ALLOCATION_STRATEGY_IDX] = this->resource_allocation_strategy == ResourceAllocationStrategy::kDynamic ? "1" : "0"; + serialized_info[IS_MD_ENGINE_IDX] = this->is_md ? 
"1" : "0"; + if (this->is_md) { + serialized_info[OPTIONAL_RANK_IDX] = std::to_string(this->rank); + serialized_info[OPTIONAL_WORLD_SIZE_IDX] = std::to_string(this->world_size); + } return serialized_info; } @@ -519,6 +538,61 @@ void TRTEngine::set_resource_allocation_strategy(TRTEngine::ResourceAllocationSt } } +#ifdef ENABLE_TRT_NCCL_COLLECTIVES +void TRTEngine::detect_distributed_context(const std::string& group_name) { + auto pg = c10d::resolve_process_group(group_name); + if (pg) { + this->rank = pg->getRank(); + this->world_size = pg->getSize(); + this->is_md = this->world_size > 1; + LOG_DEBUG("Detected distributed context: rank=" << this->rank << ", world_size=" << this->world_size); + } +} + +void TRTEngine::setup_nccl_comm(const std::string& group_name) { + auto pg = c10d::resolve_process_group(group_name); + TORCHTRT_CHECK(pg != nullptr, "ProcessGroup '" << group_name << "' not found in registry"); + + // Set rank/world_size if not already set (e.g. load from disk without setup_engine) + if (this->rank < 0) { + this->rank = pg->getRank(); + this->world_size = pg->getSize(); + LOG_DEBUG("Set distributed context in setup_nccl_comm: rank=" << this->rank << ", world_size=" << this->world_size); + } + + auto backend = pg->getBackend(c10d::ProcessGroup::BackendType::NCCL); + TORCHTRT_CHECK(backend != nullptr, "ProcessGroup '" << group_name << "' has no NCCL backend"); + + auto* nccl_pg = dynamic_cast(backend.get()); + TORCHTRT_CHECK(nccl_pg != nullptr, "Backend is not ProcessGroupNCCL"); + + at::cuda::set_device(this->device_info.id); + + int64_t comm_ptr = nccl_pg->getCommPtr(); + TORCHTRT_CHECK( + comm_ptr != 0, + "NCCL communicator not initialized for device " << this->device_info.id + << ". 
Ensure a collective operation has been performed first."); + + this->nccl_comm = reinterpret_cast(comm_ptr); + set_nccl_communicator_to_trt_context(); + LOG_INFO("NCCL comm set up (rank=" << this->rank << ", device=" << this->device_info.id << ")"); +} + +bool TRTEngine::set_nccl_communicator_to_trt_context() { + TORCHTRT_CHECK(exec_ctx != nullptr, "Cannot set NCCL communicator: execution context is null"); + TORCHTRT_CHECK(this->nccl_comm != nullptr, "NCCL communicator is not set"); + + exec_ctx->setCommunicator(this->nccl_comm); + + LOG_INFO( + "NCCL communicator set on TensorRT execution context " + "(rank=" + << this->rank << ", device=" << this->device_info.id << ")"); + return true; +} +#endif // ENABLE_TRT_NCCL_COLLECTIVES + } // namespace runtime } // namespace core } // namespace torch_tensorrt diff --git a/core/runtime/TRTEngine.h b/core/runtime/TRTEngine.h index 363631863f..35591931aa 100644 --- a/core/runtime/TRTEngine.h +++ b/core/runtime/TRTEngine.h @@ -9,12 +9,25 @@ #include "ATen/core/function_schema.h" #include "ATen/cuda/CUDAGraph.h" #include "NvInfer.h" +#include "NvInferVersion.h" #include "c10/cuda/CUDAStream.h" #include "torch/custom_class.h" #include "core/runtime/TRTEngineProfiler.h" #include "core/util/prelude.h" +// TensorRT 10.16+ has native NCCL collective support via IExecutionContext::setCommunicator() +#if NV_TENSORRT_MAJOR > 10 || (NV_TENSORRT_MAJOR == 10 && NV_TENSORRT_MINOR >= 16) +#define TRT_HAS_NATIVE_NCCL 1 +#endif + +// Full TRT NCCL collectives support requires both: +// 1. PyTorch built with NCCL (USE_C10D_NCCL defined via Bazel) +// 2. 
TensorRT 10.16+ (TRT_HAS_NATIVE_NCCL defined above) +#if defined(USE_C10D_NCCL) && defined(TRT_HAS_NATIVE_NCCL) +#define ENABLE_TRT_NCCL_COLLECTIVES 1 +#endif + namespace torch_tensorrt { namespace core { namespace runtime { @@ -196,6 +209,22 @@ struct TRTEngine : torch::CustomClassHolder { bool use_output_allocator_outputs = false; // users specify to use output allocator std::shared_ptr output_allocator; + // Member variables for distributed inference (-1 indicates non-distributed mode) + bool is_md = false; + int64_t rank = -1; + int64_t world_size = -1; + +#ifdef ENABLE_TRT_NCCL_COLLECTIVES + void* nccl_comm = nullptr; + + // Detect rank and world_size from ProcessGroup + void detect_distributed_context(const std::string& group_name); + + // Resolve ProcessGroup, get NCCL communicator, and bind to TRT context + void setup_nccl_comm(const std::string& group_name); + bool set_nccl_communicator_to_trt_context(); +#endif + // TODO: Implement a call method // c10::List Run(c10::List inputs); diff --git a/core/runtime/execute_engine.cpp b/core/runtime/execute_engine.cpp index 553469392b..0c7d91848e 100644 --- a/core/runtime/execute_engine.cpp +++ b/core/runtime/execute_engine.cpp @@ -311,6 +311,14 @@ std::vector execute_engine(std::vector inputs, c10::intr std::make_unique(compiled_engine->enqueue_profile_path); } + // Distributed setup - bind NCCL communicator to TRT execution context + // setup_nccl_comm must have been called from Python before first forward +#ifdef ENABLE_TRT_NCCL_COLLECTIVES + if (compiled_engine->is_md && compiled_engine->nccl_comm != nullptr) { + compiled_engine->set_nccl_communicator_to_trt_context(); + } +#endif + // Block engine stream until results are available on caller stream at::cuda::CUDAEvent caller_exec_complete; caller_exec_complete.record(compiled_engine->caller_stream); diff --git a/core/runtime/register_jit_hooks.cpp b/core/runtime/register_jit_hooks.cpp index e8f6217a21..aaaecbabec 100644 --- 
a/core/runtime/register_jit_hooks.cpp +++ b/core/runtime/register_jit_hooks.cpp @@ -108,6 +108,19 @@ static auto TORCHTRT_UNUSED TRTEngineTSRegistrtion = &TRTEngine::set_device_memory_budget) .def_property("streamable_device_memory_budget", &TRTEngine::get_streamable_device_memory_budget) .def_property("automatic_device_memory_budget", &TRTEngine::get_automatic_device_memory_budget) + .def_readonly("is_md", &TRTEngine::is_md) + .def_readonly("rank", &TRTEngine::rank) + .def_readonly("world_size", &TRTEngine::world_size) +#ifdef ENABLE_TRT_NCCL_COLLECTIVES + .def( + "detect_distributed_context", + [](c10::intrusive_ptr self, std::string group_name) { + self->detect_distributed_context(group_name); + }) + .def( + "setup_nccl_comm", + [](c10::intrusive_ptr self, std::string group_name) { self->setup_nccl_comm(group_name); }) +#endif .def_pickle( [](const c10::intrusive_ptr& self) -> std::vector { return self->serialize(); }, [](std::vector serialized_info) -> c10::intrusive_ptr { @@ -150,6 +163,16 @@ TORCH_LIBRARY(tensorrt, m) { m.def("REQUIRES_OUTPUT_ALLOCATOR_IDX", []() -> int64_t { return REQUIRES_OUTPUT_ALLOCATOR_IDX; }); m.def("SERIALIZATION_LEN", []() -> int64_t { return SERIALIZATION_LEN; }); m.def("RESOURCE_ALLOCATION_STRATEGY_IDX", []() -> int64_t { return RESOURCE_ALLOCATION_STRATEGY_IDX; }); + m.def("IS_MD_ENGINE_IDX", []() -> int64_t { return IS_MD_ENGINE_IDX; }); + m.def("OPTIONAL_RANK_IDX", []() -> int64_t { return OPTIONAL_RANK_IDX; }); + m.def("OPTIONAL_WORLD_SIZE_IDX", []() -> int64_t { return OPTIONAL_WORLD_SIZE_IDX; }); + m.def("NATIVE_TRT_COLLECTIVES_AVAIL", []() -> bool { +#ifdef ENABLE_TRT_NCCL_COLLECTIVES + return true; +#else + return false; +#endif + }); m.def("_platform_linux_x86_64", []() -> std::string { auto it = get_platform_name_map().find(Platform::PlatformEnum::kLINUX_X86_64); return it->second; diff --git a/core/runtime/runtime.h b/core/runtime/runtime.h index d8f71683d3..741d8dee3b 100644 --- a/core/runtime/runtime.h +++ 
b/core/runtime/runtime.h @@ -16,7 +16,7 @@ namespace core { namespace runtime { using EngineID = int64_t; -const std::string ABI_VERSION = "8"; +const std::string ABI_VERSION = "9"; extern bool MULTI_DEVICE_SAFE_MODE; typedef enum { @@ -39,6 +39,9 @@ typedef enum { TARGET_PLATFORM_IDX, REQUIRES_OUTPUT_ALLOCATOR_IDX, RESOURCE_ALLOCATION_STRATEGY_IDX, + IS_MD_ENGINE_IDX, + OPTIONAL_RANK_IDX, + OPTIONAL_WORLD_SIZE_IDX, SERIALIZATION_LEN, // NEVER USED FOR DATA, USED TO DETERMINE LENGTH OF SERIALIZED INFO } SerializedInfoIndex; diff --git a/examples/distributed_inference/tensor_parallel_simple_example.py b/examples/distributed_inference/tensor_parallel_simple_example.py index f2dc6861cb..a7a8d55b8b 100755 --- a/examples/distributed_inference/tensor_parallel_simple_example.py +++ b/examples/distributed_inference/tensor_parallel_simple_example.py @@ -16,27 +16,51 @@ ----- .. code-block:: bash - mpirun -n 2 --allow-run-as-root python tensor_parallel_simple_example.py + # JIT mode python runtime + mpirun -n 2 python tensor_parallel_simple_example.py --mode jit_python + + # JIT mode cpp runtime + mpirun -n 2 python tensor_parallel_simple_example.py --mode jit_cpp + + WIP: Export and load mode + mpirun -n 2 python tensor_parallel_simple_example.py --mode export --save-path /tmp/tp_model.ep + mpirun -n 2 python tensor_parallel_simple_example.py --mode load --save-path /tmp/tp_model.ep + """ +import argparse import time -import tensorrt as trt import torch import torch.distributed as dist import torch.nn as nn +import torch.utils._pytree from tensor_parallel_initialize_dist import ( cleanup_distributed_env, - get_tensor_parallel_device_mesh, initialize_distributed_env, ) -# Initialize distributed environment and logger BEFORE importing torch_tensorrt -# This ensures logging is configured before any import-time log messages +torch.utils._pytree.register_constant( + torch.distributed.tensor._dtensor_spec.DTensorSpec +) + +parser = argparse.ArgumentParser(description="Tensor 
Parallel Simple Example") +parser.add_argument( + "--mode", + type=str, + choices=["jit_python", "jit_cpp", "export", "load"], + default="jit_python", +) +parser.add_argument("--save-path", type=str, default="/tmp/tp_model.ep") +args = parser.parse_args() + device_mesh, _world_size, _rank, logger = initialize_distributed_env( "tensor_parallel_simple_example" ) +import torch_tensorrt +from torch_tensorrt.dynamo.runtime._nccl_utils import setup_nccl_library +setup_nccl_library() from torch.distributed._tensor import Shard from torch.distributed.tensor.parallel import ( ColwiseParallel, @@ -92,29 +116,65 @@ def forward(self, x): inp = torch.rand(20, 10, device="cuda") python_result = tp_model(inp) -backend = "torch_tensorrt" -tp_model = torch.compile( - tp_model, - backend=backend, - options={ - "truncate_long_and_double": True, - "enabled_precisions": {torch.float32, torch.float16}, - "use_python_runtime": True, - "min_block_size": 1, - "use_distributed_mode_trace": True, - }, - dynamic=None, -) - -# For TP, input needs to be same across all TP ranks. -# Setting the random seed is to mimic the behavior of dataloader. -torch.manual_seed(0) -inp = torch.rand(20, 10, device="cuda") -start = time.time() -output = tp_model(inp) -end = time.time() -logger.info(f"Compilation time is {end - start}") -assert (python_result - output).std() < 0.01, "Result is not correct." 
+if args.mode == "load": + # Load per-rank model: /tmp/tp_model.ep -> /tmp/tp_model_rank0_of_2.ep + logger.info(f"Loading from {args.save_path}") + loaded_program = torch_tensorrt.load(args.save_path) + output = loaded_program.module()(inp) + assert (python_result - output).std() < 0.01, "Result mismatch" + logger.info("Load successful!") + +elif args.mode == "jit_python": + trt_model = torch.compile( + tp_model, + backend="torch_tensorrt", + options={ + "truncate_long_and_double": True, + "enabled_precisions": {torch.float32, torch.float16}, + "use_python_runtime": True, + "min_block_size": 1, + "use_distributed_mode_trace": True, + }, + ) + output = trt_model(inp) + assert (python_result - output).std() < 0.01, "Result mismatch" + logger.info("JIT compile successful!") + +elif args.mode == "jit_cpp": + trt_model = torch.compile( + tp_model, + backend="torch_tensorrt", + options={ + "truncate_long_and_double": True, + "enabled_precisions": {torch.float32, torch.float16}, + "use_python_runtime": False, + "min_block_size": 1, + "use_distributed_mode_trace": True, + }, + ) + output = trt_model(inp) + assert (python_result - output).std() < 0.01, "Result mismatch" + logger.info("JIT compile successful!") + +elif args.mode == "export": + # Export: torch.export + dynamo.compile - AOT compilation, can save + exported_program = torch.export.export(tp_model, (inp,), strict=False) + trt_model = torch_tensorrt.dynamo.compile( + exported_program, + inputs=[inp], + # enabled_precisions={torch.float32, torch.float16}, + truncate_double=True, + use_python_runtime=False, + min_block_size=1, + use_distributed_mode_trace=True, + ) + output = trt_model(inp) + assert (python_result - output).std() < 0.01, "Result mismatch" + + # Save per-rank: /tmp/tp_model.ep -> /tmp/tp_model_rank0_of_2.ep + save_path = torch_tensorrt.save(trt_model, args.save_path, inputs=[inp]) + logger.info(f"Saved to {save_path}") + dist.barrier() -# This cleans up the distributed process group 
cleanup_distributed_env() +logger.info("Done!") diff --git a/py/torch_tensorrt/_features.py b/py/torch_tensorrt/_features.py index 318fc79461..1f7c389d8c 100644 --- a/py/torch_tensorrt/_features.py +++ b/py/torch_tensorrt/_features.py @@ -7,6 +7,7 @@ import tensorrt from torch_tensorrt._utils import ( check_cross_compile_trt_win_lib, + check_native_trt_collectives, load_tensorrt_llm_for_nccl, sanitized_torch_version, ) @@ -25,6 +26,7 @@ "windows_cross_compile", "tensorrt_rtx", "trtllm_for_nccl", + "native_trt_collectives", ], ) @@ -50,7 +52,16 @@ _FX_FE_AVAIL = False if _TENSORRT_RTX else True _REFIT_AVAIL = True _WINDOWS_CROSS_COMPILE = check_cross_compile_trt_win_lib() -_TRTLLM_AVAIL = load_tensorrt_llm_for_nccl() + +# Check if native TRT collectives are available (TRT 10.16+ with NCCL) +_NATIVE_TRT_COLLECTIVES_AVAIL = check_native_trt_collectives( + linked_file_full_path, linked_file_runtime_full_path +) + +# Only load TRT-LLM for NCCL if native TRT collectives are not available +_TRTLLM_AVAIL = False +if not _NATIVE_TRT_COLLECTIVES_AVAIL: + _TRTLLM_AVAIL = load_tensorrt_llm_for_nccl() if _TENSORRT_RTX: @@ -78,6 +89,7 @@ _WINDOWS_CROSS_COMPILE, _TENSORRT_RTX, _TRTLLM_AVAIL, + _NATIVE_TRT_COLLECTIVES_AVAIL, ) T = TypeVar("T") @@ -85,7 +97,7 @@ def _enabled_features_str() -> str: enabled = lambda x: "ENABLED" if x else "DISABLED" - out_str: str = f"Enabled Features:\n - Dynamo Frontend: {enabled(_DYNAMO_FE_AVAIL)}\n - Torch-TensorRT Runtime: {enabled(_TORCHTRT_RT_AVAIL)}\n - FX Frontend: {enabled(_FX_FE_AVAIL)}\n - TorchScript Frontend: {enabled(_TS_FE_AVAIL)}\n - Refit: {enabled(_REFIT_AVAIL)}\n - QDP Plugin: {enabled(_QDP_PLUGIN_AVAIL)} \n - TensorRT-RTX: {enabled(_TENSORRT_RTX)}\n - TensorRT-LLM for NCCL: {enabled(_TRTLLM_AVAIL)}\n" # type: ignore[no-untyped-call] + out_str: str = f"Enabled Features:\n - Dynamo Frontend: {enabled(_DYNAMO_FE_AVAIL)}\n - Torch-TensorRT Runtime: {enabled(_TORCHTRT_RT_AVAIL)}\n - FX Frontend: {enabled(_FX_FE_AVAIL)}\n - TorchScript 
Frontend: {enabled(_TS_FE_AVAIL)}\n - Refit: {enabled(_REFIT_AVAIL)}\n - QDP Plugin: {enabled(_QDP_PLUGIN_AVAIL)} \n - TensorRT-RTX: {enabled(_TENSORRT_RTX)}\n - TensorRT-LLM for NCCL: {enabled(_TRTLLM_AVAIL)}\n - Native TRT Collectives: {enabled(_NATIVE_TRT_COLLECTIVES_AVAIL)}\n" # type: ignore[no-untyped-call] return out_str diff --git a/py/torch_tensorrt/_utils.py b/py/torch_tensorrt/_utils.py index 20521590ba..9b2993b56d 100644 --- a/py/torch_tensorrt/_utils.py +++ b/py/torch_tensorrt/_utils.py @@ -299,6 +299,39 @@ def load_and_initialize_trtllm_plugin(plugin_lib_path: str) -> bool: return False +def check_native_trt_collectives( + torchtrt_lib_path: Optional[str] = None, + torchtrt_runtime_lib_path: Optional[str] = None, +) -> bool: + """ + Check if native TRT collectives are available (TRT 10.16+ with NCCL). + + This function loads the torch_tensorrt runtime library to register torch ops, + then calls NATIVE_TRT_COLLECTIVES_AVAIL() to check if the runtime was compiled + with NCCL support and TensorRT 10.16+. + + Args: + torchtrt_lib_path: Path to libtorchtrt.so (full library) + torchtrt_runtime_lib_path: Path to libtorchtrt_runtime.so (runtime-only library) + + Returns: + bool: True if native TRT collectives are available, False otherwise. + """ + try: + # Load the runtime library to register torch ops + # Prefer full library if available, otherwise runtime-only + if torchtrt_lib_path and os.path.isfile(torchtrt_lib_path): + torch.ops.load_library(torchtrt_lib_path) + elif torchtrt_runtime_lib_path and os.path.isfile(torchtrt_runtime_lib_path): + torch.ops.load_library(torchtrt_runtime_lib_path) + else: + return False + + return bool(torch.ops.tensorrt.NATIVE_TRT_COLLECTIVES_AVAIL()) + except Exception: + return False + + def load_tensorrt_llm_for_nccl() -> bool: """ Attempts to load the TensorRT-LLM plugin and initialize it. 
diff --git a/py/torch_tensorrt/dynamo/_compiler.py b/py/torch_tensorrt/dynamo/_compiler.py index bc3cdc5721..88c941e52f 100644 --- a/py/torch_tensorrt/dynamo/_compiler.py +++ b/py/torch_tensorrt/dynamo/_compiler.py @@ -370,7 +370,11 @@ def cross_compile_for_windows( logger.info("Compilation Settings: %s\n", settings) exported_program = pre_export_lowering(exported_program, settings) exported_program = exported_program.run_decompositions( - get_decompositions(enable_experimental_decompositions, decompose_attention) + get_decompositions( + enable_experimental_decompositions, + decompose_attention, + use_distributed_mode_trace, + ) ) gm = exported_program.module() @@ -769,7 +773,11 @@ def compile( logger.info("Compilation Settings: %s\n", settings) exported_program = pre_export_lowering(exported_program, settings) exported_program = exported_program.run_decompositions( - get_decompositions(enable_experimental_decompositions, decompose_attention) + get_decompositions( + enable_experimental_decompositions, + decompose_attention, + use_distributed_mode_trace, + ) ) gm = exported_program.module() @@ -1418,7 +1426,11 @@ def convert_exported_program_to_serialized_trt_engine( logger.info("Compilation Settings: %s\n", settings) exported_program = pre_export_lowering(exported_program, settings) exported_program = exported_program.run_decompositions( - get_decompositions(enable_experimental_decompositions, decompose_attention) + get_decompositions( + enable_experimental_decompositions, + decompose_attention, + use_distributed_mode_trace, + ) ) gm = exported_program.module() diff --git a/py/torch_tensorrt/dynamo/_refit.py b/py/torch_tensorrt/dynamo/_refit.py index 0b6af849fa..c179982b3d 100644 --- a/py/torch_tensorrt/dynamo/_refit.py +++ b/py/torch_tensorrt/dynamo/_refit.py @@ -360,7 +360,9 @@ def refit_module_weights( new_weight_module = pre_export_lowering(new_weight_module, settings) new_weight_module = new_weight_module.run_decompositions( get_decompositions( - 
settings.enable_experimental_decompositions, settings.decompose_attention + settings.enable_experimental_decompositions, + settings.decompose_attention, + settings.use_distributed_mode_trace, ) ) new_gm = new_weight_module.module() diff --git a/py/torch_tensorrt/dynamo/backend/backends.py b/py/torch_tensorrt/dynamo/backend/backends.py index 00fb6977e8..b737b0c2d9 100644 --- a/py/torch_tensorrt/dynamo/backend/backends.py +++ b/py/torch_tensorrt/dynamo/backend/backends.py @@ -62,7 +62,9 @@ def aot_torch_tensorrt_aten_backend( ) settings_aot_autograd = {} settings_aot_autograd["decompositions"] = get_decompositions( - settings.enable_experimental_decompositions, settings.decompose_attention + settings.enable_experimental_decompositions, + settings.decompose_attention, + settings.use_distributed_mode_trace, ) # This is added since detach lowering leads to alias nodes # Error - View operation returned a tensor that is the same as the input base tensor @@ -143,6 +145,7 @@ def _pretraced_backend( decompositions=get_decompositions( settings.enable_experimental_decompositions, settings.decompose_attention, + settings.use_distributed_mode_trace, ), ) diff --git a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py index f04836c52c..ba0cc90c21 100644 --- a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py +++ b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py @@ -225,6 +225,12 @@ def _populate_trt_builder_config( ) -> trt.IBuilderConfig: builder_config = self.builder.create_builder_config() + if ENABLED_FEATURES.native_trt_collectives: + _LOGGER.info("Using native TRT collectives") + builder_config.set_preview_feature( + trt.PreviewFeature.MULTIDEVICE_RUNTIME_10_16, True + ) + if self._debugger_config and self._debugger_config.engine_builder_monitor: builder_config.progress_monitor = TRTBulderMonitor() @@ -453,7 +459,7 @@ def check_weight_equal( except Exception: return torch.all(sd_weight == 
network_weight) - @needs_refit # type: ignore[misc] + @needs_refit def _save_weight_mapping(self) -> None: """ Construct the weight name mapping from engine weight name to state_dict weight name. diff --git a/py/torch_tensorrt/dynamo/conversion/_conversion.py b/py/torch_tensorrt/dynamo/conversion/_conversion.py index e47d3f404f..069ae3f43c 100644 --- a/py/torch_tensorrt/dynamo/conversion/_conversion.py +++ b/py/torch_tensorrt/dynamo/conversion/_conversion.py @@ -4,6 +4,7 @@ import logging from typing import Any, Dict, List, NamedTuple, Optional, Sequence +import tensorrt as trt import torch from torch_tensorrt._enums import dtype from torch_tensorrt._features import ENABLED_FEATURES @@ -25,8 +26,6 @@ ) from torch_tensorrt.logging import TRT_LOGGER -import tensorrt as trt - logger = logging.getLogger(__name__) @@ -49,7 +48,7 @@ def infer_module_output_dtypes( """ outputs = [node for node in module.graph.nodes if node.op == "output"] outputs = outputs[0].args - return get_output_dtypes(outputs, truncate_double) + return list(get_output_dtypes(outputs, truncate_double)) def insert_engine_to_cache( @@ -358,6 +357,22 @@ def convert_module( "Since Torch-TensorRT runtime is not available, using Python Runtime, some features may not be available" ) + if settings.use_distributed_mode_trace: + # Check if distributed backends are available + if ENABLED_FEATURES.native_trt_collectives: + logger.info( + "Native TRT collectives available (TRT 10.16+) for distributed execution" + ) + elif ENABLED_FEATURES.trtllm_for_nccl: + logger.info("TRT-LLM NCCL plugins available for distributed execution") + else: + logger.warning( + "Distributed mode requested but neither native TRT collectives nor TRT-LLM NCCL plugins are available. " + "Distributed execution may not work correctly. " + "For native TRT collectives, ensure TensorRT 10.16+ and torch_tensorrt built with NCCL support. " + "For TRT-LLM fallback, set TRTLLM_PLUGINS_PATH or USE_TRTLLM_PLUGINS=1." 
+ ) + return rt_cls( serialized_engine=serialized_interpreter_result.serialized_engine, input_binding_names=list(serialized_interpreter_result.input_names), diff --git a/py/torch_tensorrt/dynamo/conversion/custom_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/custom_ops_converters.py index 302a254f60..2324b04806 100644 --- a/py/torch_tensorrt/dynamo/conversion/custom_ops_converters.py +++ b/py/torch_tensorrt/dynamo/conversion/custom_ops_converters.py @@ -14,11 +14,69 @@ ) from torch_tensorrt.dynamo.lowering.passes.fuse_distributed_ops import ( tensorrt_fused_nccl_all_gather_op, + tensorrt_fused_nccl_all_reduce_op, tensorrt_fused_nccl_reduce_scatter_op, ) _LOGGER: logging.Logger = logging.getLogger(__name__) +if ENABLED_FEATURES.native_trt_collectives: + # Use native TensorRT DistCollective API (no TensorRT-LLM dependency) + _LOGGER.info("Using native TensorRT DistCollective API for distributed operations") + + @dynamo_tensorrt_converter(tensorrt_fused_nccl_all_gather_op) + def fused_nccl_gather( + ctx: ConversionContext, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, + ) -> Union[trt.ITensor, Sequence[trt.ITensor]]: + """All-gather using native TensorRT DistCollective API""" + return impl.nccl_ops.nccl_gather_native( + ctx, + target, + SourceIR.ATEN, + name, + [args[0]], + ) + + @dynamo_tensorrt_converter(tensorrt_fused_nccl_reduce_scatter_op) + def fused_nccl_reduce_scatter( + ctx: ConversionContext, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, + ) -> Union[trt.ITensor, Sequence[trt.ITensor]]: + """Reduce-scatter using native TensorRT DistCollective API""" + return impl.nccl_ops.nccl_reduce_scatter_native( + ctx, + target, + SourceIR.ATEN, + name, + [args[0]], + ) + + @dynamo_tensorrt_converter(tensorrt_fused_nccl_all_reduce_op) + def fused_nccl_all_reduce( + ctx: ConversionContext, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + 
name: str, + ) -> Union[trt.ITensor, Sequence[trt.ITensor]]: + """All-reduce using native TensorRT DistCollective API""" + reduce_op = args[1] if len(args) > 1 else "sum" + return impl.nccl_ops.nccl_all_reduce_native( + ctx, + target, + SourceIR.ATEN, + name, + [args[0]], + reduce_op=reduce_op, + ) + # Conditionally register NCCL converters only if TensorRT-LLM plugin is available. # We use an `if` statement instead of @needs_trtllm_for_nccl decorator because @@ -28,7 +86,7 @@ # Order 1: @needs_trtllm_for_nccl followed by registering the converter leads to plugin registry not finding nccl ops plugins since we register the bare converter, without the decorator # Order 2: registering the converter first followed by @needs_trtllm_for_nccl leads to "NotImplementedError: TensorRT-LLM plugin for NCCL is not available :TensorRT-LLM plugin for NCCL is not available" and no fall back to pytorch -if ENABLED_FEATURES.trtllm_for_nccl: +elif ENABLED_FEATURES.trtllm_for_nccl: _LOGGER.debug( "TensorRT-LLM plugin for NCCL is available. Registering NCCL converters." ) @@ -65,6 +123,22 @@ def fused_nccl_reduce_scatter( [args[0]], ) + @dynamo_tensorrt_converter(tensorrt_fused_nccl_all_reduce_op) + def fused_nccl_all_reduce( + ctx: ConversionContext, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, + ) -> Union[trt.ITensor, Sequence[trt.ITensor]]: + return impl.nccl_ops.nccl_all_reduce( + ctx, + target, + SourceIR.ATEN, + name, + [args[0]], + ) + else: _LOGGER.info( "TensorRT-LLM plugin for NCCL is not available. 
" diff --git a/py/torch_tensorrt/dynamo/conversion/impl/nccl_ops.py b/py/torch_tensorrt/dynamo/conversion/impl/nccl_ops.py index e64c06ca39..55df569f09 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/nccl_ops.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/nccl_ops.py @@ -1,3 +1,4 @@ +import logging import os from enum import IntEnum, IntFlag, auto from typing import Optional, Tuple, Union @@ -9,6 +10,8 @@ from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext from torch_tensorrt.dynamo.conversion.converter_utils import set_layer_name +logger = logging.getLogger(__name__) + # class for AllReduce class AllReduceStrategy(IntEnum): @@ -33,6 +36,34 @@ class AllReduceConfig(IntFlag): PUSH_MODE = auto() +def _get_distributed_rank_and_world_size() -> Tuple[int, int]: + """Get rank and world_size from environment variables. + + Returns: + (rank, world_size) tuple. + + Raises: + RuntimeError: If WORLD_SIZE is not set. + """ + _world_size = os.environ.get("WORLD_SIZE") + if _world_size is None: + raise RuntimeError( + "The WORLD_SIZE env variable is not set in distributed environment" + ) + world_size = int(_world_size) + + # Get rank from environment + _rank = int(os.environ.get("RANK", 0)) + if _rank is not None: + rank = int(_rank) + else: + raise RuntimeError( + "The RANK env variable is not set in distributed environment" + ) + + return rank, world_size + + def nccl_gather( ctx: ConversionContext, target: Union[Target, str], @@ -44,13 +75,11 @@ def nccl_gather( "AllGather", "1", "tensorrt_llm" ) assert allgather_plg_creator is not None - _world_size = os.environ.get("WORLD_SIZE") - if _world_size is not None: - world_size = int(_world_size) - else: - raise RuntimeError( - "The WORLD_SIZE env variable is not set in distributed environment" - ) + rank, world_size = _get_distributed_rank_and_world_size() + logger.debug( + f"Adding TRT-LLM NCCL gather: name={name}, rank={rank}, world_size={world_size}" + ) + group = list(range(world_size)) 
group = trt.PluginField( "group", np.array(group, dtype=np.int32), trt.PluginFieldType.INT32 @@ -66,6 +95,56 @@ def nccl_gather( return layer.get_output(0) +def nccl_all_reduce( + ctx: ConversionContext, + target: Union[Target, str], + source_ir: Optional[SourceIR], + name: str, + plug_inputs: Tuple[Argument, ...], +) -> trt.ITensor: + allreduce_plg_creator = trt.get_plugin_registry().get_plugin_creator( + "AllReduce", "1", "tensorrt_llm" + ) + assert allreduce_plg_creator is not None + + counter = 0 + strategy = AllReduceStrategy.NCCL + config = AllReduceConfig(0) + rank, world_size = _get_distributed_rank_and_world_size() + logger.debug( + f"Adding TRT-LLM NCCL all reduce: name={name}, rank={rank}, world_size={world_size}" + ) + group = list(range(world_size)) + group = trt.PluginField( + "group", np.array(group, dtype=np.int32), trt.PluginFieldType.INT32 + ) + + p_dtype = trt.float32 + pf_dtype = trt.PluginField( + "type_id", np.array([int(p_dtype)], np.int32), trt.PluginFieldType.INT32 + ) + pfc = [group, pf_dtype] + p_strategy = trt.PluginField( + "strategy", np.array([int(strategy)], np.int8), trt.PluginFieldType.INT8 + ) + pfc.append(p_strategy) + p_config = trt.PluginField( + "config", np.array([int(config)], np.int8), trt.PluginFieldType.INT8 + ) + pfc.append(p_config) + p_counter = trt.PluginField( + "counter", np.array([counter], np.int32), trt.PluginFieldType.INT32 + ) + pfc.append(p_counter) + + pfc = trt.PluginFieldCollection(pfc) + ar_plug = allreduce_plg_creator.create_plugin("allreduce", pfc) + + layer = ctx.net.add_plugin_v2(plug_inputs, ar_plug) + set_layer_name(layer, target, name, source_ir) + return layer.get_output(0) + + def nccl_reduce_scatter( ctx: ConversionContext, target: Union[Target, str], @@ -82,13 +161,10 @@ def nccl_reduce_scatter( counter = 0 strategy = AllReduceStrategy.NCCL config = AllReduceConfig(0) - _world_size = os.environ.get("WORLD_SIZE") - if _world_size is not None: - world_size = int(_world_size) - else: - raise 
RuntimeError( - "The WORLD_SIZE env variable is not set in distributed environment" - ) + rank, world_size = _get_distributed_rank_and_world_size() + logger.debug( + f"Adding TRT-LLM NCCL reduce scatter: name={name}, rank={rank}, world_size={world_size}" + ) group = list(range(world_size)) group = trt.PluginField( "group", np.array(group, dtype=np.int32), trt.PluginFieldType.INT32 @@ -118,3 +194,250 @@ def nccl_reduce_scatter( layer = ctx.net.add_plugin_v2(plug_inputs, ar_plug) set_layer_name(layer, target, name, source_ir) return layer.get_output(0) + + +def nccl_gather_native( + ctx: ConversionContext, + target: Union[Target, str], + source_ir: Optional[SourceIR], + name: str, + plug_inputs: Tuple[Argument, ...], +) -> trt.ITensor: + """ + Implement all_gather using native TensorRT DistCollective API. + + This operation gathers tensors from all ranks and concatenates them. + Each rank contributes a tensor, and all ranks receive the concatenated result. + + Returns: + Output tensor after all_gather operation + + Example: + Input on rank 0: [1, 2] shape=(2,) + Input on rank 1: [3, 4] shape=(2,) + Output on all ranks: [1, 2, 3, 4] shape=(4,) + """ + rank, world_size = _get_distributed_rank_and_world_size() + logger.debug( + f"Adding native all_gather: name={name}, rank={rank}, world_size={world_size}" + ) + + # Get the input tensor + input_tensor = plug_inputs[0] + + try: + # Use native TensorRT DistCollective API for ALL_GATHER + # For ALL_GATHER, the reduce operation and root rank parameters are ignored + # The last parameter (group) can be None to include all ranks + import numpy as np + + # Create array of all participating rank IDs [0, 1, 2, ..., world_size-1] + groups = np.arange(world_size, dtype=np.int64) + + logger.debug( + f"Creating ALL_GATHER layer: groups={groups.tolist()}, groupSize={world_size}" + ) + layer = ctx.net.add_dist_collective( + input_tensor, + trt.CollectiveOperation.ALL_GATHER, + trt.ReduceOperation.NONE, # Ignored for ALL_GATHER + -1, # 
Root rank - ignored for ALL_GATHER + groups, # None means all ranks participate (world_size ranks) + ) + + logger.debug(f"Successfully created native ALL_GATHER layer: {name}") + logger.debug( + f"Calling add_dist_collective: input_shape={input_tensor.shape}, " + f"groups={groups.tolist()}, groupSize={len(groups)} (inferred from array)" + ) + + set_layer_name(layer, target, name, source_ir) + + output = layer.get_output(0) + layer.num_ranks = world_size + + return output + + except AttributeError as e: + error_msg = ( + f"Native ALL_GATHER failed: {e}. " + "This usually means TensorRT doesn't support native distributed collectives. " + f"Your TensorRT version: {trt.__version__}. " + "Native collectives require TensorRT 10.16 or later. " + "Consider using TensorRT-LLM plugins instead by setting USE_NATIVE_TRT_COLLECTIVES=0" + ) + logger.error(error_msg) + raise RuntimeError(error_msg) from e + + except Exception as e: + logger.error(f"Native ALL_GATHER failed: {e} (type: {type(e).__name__})") + raise + + +def nccl_reduce_scatter_native( + ctx: ConversionContext, + target: Union[Target, str], + source_ir: Optional[SourceIR], + name: str, + plug_inputs: Tuple[Argument, ...], + reduce_op: str = "sum", +) -> trt.ITensor: + """ + Implement reduce_scatter using native TensorRT DistCollective API. + + This operation reduces tensors from all ranks and scatters the result. + The input is split along dimension 0 and each rank receives one chunk + after the reduction operation. 
+ + Returns: + Output tensor after reduce_scatter operation + reduce_op: Reduction operation ("sum", "prod", "min", "max", "avg") + + Example (with SUM reduction): + Input on rank 0: [1, 2, 3, 4] shape=(4,) + Input on rank 1: [5, 6, 7, 8] shape=(4,) + Output on rank 0: [1+5, 2+6] = [6, 8] shape=(2,) + Output on rank 1: [3+7, 4+8] = [10, 12] shape=(2,) + """ + rank, world_size = _get_distributed_rank_and_world_size() + logger.debug( + f"Adding native reduce_scatter: name={name}, rank={rank}, world_size={world_size}, reduce_op={reduce_op}" + ) + + # Get the input tensor + input_tensor = plug_inputs[0] + + # Map string reduction op to TensorRT ReduceOperation enum + reduce_op_map = { + "sum": trt.ReduceOperation.SUM, + "prod": trt.ReduceOperation.PROD, + "min": trt.ReduceOperation.MIN, + "max": trt.ReduceOperation.MAX, + "avg": trt.ReduceOperation.AVG, + } + + if reduce_op.lower() not in reduce_op_map: + raise ValueError( + f"Unsupported reduce operation: {reduce_op}. " + f"Supported: {list(reduce_op_map.keys())}" + ) + + trt_reduce_op = reduce_op_map[reduce_op.lower()] + + try: + layer = ctx.net.add_dist_collective( + input_tensor, + trt.CollectiveOperation.REDUCE_SCATTER, + trt_reduce_op, + -1, + None, # None means all ranks participate + ) + + set_layer_name(layer, target, name, source_ir) + + output = layer.get_output(0) + layer.num_ranks = world_size + logger.debug( + f"Successfully created native REDUCE_SCATTER layer: {name}, reduce_op={reduce_op}" + ) + + return output + + except AttributeError as e: + error_msg = ( + f"Native ALL_REDUCE_SCATTER failed: {e}. " + "This usually means TensorRT doesn't support native distributed collectives. " + f"Your TensorRT version: {trt.__version__}. " + "Native collectives require TensorRT 10.16 or later. 
" + "Consider using TensorRT-LLM plugins instead by setting USE_NATIVE_TRT_COLLECTIVES=0" + ) + logger.error(error_msg) + raise RuntimeError(error_msg) from e + + except Exception as e: + logger.error(f"Native REDUCE_SCATTER failed: {e} (type: {type(e).__name__})") + raise + + +def nccl_all_reduce_native( + ctx: ConversionContext, + target: Union[Target, str], + source_ir: Optional[SourceIR], + name: str, + plug_inputs: Tuple[Argument, ...], + reduce_op: str = "sum", +) -> trt.ITensor: + """ + Implement all_reduce using native TensorRT DistCollective API. + + This operation reduces tensors across all ranks in-place. Every rank + receives the same reduced result. + + Returns: + Output tensor after all_reduce operation + + Args: + reduce_op: Reduction operation ("sum", "prod", "min", "max", "avg") + + Example (with SUM reduction): + Input on rank 0: [1, 2, 3, 4] shape=(4,) + Input on rank 1: [5, 6, 7, 8] shape=(4,) + Output on rank 0: [6, 8, 10, 12] shape=(4,) + Output on rank 1: [6, 8, 10, 12] shape=(4,) + """ + rank, world_size = _get_distributed_rank_and_world_size() + logger.debug( + f"Adding native all_reduce: name={name}, rank={rank}, world_size={world_size}, reduce_op={reduce_op}" + ) + + input_tensor = plug_inputs[0] + + reduce_op_map = { + "sum": trt.ReduceOperation.SUM, + "prod": trt.ReduceOperation.PROD, + "min": trt.ReduceOperation.MIN, + "max": trt.ReduceOperation.MAX, + "avg": trt.ReduceOperation.AVG, + } + + if reduce_op.lower() not in reduce_op_map: + raise ValueError( + f"Unsupported reduce operation: {reduce_op}. 
" + f"Supported: {list(reduce_op_map.keys())}" + ) + + trt_reduce_op = reduce_op_map[reduce_op.lower()] + + try: + layer = ctx.net.add_dist_collective( + input_tensor, + trt.CollectiveOperation.ALL_REDUCE, + trt_reduce_op, + -1, + None, + ) + + set_layer_name(layer, target, name, source_ir) + + output = layer.get_output(0) + layer.num_ranks = world_size + logger.debug( + f"Successfully created native ALL_REDUCE layer: {name}, reduce_op={reduce_op}" + ) + + return output + + except AttributeError as e: + error_msg = ( + f"Native ALL_REDUCE failed: {e}. " + "This usually means TensorRT doesn't support native distributed collectives. " + f"Your TensorRT version: {trt.__version__}. " + "Native collectives require TensorRT 10.16 or later. " + "Consider using TensorRT-LLM plugins instead by setting USE_NATIVE_TRT_COLLECTIVES=0" + ) + logger.error(error_msg) + raise RuntimeError(error_msg) from e + + except Exception as e: + logger.error(f"Native ALL_REDUCE failed: {e} (type: {type(e).__name__})") + raise diff --git a/py/torch_tensorrt/dynamo/lowering/_decompositions.py b/py/torch_tensorrt/dynamo/lowering/_decompositions.py index f6226385de..5df901a300 100644 --- a/py/torch_tensorrt/dynamo/lowering/_decompositions.py +++ b/py/torch_tensorrt/dynamo/lowering/_decompositions.py @@ -647,6 +647,7 @@ def masked_scatter_decomposition( def get_decompositions( enable_experimental_decompositions: bool = False, decompose_attention: bool = False, + use_distributed_mode_trace: bool = False, ) -> Dict[OpOverload, Callable[[Any], Any]]: trt_decomps = ( TORCH_TRT_DECOMPOSITIONS @@ -658,11 +659,19 @@ def get_decompositions( } ) + # For distributed (DTensor) models, allow aten.linear to decompose to addmm + # so that DTensor's sharding dispatch can find a strategy. 
+ discard_decompositions = ( + torch_disabled_decompositions - {aten.linear.default} + if use_distributed_mode_trace + else torch_disabled_decompositions + ) + if enable_experimental_decompositions: CORE_ATEN_DECOMPOSITIONS_FILTERED: Dict[OpOverload, Callable[[Any], Any]] = { decomp: _core_aten_decompositions[decomp] for decomp in _core_aten_decompositions - if decomp not in torch_disabled_decompositions + if decomp not in discard_decompositions } return {**CORE_ATEN_DECOMPOSITIONS_FILTERED, **trt_decomps} else: @@ -674,7 +683,7 @@ def get_decompositions( DECOMP_TABLE_FILTERED: Dict[OpOverload, Callable[[Any], Any]] = { decomp: decomp_table[decomp] for decomp in decomp_table - if decomp not in torch_disabled_decompositions + if decomp not in discard_decompositions } return { diff --git a/py/torch_tensorrt/dynamo/lowering/passes/fuse_distributed_ops.py b/py/torch_tensorrt/dynamo/lowering/passes/fuse_distributed_ops.py index 02cb2ccd56..a6a826dbb6 100644 --- a/py/torch_tensorrt/dynamo/lowering/passes/fuse_distributed_ops.py +++ b/py/torch_tensorrt/dynamo/lowering/passes/fuse_distributed_ops.py @@ -33,6 +33,16 @@ def tensorrt_fused_nccl_reduce_scatter_op( ) +def tensorrt_fused_nccl_all_reduce_op( + inp: Any, reduce_op: str, group_size: int, group_name: str +) -> torch.Tensor: + return torch.ops._c10d_functional.wait_tensor.default( + torch.ops._c10d_functional.all_reduce.default( + inp, reduce_op, group_size, group_name + ) + ) + + def fuse_distributed_ops( gm: torch.fx.GraphModule, settings: CompilationSettings ) -> torch.fx.GraphModule: @@ -43,6 +53,7 @@ def fuse_distributed_ops( in ( torch.ops._c10d_functional.all_gather_into_tensor.default, torch.ops._c10d_functional.reduce_scatter_tensor.default, + torch.ops._c10d_functional.all_reduce.default, ) and len(node.users) == 1 and list(node.users)[0].target @@ -53,14 +64,23 @@ def fuse_distributed_ops( with gm.graph.inserting_after(wait_tensor_node): fused_node = gm.graph.create_node( op="call_function", - 
target=tensorrt_fused_nccl_all_gather_op, # Define your custom fused function + target=tensorrt_fused_nccl_all_gather_op, args=(node.args[0], node.args[1], node.args[2]), ) + elif ( + node.target == torch.ops._c10d_functional.reduce_scatter_tensor.default + ): + with gm.graph.inserting_after(wait_tensor_node): + fused_node = gm.graph.create_node( + op="call_function", + target=tensorrt_fused_nccl_reduce_scatter_op, + args=(node.args[0], node.args[1], node.args[2], node.args[3]), + ) else: with gm.graph.inserting_after(wait_tensor_node): fused_node = gm.graph.create_node( op="call_function", - target=tensorrt_fused_nccl_reduce_scatter_op, # Define your custom fused function + target=tensorrt_fused_nccl_all_reduce_op, args=(node.args[0], node.args[1], node.args[2], node.args[3]), ) diff --git a/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py index 31182bbe21..39764c6653 100644 --- a/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py +++ b/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py @@ -4,11 +4,14 @@ from contextlib import nullcontext from typing import Any, Dict, List, Optional, Sequence, Tuple +import tensorrt as trt import torch +import torch.distributed as dist import torch_tensorrt from torch.nn import Module from torch_tensorrt._Device import Device from torch_tensorrt._enums import Platform, dtype +from torch_tensorrt._features import ENABLED_FEATURES from torch_tensorrt.dynamo._defaults import DEBUG_LOGGING_DIR from torch_tensorrt.dynamo._settings import CompilationSettings from torch_tensorrt.dynamo.debug._DebuggerConfig import DebuggerConfig @@ -21,8 +24,6 @@ multi_gpu_device_check, ) -import tensorrt as trt - logger = logging.getLogger(__name__) @@ -149,7 +150,6 @@ def __init__( weight_name_map (dict): Mapping of engine weight name to state_dict weight name requires_output_allocator (bool): Boolean flag indicating if the converter creates operators 
which require an Output Allocator to run (e.g. data dependent operators) symbolic_shape_expressions (List[str]): List of symbolic shape expressions for each output binding - Example: .. code-block:: py @@ -229,6 +229,15 @@ def __init__( if self.serialized_engine is not None and not self.settings.lazy_engine_init: self.setup_engine() + # Auto-detect distributed context + if dist.is_initialized(): + self.rank = dist.get_rank() + self.world_size = dist.get_world_size() + else: + self.rank = -1 + self.world_size = -1 + self._nccl_comm: Optional[Any] = None + def set_output_tensors_as_unowned(self, enabled: bool) -> None: """ Flag to set if the output tensors of this engine are solely owned by the Torch-TensorRT Runtime or if they might be shared with a user. @@ -280,6 +289,67 @@ def set_default_device_memory_budget(self) -> int: logger.debug(f"Weight streaming budget set to {budget_bytes}B") return self._set_device_memory_budget(budget_bytes) + # Distributed functions + @property + def is_distributed(self) -> bool: + """Check if this module is configured for distributed execution.""" + return bool(self.world_size > 1) + + def setup_nccl_comm(self) -> None: + """Set up NCCL communicator from PyTorch's ProcessGroup. + + In PythonTorchTensorRTModule, this is a single call that gets the NCCL comm + and binds it to the TRT context. rank/world_size are already set in __init__ + via dist.get_rank(). + + In TorchTensorRTModule (C++ runtime), this is split into two calls: + - detect_distributed_context(group_name): sets rank/world_size on the C++ engine + (called in setup_engine, needed for serialization before forward) + - setup_nccl_comm(group_name): gets NCCL comm and binds to TRT context + (called lazily on first forward) + """ + if not self.is_distributed: + return + + if not dist.is_initialized(): + raise RuntimeError( + "torch.distributed must be initialized before calling setup_nccl_comm(). " + "Call dist.init_process_group('nccl') first." 
+ ) + + pg = dist.group.WORLD + if pg is None or dist.get_backend(pg) != "nccl": + raise RuntimeError("Default ProcessGroup must use NCCL backend") + + backend = pg._get_backend(torch.device("cuda")) + + # Force NCCL communicator initialization with a dummy collective + dummy = torch.zeros(1, device="cuda") + dist.all_reduce(dummy) + + comm_ptr = backend._comm_ptr() + if comm_ptr is None or comm_ptr == 0: + raise RuntimeError("Failed to get NCCL communicator from ProcessGroup") + + self._nccl_comm = comm_ptr + + # Bind communicator to TRT execution context (PyCapsule required by TRT Python API) + if self.context is not None: + import ctypes + + ctypes.pythonapi.PyCapsule_New.restype = ctypes.py_object + ctypes.pythonapi.PyCapsule_New.argtypes = [ + ctypes.c_void_p, + ctypes.c_char_p, + ctypes.c_void_p, + ] + comm_capsule = ctypes.pythonapi.PyCapsule_New(comm_ptr, None, None) + self.context.set_communicator(comm_capsule) + + logger.info( + f"NCCL comm set up (rank={self.rank}, world_size={self.world_size})" + ) + def setup_engine(self) -> None: assert ( self.target_platform == Platform.current_platform() @@ -333,6 +403,9 @@ def _on_state_dict(self, state_dict: Dict[str, Any], prefix: str, _: Any) -> Non state_dict[prefix + "input_names"] = self.input_names state_dict[prefix + "output_names"] = self.output_names state_dict[prefix + "platform"] = self.target_platform + # Distributed info (always saved, -1 indicates non-distributed) + state_dict[prefix + "rank"] = self.rank + state_dict[prefix + "world_size"] = self.world_size def _load_from_state_dict( self, @@ -349,6 +422,27 @@ def _load_from_state_dict( self.output_names = state_dict[prefix + "output_names"] self.target_platform = state_dict[prefix + "platform"] + # Same rule as C++ TRTEngine: serialized rank/world_size are build-time + # metadata for logging. Runtime rank is auto-detected from ProcessGroup. 
+ build_rank = state_dict.get(prefix + "rank", -1) + build_world_size = state_dict.get(prefix + "world_size", -1) + if dist.is_initialized(): + self.rank = dist.get_rank() + self.world_size = dist.get_world_size() + else: + self.rank = -1 + self.world_size = -1 + if build_world_size > 1: + if build_rank != self.rank: + logger.info( + f"Distributed engine originally built on rank {build_rank} of {build_world_size}, " + f"now running on rank {self.rank} of {self.world_size}" + ) + else: + logger.info( + f"Distributed engine: rank {self.rank} of {self.world_size}" + ) + # Run multi-gpu device check to validate engine instantiation multi_gpu_device_check() self.setup_engine() @@ -357,10 +451,14 @@ def __getstate__(self) -> Dict[str, Any]: state = self.__dict__.copy() state.pop("engine", None) state.pop("context", None) + # NCCL comm handle cannot be pickled + state.pop("_nccl_comm", None) return state def __setstate__(self, state: Dict[str, Any]) -> None: self.__dict__.update(state) + # Reset the NCCL comm after unpickling; it is re-created lazily on the next forward call. 
+ self._nccl_comm = None self.setup_engine() def __deepcopy__(self, memo: Any) -> PythonTorchTensorRTModule: @@ -699,6 +797,23 @@ def run_output_allocator() -> torch.Tensor | Tuple[torch.Tensor, ...]: ): self._check_initialized() + if self.is_distributed and self._nccl_comm is None: + nccl_type = ( + "native TRT collectives" + if ENABLED_FEATURES.native_trt_collectives + else ( + "TRT-LLM NCCL plugins" + if ENABLED_FEATURES.trtllm_for_nccl + else "unknown backend" + ) + ) + logger.info( + f"Setting up NCCL for distributed execution using {nccl_type} " + f"(rank={self.rank}, world_size={self.world_size})" + ) + self.setup_nccl_comm() + logger.info(f"NCCL setup complete, comm={self._nccl_comm}") + # If in safe mode, check at each iteration for whether a switch is required if ( torch_tensorrt.runtime._multi_device_safe_mode._PY_RT_MULTI_DEVICE_SAFE_MODE diff --git a/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py index d77c0bf39f..05aa0f05e8 100644 --- a/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py +++ b/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py @@ -7,6 +7,7 @@ from typing import Any, Dict, List, Optional, Tuple, Union import torch +import torch.distributed as dist from torch_tensorrt._Device import Device from torch_tensorrt._enums import Platform from torch_tensorrt._features import ( @@ -36,6 +37,9 @@ TARGET_PLATFORM_IDX = -1 # Not implemented REQUIRES_OUTPUT_ALLOCATOR_IDX = -1 # Not implemented SERIALIZATION_LEN = -1 # Not implemented +IS_MD_ENGINE_IDX = -1 # Not implemented +OPTIONAL_RANK_IDX = -1 # Not implemented +OPTIONAL_WORLD_SIZE_IDX = -1 # Not implemented if ENABLED_FEATURES.torch_tensorrt_runtime: ABI_TARGET_IDX = torch.ops.tensorrt.ABI_TARGET_IDX() # 0 @@ -53,7 +57,10 @@ RESOURCE_ALLOCATION_STRATEGY_IDX = ( torch.ops.tensorrt.RESOURCE_ALLOCATION_STRATEGY_IDX() ) # 10 - SERIALIZATION_LEN = torch.ops.tensorrt.SERIALIZATION_LEN() # 11 + IS_MD_ENGINE_IDX = 
torch.ops.tensorrt.IS_MD_ENGINE_IDX() # 11 + OPTIONAL_RANK_IDX = torch.ops.tensorrt.OPTIONAL_RANK_IDX() # 12 + OPTIONAL_WORLD_SIZE_IDX = torch.ops.tensorrt.OPTIONAL_WORLD_SIZE_IDX() # 13 + SERIALIZATION_LEN = torch.ops.tensorrt.SERIALIZATION_LEN() # 14 @for_all_methods(needs_torch_tensorrt_runtime) @@ -203,6 +210,12 @@ def _pack_engine_info(self) -> List[str | bytes]: engine_info[RESOURCE_ALLOCATION_STRATEGY_IDX] = str( int(self.dynamically_allocate_resources) ) + is_md = dist.is_initialized() and dist.get_world_size() > 1 + engine_info[IS_MD_ENGINE_IDX] = str(int(is_md)) + # serialized engine info for build time rank and world size + if is_md: + engine_info[OPTIONAL_RANK_IDX] = str(dist.get_rank()) + engine_info[OPTIONAL_WORLD_SIZE_IDX] = str(dist.get_world_size()) return engine_info @@ -239,6 +252,14 @@ def use_dynamically_allocated_resources( self.dynamically_allocate_resources ) + def _get_default_group_name(self) -> str: + """Get the group name of the default ProcessGroup.""" + if dist.is_available() and dist.is_initialized(): + pg = dist.group.WORLD + if pg is not None and hasattr(pg, "group_name"): + return str(pg.group_name) + return "" + def setup_engine(self) -> None: """ Setup engine for a module which has deferred engine setup. @@ -252,6 +273,18 @@ def setup_engine(self) -> None: return self.engine = torch.classes.tensorrt.Engine(self._pack_engine_info()) + # Distributed setup is split into two calls for TorchTensorRTModule (C++ runtime): + # 1. detect_distributed_context: sets rank/world_size on C++ engine (here, at setup time) + # Needed so rank/world_size are available for serialization before any forward call. + # 2. setup_nccl_comm: gets NCCL comm and binds to TRT context (lazily, in forward) + # Deferred because NCCL comm is only needed at execution time. + # + # In PythonTorchTensorRTModule, this is a single setup_nccl_comm() call in forward + # because rank/world_size are set in __init__ via dist.get_rank(). 
+ group_name = self._get_default_group_name() + if group_name: + self.engine.detect_distributed_context(group_name) + def encode_metadata(self, metadata: Any) -> str: metadata = copy.deepcopy(metadata) dumped_metadata = pickle.dumps(metadata) @@ -341,6 +374,17 @@ def forward(self, *inputs: Any) -> torch.Tensor | Tuple[torch.Tensor, ...]: if self.engine is None: raise RuntimeError("Engine has not been setup yet.") + # Lazy NCCL setup on first forward + if ( + not torch.compiler.is_exporting() + and self.engine.is_md + and not hasattr(self, "_nccl_initialized") + ): + group_name = self._get_default_group_name() + if group_name: + self.engine.setup_nccl_comm(group_name) + self._nccl_initialized = True + assert len(inputs) == len( self.input_binding_names ), f"Wrong number of inputs, expected {len(self.input_binding_names)} got {len(inputs)}." diff --git a/py/torch_tensorrt/dynamo/runtime/__init__.py b/py/torch_tensorrt/dynamo/runtime/__init__.py index 0eb66b24b0..63cabbef67 100644 --- a/py/torch_tensorrt/dynamo/runtime/__init__.py +++ b/py/torch_tensorrt/dynamo/runtime/__init__.py @@ -1,4 +1,9 @@ import torch_tensorrt +from torch_tensorrt.dynamo.runtime._nccl_utils import ( # noqa: F401 + check_nccl_library_path, + get_nccl_library_path, + setup_nccl_library, +) from torch_tensorrt.dynamo.runtime._PythonTorchTensorRTModule import ( # noqa: F401 PythonTorchTensorRTModule, ) diff --git a/py/torch_tensorrt/dynamo/runtime/_nccl_utils.py b/py/torch_tensorrt/dynamo/runtime/_nccl_utils.py new file mode 100644 index 0000000000..9f8519d715 --- /dev/null +++ b/py/torch_tensorrt/dynamo/runtime/_nccl_utils.py @@ -0,0 +1,176 @@ +""" +NCCL Library Utilities for Distributed TensorRT Inference + +This module handles NCCL library path resolution to ensure TensorRT and PyTorch +use the same NCCL library instance. This is critical for sharing NCCL communicators +between PyTorch's distributed backend and TensorRT's native NCCL collectives. 
+ +Background: +----------- +TensorRT's dlopen("libnccl.so") may load a different NCCL library than PyTorch, +causing crashes when sharing NCCL communicators. + +- PyTorch loads NCCL via RPATH baked at compile time (libnccl.so.2) +- TensorRT lazy-loads NCCL via dlopen("libnccl.so") at runtime + +The mismatch occurs because: +1. pip's nvidia-nccl-cu* package only ships libnccl.so.2 (no libnccl.so symlink) +2. TRT specifically looks for libnccl.so, misses pip's copy, falls back to system NCCL + +Environments: +------------- +- NGC containers: No action needed (both use system NCCL) +- pip install torch: Requires symlink + LD_LIBRARY_PATH setup + +Future: +------- +TensorRT 11.0 will support TRT_NCCL_LIBRARY env var, eliminating the need for +symlink workarounds. +""" + +import logging +import os +from typing import Optional + +logger = logging.getLogger(__name__) + +_nccl_setup_checked = False + + +def get_nccl_library_path() -> Optional[str]: + """ + Get the path to PyTorch's NCCL library directory. + + Returns: + Path to NCCL lib directory if nvidia.nccl package exists, None otherwise. + None indicates system NCCL is being used (e.g., NGC containers). + """ + try: + import nvidia.nccl + + nccl_lib_dir = os.path.join(list(nvidia.nccl.__path__)[0], "lib") + if os.path.isdir(nccl_lib_dir): + return nccl_lib_dir + return None + except ImportError: + # nvidia.nccl not installed - using system NCCL (e.g., NGC container) + return None + + +def ensure_nccl_symlink(nccl_lib_dir: str) -> bool: + """ + Ensure libnccl.so symlink exists pointing to libnccl.so.2. + + TensorRT's dlopen looks for "libnccl.so", but pip's nvidia-nccl package + only ships "libnccl.so.2". This creates the necessary symlink. + + Args: + nccl_lib_dir: Path to the NCCL library directory + + Returns: + True if symlink exists or was created, False otherwise. 
+ """ + nccl_so = os.path.join(nccl_lib_dir, "libnccl.so") + nccl_so_2 = os.path.join(nccl_lib_dir, "libnccl.so.2") + + # Check if symlink already exists + if os.path.lexists(nccl_so): + return True + + # Check if target exists + if not os.path.exists(nccl_so_2): + logger.warning(f"NCCL library not found at {nccl_so_2}") + return False + + # Try to create symlink + try: + os.symlink("libnccl.so.2", nccl_so) + logger.info(f"Created NCCL symlink: {nccl_so} -> libnccl.so.2") + return True + except PermissionError: + logger.warning( + f"Cannot create NCCL symlink at {nccl_so} (permission denied). " + f"Please run: ln -sf libnccl.so.2 {nccl_so}" + ) + return False + except OSError as e: + logger.warning(f"Failed to create NCCL symlink: {e}") + return False + + +def check_nccl_library_path() -> bool: + """ + Check if LD_LIBRARY_PATH includes PyTorch's NCCL directory. + + Returns: + True if configuration is correct, False if LD_LIBRARY_PATH needs updating. + """ + nccl_lib_dir = get_nccl_library_path() + + if nccl_lib_dir is None: + # System NCCL - no action needed + return True + + ld_library_path = os.environ.get("LD_LIBRARY_PATH", "") + return nccl_lib_dir in ld_library_path + + +def setup_nccl_library() -> None: + """ + Setup NCCL library path for TensorRT distributed inference. + + This function: + 1. Detects if nvidia.nccl pip package is installed + 2. Creates libnccl.so symlink if needed + 3. Warns if LD_LIBRARY_PATH is not configured + + Call this before initializing NCCL communicators for TensorRT. + + For NGC containers (system NCCL), this is a no-op. + For pip-installed PyTorch, this ensures TensorRT can find the correct NCCL. + """ + global _nccl_setup_checked + + # Only check once per process + if _nccl_setup_checked: + return + _nccl_setup_checked = True + + nccl_lib_dir = get_nccl_library_path() + + if nccl_lib_dir is None: + # NGC container or system NCCL - no action needed + logger.debug( + "nvidia.nccl package not found. 
" + "Assuming system NCCL is used by both PyTorch and TensorRT." + ) + return + + logger.debug(f"Found nvidia.nccl package at: {nccl_lib_dir}") + + # Ensure symlink exists + symlink_ok = ensure_nccl_symlink(nccl_lib_dir) + + # Check LD_LIBRARY_PATH + ld_library_path = os.environ.get("LD_LIBRARY_PATH", "") + if nccl_lib_dir not in ld_library_path: + logger.warning( + f"\n" + f"{'=' * 70}\n" + f"NCCL LIBRARY PATH WARNING\n" + f"{'=' * 70}\n" + f"PyTorch's NCCL library directory is not in LD_LIBRARY_PATH.\n" + f"TensorRT may load a different NCCL library than PyTorch,\n" + f"causing distributed inference to fail.\n" + f"\n" + f"NCCL directory: {nccl_lib_dir}\n" + f"\n" + f"Please set before running:\n" + f" export LD_LIBRARY_PATH={nccl_lib_dir}:$LD_LIBRARY_PATH\n" + f"{'=' * 70}\n" + ) + else: + logger.debug(f"LD_LIBRARY_PATH includes NCCL directory: {nccl_lib_dir}") + + if symlink_ok: + logger.debug("NCCL library setup complete") diff --git a/third_party/libtorch/BUILD b/third_party/libtorch/BUILD index 37309f7209..45d45646a9 100644 --- a/third_party/libtorch/BUILD +++ b/third_party/libtorch/BUILD @@ -34,10 +34,12 @@ cc_library( hdrs = glob( [ "include/torch/**/*.h", + "include/torch/**/*.hpp", ], allow_empty = True, exclude = [ "include/torch/csrc/api/include/**/*.h", + "include/torch/csrc/api/include/**/*.hpp", ], ) + glob([ "include/torch/csrc/api/include/**/*.h", diff --git a/toolchains/torch_nccl/BUILD b/toolchains/torch_nccl/BUILD new file mode 100644 index 0000000000..ffd0fb0cdc --- /dev/null +++ b/toolchains/torch_nccl/BUILD @@ -0,0 +1 @@ +package(default_visibility = ["//visibility:public"]) diff --git a/toolchains/torch_nccl/defs.bzl b/toolchains/torch_nccl/defs.bzl new file mode 100644 index 0000000000..9a7a015c2c --- /dev/null +++ b/toolchains/torch_nccl/defs.bzl @@ -0,0 +1,60 @@ +"""NCCL detection for PyTorch builds.""" + +def _torch_nccl_detect_impl(repository_ctx): + """Detect if PyTorch was built with NCCL support.""" + + # Skip detection on non-Linux 
(NCCL not available) + os_name = repository_ctx.os.name.lower() + if "linux" not in os_name: + has_nccl = False + else: + # Find libtorch path + result = repository_ctx.execute([ + "python3", + "-c", + "import torch; import os; print(os.path.dirname(torch.__file__))", + ]) + + if result.return_code != 0: + has_nccl = False + else: + torch_path = result.stdout.strip() + lib_path = torch_path + "/lib/libtorch_cuda.so" + + # Check for ProcessGroupNCCL symbol + result = repository_ctx.execute([ + "grep", + "-q", + "ProcessGroupNCCL", + lib_path, + ]) + has_nccl = (result.return_code == 0) + + # Generate BUILD file with config_setting + repository_ctx.file("BUILD", """ +load("@bazel_skylib//rules:common_settings.bzl", "bool_flag") + +package(default_visibility = ["//visibility:public"]) + +bool_flag( + name = "use_nccl", + build_setting_default = {has_nccl}, +) + +config_setting( + name = "nccl_enabled", + flag_values = {{":use_nccl": "True"}}, +) +""".format(has_nccl = has_nccl)) + +torch_nccl_detect = repository_rule( + implementation = _torch_nccl_detect_impl, + local = True, # Re-run on each build to detect changes +) + +def if_torch_nccl(if_true, if_false = []): + """Returns if_true if PyTorch has NCCL, else if_false.""" + return select({ + "@torch_nccl//:nccl_enabled": if_true, + "//conditions:default": if_false, + }) diff --git a/tools/llm/tensor_parallel_llama_llm.py b/tools/llm/tensor_parallel_llama_llm.py new file mode 100644 index 0000000000..4ce2524e6d --- /dev/null +++ b/tools/llm/tensor_parallel_llama_llm.py @@ -0,0 +1,343 @@ +""" +.. _run_llm_tp: + +Tensor Parallel LLM inference with Torch-TensorRT +================================================== + +This script extends run_llm.py to support Tensor Parallelism (TP) across multiple GPUs. +Weights in Attention (Q, K, V, O projections) and MLP (gate, up, down projections) +are sharded across ranks using PyTorch's parallelize_module API. 
AllReduce is inserted +automatically by RowwiseParallel at the output projection of each sub-block. +Torch-TensorRT lowers the resulting collective ops into TRT ncclwrapper via TRT-MD. + +Usage +----- +.. code-block:: bash + + mpirun -n 2 python3 tensor_parallel_llama_llm.py \\ + --model meta-llama/Llama-3.2-1B-Instruct \\ + --prompt "What is parallel programming?" \\ + --model_precision FP16 --num_tokens 128 +""" + +import argparse +import logging +import os +from contextlib import nullcontext + +# Distributed init must happen before importing torch_tensorrt. +# mpirun sets OMPI_COMM_WORLD_LOCAL_RANK / OMPI_COMM_WORLD_SIZE but NOT the +# RANK/WORLD_SIZE/MASTER_ADDR/MASTER_PORT vars that PyTorch's env:// rendezvous +# expects, so we translate them here. +import torch +import torch.distributed as dist +import torch.distributed.tensor._dtensor_spec +import torch.utils._pytree +from torch.distributed.device_mesh import init_device_mesh + +# DTensorSpec appears in the graph during torch.export of a TP model. +# Register it as a pytree constant so the exporter treats it as a +# compile-time constant rather than a dynamic input. 
+torch.utils._pytree.register_constant( + torch.distributed.tensor._dtensor_spec.DTensorSpec +) + +_ompi_rank = int(os.environ.get("OMPI_COMM_WORLD_LOCAL_RANK", 0)) +_ompi_size = int(os.environ.get("OMPI_COMM_WORLD_SIZE", 1)) +os.environ.setdefault("RANK", str(_ompi_rank)) +os.environ.setdefault("WORLD_SIZE", str(_ompi_size)) +os.environ.setdefault("MASTER_ADDR", "127.0.0.1") +os.environ.setdefault("MASTER_PORT", "29501") + +dist.init_process_group(backend="nccl") +rank = dist.get_rank() +world_size = dist.get_world_size() +DEVICE = torch.device(f"cuda:{rank}") +torch.cuda.set_device(DEVICE) + + +def initialize_logger( + rank, logger_file_name, file_level=logging.DEBUG, console_level=logging.INFO +): + """Initialize rank-specific Torch-TensorRT logger with configurable handler levels.""" + logger = logging.getLogger("torch_tensorrt") + logger.setLevel(logging.DEBUG) + logger.handlers.clear() + + # File handler + fh = logging.FileHandler(logger_file_name + f"_{rank}.log", mode="w") + fh.setLevel(file_level) + fh.setFormatter( + logging.Formatter( + f"[Rank {rank}] %(asctime)s - %(name)s - %(levelname)s - %(message)s" + ) + ) + logger.addHandler(fh) + + # Console handler + ch = logging.StreamHandler() + ch.setLevel(console_level) + ch.setFormatter(logging.Formatter(f"[Rank {rank}] %(levelname)s: %(message)s")) + logger.addHandler(ch) + + logger.propagate = False + return logger + + +# Initialize logger for this rank +logger = initialize_logger( + rank, "llm_tp_log_mod", file_level=logging.DEBUG, console_level=logging.INFO +) +logger.info( + f"Initialized distributed environment: rank={rank}, world_size={world_size}, device={DEVICE}" +) + +import torch_tensorrt +from torch_tensorrt.dynamo.runtime._nccl_utils import setup_nccl_library + +setup_nccl_library() +from torch.distributed._tensor import Replicate, Shard +from torch.distributed.tensor.parallel import ( + ColwiseParallel, + RowwiseParallel, + parallelize_module, +) +from torchtrt_ext import register_sdpa +from 
transformers import AutoModelForCausalLM, AutoTokenizer +from utils import generate, record_stats, time_generate + + +def get_model(args, device_mesh): + with torch.no_grad(): + model = ( + AutoModelForCausalLM.from_pretrained( + args.model, + use_cache=False, + attn_implementation="sdpa", + ignore_mismatched_sizes=True, + ) + .eval() + .to(DEVICE) + ) + register_sdpa.enable_sdpa_converter(args.model, model.config) + + if args.model_precision == "FP16": + model = model.to(torch.float16) + elif args.model_precision == "BF16": + model = model.to(torch.bfloat16) + + # Build TP plan: ColwiseParallel for first linear in each pair, + # RowwiseParallel for second linear (inserts AllReduce when output_layouts=Replicate). + tp_plan = {} + for i in range(model.config.num_hidden_layers): + tp_plan.update( + { + f"model.layers.{i}.self_attn.q_proj": ColwiseParallel(), + f"model.layers.{i}.self_attn.k_proj": ColwiseParallel(), + f"model.layers.{i}.self_attn.v_proj": ColwiseParallel(), + f"model.layers.{i}.self_attn.o_proj": RowwiseParallel(), + f"model.layers.{i}.mlp.gate_proj": ColwiseParallel(), + f"model.layers.{i}.mlp.up_proj": ColwiseParallel(), + f"model.layers.{i}.mlp.down_proj": RowwiseParallel(), + } + ) + parallelize_module(model, device_mesh, tp_plan) + + # HuggingFace attention uses self.num_heads / self.num_key_value_heads to + # reshape Q/K/V outputs. After weight sharding each rank only holds + # num_heads // world_size columns, so these attributes must be updated or + # the reshape will produce a shape mismatch error. 
+ for layer in model.model.layers: + layer.self_attn.num_heads = model.config.num_attention_heads // world_size + layer.self_attn.num_key_value_heads = ( + model.config.num_key_value_heads // world_size + ) + + return model + + +def compile_torchtrt(model, input_ids, args): + use_fp32_acc = False + use_explicit_typing = False + if args.model_precision == "FP16": + enabled_precisions = {torch.float32} + use_fp32_acc = True + use_explicit_typing = True + elif args.model_precision == "BF16": + enabled_precisions = {torch.bfloat16} + else: + enabled_precisions = {torch.float32} + + # torch.export does not support DTensor-parallelized models (sharding propagation + # fails during run_decompositions). Use torch.compile with dynamic=True so that + # torch._dynamo traces via aot_autograd (use_distributed_mode_trace=True path) + # and builds a single TRT engine with dynamic sequence-length profiles. + with torch_tensorrt.logging.debug() if args.debug else nullcontext(): + trt_model = torch.compile( + model, + backend="torch_tensorrt", + dynamic=False, + options={ + "enabled_precisions": enabled_precisions, + "use_explicit_typing": use_explicit_typing, + "use_fp32_acc": use_fp32_acc, + "device": DEVICE, + "disable_tf32": True, + "use_python_runtime": True, + "use_distributed_mode_trace": True, + "debug": args.debug, + "min_block_size": args.min_block_size, + "assume_dynamic_shape_support": True, + }, + ) + + return trt_model + + +def print_outputs(backend_name, gen_tokens, tokenizer): + print(f"========= {backend_name} =========") + print( + f"{backend_name} model generated text: ", + tokenizer.decode(gen_tokens[0], skip_special_tokens=True), + ) + print("===================================") + + +if __name__ == "__main__": + arg_parser = argparse.ArgumentParser( + description="Run tensor parallel LLM inference with Torch-TensorRT" + ) + arg_parser.add_argument( + "--model", + type=str, + default="meta-llama/Llama-3.2-1B-Instruct", + help="Name of LLM model", + ) + 
arg_parser.add_argument( + "--tokenizer", + type=str, + default="", + help="Name of LLM model tokenizer", + ) + arg_parser.add_argument( + "--prompt", + type=str, + default="What is parallel programming?", + help="Prompt", + ) + arg_parser.add_argument( + "--model_precision", + type=str, + default="FP16", + help="Precision to use in the model. Options: FP16, BF16, FP32", + ) + arg_parser.add_argument( + "--num_tokens", + type=int, + default=128, + help="Number of output tokens to generate", + ) + arg_parser.add_argument( + "--min_block_size", + type=int, + default=1, + help="Minimum block size for TensorRT compilation", + ) + arg_parser.add_argument( + "--benchmark", + action="store_true", + help="Enable benchmarking mode (default: False)", + ) + arg_parser.add_argument( + "--iterations", + type=int, + default=5, + help="Number of benchmark iterations", + ) + arg_parser.add_argument( + "--batch_size", + type=int, + default=1, + help="Batch size for benchmarking", + ) + arg_parser.add_argument( + "--isl", + type=int, + default=2048, + help="Input sequence length for benchmarking", + ) + arg_parser.add_argument( + "--debug", + action="store_true", + help="Enable debug logging (default: False)", + ) + args = arg_parser.parse_args() + + device_mesh = init_device_mesh("cuda", (world_size,)) + + with torch.inference_mode(): + model = get_model(args, device_mesh) + + assert model.config.num_key_value_heads % world_size == 0, ( + f"num_key_value_heads ({model.config.num_key_value_heads}) must be " + f"divisible by world_size ({world_size})." 
+ ) + + tokenizer = AutoTokenizer.from_pretrained(args.tokenizer or args.model) + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + + if args.benchmark: + input_ids = torch.randint( + 1, 10000, (args.batch_size, args.isl), dtype=torch.int64 + ).to(DEVICE) + else: + model_inputs = tokenizer(args.prompt, return_tensors="pt") + input_ids = model_inputs["input_ids"].to(DEVICE) + + MAX_OUTPUT_SEQ_LENGTH = input_ids.shape[1] + args.num_tokens + + # Run uncompiled torch model first for comparison + torch_gen_tokens = generate( + model, + input_ids.clone(), + MAX_OUTPUT_SEQ_LENGTH, + tokenizer.eos_token_id, + ) + if rank == 0: + print_outputs("Torch-TP (uncompiled)", torch_gen_tokens, tokenizer) + + trt_model = compile_torchtrt(model, input_ids, args) + + trt_gen_tokens = generate( + trt_model, + input_ids.clone(), + MAX_OUTPUT_SEQ_LENGTH, + tokenizer.eos_token_id, + ) + + if args.benchmark: + trt_timings = time_generate( + generate, + trt_model, + input_ids.clone(), + MAX_OUTPUT_SEQ_LENGTH, + tokenizer.eos_token_id, + iterations=args.iterations, + ) + + if rank == 0: + if not args.benchmark: + print_outputs("TensorRT-TP", trt_gen_tokens, tokenizer) + else: + trt_stats = record_stats( + "TensorRT-TP", + trt_timings, + args.model_precision, + batch_size=args.batch_size, + compile_time_s=None, + ) + print("=========TensorRT-TP PERFORMANCE============") + print(trt_stats) + + dist.destroy_process_group() diff --git a/trtengine_change.md b/trtengine_change.md new file mode 100644 index 0000000000..3cd10c8036 --- /dev/null +++ b/trtengine_change.md @@ -0,0 +1,751 @@ +# Design Changes for PR #4157 + +This document contains the exact code changes for the redesigned Multi-Device TensorRT Runtime, addressing review comments. 
+ +**Note:** Build config changes (MODULE.bazel, pyproject.toml, setup.py, py/requirements.txt) and debug logging additions (backends.py, remove_sym_nodes.py, partitioning/common.py, utils.py) are NOT included — those are local environment changes. + +--- + +## 1. `core/runtime/runtime.h` + +**Change:** ABI version bump and renamed serialization indices. + +```diff +-const std::string ABI_VERSION = "8"; ++const std::string ABI_VERSION = "9"; +``` + +```diff +- RANK_IDX, +- WORLD_SIZE_IDX, ++ IS_MD_ENGINE_IDX, ++ OPTIONAL_RANK_IDX, ++ OPTIONAL_WORLD_SIZE_IDX, + SERIALIZATION_LEN, +``` + +--- + +## 2. `core/runtime/TRTEngine.h` + +**Change:** Removed `set_rank`, `set_world_size`, `set_nccl_comm`, `init_nccl_comm`, `set_process_group_from_registry`. Added `detect_distributed_context` and `setup_nccl_comm`. + +```diff + // Distributed inference fields (-1 indicates non-distributed mode) + int64_t rank = -1; + int64_t world_size = -1; + +- // Set rank and world_size for distributed inference +- void set_rank(int64_t rank_val); +- void set_world_size(int64_t world_size_val); +- + #ifdef ENABLE_TRT_NCCL_COLLECTIVES + ncclComm_t nccl_comm = nullptr; +- void set_nccl_comm(int64_t comm_ptr); +- void init_nccl_comm(const std::string& group_name = "default"); +- bool set_process_group_from_registry(const std::string& group_name = "default"); ++ ++ // Detect rank and world_size from ProcessGroup ++ void detect_distributed_context(const std::string& group_name); ++ ++ // Resolve ProcessGroup, get NCCL communicator, and bind to TRT context ++ void setup_nccl_comm(const std::string& group_name); + bool set_nccl_communicator_to_trt_context(); + #endif +``` + +--- + +## 3. `core/runtime/TRTEngine.cpp` + +### 3a. 
Constructor 2 (deserialization) — log build-time rank, don't overwrite + +```diff + : ResourceAllocationStrategy::kStatic)) { +- // Load distributed info if available (backward compatible with older ABI versions) +- if (serialized_info.size() > RANK_IDX && !serialized_info[RANK_IDX].empty()) { +- this->rank = std::stoll(serialized_info[RANK_IDX]); +- } +- if (serialized_info.size() > WORLD_SIZE_IDX && !serialized_info[WORLD_SIZE_IDX].empty()) { +- this->world_size = std::stoll(serialized_info[WORLD_SIZE_IDX]); +- } ++ if (std::stoi(serialized_info[IS_MD_ENGINE_IDX])) { ++ int64_t build_rank = std::stoll(serialized_info[OPTIONAL_RANK_IDX]); ++ int64_t build_world_size = std::stoll(serialized_info[OPTIONAL_WORLD_SIZE_IDX]); ++ if (build_rank != this->rank) { ++ LOG_INFO( ++ "Distributed engine originally built on rank " << build_rank << " of " << build_world_size ++ << ", now running on rank " << this->rank << " of " << this->world_size); ++ } else { ++ LOG_INFO( ++ "Distributed engine: rank " << this->rank << " of " << this->world_size); ++ } ++ } + } +``` + +### 3b. Constructor 3 — no distributed logic (removed detect_distributed_context call) + +No changes to constructor 3. It is clean — no distributed code. + +### 3c. Removed set_rank, set_world_size + +```diff +-void TRTEngine::set_rank(int64_t rank_val) { +- this->rank = rank_val; +- LOG_DEBUG("Rank set on TRTEngine: " << this->rank); +-} +- +-void TRTEngine::set_world_size(int64_t world_size_val) { +- this->world_size = world_size_val; +- LOG_DEBUG("World size set on TRTEngine: " << this->world_size); +-} +``` + +### 3d. Removed set_nccl_comm, init_nccl_comm, set_process_group_from_registry + +All three functions removed entirely. + +### 3e. 
New: detect_distributed_context
+
+```cpp
+#ifdef ENABLE_TRT_NCCL_COLLECTIVES
+void TRTEngine::detect_distributed_context(const std::string& group_name) {
+  auto pg = c10d::resolve_process_group(group_name);
+  if (pg) {
+    this->rank = pg->getRank();
+    this->world_size = pg->getSize();
+    LOG_DEBUG("Detected distributed context: rank=" << this->rank << ", world_size=" << this->world_size);
+  }
+}
+```
+
+### 3f. New: setup_nccl_comm (replaces set_process_group_from_registry)
+
+```cpp
+void TRTEngine::setup_nccl_comm(const std::string& group_name) {
+  auto pg = c10d::resolve_process_group(group_name);
+  TORCHTRT_CHECK(pg != nullptr, "ProcessGroup '" << group_name << "' not found in registry");
+
+  auto backend = pg->getBackend(c10d::ProcessGroup::BackendType::NCCL);
+  TORCHTRT_CHECK(backend != nullptr, "ProcessGroup '" << group_name << "' has no NCCL backend");
+
+  auto* nccl_pg = dynamic_cast<c10d::ProcessGroupNCCL*>(backend.get());
+  TORCHTRT_CHECK(nccl_pg != nullptr, "Backend is not ProcessGroupNCCL");
+
+  at::cuda::set_device(this->device_info.id);
+
+  int64_t comm_ptr = nccl_pg->getCommPtr();
+  TORCHTRT_CHECK(
+      comm_ptr != 0,
+      "NCCL communicator not initialized for device " << this->device_info.id
+          << ". Ensure a collective operation has been performed first.");
+
+  this->nccl_comm = reinterpret_cast<ncclComm_t>(comm_ptr);
+  set_nccl_communicator_to_trt_context();
+  LOG_INFO("NCCL comm set up (rank=" << this->rank << ", device=" << this->device_info.id << ")");
+}
+```
+
+### 3g. 
set_nccl_communicator_to_trt_context — replaced try-catch with TORCHTRT_CHECK
+
+```diff
+ bool TRTEngine::set_nccl_communicator_to_trt_context() {
+- if (!exec_ctx) {
+- LOG_ERROR("Cannot set NCCL communicator: execution context is null");
+- return false;
+- }
+- if (this->nccl_comm == nullptr) {
+- LOG_WARNING(...);
+- return false;
+- }
+- try {
+- void* comm_ptr = static_cast<void*>(this->nccl_comm);
+- exec_ctx->setCommunicator(comm_ptr);
+- LOG_INFO(...);
+- return true;
+- } catch (const std::exception& e) {
+- LOG_ERROR(...);
+- return false;
+- }
++ TORCHTRT_CHECK(exec_ctx != nullptr, "Cannot set NCCL communicator: execution context is null");
++ TORCHTRT_CHECK(this->nccl_comm != nullptr, "NCCL communicator is not set");
++
++ void* comm_ptr = static_cast<void*>(this->nccl_comm);
++ exec_ctx->setCommunicator(comm_ptr);
++
++ LOG_INFO(
++ "NCCL communicator set on TensorRT execution context "
++ "(rank=" << this->rank << ", device=" << this->device_info.id << ")");
++ return true;
+ }
+```
+
+### 3h. serialize() — write IS_MD_ENGINE and optional rank/world_size
+
+```diff
+ serialized_info[RESOURCE_ALLOCATION_STRATEGY_IDX] =
+ this->resource_allocation_strategy == ResourceAllocationStrategy::kDynamic ? "1" : "0";
++ bool is_md = this->world_size > 1;
++ serialized_info[IS_MD_ENGINE_IDX] = is_md ? "1" : "0";
++ if (is_md) {
++ serialized_info[OPTIONAL_RANK_IDX] = std::to_string(this->rank);
++ serialized_info[OPTIONAL_WORLD_SIZE_IDX] = std::to_string(this->world_size);
++ }
+
+ return serialized_info;
+```
+
+---
+
+## 4. `core/runtime/register_jit_hooks.cpp`
+
+### 4a. 
Removed old bindings, added new ones
+
+```diff
+ .def_readonly("rank", &TRTEngine::rank)
+ .def_readonly("world_size", &TRTEngine::world_size)
+- .def("set_rank", &TRTEngine::set_rank)
+- .def("set_world_size", &TRTEngine::set_world_size)
+ #ifdef ENABLE_TRT_NCCL_COLLECTIVES
+- .def("set_nccl_comm", &TRTEngine::set_nccl_comm)
+ .def(
+- "init_nccl_comm",
+- [](c10::intrusive_ptr<TRTEngine> self, std::string group_name = "default") {
+- self->init_nccl_comm(group_name);
++ "detect_distributed_context",
++ [](c10::intrusive_ptr<TRTEngine> self, std::string group_name) {
++ self->detect_distributed_context(group_name);
++ })
++ .def(
++ "setup_nccl_comm",
++ [](c10::intrusive_ptr<TRTEngine> self, std::string group_name) {
++ self->setup_nccl_comm(group_name);
+ })
+ #endif
+```
+
+### 4b. Updated constant names
+
+```diff
+- m.def("RANK_IDX", []() -> int64_t { return RANK_IDX; });
+- m.def("WORLD_SIZE_IDX", []() -> int64_t { return WORLD_SIZE_IDX; });
++ m.def("IS_MD_ENGINE_IDX", []() -> int64_t { return IS_MD_ENGINE_IDX; });
++ m.def("OPTIONAL_RANK_IDX", []() -> int64_t { return OPTIONAL_RANK_IDX; });
++ m.def("OPTIONAL_WORLD_SIZE_IDX", []() -> int64_t { return OPTIONAL_WORLD_SIZE_IDX; });
+```
+
+---
+
+## 5. `core/runtime/execute_engine.cpp`
+
+**Change:** Only binds NCCL comm to TRT context. Does NOT call `setup_nccl_comm` — Python handles it. 
+ +```diff +- // Distributed setup - set NCCL communicator on TensorRT execution context ++ // Distributed setup - bind NCCL communicator to TRT execution context ++ // setup_nccl_comm must have been called from Python before first forward + #ifdef ENABLE_TRT_NCCL_COLLECTIVES +- if (compiled_engine->rank >= 0 && compiled_engine->world_size > 1) { +- bool result = compiled_engine->set_nccl_communicator_to_trt_context(); +- if (!result) { +- LOG_ERROR("Failed to set NCCL communicator on TRT context"); +- LOG_ERROR("This will cause collective operations to fail at runtime"); +- LOG_ERROR("Make sure to call module.init_nccl_comm() after compilation"); +- } +- } else { +- LOG_DEBUG( +- "Single-device mode (rank=" << compiled_engine->rank << ", world_size=" << compiled_engine->world_size +- << ") - skipping NCCL setup"); ++ if (compiled_engine->world_size > 1 && compiled_engine->nccl_comm != nullptr) { ++ compiled_engine->set_nccl_communicator_to_trt_context(); + } + #endif +``` + +--- + +## 6. `py/torch_tensorrt/dynamo/conversion/_conversion.py` + +**Change:** Removed rank/world_size detection and passing to module constructor. + +```diff +- rank = -1 +- world_size = -1 + if settings.use_distributed_mode_trace: +- import os +- import torch.distributed as dist + # Check if distributed backends are available + ... + + return rt_cls( + serialized_engine=..., + ... +- rank=rank, +- world_size=world_size, + ) +``` + +--- + +## 7. `py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py` + +### 7a. 
Updated constants + +```diff +- RANK_IDX = torch.ops.tensorrt.RANK_IDX() # 11 +- WORLD_SIZE_IDX = torch.ops.tensorrt.WORLD_SIZE_IDX() # 12 +- SERIALIZATION_LEN = torch.ops.tensorrt.SERIALIZATION_LEN() # 13 ++ IS_MD_ENGINE_IDX = torch.ops.tensorrt.IS_MD_ENGINE_IDX() # 11 ++ OPTIONAL_RANK_IDX = torch.ops.tensorrt.OPTIONAL_RANK_IDX() # 12 ++ OPTIONAL_WORLD_SIZE_IDX = torch.ops.tensorrt.OPTIONAL_WORLD_SIZE_IDX() # 13 ++ SERIALIZATION_LEN = torch.ops.tensorrt.SERIALIZATION_LEN() # 14 +``` + +### 7b. Constructor — removed rank/world_size args + +```diff + def __init__( + self, + serialized_engine: Optional[bytes] = None, + ... +- rank: int = -1, +- world_size: int = 1, + ): +``` + +Removed `self.rank = rank`, `self.world_size = world_size`, `self._nccl_setup_done`. + +### 7c. New helper: _get_default_group_name + +```python +def _get_default_group_name(self) -> str: + """Get the group name of the default ProcessGroup.""" + import torch.distributed as dist + if dist.is_available() and dist.is_initialized(): + pg = dist.group.WORLD + if pg is not None and hasattr(pg, "group_name"): + return pg.group_name + return "" +``` + +### 7d. setup_engine — calls detect_distributed_context + +```diff + def setup_engine(self) -> None: + if self.engine is not None: + return + self.engine = torch.classes.tensorrt.Engine(self._pack_engine_info()) ++ ++ # Detect distributed context (rank/world_size) from ProcessGroup ++ group_name = self._get_default_group_name() ++ if group_name: ++ self.engine.detect_distributed_context(group_name) +``` + +### 7e. 
_pack_engine_info — uses dist.is_initialized + +```diff +- engine_info[RANK_IDX] = str(self.rank) +- engine_info[WORLD_SIZE_IDX] = str(self.world_size) ++ import torch.distributed as dist ++ is_md = dist.is_initialized() and dist.get_world_size() > 1 ++ engine_info[IS_MD_ENGINE_IDX] = str(int(is_md)) ++ if is_md: ++ engine_info[OPTIONAL_RANK_IDX] = str(dist.get_rank()) ++ engine_info[OPTIONAL_WORLD_SIZE_IDX] = str(dist.get_world_size()) +``` + +### 7f. forward — lazy NCCL setup + +```diff + def forward(self, *inputs): + if self.engine is None: + raise RuntimeError("Engine has not been setup yet.") + ++ # Lazy NCCL setup on first forward ++ if self.engine.world_size > 1 and not hasattr(self, '_nccl_initialized'): ++ group_name = self._get_default_group_name() ++ if group_name: ++ self.engine.setup_nccl_comm(group_name) ++ self._nccl_initialized = True ++ + assert len(inputs) == len(self.input_binding_names), ... +``` + +### 7g. Removed functions + +- `_auto_init_distributed()` — replaced by lazy setup in forward +- `set_distributed_info()` — called removed `set_rank`/`set_world_size` +- `init_nccl_comm()` — replaced by `setup_nccl_comm` in forward +- `setup_nccl_library` import — no longer needed + +--- + +## 8. `py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py` + +### 8a. Constructor — removed rank/world_size args, auto-detect + +```diff + def __init__( + self, + ... +- rank: int = -1, +- world_size: int = 1, + ): + ... +- self.rank = rank +- self.world_size = world_size ++ # Auto-detect distributed context ++ import torch.distributed as dist ++ if dist.is_initialized(): ++ self.rank = dist.get_rank() ++ self.world_size = dist.get_world_size() ++ else: ++ self.rank = -1 ++ self.world_size = -1 + self._nccl_comm: Optional[Any] = None +``` + +### 8b. 
Simplified setup_nccl_comm
+
+Replaced `setup_nccl`, `set_nccl_communicator`, `get_nccl_communicator`, `_get_nccl_comm_from_process_group`, `_create_nccl_comm_via_nccl_lib` with a single function:
+
+```python
+def setup_nccl_comm(self) -> None:
+    """Set up NCCL communicator from PyTorch's ProcessGroup.
+
+    In PythonTorchTensorRTModule, this is a single call that gets the NCCL comm
+    and binds it to the TRT context. rank/world_size are already set in __init__
+    via dist.get_rank().
+
+    In TorchTensorRTModule (C++ runtime), this is split into two calls:
+    - detect_distributed_context(group_name): sets rank/world_size on the C++ engine
+      (called in setup_engine, needed for serialization before forward)
+    - setup_nccl_comm(group_name): gets NCCL comm and binds to TRT context
+      (called lazily on first forward)
+    """
+    if not self.is_distributed:
+        return
+
+    setup_nccl_library()
+
+    import torch.distributed as dist
+    if not dist.is_initialized():
+        raise RuntimeError(
+            "torch.distributed must be initialized before calling setup_nccl_comm(). "
+            "Call dist.init_process_group('nccl') first." 
+        )
+
+    pg = dist.group.WORLD
+    if pg is None or dist.get_backend(pg) != "nccl":
+        raise RuntimeError("Default ProcessGroup must use NCCL backend")
+
+    backend = pg._get_backend(torch.device("cuda"))
+
+    # Force NCCL communicator initialization with a dummy collective
+    dummy = torch.zeros(1, device="cuda")
+    dist.all_reduce(dummy)
+
+    comm_ptr = backend._comm_ptr()
+    if comm_ptr is None or comm_ptr == 0:
+        raise RuntimeError("Failed to get NCCL communicator from ProcessGroup")
+
+    self._nccl_comm = comm_ptr
+
+    # Bind communicator to TRT execution context (PyCapsule required by TRT Python API)
+    if self.context is not None:
+        import ctypes
+        ctypes.pythonapi.PyCapsule_New.restype = ctypes.py_object
+        ctypes.pythonapi.PyCapsule_New.argtypes = [
+            ctypes.c_void_p, ctypes.c_char_p, ctypes.c_void_p,
+        ]
+        comm_capsule = ctypes.pythonapi.PyCapsule_New(comm_ptr, None, None)
+        self.context.set_communicator(comm_capsule)
+
+    logger.info(f"NCCL comm set up (rank={self.rank}, world_size={self.world_size})")
+```
+
+### 8c. Removed functions
+
+- `get_rank()`, `get_world_size()` — fields are public
+- `set_nccl_communicator()` — merged into `setup_nccl_comm`
+- `get_nccl_communicator()` — `_nccl_comm` is accessible
+- `has_native_trt_collectives` property — use `ENABLED_FEATURES.native_trt_collectives`
+- `_create_nccl_comm_via_nccl_lib()` — removed `nccl.core` dependency
+- `_get_nccl_comm_from_process_group()` — merged into `setup_nccl_comm`
+
+### 8d. 
_load_from_state_dict — auto-detect rank, log build-time rank
+
+```diff
+ self.target_platform = state_dict[prefix + "platform"]
+- self.rank = state_dict.get(prefix + "rank", -1)
+- self.world_size = state_dict.get(prefix + "world_size", -1)
++
++ build_rank = state_dict.get(prefix + "rank", -1)
++ build_world_size = state_dict.get(prefix + "world_size", -1)
++ import torch.distributed as dist
++ if dist.is_initialized():
++ self.rank = dist.get_rank()
++ self.world_size = dist.get_world_size()
++ else:
++ self.rank = -1
++ self.world_size = -1
++ if build_world_size > 1:
++ if build_rank != self.rank:
++ logger.info(
++ f"Distributed engine originally built on rank {build_rank} of {build_world_size}, "
++ f"now running on rank {self.rank} of {self.world_size}"
++ )
++ else:
++ logger.info(f"Distributed engine: rank {self.rank} of {self.world_size}")
+```
+
+### 8e. forward — uses ENABLED_FEATURES directly
+
+```diff
+ if self.is_distributed and self._nccl_comm is None:
+ nccl_type = (
+ "native TRT collectives"
+- if self.has_native_trt_collectives
++ if ENABLED_FEATURES.native_trt_collectives
+ else (
+```
+
+---
+
+## 9. Remove `nccl.h` dependency — use `void*` for NCCL communicator
+
+**Rationale:** `nccl.h` is not a Bazel dependency — it's picked up from the system/PyTorch install path. Using `void*` instead of `ncclComm_t` removes this fragile dependency. We don't own the communicator (PyTorch's ProcessGroupNCCL owns it), so we just pass it as an opaque pointer to TRT's `setCommunicator(void*)`.
+
+### `core/runtime/TRTEngine.h`
+
+```diff
+ #ifdef ENABLE_TRT_NCCL_COLLECTIVES
+-#include <nccl.h>
++// Using void* instead of ncclComm_t to avoid nccl.h dependency.
++// We don't own the communicator — it's owned by PyTorch's ProcessGroupNCCL. 
+ #endif
+```
+
+```diff
+ #ifdef ENABLE_TRT_NCCL_COLLECTIVES
+- ncclComm_t nccl_comm = nullptr;
++ void* nccl_comm = nullptr;
+```
+
+### `core/runtime/TRTEngine.cpp`
+
+In `setup_nccl_comm`:
+```diff
+- this->nccl_comm = reinterpret_cast<ncclComm_t>(comm_ptr);
++ this->nccl_comm = reinterpret_cast<void*>(comm_ptr);
+```
+
+In `set_nccl_communicator_to_trt_context`:
+```diff
+- void* comm_ptr = static_cast<void*>(this->nccl_comm);
+- exec_ctx->setCommunicator(comm_ptr);
++ exec_ctx->setCommunicator(this->nccl_comm);
+```
+
+Also update section 3f `setup_nccl_comm` code to use `void*`:
+
+In section 3f above, replace:
+```cpp
+  this->nccl_comm = reinterpret_cast<ncclComm_t>(comm_ptr);
+```
+with:
+```cpp
+  this->nccl_comm = reinterpret_cast<void*>(comm_ptr);
+```
+
+And section 3g `set_nccl_communicator_to_trt_context` simplifies to:
+```cpp
+bool TRTEngine::set_nccl_communicator_to_trt_context() {
+  TORCHTRT_CHECK(exec_ctx != nullptr, "Cannot set NCCL communicator: execution context is null");
+  TORCHTRT_CHECK(this->nccl_comm != nullptr, "NCCL communicator is not set");
+
+  exec_ctx->setCommunicator(this->nccl_comm);
+
+  LOG_INFO(
+      "NCCL communicator set on TensorRT execution context "
+      "(rank=" << this->rank << ", device=" << this->device_info.id << ")");
+  return true;
+}
+```
+
+---
+
+## Compatibility bug fixes (for PyTorch 2.10 / NGC 26.01)
+
+These are separate from the design changes but needed to run on the test environment:
+
+### `_FakeTensorUpdater.py` — guard for `torch._inductor.fx_passes.reinplace`
+
+```python
+is_scatter = False
+if hasattr(torch._inductor.fx_passes, "reinplace"):
+    is_scatter = (
+        node.target
+        is torch._inductor.fx_passes.reinplace._generalized_scatter
+    )
+```
+
+### `fuse_distributed_ops.py` — handle all_reduce with 3 args
+
+```python
+fused_args = tuple(node.args)
+if len(fused_args) < 4:
+    logger.debug(f"all_reduce node has {len(fused_args)} args instead of 4")
+```
+
+### `_TorchTensorRTModule.py` — typo fix
+
+```diff
+- self.int_nccl_comm(pg)
++ 
self.init_nccl_comm(pg) +``` + +(This line is now removed entirely in the redesign, but was needed for the original PR.) + +--- + +## 10. Move `import torch.distributed as dist` to top-level + +Both Python runtime modules had `import torch.distributed as dist` scattered as local imports +inside multiple functions. Moved to top-level since `torch.distributed` is part of PyTorch +(no external dependency). + +### `_TorchTensorRTModule.py` + +```diff + import torch ++import torch.distributed as dist + from torch_tensorrt._Device import Device +``` + +Removed local imports from `_pack_engine_info()` and `_get_default_group_name()`. + +### `_PythonTorchTensorRTModule.py` + +```diff + import torch ++import torch.distributed as dist + import torch_tensorrt +``` + +Removed local imports from `__init__()`, `setup_nccl_comm()`, and `_load_from_state_dict()`. + +Added comment in `_load_from_state_dict` explaining the design rule: + +```python +# Same rule as C++ TRTEngine: serialized rank/world_size are build-time +# metadata for logging. Runtime rank is auto-detected from ProcessGroup. +build_rank = state_dict.get(prefix + "rank", -1) +build_world_size = state_dict.get(prefix + "world_size", -1) +if dist.is_initialized(): + self.rank = dist.get_rank() + self.world_size = dist.get_world_size() +else: + self.rank = -1 + self.world_size = -1 +if build_world_size > 1: + if build_rank != self.rank: + logger.info( + f"Distributed engine originally built on rank {build_rank} of {build_world_size}, " + f"now running on rank {self.rank} of {self.world_size}" + ) + else: + logger.info(f"Distributed engine: rank {self.rank} of {self.world_size}") +``` + +--- + +## 11. 
Function naming: `setup_nccl_comm` in both runtimes + +Both runtime modules use `setup_nccl_comm` as the function name for setting up NCCL, +but they work differently due to the C++ vs Python runtime distinction: + +### `_TorchTensorRTModule` (C++ runtime) — two separate calls + +```python +# In setup_engine(): +self.engine.detect_distributed_context(group_name) # sets rank/world_size on C++ engine + +# In forward() (lazily): +self.engine.setup_nccl_comm(group_name) # gets NCCL comm, binds to TRT context +``` + +**Why split:** rank/world_size must be available for serialization before any forward call. +The NCCL communicator is only needed at execution time. + +### `_PythonTorchTensorRTModule` (Python runtime) — single call + +```python +# In forward() (lazily): +self.setup_nccl_comm() # gets NCCL comm, converts to PyCapsule, sets on TRT context +``` + +**Why single:** rank/world_size are already set in `__init__` via `dist.get_rank()`. +No C++ engine to populate. Only need to get the NCCL comm and bind it. + +Comment in `_PythonTorchTensorRTModule.setup_nccl_comm`: +```python +def setup_nccl_comm(self) -> None: + """Set up NCCL communicator from PyTorch's ProcessGroup. + + In PythonTorchTensorRTModule, this is a single call that gets the NCCL comm + and binds it to the TRT context. rank/world_size are already set in __init__ + via dist.get_rank(). + + In TorchTensorRTModule (C++ runtime), this is split into two calls: + - detect_distributed_context(group_name): sets rank/world_size on the C++ engine + (called in setup_engine, needed for serialization before forward) + - setup_nccl_comm(group_name): gets NCCL comm and binds to TRT context + (called lazily on first forward) + """ +``` + +## 12. Move `setup_nccl_library()` to user scripts + +**Rationale:** `setup_nccl_library()` sets `LD_LIBRARY_PATH` so TensorRT can find `libnccl.so`. +This is a one-time environment setup, not an engine-level concern. 
The reviewer said this +should be a utility the user calls, not hidden inside engine code. + +### `_PythonTorchTensorRTModule.py` — removed call and import + +```diff +-from torch_tensorrt.dynamo.runtime._nccl_utils import setup_nccl_library +``` + +```diff + def setup_nccl_comm(self) -> None: + if not self.is_distributed: + return + +- setup_nccl_library() +- + if not dist.is_initialized(): +``` + +### Example scripts — added call after imports + +`examples/distributed_inference/tensor_parallel_simple_example.py`: +```python +import torch_tensorrt +from torch_tensorrt.dynamo.runtime._nccl_utils import setup_nccl_library + +setup_nccl_library() +``` + +`tools/llm/tensor_parallel_llama_llm.py`: +```python +import torch_tensorrt +from torch_tensorrt.dynamo.runtime._nccl_utils import setup_nccl_library + +setup_nccl_library() +``` + +The user is now responsible for calling `setup_nccl_library()` once before +distributed TRT inference. The function remains in `_nccl_utils` as a public utility.