# Python package dependencies (pip requirements format).
# jax and jaxlib are pinned to the same release; upgrade them together.
dm_haiku==0.0.9
e3nn_jax==0.14.0
jax==0.4.1
jaxlib==0.4.1
jraph==0.0.6.dev0
# jax/jaxlib 0.4.1 predate NumPy 2.0, whose C ABI is incompatible with
# extensions built against NumPy 1.x. Cap numpy below 2 so the pinned
# jaxlib stays importable; lift the cap only after upgrading jax/jaxlib
# to a release with NumPy 2 support.
numpy>=1.23.4,<2
optax==0.1.3
