JAX.lax.scan tutorial (for autoregressive rollout)

  • Published: 5 Oct 2024

Comments • 10

  • @rene9901  5 months ago  +2

    Hi, I was looking for JAX tutorials, found this video and your channel, and hence must say THANKS for the nice explanations.

  • @Michael-vs1mw  10 months ago

    Nice! This is literally my favorite JAX function. I use it all the time for time-series modeling with custom recurrent neural networks. Looking forward to the next video!

    • @MachineLearningSimulation  9 months ago

      Nice, yes. 😊
      Thanks for the support. See you in one of the next JAX videos (probably on custom JVPs).
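
A minimal sketch of the autoregressive-rollout use of jax.lax.scan that the video and the comment above refer to; the linear step function and every name here are made-up placeholders, not taken from the video:

    import jax
    import jax.numpy as jnp

    # Hypothetical one-step transition: a fixed linear map standing in for
    # whatever stepper or recurrent network drives the rollout.
    A = jnp.array([[0.9, 0.1],
                   [-0.1, 0.9]])

    def step(state, _):
        next_state = A @ state          # advance the state by one step
        return next_state, next_state   # (carry, per-step output to stack)

    u_0 = jnp.array([1.0, 0.0])

    # Roll out 2000 autoregressive steps; `trajectory` has shape (2000, 2)
    # and stacks every intermediate state along the leading axis.
    final_state, trajectory = jax.lax.scan(step, u_0, xs=None, length=2000)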

  • @neelg7057  10 months ago

    Short and concise - really cool! Need more vids about JAX and its sharp edges, XLA, and how the internal autograd works!

  • @dilipkrishnan9677  8 months ago  +1

    Nice tutorial - thanks!

  • @paveltolmachev1898  7 months ago  +1

    Is there a way to incorporate a stopping criterion? For instance, you want to run your simulation not just for 2000 steps, but for however many steps it takes until it hits a certain state.

    • @MachineLearningSimulation  7 months ago

      If you are only interested in the final iterate at the point where you hit your criterion, you might want to look into jax.lax.while_loop.
      I assume, though, that you still want to stack all iterates (like here in the video). In that case I'm unsure it fits the JAX compute model, because the array shape would be unknown at compile time (as soon as you wrap your jax.lax.scan in any outer function transformation). A remedy, if you can computationally afford it, is to stack as many iterates as you could possibly need at a maximum, then compute the first index at which the condition is fulfilled and slice the array afterwards (a rough sketch of both options follows below).
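
A rough sketch of both options, under made-up assumptions (a simple linear step, a norm-based stopping criterion, 2000 as the maximum step count); nothing here is taken from the video itself:

    import jax
    import jax.numpy as jnp

    # Made-up stand-ins for the actual simulation: a linear step and a
    # stopping criterion based on the state norm.
    A = jnp.array([[0.9, 0.0],
                   [0.0, 0.5]])
    threshold = 1e-3
    u_0 = jnp.array([1.0, 1.0])
    MAX_STEPS = 2000

    # Option 1: only the final iterate is needed -> jax.lax.while_loop keeps
    # iterating while the condition holds and returns just the final state.
    def cond_fun(state):
        return jnp.linalg.norm(state) >= threshold

    def body_fun(state):
        return A @ state

    final_state = jax.lax.while_loop(cond_fun, body_fun, u_0)

    # Option 2: all iterates are needed -> run jax.lax.scan for the maximum
    # number of steps (static shape), then locate the stopping index afterwards.
    def step(state, _):
        next_state = A @ state
        return next_state, next_state

    _, trajectory = jax.lax.scan(step, u_0, xs=None, length=MAX_STEPS)

    # argmax over a boolean array returns the first True, i.e. the first step
    # at which the criterion is fulfilled (it returns 0 if it never is).
    hit = jnp.linalg.norm(trajectory, axis=-1) < threshold
    stop_index = jnp.argmax(hit)

    # Slicing to a data-dependent length only works outside of jit; inside a
    # jitted function one would mask instead of slicing.
    valid_trajectory = trajectory[: stop_index + 1]

Note that option 2 always pays for all MAX_STEPS iterations, which is the "if you can computationally afford it" caveat in the reply above.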