-
Notifications
You must be signed in to change notification settings - Fork 392
Multi-Device TensorRT Runtime with Native NCCL Collectives #4157
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 1 commit
6152d3a
d3a2ef5
02268c0
5b6abf4
8f911f5
b791b40
b5c203e
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -10,6 +10,13 @@ | |
| #include "core/util/prelude.h" | ||
| #include "torch/torch.h" | ||
|
|
||
| #ifdef ENABLE_TRT_NCCL_COLLECTIVES | ||
| #include "torch/csrc/distributed/c10d/GroupRegistry.hpp" | ||
| #include "torch/csrc/distributed/c10d/NCCLUtils.hpp" | ||
| #include "torch/csrc/distributed/c10d/ProcessGroup.hpp" | ||
| #include "torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp" | ||
| #endif | ||
|
|
||
| namespace torch_tensorrt { | ||
| namespace core { | ||
| namespace runtime { | ||
|
|
@@ -88,7 +95,15 @@ TRTEngine::TRTEngine(std::vector<std::string> serialized_info) | |
| serialized_info[SERIALIZED_METADATA_IDX], | ||
| (static_cast<bool>(std::stoi(serialized_info[RESOURCE_ALLOCATION_STRATEGY_IDX])) | ||
| ? ResourceAllocationStrategy::kDynamic | ||
| : ResourceAllocationStrategy::kStatic)) {} | ||
| : ResourceAllocationStrategy::kStatic)) { | ||
| // Load distributed info if available (backward compatible with older ABI versions) | ||
| if (serialized_info.size() > RANK_IDX && !serialized_info[RANK_IDX].empty()) { | ||
| this->rank = std::stoll(serialized_info[RANK_IDX]); | ||
| } | ||
| if (serialized_info.size() > WORLD_SIZE_IDX && !serialized_info[WORLD_SIZE_IDX].empty()) { | ||
| this->world_size = std::stoll(serialized_info[WORLD_SIZE_IDX]); | ||
| } | ||
| } | ||
|
|
||
| TRTEngine::TRTEngine( | ||
| const std::string& mod_name, | ||
|
|
@@ -519,6 +534,149 @@ void TRTEngine::set_resource_allocation_strategy(TRTEngine::ResourceAllocationSt | |
| } | ||
| } | ||
|
|
||
| void TRTEngine::set_rank(int64_t rank_val) { | ||
| this->rank = rank_val; | ||
| LOG_DEBUG("Rank set on TRTEngine: " << this->rank); | ||
| } | ||
|
|
||
| void TRTEngine::set_world_size(int64_t world_size_val) { | ||
| this->world_size = world_size_val; | ||
| LOG_DEBUG("World size set on TRTEngine: " << this->world_size); | ||
| } | ||
|
|
||
| #ifdef ENABLE_TRT_NCCL_COLLECTIVES | ||
| void TRTEngine::set_nccl_comm(int64_t comm_ptr) { | ||
| this->nccl_comm = reinterpret_cast<ncclComm_t>(comm_ptr); | ||
| LOG_DEBUG("NCCL communicator stored on TRTEngine (rank=" << this->rank << ")"); | ||
|
|
||
| // Also set on TensorRT execution context | ||
| set_nccl_communicator_to_trt_context(); | ||
| } | ||
|
|
||
| bool TRTEngine::set_nccl_communicator_to_trt_context() { | ||
| // Set NCCL communicator on TensorRT execution context | ||
| // The communicator should be set from Python via set_nccl_comm() or set_process_group() | ||
|
|
||
| if (!exec_ctx) { | ||
| LOG_ERROR("Cannot set NCCL communicator: execution context is null"); | ||
| return false; | ||
| } | ||
|
|
||
| if (this->nccl_comm == nullptr) { | ||
|
||
| LOG_WARNING( | ||
| "Distributed inference enabled but no NCCL communicator set. " | ||
| "Call set_process_group() or set_nccl_comm() from Python first."); | ||
| return false; | ||
| } | ||
|
|
||
| // Set NCCL communicator on TensorRT execution context | ||
| try { | ||
|
||
| // Cast ncclComm_t to void* for TensorRT API | ||
| void* comm_ptr = static_cast<void*>(this->nccl_comm); | ||
|
|
||
| // Set the NCCL communicator on the execution context | ||
| // The device ID is used to identify which GPU's communicator this is | ||
| exec_ctx->setCommunicator(comm_ptr); | ||
|
|
||
| LOG_INFO( | ||
| "NCCL communicator set on TensorRT execution context " | ||
| "(rank=" | ||
| << this->rank << ", device=" << this->device_info.id << ")"); | ||
| return true; | ||
| } catch (const std::exception& e) { | ||
| LOG_ERROR("Failed to set NCCL communicator on execution context: " << e.what()); | ||
| return false; | ||
| } | ||
| } | ||
|
|
||
| void TRTEngine::init_nccl_comm(const std::string& group_name) { | ||
|
||
| // Use C++ registry to get NCCL communicator | ||
| set_process_group_from_registry(group_name); | ||
| } | ||
|
|
||
| bool TRTEngine::set_process_group_from_registry(const std::string& group_name) { | ||
| // Get ProcessGroup from C++ registry and extract NCCL communicator | ||
| // This avoids the need to pass the ProcessGroup from Python | ||
| LOG_INFO("TRTEngine::set_process_group_from_registry() called with group_name: " << group_name); | ||
| LOG_INFO(" Current rank: " << this->rank); | ||
| LOG_INFO(" Current world_size: " << this->world_size); | ||
| LOG_INFO(" Current device_id: " << this->device_info.id); | ||
|
|
||
| try { | ||
|
||
| // Resolve ProcessGroup from the native registry | ||
| auto pg = c10d::resolve_process_group(group_name); | ||
| if (!pg) { | ||
| LOG_ERROR("Failed to resolve ProcessGroup '" << group_name << "' from registry"); | ||
| return false; | ||
| } | ||
| LOG_INFO(" Resolved ProcessGroup from registry: rank=" << pg->getRank() << ", size=" << pg->getSize()); | ||
|
|
||
| // Update rank and world_size from the ProcessGroup if not already set | ||
| if (this->rank < 0) { | ||
| this->rank = pg->getRank(); | ||
| LOG_INFO(" Set rank from ProcessGroup: " << this->rank); | ||
| } | ||
| if (this->world_size < 0) { | ||
| this->world_size = pg->getSize(); | ||
| LOG_INFO(" Set world_size from ProcessGroup: " << this->world_size); | ||
| } | ||
|
|
||
| // Get the NCCL backend from the ProcessGroup | ||
| // ProcessGroup wraps Backend objects - we need to get the NCCL backend explicitly | ||
| c10::intrusive_ptr<c10d::Backend> backend; | ||
| try { | ||
| backend = pg->getBackend(c10d::ProcessGroup::BackendType::NCCL); | ||
| } catch (const std::exception& e) { | ||
| LOG_ERROR("Failed to get NCCL backend from ProcessGroup: " << e.what()); | ||
| return false; | ||
| } | ||
|
|
||
| if (!backend) { | ||
| LOG_ERROR("ProcessGroup '" << group_name << "' does not have an NCCL backend"); | ||
| return false; | ||
| } | ||
| LOG_INFO(" Got NCCL backend from ProcessGroup"); | ||
|
|
||
| // Cast the backend to ProcessGroupNCCL | ||
| auto* nccl_pg = dynamic_cast<c10d::ProcessGroupNCCL*>(backend.get()); | ||
|
||
| if (!nccl_pg) { | ||
| LOG_ERROR("Backend is not ProcessGroupNCCL (unexpected)"); | ||
| return false; | ||
| } | ||
| LOG_INFO(" Successfully cast to ProcessGroupNCCL"); | ||
|
|
||
| // Set current CUDA device to match the engine's device before getting comm | ||
| // getCommPtr() uses at::cuda::current_device() internally | ||
| at::cuda::set_device(this->device_info.id); | ||
| LOG_INFO(" Set current CUDA device to: " << this->device_info.id); | ||
|
|
||
| // Get NCCL comm pointer using the public getCommPtr() method | ||
| // This returns the communicator for the current CUDA device | ||
| int64_t comm_ptr = nccl_pg->getCommPtr(); | ||
| if (comm_ptr == 0) { | ||
| LOG_ERROR( | ||
| "Failed to get NCCL communicator for device " << this->device_info.id | ||
| << ". The communicator may not be initialized yet."); | ||
| LOG_ERROR("Hint: Ensure a collective operation has been performed on this device first."); | ||
| return false; | ||
| } | ||
|
|
||
| // Convert int64_t pointer to ncclComm_t | ||
| ncclComm_t comm = reinterpret_cast<ncclComm_t>(comm_ptr); | ||
|
|
||
| this->nccl_comm = comm; | ||
| LOG_INFO(" Successfully extracted NCCL communicator from registry"); | ||
| LOG_INFO(" nccl_comm: " << (void*)this->nccl_comm); | ||
| // Set on TensorRT execution context | ||
| return True; | ||
|
|
||
| } catch (const std::exception& e) { | ||
| LOG_ERROR("Failed to get ProcessGroup from registry: " << e.what()); | ||
| return false; | ||
| } | ||
| } | ||
| #endif // ENABLE_TRT_NCCL_COLLECTIVES | ||
|
|
||
| } // namespace runtime | ||
| } // namespace core | ||
| } // namespace torch_tensorrt | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -9,12 +9,29 @@ | |
| #include "ATen/core/function_schema.h" | ||
| #include "ATen/cuda/CUDAGraph.h" | ||
| #include "NvInfer.h" | ||
| #include "NvInferVersion.h" | ||
| #include "c10/cuda/CUDAStream.h" | ||
| #include "torch/custom_class.h" | ||
|
|
||
| #include "core/runtime/TRTEngineProfiler.h" | ||
| #include "core/util/prelude.h" | ||
|
|
||
| // TensorRT 10.16+ has native NCCL collective support via IExecutionContext::setCommunicator() | ||
| #if NV_TENSORRT_MAJOR > 10 || (NV_TENSORRT_MAJOR == 10 && NV_TENSORRT_MINOR >= 16) | ||
| #define TRT_HAS_NATIVE_NCCL 1 | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Does this work for say jetson or windows? |
||
| #endif | ||
|
|
||
| // Full TRT NCCL collectives support requires both: | ||
| // 1. PyTorch built with NCCL (USE_C10D_NCCL defined via Bazel) | ||
| // 2. TensorRT 10.16+ (TRT_HAS_NATIVE_NCCL defined above) | ||
| #if defined(USE_C10D_NCCL) && defined(TRT_HAS_NATIVE_NCCL) | ||
| #define ENABLE_TRT_NCCL_COLLECTIVES 1 | ||
| #endif | ||
|
|
||
| #ifdef ENABLE_TRT_NCCL_COLLECTIVES | ||
| #include <nccl.h> | ||
|
||
| #endif | ||
|
|
||
| namespace torch_tensorrt { | ||
| namespace core { | ||
| namespace runtime { | ||
|
|
@@ -196,6 +213,22 @@ struct TRTEngine : torch::CustomClassHolder { | |
| bool use_output_allocator_outputs = false; // users specify to use output allocator | ||
| std::shared_ptr<DynamicOutputAllocator> output_allocator; | ||
|
|
||
| // Member variables for distributed inference (-1 indicates non-distributed mode) | ||
| int64_t rank = -1; | ||
| int64_t world_size = -1; | ||
|
|
||
| // Set rank and world_size for distributed inference | ||
| void set_rank(int64_t rank_val); | ||
|
||
| void set_world_size(int64_t world_size_val); | ||
|
|
||
| #ifdef ENABLE_TRT_NCCL_COLLECTIVES | ||
| ncclComm_t nccl_comm = nullptr; | ||
| void set_nccl_comm(int64_t comm_ptr); | ||
| void init_nccl_comm(const std::string& group_name = "default"); | ||
| bool set_process_group_from_registry(const std::string& group_name = "default"); | ||
| bool set_nccl_communicator_to_trt_context(); | ||
| #endif | ||
|
|
||
| // TODO: Implement a call method | ||
| // c10::List<at::Tensor> Run(c10::List<at::Tensor> inputs); | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -39,6 +39,8 @@ typedef enum { | |
| TARGET_PLATFORM_IDX, | ||
| REQUIRES_OUTPUT_ALLOCATOR_IDX, | ||
| RESOURCE_ALLOCATION_STRATEGY_IDX, | ||
| RANK_IDX, | ||
|
||
| WORLD_SIZE_IDX, | ||
|
||
| SERIALIZATION_LEN, // NEVER USED FOR DATA, USED TO DETERMINE LENGTH OF SERIALIZED INFO | ||
| } SerializedInfoIndex; | ||
|
|
||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We don't need backwards compat unless some semantic definition changed; just bump the version.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The rank and world_size fields are optional — they are only populated for distributed engines. The size/empty guards ensure backward compatibility with non-distributed engines and older ABI versions where these fields don't exist. Since these fields are not always present, bumping the ABI version would unnecessarily break non-distributed engines that don't use them.
Do you want something like we bump the ABI version and introduce a field to handle the rank and world_size fields? Something like IS_MD_ENGINE as you said? Would that not just be a different approach?