Description
Please see the attached reproducer: unique_0016.py
System info (python version, jaxlib version, accelerator, etc.)
expected (jax.disable_jit): [7.539458e-05 7.539458e-05 7.539458e-05 7.539458e-05 7.539458e-05
7.539458e-05]
actual (jax.jit) : [0.00011158 0.00011158 0.00011158 0.00011158 0.00011158 0.00011158]
rel L2 : 4.8000e-01
jax.jit produces significantly incorrect results compared to eager execution
(under jax.disable_jit), even on CPU.
Observed relative L2 error: ~48%.
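For reference, the "rel L2" figure above is presumably the standard relative L2 error, reproduced here from the printed expected/actual values (the vector length of 6 is taken from the output above):

```python
import numpy as np

# Values copied from the expected/actual output above.
expected = np.full(6, 7.539458e-05, dtype=np.float32)
actual = np.full(6, 0.00011158, dtype=np.float32)

# Relative L2 error: ||actual - expected||_2 / ||expected||_2
rel_l2 = np.linalg.norm(actual - expected) / np.linalg.norm(expected)
print(f"rel L2 : {rel_l2:.4e}")  # ~4.8e-01, matching the report
```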
The issue arises in a model combining:
- constant folding (mul-zero elimination)
- where with constant condition
- fused matmul + bias + relu
- layernorm + relu
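A minimal sketch of a model combining these ingredients, for readers without the attachment. This is NOT the attached unique_0016.py; all shapes, seeds, and the layernorm epsilon are hypothetical, and it simply compares eager and jitted outputs:

```python
import jax
import jax.numpy as jnp

def model(x, w, b):
    # Multiply-by-zero term that constant folding may eliminate.
    x = x + x * 0.0
    # jnp.where with a constant (always-true) condition.
    x = jnp.where(True, x, -x)
    # Matmul + bias + relu (a fusion candidate for XLA).
    h = jax.nn.relu(x @ w + b)
    # Layernorm over the last axis, followed by relu.
    mean = h.mean(axis=-1, keepdims=True)
    var = h.var(axis=-1, keepdims=True)
    h = (h - mean) / jnp.sqrt(var + 1e-5)
    return jax.nn.relu(h)

x = jax.random.normal(jax.random.PRNGKey(0), (2, 4))
w = jax.random.normal(jax.random.PRNGKey(1), (4, 4))
b = jnp.zeros(4)

eager = model(x, w, b)          # as run under jax.disable_jit
jitted = jax.jit(model)(x, w, b)
print(jnp.max(jnp.abs(eager - jitted)))
```

With fixed PRNG keys the comparison is deterministic, matching the report's claim that the issue reproduces with fixed inputs.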
This suggests that an XLA optimization pass (likely fusion or constant
propagation) is altering the computation's semantics, not merely its
numerical precision.
The issue is deterministic and reproducible with fixed inputs.