Changes from 7 commits
3 changes: 2 additions & 1 deletion .github/workflows/build-test-linux-x86_64.yml
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ jobs:
set -euo pipefail
pushd .
cd tests/py/dynamo
python -m pytest -ra -n 8 --junitxml=${RUNNER_TEST_RESULTS_DIR}/l0_dynamo_converter_tests_results.xml --dist=loadscope --maxfail=20 conversion/
python -m pytest -ra -n 8 --junitxml=${RUNNER_TEST_RESULTS_DIR}/l0_dynamo_converter_tests_results.xml --maxfail=20 conversion/
popd

L0-dynamo-core-tests:
Expand Down Expand Up @@ -236,6 +236,7 @@ jobs:
cd tests/py/dynamo
python -m pytest -ra -n 8 --junitxml=${RUNNER_TEST_RESULTS_DIR}/l1_dynamo_core_tests_results.xml runtime/test_001_*
python -m pytest -ra -n 8 --junitxml=${RUNNER_TEST_RESULTS_DIR}/l1_dynamo_core_partitioning_tests_results.xml partitioning/test_001_*
python -m pytest -ra -n 8 --junitxml=${RUNNER_TEST_RESULTS_DIR}/l1_dynamo_hlo_tests_results.xml hlo/

popd

Expand Down
3 changes: 2 additions & 1 deletion .github/workflows/build-test-linux-x86_64_rtx.yml
Expand Up @@ -107,7 +107,7 @@ jobs:
set -euo pipefail
pushd .
cd tests/py/dynamo
python -m pytest -ra -n 8 --junitxml=${RUNNER_TEST_RESULTS_DIR}/l0_dynamo_converter_tests_results.xml --dist=loadscope --maxfail=20 conversion/
python -m pytest -ra -n 8 --junitxml=${RUNNER_TEST_RESULTS_DIR}/l0_dynamo_converter_tests_results.xml --maxfail=20 conversion/
popd

L0-dynamo-core-tests:
Expand Down Expand Up @@ -204,6 +204,7 @@ jobs:
pushd .
cd tests/py/dynamo
python -m pytest -ra -n 8 --junitxml=${RUNNER_TEST_RESULTS_DIR}/l1_dynamo_core_tests_results.xml runtime/test_001_*
python -m pytest -ra -n 8 --junitxml=${RUNNER_TEST_RESULTS_DIR}/l1_dynamo_hlo_tests_results.xml hlo/
popd

L1-dynamo-compile-tests:
Expand Down
1 change: 1 addition & 0 deletions .github/workflows/build-test-windows.yml
Expand Up @@ -226,6 +226,7 @@ jobs:
cd tests/py/dynamo
../../../packaging/vc_env_helper.bat python -m pytest -ra -n 8 --junitxml=${RUNNER_TEST_RESULTS_DIR}/l1_dynamo_core_tests_results.xml runtime/test_001_*
../../../packaging/vc_env_helper.bat python -m pytest -ra -n 8 --junitxml=${RUNNER_TEST_RESULTS_DIR}/l1_dynamo_core_partitioning_tests_results.xml partitioning/test_001_*
../../../packaging/vc_env_helper.bat python -m pytest -ra -n 8 --junitxml=${RUNNER_TEST_RESULTS_DIR}/l1_dynamo_hlo_tests_results.xml hlo/
popd

L1-dynamo-compile-tests:
Expand Down
1 change: 1 addition & 0 deletions .github/workflows/build-test-windows_rtx.yml
Expand Up @@ -200,6 +200,7 @@ jobs:
pushd .
cd tests/py/dynamo
../../../packaging/vc_env_helper.bat python -m pytest -ra -n 8 --junitxml=${RUNNER_TEST_RESULTS_DIR}/l1_dynamo_core_tests_results.xml runtime/test_001_*
../../../packaging/vc_env_helper.bat python -m pytest -ra -n 8 --junitxml=${RUNNER_TEST_RESULTS_DIR}/l1_dynamo_hlo_tests_results.xml hlo/
popd

L1-dynamo-compile-tests:
Expand Down
233 changes: 229 additions & 4 deletions docsrc/contributors/complex_number_support.rst
Expand Up @@ -135,19 +135,244 @@ runtime modules handle the conversion:
Key Implementation Invariants
-------------------------------

* **``originally_complex`` set** — the set of nodes that were complex-dtype
*before* any rewrites. After ``replace_input_node``, complex placeholders become
``float32`` so ``is_complex_dtype()`` returns ``False``. The ``originally_complex``
set is used to decide which ``mul.Tensor`` nodes need the complex mul rewrite.
* **``node.meta["is_complex_layout"]``** — every node that represents a complex
quantity (either originally complex-dtype, or a real ``(..., 2)`` tensor produced
by the rewriter) is annotated with ``node.meta["is_complex_layout"] = True``.
This annotation is set during the detection phase (before any rewrites begin) and
propagated by every rewrite handler as it emits new nodes. It survives dtype
changes: after ``replace_input_node`` converts a ``placeholder`` from complex to
``float32``, the dtype-based check ``is_complex_dtype()`` would return ``False``,
but the metadata flag remains. ``_is_complex_layout_node(n)`` is simply
``n.meta.get("is_complex_layout", False)`` — no shape heuristics or recursion.
* **FakeTensorMode reuse** — ``propagate_metadata`` must use the ``FakeTensorMode``
from existing placeholder fake tensors (not a fresh mode) to avoid mode-mismatch
errors under ``torch.compile`` and to preserve SymInt for dynamic shapes.
* **Dotted buffer names** — ``register_buffer`` rejects names containing ``.``.
Nested submodule parameter names (e.g. ``layers.0.weight``) must have ``.``
replaced with ``__`` before registration.
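
The name sanitization in the last bullet can be sketched as follows
(``sanitize_buffer_name`` is an illustrative helper name, not the pass's
actual function):

.. code-block:: python

    def sanitize_buffer_name(name: str) -> str:
        # register_buffer rejects dotted names, so flatten the module path:
        # "layers.0.weight" -> "layers__0__weight"
        return name.replace(".", "__")

    assert sanitize_buffer_name("layers.0.weight") == "layers__0__weight"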

The Decomposition System — How It Is Built
-------------------------------------------

The rewriter is split across two classes and wired together by a lightweight
dispatch mechanism. This section walks through each piece and explains the
design decisions.

ComplexOpDetector — Subgraph Discovery
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

``ComplexOpDetector`` walks the graph to find the set of nodes that participate
in complex arithmetic.

``node_include_in_subgraph``
""""""""""""""""""""""""""""

A node is included in a complex subgraph if:

1. Its output dtype is ``complex64`` or ``complex128`` (``is_complex_dtype``), **or**
2. Any of its inputs are complex (``has_complex_input``).

The second condition is necessary to catch real-output ops — ``abs``, ``angle``,
``real``, ``imag`` — whose inputs are complex. These must be rewritten alongside
the rest of the subgraph even though their outputs are real.
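
A minimal sketch of the predicate, using stand-in ``Node`` objects and string
dtypes in place of FX nodes and ``torch`` dtypes:

.. code-block:: python

    COMPLEX_DTYPES = {"complex64", "complex128"}  # stand-ins for torch dtypes

    class Node:
        def __init__(self, dtype, inputs=()):
            self.dtype = dtype          # output dtype of this op
            self.inputs = list(inputs)  # upstream Nodes

    def is_complex_dtype(n):
        return n.dtype in COMPLEX_DTYPES

    def has_complex_input(n):
        return any(is_complex_dtype(i) for i in n.inputs)

    def node_include_in_subgraph(n):
        # Condition 1: complex output; condition 2: any complex input
        # (catches real-output ops like abs/angle/real/imag).
        return is_complex_dtype(n) or has_complex_input(n)

    src = Node("complex64")
    angle = Node("float32", [src])   # real output, complex input: included
    assert node_include_in_subgraph(src) and node_include_in_subgraph(angle)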

``subgraph_from_anchor``
""""""""""""""""""""""""

For ``view_as_real``-bounded subgraphs, detection starts at a ``view_as_real``
*anchor* node and performs a backward BFS:

.. code-block:: text

    view_as_real ← mul (complex) ← reshape ← placeholder (complex)
    ↑ anchor       ↑ subgraph      ↑ subgraph  ↑ input

At each step, if an upstream node satisfies ``node_include_in_subgraph`` it is
added to the subgraph; otherwise it becomes an *input node* (the boundary). The
result is a ``ComplexSubGraphInfo`` containing anchor nodes, subgraph nodes, and
input nodes.
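
The walk can be sketched as a plain backward BFS over stand-in ``Node``
objects; treating ``placeholder`` nodes as boundaries (as in the diagram
above) is an assumption of this sketch:

.. code-block:: python

    from collections import deque

    class Node:
        def __init__(self, op, is_cplx, inputs=()):
            self.op, self.is_cplx, self.inputs = op, is_cplx, list(inputs)

    def subgraph_from_anchor(anchor):
        subgraph, input_nodes, seen = [], [], {anchor}
        queue = deque(anchor.inputs)
        while queue:
            n = queue.popleft()
            if n in seen:
                continue
            seen.add(n)
            if n.is_cplx and n.op != "placeholder":  # node_include_in_subgraph
                subgraph.append(n)
                queue.extend(n.inputs)               # keep walking upstream
            else:
                input_nodes.append(n)                # boundary: an input node
        return subgraph, input_nodes

    ph = Node("placeholder", True)
    reshape = Node("reshape", True, [ph])
    mul = Node("mul", True, [reshape])
    anchor = Node("view_as_real", False, [mul])
    sub, inputs = subgraph_from_anchor(anchor)
    assert {n.op for n in sub} == {"mul", "reshape"}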

After collection the subgraph is **sorted in topological order** (by position in
the graph's node list). This is critical: without it a ``mul`` node could be
processed before its ``sin`` or ``cos`` operands, causing the rewriter to see the
original complex node instead of the already-rewritten real node.
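
A sketch of the ordering step, using op names as stand-ins for graph nodes:

.. code-block:: python

    # Restore topological order by position in the graph's node list.
    graph_nodes = ["placeholder", "sin", "cos", "mul", "view_as_real"]
    order = {name: i for i, name in enumerate(graph_nodes)}

    collected = {"mul", "sin", "cos"}          # BFS visits anchor-first
    subgraph = sorted(collected, key=order.__getitem__)
    assert subgraph == ["sin", "cos", "mul"]   # operands now precede mul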

``find_complex_op_subgraphs`` and subgraph merging
"""""""""""""""""""""""""""""""""""""""""""""""""""

When a model has multiple ``view_as_real`` anchors that share upstream nodes
(e.g. ``xq_out`` and ``xk_out`` in a RoPE layer both descend from the same
``freqs_cis`` placeholder), their subgraphs would otherwise be detected
separately. ``find_complex_op_subgraphs`` merges overlapping subgraphs by
set intersection so each node is rewritten exactly once.
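
The merge can be sketched as plain set algebra (``merge_overlapping`` is an
illustrative name, not the pass's actual helper):

.. code-block:: python

    def merge_overlapping(subgraphs):
        """Merge node sets that share any node, so each node is rewritten once."""
        merged = []
        for sg in subgraphs:
            sg = set(sg)
            keep = []
            for g in merged:
                if g & sg:       # absorb every existing group that intersects
                    sg |= g
                else:
                    keep.append(g)
            merged = keep + [sg]
        return merged

    # xq_out and xk_out both descend from the same freqs_cis placeholder:
    a = {"freqs_cis", "mul_q", "xq_out"}
    b = {"freqs_cis", "mul_k", "xk_out"}
    assert merge_overlapping([a, b]) == [a | b]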

``find_all_complex_subgraphs`` — unbounded complex ops
"""""""""""""""""""""""""""""""""""""""""""""""""""""""

Some models produce a complex tensor as a graph *output* without passing it
through ``view_as_real``. ``find_all_complex_subgraphs`` is a forward scan that
collects every ``call_function`` node with a complex output, regardless of
anchoring. The resulting subgraph is processed the same way as an
anchor-bounded one.

ComplexGraphRewriter — Dispatch-Based Rewriting
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

``ComplexGraphRewriter`` is decorated with ``@_register_unpackers``, which at
class-definition time scans every method for the ``@_complex_unpacker(op, ...)``
decorator and builds a ``cls._DISPATCH`` dictionary mapping aten ops to rewrite
methods.

.. code-block:: python

    @_complex_unpacker(torch.ops.aten.mul.Tensor)
    def _rewrite_mul(self, node: Node, b: SubgraphBuilder, ...):
        ...
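
The registration mechanism itself can be sketched in plain Python; op keys are
strings here rather than ``torch.ops`` overloads, and both decorators are
simplified stand-ins for the real ones:

.. code-block:: python

    def _complex_unpacker(op):
        def tag(fn):
            fn._unpacks = op          # stash the op on the method
            return fn
        return tag

    def _register_unpackers(cls):
        # At class-definition time, collect every tagged method into _DISPATCH.
        cls._DISPATCH = {
            fn._unpacks: fn
            for fn in vars(cls).values()
            if callable(fn) and hasattr(fn, "_unpacks")
        }
        return cls

    @_register_unpackers
    class ComplexGraphRewriter:
        @_complex_unpacker("aten.mul.Tensor")
        def _rewrite_mul(self, node):
            return True

    assert "aten.mul.Tensor" in ComplexGraphRewriter._DISPATCH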

The entry point ``rewrite_subgraph_nodes`` iterates over the (topologically
ordered) subgraph nodes and for each node:

1. Looks up ``node.target`` in ``_DISPATCH``.
2. If found, calls the corresponding rewrite method.
3. If not found but the op is in ``_ELEMENTWISE_SAFE``, skips it (the op applies
independently to every scalar, so the ``(..., 2)`` real layout is already
correct).
4. Otherwise logs a warning and leaves the node unchanged.
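
A sketch of this dispatch loop, with string op names standing in for node
targets and stub ``_DISPATCH`` / ``_ELEMENTWISE_SAFE`` tables:

.. code-block:: python

    import logging

    _DISPATCH = {"aten.mul.Tensor": lambda node: f"rewrote {node}"}
    _ELEMENTWISE_SAFE = {"aten.add.Tensor", "aten.clone.default"}

    def rewrite_subgraph_nodes(nodes):
        results = []
        for target in nodes:
            if target in _DISPATCH:                # 1-2: dedicated handler
                results.append(_DISPATCH[target](target))
            elif target in _ELEMENTWISE_SAFE:      # 3: layout already correct
                results.append(f"skipped {target}")
            else:                                  # 4: warn, leave unchanged
                logging.warning("no complex rewrite for %s", target)
                results.append(f"unchanged {target}")
        return results

    out = rewrite_subgraph_nodes(
        ["aten.mul.Tensor", "aten.add.Tensor", "aten.weird.op"]
    )
    assert out == ["rewrote aten.mul.Tensor",
                   "skipped aten.add.Tensor",
                   "unchanged aten.weird.op"]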

``_ELEMENTWISE_SAFE``
"""""""""""""""""""""

The ``_ELEMENTWISE_SAFE`` set contains ops that apply to every element of the
tensor independently — ``add.Tensor``, ``sub.Tensor``, ``neg``, ``mul.Scalar``,
``clone``, ``where``, etc. On the ``(..., 2)`` real layout these are already
correct: adding two complex tensors element-wise is the same as adding their
real and imaginary parts independently.

Notably **excluded** from this set:

* ``permute.default`` — must append the trailing real/imag dim index.
* ``add.Scalar`` / ``sub.Scalar`` — a scalar added to a complex number only
shifts the real part; on the ``(..., 2)`` layout both parts would be shifted.
* ``reshape`` / ``view`` — shape arguments need updating for the extra ``2`` dim.
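
A quick numeric check of why ``add.Tensor`` is elementwise-safe on the
``(..., 2)`` layout while ``add.Scalar`` is not:

.. code-block:: python

    z, w = complex(1, 2), complex(3, 4)
    zl, wl = [1.0, 2.0], [3.0, 4.0]           # [real, imag] layout

    # add.Tensor: componentwise add matches complex add.
    s = [zl[0] + wl[0], zl[1] + wl[1]]
    assert s == [(z + w).real, (z + w).imag]  # [4.0, 6.0]

    # add.Scalar: a naive elementwise add shifts the imaginary part too.
    naive = [zl[0] + 5.0, zl[1] + 5.0]        # [6.0, 7.0]  (wrong)
    correct = [(z + 5).real, (z + 5).imag]    # [6.0, 2.0]  (only real shifts)
    assert naive != correct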

Complex Multiply Decomposition
"""""""""""""""""""""""""""""""

The most important rewrite is ``mul.Tensor`` between two complex operands.
The rewriter calls ``complex_mul_replacement``:

.. code-block:: python

    # inputs a, b have shape (..., 2) — last dim is [real, imag]
    re_a = select(a, -1, 0); im_a = select(a, -1, 1)
    re_b = select(b, -1, 0); im_b = select(b, -1, 1)
    real_out = re_a * re_b - im_a * im_b   # ac - bd
    imag_out = re_a * im_b + im_a * re_b   # ad + bc
    result = stack([real_out, imag_out], dim=-1)

Each step is inserted via a ``SubgraphBuilder`` anchored at the ``mul`` node,
so all six new nodes appear immediately after it in topological order. The
original ``mul`` node is then replaced and erased.

See :ref:`subgraph_builder` for more on how ``SubgraphBuilder`` manages
cursor-based insertion.
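
The decomposition can be verified numerically against Python's native complex
multiply (``complex_mul_real_layout`` is an illustrative stand-in for the
pass's ``complex_mul_replacement``):

.. code-block:: python

    def complex_mul_real_layout(a, b):
        """The decomposition above, on plain [real, imag] pairs."""
        re_a, im_a = a
        re_b, im_b = b
        return [re_a * re_b - im_a * im_b,   # ac - bd
                re_a * im_b + im_a * re_b]   # ad + bc

    # Check: (1+2j) * (3+4j) = -5+10j
    out = complex_mul_real_layout([1.0, 2.0], [3.0, 4.0])
    ref = complex(1, 2) * complex(3, 4)
    assert out == [ref.real, ref.imag] == [-5.0, 10.0]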

The ``is_complex_layout`` Metadata Invariant
"""""""""""""""""""""""""""""""""""""""""""""

Input replacement (Stage 2) converts complex ``placeholder`` nodes to
``float32``. After that, ``is_complex_dtype(node)`` returns ``False`` for those
nodes even though they logically represent complex quantities.

To avoid missed rewrites, every node that represents a complex quantity is
annotated with ``node.meta["is_complex_layout"] = True`` during the detection
phase (in ``rewrite_subgraph_nodes``, before any rewrites begin). The
annotation is then propagated forward by every rewrite handler:

* ``replace_input_node`` stamps it on the new placeholder and ``get_attr`` nodes.
* ``_inline_cat_re_im`` stamps it on every ``[re_u, im_u]`` concatenation node,
covering all math handlers (``exp``, ``log``, ``sin``, ``mul``, etc.) at once.
* Each shape-manipulation handler (``reshape``, ``permute``, ``unsqueeze``,
``cat``, ``stack``, etc.) stamps it on its output node explicitly.

``_is_complex_layout_node(n)`` is therefore a direct metadata lookup — no shape
heuristics (``val.shape[-1] == 2``), no recursive ``_SHAPE_TRANSPARENT_OPS``
propagation. This also eliminates false positives on real parameters that
coincidentally have a trailing dimension of size 2.
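
A minimal sketch of the invariant, with a stand-in ``Node``:

.. code-block:: python

    class Node:
        def __init__(self, dtype):
            self.dtype, self.meta = dtype, {}

    def _is_complex_layout_node(n):
        return n.meta.get("is_complex_layout", False)

    ph = Node("complex64")
    ph.meta["is_complex_layout"] = True      # stamped during detection

    ph.dtype = "float32"                     # replace_input_node rewrites dtype
    assert ph.dtype not in ("complex64", "complex128")  # dtype check now fails
    assert _is_complex_layout_node(ph)                  # metadata still holds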

FakeTensorMode Reuse for Dynamic Shapes
"""""""""""""""""""""""""""""""""""""""""

When inserting a new ``placeholder`` for a complex input, the pass must populate
``meta["val"]`` with a ``FakeTensor`` of the new real shape. Using a fresh
``FakeTensorMode()`` would create a *new* ``ShapeEnv``, which is incompatible
with the one that ``torch.export`` used to encode dynamic shape constraints
(SymInt ranges).

The fix is to extract the ``FakeTensorMode`` from the *original* placeholder's
``meta["val"].fake_mode`` and reuse it. The new fake tensor is then constructed
by appending a concrete ``2`` to the symbolic shape list:

.. code-block:: python

    orig_fake = input_node.meta["val"]
    sym_shape = list(orig_fake.shape) + [2]
    with orig_fake.fake_mode:
        fake_tensor = torch.empty(sym_shape, dtype=new_dtype, device=device)

This preserves all SymInt identity across the graph and keeps
dynamic-shape exports working correctly.

Entry Point: ``complex_graph_detection``
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

The public entry point called by the lowering pipeline is
``complex_graph_detection(gm, settings)``. It:

1. Instantiates ``ComplexOpDetector`` and ``ComplexGraphRewriter``.
2. Calls ``find_complex_op_subgraphs`` anchored on ``view_as_real`` to find
bounded complex subgraphs.
3. Calls ``find_all_complex_subgraphs`` for any remaining complex nodes that
are not ``view_as_real``-bounded.
4. For each subgraph:

a. Calls ``replace_input_node`` on every boundary input node (Stage 2).
b. Calls ``rewrite_subgraph_nodes`` on the ordered subgraph (Stage 3).
c. Calls ``clean_up_graph_after_modifications`` to remove dead nodes.

5. Returns the modified ``GraphModule``.

Adding New Op Rewrites
^^^^^^^^^^^^^^^^^^^^^^^

To teach the rewriter about a new complex op, add a method to
``ComplexGraphRewriter`` tagged with ``@_complex_unpacker``:

.. code-block:: python

    @_complex_unpacker(torch.ops.aten.my_new_op.default)
    def _rewrite_my_new_op(self, node: Node) -> bool:
        inp = node.args[0]
        with SubgraphBuilder(self.gm.graph, node) as b:
            re = b(torch.ops.aten.select.int, inp, -1, 0)
            im = b(torch.ops.aten.select.int, inp, -1, 1)
            out = b(my_real_impl, re, im)
        # If the output is still a complex-layout [..., 2] tensor, annotate it.
        # (Not needed if using _inline_cat_re_im, which sets the flag automatically.)
        out.meta["is_complex_layout"] = True
        node.replace_all_uses_with(out)
        self.gm.graph.erase_node(node)
        return True

``@_register_unpackers`` (applied to the class) picks up the new entry
automatically at import time — no other registration is required.

If the new op is elementwise-safe on the ``(..., 2)`` layout (i.e. it acts
independently on every scalar), add it to ``_ELEMENTWISE_SAFE`` instead.

Related
-------

* :ref:`lowering` — the complex rewrite is a lowering pass.
* :ref:`subgraph_builder` — the ``SubgraphBuilder`` helper used in every rewrite method.
* :ref:`lowering_passes_catalog` — pass ordering and management.
3 changes: 2 additions & 1 deletion docsrc/tutorials/advanced_usage.rst
Expand Up @@ -2,7 +2,7 @@ Advanced Usage
==============

Step-by-step tutorials covering engine caching, quantization, custom kernels,
dynamic shapes, weight streaming, debugging, and more.
dynamic shapes, weight streaming, debugging, complex numerics, and more.

.. toctree::
:maxdepth: 2
Expand All @@ -14,5 +14,6 @@ dynamic shapes, weight streaming, debugging, and more.
weight_refit/index
runtime_opt/index
deployment/index
complex_numerics/index
Example: Distributed Inference <_rendered_examples/distributed_inference/index>
../indices/supported_ops
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,12 @@ compilation.
This page explains what the rewriter does, which patterns are supported, and what
limitations to be aware of when compiling models with complex inputs.

.. seealso::

    :doc:`../_rendered_examples/dynamo/torch_export_3d_rope` — a runnable
    end-to-end example compiling a video-transformer 3D RoPE attention block
    (CogVideoX / Wan / HunyuanVideo style) with dynamic T×H×W shapes.

----

How the Rewriter Works
Expand Down
10 changes: 10 additions & 0 deletions docsrc/tutorials/complex_numerics/index.rst
@@ -0,0 +1,10 @@
Complex Numerics
===================

Compatibility support for numerical datatypes, such as complex numerics, that are not natively supported by TensorRT.

.. toctree::
:maxdepth: 1

complex_tensors
Example: 3D RoPE with Complex Numerics <../_rendered_examples/dynamo/torch_export_3d_rope>
1 change: 0 additions & 1 deletion docsrc/tutorials/deployment/index.rst
Expand Up @@ -12,4 +12,3 @@ complex-valued model support.
cross_compile_windows
Example: Cross-runtime Compilation for Windows <../_rendered_examples/dynamo/cross_runtime_compilation_for_windows>
distributed_inference
complex_tensors
1 change: 1 addition & 0 deletions docsrc/tutorials/extensibility/lowering/index.rst
Expand Up @@ -8,3 +8,4 @@ rewrite ATen ops before TensorRT compilation.
:maxdepth: 1

writing_dynamo_aten_lowering_passes
subgraph_builder