Description
Please see the attached script, unique_2525.py.
System info (python version, jaxlib version, accelerator, etc.)
expected (pytorch_eager): [8.619985e-09 8.619985e-09 8.619985e-09 8.619985e-09 8.619985e-09
8.619985e-09]
actual (jax.jit/XLA) : [6.821211e-08 6.821211e-08 6.821211e-08 6.821211e-08 6.821211e-08
6.821211e-08]
rel L2 : 6.6355e+00
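For reference, a "rel L2" figure like the one above is usually the relative L2 error ||actual - expected||_2 / ||expected||_2. The sketch below assumes that definition. The attached script may normalize differently, for example by adding an epsilon to the denominator. The helper name `rel_l2` is hypothetical and is not part of JAX or PyTorch.

```python
import numpy as np


def rel_l2(actual, expected):
    """Relative L2 error: ||actual - expected||_2 / ||expected||_2.

    Assumption: this matches the "rel L2" metric in the report; the
    original script may use a different normalization or an epsilon.
    Undefined (division by zero) when `expected` is all zeros.
    """
    # Flatten so arrays of any shape are compared element-wise.
    actual = np.asarray(actual, dtype=np.float64).ravel()
    expected = np.asarray(expected, dtype=np.float64).ravel()
    if actual.shape != expected.shape:
        raise ValueError("actual and expected must have the same number of elements")
    return np.linalg.norm(actual - expected) / np.linalg.norm(expected)


# Hypothetical usage: a reference output vs. an output that is off by 2x.
expected = np.full(6, 1.0, dtype=np.float32)
actual = np.full(6, 2.0, dtype=np.float32)
print(rel_l2(actual, expected))  # 1.0
```

The inputs are upcast to float64 before taking the norms. This keeps the metric itself from adding rounding error when comparing float32 outputs whose magnitudes are as small as the ones reported here.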