
Fix torch.compile graph breaks from Params4bit __getattr__ (#1904, #1917) #1916

Merged
Titus-von-Koeller merged 3 commits into main from fix/issue-1904 on Apr 10, 2026

Conversation

Titus-von-Koeller (Collaborator) commented Apr 7, 2026

Summary

  • Replace __getattr__ + _QUANT_STATE_ATTR_MAP on Params4bit with @property descriptors to eliminate torch.compile graph breaks under activation checkpointing
  • Remove dead quant_state.dtype mutation in matmul_4bit CPU path that caused a separate torch.compile graph break under activation checkpointing
  • Add regression test test_linear4bit_torch_compile_activation_checkpointing that compiles with fullgraph=True + torch.utils.checkpoint to catch this class of issue

Context

PR #1866 added __getattr__ to Params4bit for FSDP state_dict traversal. Since Params4bit is a torch.Tensor subclass, Dynamo cannot trace through __getattr__, creating graph breaks on every attribute access. With activation checkpointing these multiply across layers, causing significant compilation overhead (#1904).

@property descriptors use Python's descriptor protocol (resolved at class level), which Dynamo handles correctly — no graph breaks. FSDP still works because getattr(weight, "absmax") resolves the same way.

Attributes that collide with Params4bit instance attrs (blocksize, quant_type) or torch.Tensor attrs (dtype, shape) are intentionally omitted — they're packed into the bitsandbytes__* blob and never traversed by FSDP as separate keys.

QuantState.__getattr__ is left unchanged since QuantState is not a Tensor subclass.
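The shape of the change can be sketched in plain Python. The class names below are illustrative stand-ins, not the actual bitsandbytes classes: each property delegates to the quant_state stored in `__dict__`, so the lookup resolves through the class-level descriptor protocol rather than an instance-level `__getattr__` fallback.

```python
class QuantStateSketch:
    """Illustrative stand-in for bitsandbytes' QuantState."""

    def __init__(self, absmax):
        self.absmax = absmax


class Params4bitSketch:
    """Illustrative stand-in for Params4bit. A @property is found on the
    class via the descriptor protocol, which Dynamo traces without a graph
    break, whereas a __getattr__ fallback on a torch.Tensor subclass would
    force one on every attribute access."""

    def __init__(self, quant_state=None):
        if quant_state is not None:
            self.quant_state = quant_state

    @property
    def absmax(self):
        qs = self.__dict__.get("quant_state")
        if qs is None:
            raise AttributeError(
                f"'{type(self).__name__}' object has no attribute 'absmax'"
            )
        return qs.absmax
```

FSDP-style traversal via `getattr(weight, "absmax")` resolves identically through the property, which is why the state_dict behavior from #1866 is preserved.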

Additionally, matmul_4bit() mutated quant_state.dtype = A.dtype on the CPU path (#1917). This in-place mutation is unnecessary — MatMul4Bit.forward already casts via .to(A.dtype), and gemv_4bit doesn't read state.dtype. Removing it eliminates the Dynamo graph break on CPU under activation checkpointing.
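A toy model of why the mutation is dead (pure-Python stand-ins, not the real matmul_4bit): the output is determined by the final cast on the result, so writing state.dtype beforehand changes nothing observable except the state object itself, and that in-place write is exactly the side effect Dynamo flags.

```python
class StateSketch:
    """Stand-in for QuantState; only 'dtype' matters for this sketch."""

    def __init__(self, dtype):
        self.dtype = dtype


def matmul_4bit_sketch(a, state, mutate_state=False):
    """Toy model of the CPU path. The kernel stand-in never reads
    state.dtype, and the real code casts the result via .to(A.dtype)
    afterwards, so the optional mutation below is a pure side effect."""
    if mutate_state:
        state.dtype = "float16"  # the removed, dead in-place write
    return [x * 2 for x in a]    # stand-in for dequantize + matmul
```

Running the sketch with and without the mutation yields identical outputs; only the state object differs, confirming the write was dead code.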

Validated against three code states

| State | Regression test | Existing tests |
| --- | --- | --- |
| Before #1866 (no __getattr__) | 4 passed | N/A |
| At #1866 (with __getattr__) | 4 failed (graph break) | no existing tests failed |
| After this fix (@property) | 4 passed | 2765 passed, none failed |

Test plan

  • test_linear4bit_torch_compile_activation_checkpointing — 4 variants pass (nf4/fp4 × compress_statistics), including CPU
  • test_linear4bit_torch_compile — all 64 variants pass (no regressions)
  • test_params4bit_quant_state_attr_access — all 4 variants pass (FSDP traversal still works)
  • ran complete test suite — clean
  • pre-commit run --all-files — clean

Fixes #1904
Fixes #1917

🤖 Generated with Claude Code

Replace __getattr__ + _QUANT_STATE_ATTR_MAP on Params4bit with @property
descriptors. Dynamo cannot trace __getattr__ on torch.Tensor subclasses,
causing graph breaks that multiply under activation checkpointing.
Properties use the descriptor protocol which Dynamo handles correctly.

Add regression test that compiles Linear4bit with fullgraph=True and
torch.utils.checkpoint to catch this class of issue.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
github-actions bot commented Apr 7, 2026

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Titus-von-Koeller (Collaborator, Author) commented:

PR Review: #1916 -- Fix torch.compile graph breaks from Params4bit __getattr__ (#1904)

Classification: [bug-fix]
Size: Small-medium (150 added / 37 deleted, 2 files)
Author: Titus-von-Koeller (maintainer)
Verdict: Approve

Bug fix replacing __getattr__ + _QUANT_STATE_ATTR_MAP on Params4bit with @property descriptors to eliminate torch.compile graph breaks under activation checkpointing. Root cause is well-understood: Dynamo cannot trace through __getattr__ on torch.Tensor subclasses, so every attribute access creates a graph break. With activation checkpointing these multiply across layers. The descriptor protocol (resolved at class level) is Dynamo-compatible.

No blocking issues.

The approach is correct. Properties are resolved via Python's descriptor protocol at the class MRO level, which Dynamo handles natively. The behavioral semantics are preserved: each property delegates to self.__dict__.get("quant_state") and raises AttributeError with the same message format when the attribute is unavailable. The four attributes that collided with Params4bit instance attrs (blocksize, quant_type) or Tensor attrs (dtype, shape) are correctly omitted -- these were dead code in __getattr__ since regular attribute lookup would find them first, and as properties they would incorrectly shadow the instance/tensor attributes.

One observation (non-blocking):

The nested_offset property checks qs is not None but does not check qs.state2 is not None, unlike the other nested_* properties. This matches the old _QUANT_STATE_ATTR_MAP behavior exactly ("nested_offset": lambda qs: qs.offset), and it matches as_dict() which serializes nested_offset as self.offset.item() (not self.state2.offset). So this is correct as-is -- just noting it since the asymmetry is easy to misread.

Regression test is well-designed. Uses fullgraph=True which causes torch.compile to error on graph breaks rather than silently falling back to eager. Covers forward + backward with activation checkpointing, parametrized across devices, quant types, and compress_statistics. Existing test_params4bit_quant_state_attr_access continues to verify FSDP traversal of all state_dict keys.

CI status: All CUDA/GPU tests pass across A10, L40S, T4 (CUDA 11.8/12.8/13.0). All builds pass. Lint passes. 4 CPU test failures on torch 2.10.0 appear pre-existing/infra -- PR only touches Python code.

  • Security: Clear
  • Downstream impact: None -- behavioral contract preserved. getattr(weight, "absmax") still resolves. PEFT/Transformers __dict__ round-trips unaffected (properties don't appear in __dict__). FSDP traversal works identically.
  • Serialization: No change -- as_dict(), from_dict(), state_dict keys all unchanged
  • Cross-PR conflicts: PR #1863 ("fix: Stop quantize_blockwise and quantize_4bit from mutating user-provided absmax") also touches tests/test_linear4bit.py, but in a different test function; should merge cleanly in either order

Titus-von-Koeller and others added 2 commits April 8, 2026 15:10
matmul_4bit mutates quant_state.dtype in-place on CPU, which Dynamo
flags as a side effect under fullgraph=True + activation checkpointing.
This is a pre-existing issue unrelated to the __getattr__ → @property
fix. Skip on CPU and track the mutation fix separately in #1917.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
The mutation `quant_state.dtype = A.dtype` is unnecessary: MatMul4Bit.forward
already casts via `.to(A.dtype)`, and gemv_4bit doesn't read state.dtype.
Removing it eliminates the Dynamo graph break on CPU under activation
checkpointing, so the regression test no longer needs a CPU skip.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Titus-von-Koeller changed the title from "Fix torch.compile graph breaks from Params4bit __getattr__ (#1904)" to "Fix torch.compile graph breaks from Params4bit __getattr__ (#1904, #1917)" on Apr 8, 2026
matthewdouglas added this to the v0.50.0 milestone on Apr 8, 2026
Titus-von-Koeller merged commit 023bb36 into main on Apr 10, 2026
144 checks passed

Development

Successfully merging this pull request may close these issues.

  • torch.compile fullgraph breaks on CPU: matmul_4bit mutates quant_state.dtype (#1917)
  • Params4bit.__getattr__ breaks torch.compile - use @property instead (#1904)
