Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions MODULE.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,13 @@ new_local_repository = use_repo_rule("@bazel_tools//tools/build_defs/repo:local.

local_torch = use_repo_rule("//toolchains:local_torch.bzl", "local_torch")

local_executorch = use_repo_rule("//toolchains:local_executorch.bzl", "local_executorch")

# Detect the locally installed ExecuTorch source tree at build time.
# Set EXECUTORCH_PATH to the directory containing runtime/, extension/, cmake-out/.
# Requires cmake-out/libexecutorch_core.a to be built first.
local_executorch(name = "executorch")

# External dependency for torch_tensorrt if you already have precompiled binaries.
new_local_repository(
name = "torch_tensorrt",
Expand Down Expand Up @@ -77,6 +84,15 @@ local_torch(name = "libtorch")
# build_file = "third_party/libtorch/BUILD"
#)

# ExecuTorch source tree. The repository root is the *parent* of the
# executorch/ directory so that headers resolve as <executorch/runtime/...>.
# Requires a cmake-out/ build inside the executorch source tree.
#new_local_repository(
# name = "executorch",
# build_file = "@//third_party/executorch:BUILD",
# path = "/home/lanl/git/executorch",
#)

#new_local_repository(
# name = "tensorrt",
# path = "/usr/",
Expand Down
32 changes: 29 additions & 3 deletions core/runtime/BUILD
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
load("@rules_cc//cc:defs.bzl", "cc_library")
load("@rules_pkg//:pkg.bzl", "pkg_tar")
load("@rules_pkg//pkg:mappings.bzl", "pkg_files")

package(default_visibility = ["//visibility:public"])

config_setting(
Expand Down Expand Up @@ -58,17 +59,22 @@ config_setting(
],
)

# runtime_base: TRTEngine + device management + serialization utilities.
# Does NOT include register_jit_hooks.cpp (TorchScript torch::class_ /
# TORCH_LIBRARY registrations), so it can be linked into
# libtrt_executorch_backend.so without causing a duplicate-registration
# crash when libtorchtrt.so is also loaded in the same process.
cc_library(
name = "runtime",
name = "runtime_base",
srcs = [
"DeviceList.cpp",
"Platform.cpp",
"RTDevice.cpp",
"TRTEngine.cpp",
"TRTEngineProfiler.cpp",
"execute_engine.cpp",
"register_jit_hooks.cpp",
"runtime.cpp",
"runtime_utils.cpp",
],
hdrs = [
"Platform.h",
Expand Down Expand Up @@ -100,6 +106,26 @@ cc_library(
alwayslink = True,
)

# runtime: full runtime including TorchScript torch::class_ / TORCH_LIBRARY
# registrations. Used by the main libtorchtrt.so.
cc_library(
name = "runtime",
srcs = [
"register_jit_hooks.cpp",
],
hdrs = [
"Platform.h",
"RTDevice.h",
"TRTEngine.h",
"TRTEngineProfiler.h",
"runtime.h",
],
deps = [
":runtime_base",
],
alwayslink = True,
)

filegroup(
name = "include_files",
srcs = [
Expand All @@ -121,6 +147,6 @@ pkg_tar(
pkg_files(
name = "include_pkg_files",
srcs = [":include_files"],
visibility = ["//visibility:public"],
prefix = "include/torch_tensorrt/core/runtime/",
visibility = ["//visibility:public"],
)
23 changes: 23 additions & 0 deletions core/runtime/TRTEngine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -447,6 +447,29 @@ TRTEngine& TRTEngine::operator=(const TRTEngine& other) {
}

void TRTEngine::verify_serialization_fmt(const std::vector<std::string>& serialized_info) {
static const char* kIndexNames[] = {
"ABI_TARGET_IDX",
"NAME_IDX",
"DEVICE_IDX",
"ENGINE_IDX",
"INPUT_BINDING_NAMES_IDX",
"OUTPUT_BINDING_NAMES_IDX",
"HW_COMPATIBLE_IDX",
"SERIALIZED_METADATA_IDX",
"TARGET_PLATFORM_IDX",
"REQUIRES_OUTPUT_ALLOCATOR_IDX",
"RESOURCE_ALLOCATION_STRATEGY_IDX",
};
fprintf(stderr, "[verify_serialization_fmt] %zu entries (expected %d):\n", serialized_info.size(), SERIALIZATION_LEN);
for (size_t i = 0; i < serialized_info.size(); ++i) {
const char* name = (i < sizeof(kIndexNames) / sizeof(kIndexNames[0])) ? kIndexNames[i] : "?";
if (i == ENGINE_IDX) {
fprintf(stderr, " [%zu] %-35s = <binary, %zu bytes>\n", i, name, serialized_info[i].size());
} else {
fprintf(stderr, " [%zu] %-35s = \"%s\"\n", i, name, serialized_info[i].c_str());
}
}

TORCHTRT_CHECK(
serialized_info.size() == SERIALIZATION_LEN,
"Program to be deserialized targets an incompatible Torch-TensorRT ABI");
Expand Down
82 changes: 82 additions & 0 deletions core/runtime/executorch/BUILD
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
load("@rules_cc//cc:defs.bzl", "cc_library")

package(default_visibility = ["//visibility:public"])

config_setting(
name = "use_torch_whl",
flag_values = {
"//toolchains/dep_src:torch": "whl",
},
)

config_setting(
name = "rtx_x86_64",
constraint_values = [
"@platforms//cpu:x86_64",
"@platforms//os:linux",
],
flag_values = {
"//toolchains/dep_collection:compute_libs": "rtx",
},
)

config_setting(
name = "rtx_win",
constraint_values = [
"@platforms//os:windows",
],
flag_values = {
"//toolchains/dep_collection:compute_libs": "rtx",
},
)

config_setting(
name = "sbsa",
constraint_values = [
"@platforms//cpu:aarch64",
],
flag_values = {
"//toolchains/dep_collection:compute_libs": "default",
},
)

config_setting(
name = "jetpack",
constraint_values = [
"@platforms//cpu:aarch64",
],
flag_values = {
"//toolchains/dep_collection:compute_libs": "jetpack",
},
)

config_setting(
name = "windows",
constraint_values = [
"@platforms//os:windows",
],
)

cc_library(
name = "tensorrt_executorch_backend",
srcs = ["TensorRTBackend.cpp"],
hdrs = ["TensorRTBackend.h"],
# Use executorch_headers (no static link) so that register_backend /
# find_backend / FreeableBuffer etc. remain undefined symbols in the
# final libtrt_executorch_backend.so. At runtime they are resolved from
# libqnn_executorch_backend.so, which _portable_lib.so loads with
# RTLD_GLOBAL, ensuring both share the same registry instance.
deps = [
"//core/runtime:runtime_base",
"//core/util:prelude",
"@executorch//:executorch_headers",
] + select({
":jetpack": ["@tensorrt_l4t//:nvinfer"],
":rtx_win": ["@tensorrt_rtx_win//:nvinfer"],
":rtx_x86_64": ["@tensorrt_rtx//:nvinfer"],
":sbsa": ["@tensorrt_sbsa//:nvinfer"],
":windows": ["@tensorrt_win//:nvinfer"],
"//conditions:default": ["@tensorrt//:nvinfer"],
}),
alwayslink = True,
)
Loading
Loading