JAX/XLA produces numerically unstable results for normalization-heavy pipelines under near-zero inputs. #36623

@beanduan22

Description

Please see the attached script.

unique_2525.py

System info (python version, jaxlib version, accelerator, etc.)

expected (pytorch_eager): [8.619985e-09 8.619985e-09 8.619985e-09 8.619985e-09 8.619985e-09 8.619985e-09]
actual (jax.jit/XLA) : [6.821211e-08 6.821211e-08 6.821211e-08 6.821211e-08 6.821211e-08 6.821211e-08]
rel L2 : 6.6355e+00
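
For anyone trying to reproduce the numbers without the attachment, here is a minimal sketch of a relative L2 check on the six printed values. It assumes the metric is the plain ratio ||actual - expected|| / ||expected||. The helper name `rel_l2` is hypothetical, and the attached script's exact definition is not shown in this report.

```python
import numpy as np


def rel_l2(actual, expected):
    """Relative L2 error, assumed to be ||actual - expected|| / ||expected||.

    This definition is an assumption; the attached script may differ.
    """
    actual = np.asarray(actual, dtype=np.float64)
    expected = np.asarray(expected, dtype=np.float64)
    return float(np.linalg.norm(actual - expected) / np.linalg.norm(expected))


# The six values printed in the report above.
expected = np.full(6, 8.619985e-09)  # pytorch_eager
actual = np.full(6, 6.821211e-08)    # jax.jit/XLA

print(f"rel L2 = {rel_l2(actual, expected):.4e}")
```

On these six values the plain ratio comes out at roughly 6.91, not the reported 6.6355. The script may therefore use a different denominator (for example, an added stabilizer term) or compare a longer tensor than the slice printed here. Neither can be confirmed from this report alone.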

Labels: bug