# Python package dependencies, one requirement specifier per line.
# Every entry declares only a minimum version (>=); no upper bounds or exact
# pins are set, so an installer may resolve to any newer release.
# NOTE(review): jax and jaxlib carry the same 0.4.16 floor — presumably kept in
# step on purpose; confirm before raising one without the other.
fastprogress>=1.0.0
jax>=0.4.16
jaxlib>=0.4.16
jaxopt>=0.8
optax>=0.1.7
typing-extensions>=4.4.0
