Hi! I was looking for JAX tutorials, found this video and your channel, and must say THANKS for the nice explanations.
You're very welcome 🤗
Welcome to the JAX side!
Nice! This is literally my favorite JAX function. I use it all the time for time-series modeling with custom recurrent neural networks. Looking forward to the next video!
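For anyone wondering what "time-series modeling with custom recurrent neural networks" looks like with scan, here is a minimal sketch, not the commenter's actual model. It assumes a plain Elman-style cell; the names `rnn_cell` and `run_rnn` and all sizes are hypothetical. The step function returns the new hidden state twice: once as the carry for the next step, and once as the per-step output that scan stacks along the time axis.

```python
import jax
import jax.numpy as jnp

def rnn_cell(params, h, x):
    # Hypothetical Elman-style cell: h_t = tanh(W_h h_{t-1} + W_x x_t + b)
    W_h, W_x, b = params
    return jnp.tanh(W_h @ h + W_x @ x + b)

def run_rnn(params, h0, xs):
    """Unroll the cell over the leading (time) axis of xs with jax.lax.scan."""
    def step(h, x):
        h_new = rnn_cell(params, h, x)
        return h_new, h_new  # (carry for next step, output to stack)
    h_final, hs = jax.lax.scan(step, h0, xs)
    return h_final, hs

key = jax.random.PRNGKey(0)
k1, k2, k3 = jax.random.split(key, 3)
hidden, features, T = 4, 3, 10
params = (
    0.1 * jax.random.normal(k1, (hidden, hidden)),
    0.1 * jax.random.normal(k2, (hidden, features)),
    jnp.zeros(hidden),
)
xs = jax.random.normal(k3, (T, features))  # one input vector per time step
h_final, hs = jax.jit(run_rnn)(params, jnp.zeros(hidden), xs)
```

`hs` holds every hidden state with shape `(T, hidden)`, and `h_final` equals its last row.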
Nice, yes. 😊
Thanks for the support. See you in one of the next JAX videos (probably on custom JVPs).
Short and concise, really cool! Need more videos about JAX and its sharp edges, XLA, and how the internal autograd works!
Thanks :).
More on JAX to come!
Nice tutorial - thanks!
You're welcome! 🙏
Is there a way to incorporate a stopping criterion? For instance, you might want to run your simulation not for a fixed 2000 steps, but until it reaches a certain state.
If you are only interested in the final iterate at the point where your criterion is met, you might want to look into "jax.lax.while_loop".
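A minimal sketch of the while_loop approach, assuming a toy update rule `x -> 0.9 * x` and a hypothetical stopping condition `x < threshold`; swap in your own dynamics. The carry is a `(step, x)` tuple, and a `max_steps` guard is added so the loop cannot run forever if the condition is never reached. Only the final state and the step count come back; intermediate iterates are not stored.

```python
import jax
import jax.numpy as jnp

def simulate_until(x0, threshold=0.1, max_steps=2000):
    """Iterate a placeholder update rule until x drops below threshold."""
    def update(x):
        return 0.9 * x  # placeholder dynamics, assumed for illustration

    def cond_fun(carry):
        step, x = carry
        # Keep looping while the condition is not met and the guard allows it
        return (x >= threshold) & (step < max_steps)

    def body_fun(carry):
        step, x = carry
        return step + 1, update(x)

    init = (jnp.int32(0), jnp.asarray(x0, dtype=jnp.float32))
    n_steps, x_final = jax.lax.while_loop(cond_fun, body_fun, init)
    return n_steps, x_final

n_steps, x_final = jax.jit(simulate_until)(1.0)
```

Because the loop only carries fixed-shape scalars, it traces and compiles under `jax.jit` even though the number of iterations is data-dependent.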
If you still want to stack all iterates (as in the video), I'm unsure that fits within the JAX compute model, because the shape of the stacked array would be unknown at compile time (if you wrap your jax.lax.scan in an outer function transformation such as jax.jit). A workaround, if you can afford the extra computation, is to stack as many iterates as you are sure you will need at most, compute the first index at which the condition is fulfilled, and slice the array afterwards.
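Here is a minimal sketch of that workaround, again assuming a toy update rule `x -> 0.9 * x` and a hypothetical condition `x < THRESHOLD`. `MAX_STEPS` is an assumed upper bound you would choose yourself. Inside `jax.jit` everything has a static shape: the full `(MAX_STEPS,)` trajectory, a scalar index of the first hit, and a flag saying whether the condition was met at all. The data-dependent slice happens afterwards, outside of jit.

```python
import jax
import jax.numpy as jnp

MAX_STEPS = 2000   # assumed upper bound on the number of steps you could need
THRESHOLD = 0.1    # hypothetical stopping condition: x < THRESHOLD

def update(x):
    return 0.9 * x  # placeholder dynamics, assumed for illustration

@jax.jit
def simulate_fixed(x0):
    def step(x, _):
        x_new = update(x)
        return x_new, x_new  # carry the state and also stack it
    x0 = jnp.asarray(x0, dtype=jnp.float32)
    _, xs = jax.lax.scan(step, x0, None, length=MAX_STEPS)
    hit = xs < THRESHOLD
    # argmax returns the first index of the maximum, i.e. the first True;
    # cast to int32 so argmax works on numbers rather than booleans
    first = jnp.argmax(hit.astype(jnp.int32))
    return xs, first, jnp.any(hit)  # all shapes are static under jit

xs, first, found = simulate_fixed(1.0)
if not bool(found):
    raise ValueError("stopping condition not met within MAX_STEPS")
# Slicing with a concrete (non-traced) index is fine outside of jit
trajectory = xs[: int(first) + 1]
```

The cost is that all `MAX_STEPS` iterations are always computed, which is the "if you can computationally afford it" caveat above.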