JAX compared to PyTorch 2: Get a feeling for JAX!

  • Published: 25 Aug 2024
  • A simple torch.nn.Module for neural network model definition and training with gradient descent in PyTorch 2, compared with a similar implementation in JAX written in a functional programming style.
    How do you convert a stateful operation into a stateless one in JAX's functional style? A simple coding example in JAX: regression via gradient descent, where the only kind of state is the model parameters.
    Link to documentation and free Colab NB:
    jax.readthedoc...
    colab.research...
    #jax
    #ai
    #parallel
    #computerscience
    #computertipsandtricks
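
    The video's own code is not reproduced on this page, so the following is only a minimal sketch of the idea the description names: linear regression via gradient descent in JAX, with the model parameters as the only state. All names (`predict`, `loss_fn`, `update`, `LEARNING_RATE`) and the synthetic data are assumptions made for illustration. The parameters are never mutated in place; each step takes the old parameters and returns new ones.

```python
import jax
import jax.numpy as jnp

LEARNING_RATE = 0.1  # hypothetical value chosen for this sketch


def predict(params, x):
    # Pure function: the output depends only on the arguments passed in.
    w, b = params
    return w * x + b


def loss_fn(params, x, y):
    # Mean squared error between predictions and targets.
    return jnp.mean((predict(params, x) - y) ** 2)


@jax.jit
def update(params, x, y):
    # Stateless update: takes the old parameters, returns new ones.
    grads = jax.grad(loss_fn)(params, x, y)
    return jax.tree_util.tree_map(
        lambda p, g: p - LEARNING_RATE * g, params, grads
    )


# Synthetic, noise-free data generated from y = 3x - 1 (an assumption).
key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (100,))
y = 3.0 * x - 1.0

# The only state is the parameter tuple, threaded explicitly through the loop.
params = (jnp.array(0.0), jnp.array(0.0))
for _ in range(200):
    params = update(params, x, y)

w, b = params
```

    In a torch.nn.Module the parameters live inside the module object and the optimizer updates them in place; here the caller owns the parameters and rebinds `params` on every step, which is what makes `update` safe to pass to `jax.jit` and `jax.grad`.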

Comments • 4

  • @MaxwellHay 1 year ago +2

    I like the way JAX handles dependencies explicitly, from a software engineer's point of view.

  • @EffectCrash 1 year ago +3

    Love every video you've made!

  • @riser9644 1 year ago +2

    Very interesting video

  • @riser9644 1 year ago +1

    Beyond experimentation and curiosity, is there any benefit to JAX and its variants (Trax, etc.) over PyTorch and TF?