jax.jit produces incorrect results vs eager execution on CPU #36619

@beanduan22

Description

Please see the attached reproducer:

unique_0016.py

System info (python version, jaxlib version, accelerator, etc.)

expected (jax.disable_jit): [7.539458e-05 7.539458e-05 7.539458e-05 7.539458e-05 7.539458e-05
 7.539458e-05]
actual   (jax.jit)        : [0.00011158 0.00011158 0.00011158 0.00011158 0.00011158 0.00011158]
rel L2 : 4.8000e-01

jax.jit produces significantly incorrect results compared to eager execution
(jax.disable_jit), even on CPU.

Observed relative L2 error: ~48%.
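For reference, a minimal sketch of how such a jit-vs-eager comparison and relative L2 error can be computed. The helper names here are illustrative, not taken from the attached script:

```python
import jax
import jax.numpy as jnp

def rel_l2(actual, expected):
    # relative L2 error: ||actual - expected|| / ||expected||
    return float(jnp.linalg.norm(actual - expected) / jnp.linalg.norm(expected))

def compare_jit_vs_eager(fn, *args):
    with jax.disable_jit():
        expected = fn(*args)         # eager (reference) result
    actual = jax.jit(fn)(*args)      # XLA-compiled result
    return expected, actual, rel_l2(actual, expected)
```

A ~48% relative error, as reported above, is far beyond anything attributable to float32 rounding.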

The issue arises in a model combining:

  • constant folding (mul-zero elimination)
  • where with constant condition
  • fused matmul + bias + relu
  • layernorm + relu
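The attached script is not reproduced here; a hypothetical minimal model combining these four ingredients could look like the following (all shapes, constants, and names are assumptions, not the actual attachment):

```python
import jax
import jax.numpy as jnp

def model(x, w, b, gamma, beta):
    x = x + x * 0.0                        # mul-by-zero, a constant-folding target
    x = jnp.where(jnp.array(True), x, -x)  # where with a constant condition
    h = jax.nn.relu(x @ w + b)             # matmul + bias + relu (fusion candidate)
    mean = h.mean(axis=-1, keepdims=True)  # layernorm ...
    var = h.var(axis=-1, keepdims=True)
    h = (h - mean) * jax.lax.rsqrt(var + 1e-5)
    return jax.nn.relu(gamma * h + beta)   # ... followed by relu

# Fixed inputs, so any jit/eager divergence is deterministic.
key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (4, 8))
w = jax.random.normal(jax.random.fold_in(key, 1), (8, 6))
b = jnp.zeros((6,))
gamma = jnp.ones((6,))
beta = jnp.zeros((6,))
```

With such a reproducer, running `model` once under `jax.disable_jit()` and once under `jax.jit` on identical inputs isolates the compiler as the only variable.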

This suggests that an XLA optimization pass (likely fusion or constant propagation)
is altering the computation's semantics, not merely its numerical precision.

The issue is deterministic and reproducible with fixed inputs.

Metadata


    Labels

    bug (Something isn't working)
