Neural Networks in pure JAX (with automatic differentiation)
- Published: 16 Jul 2024
- (Reverse-mode) automatic differentiation is the secret sauce of deep learning: it lets us differentiate almost arbitrary neural architectures. Let's use the abstractions of the JAX DL framework in Python to implement a simple MLP. Here is the code: github.com/Ceyron/machine-lea...
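As a taste of what that means in practice, here is a minimal, self-contained sketch of reverse-mode autodiff in JAX (my own illustration, not the repository's code; a fuller MLP sketch follows the timestamps below):

```python
import jax
import jax.numpy as jnp

# A scalar-valued function; JAX can differentiate (almost) arbitrary
# compositions of such primitives via reverse-mode autodiff.
def f(x):
    return jnp.sin(x) ** 2 + 0.5 * x

df = jax.grad(f)   # jax.grad transforms f into a function computing df/dx
print(df(1.5))     # derivative evaluated at x = 1.5
```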
-------
👉 This educational series is supported by the world leaders in integrating machine learning and artificial intelligence with simulation and scientific computing, Pasteur Labs and Institute for Simulation Intelligence. Check out simulation.science/ for more on their pursuit of 'Nobel-Turing' technologies (arxiv.org/abs/2112.03235), and for partnership or career opportunities.
-------
📝 : Check out the channel's GitHub repository, where I upload all the handwritten notes and source-code files (contributions are very welcome): github.com/Ceyron/machine-lea...
📢 : Follow me on LinkedIn or Twitter for updates on the channel and other cool Machine Learning & Simulation stuff: / felix-koehler and / felix_m_koehler
💸 : If you want to support my work on the channel, you can become a patron here: / mlsim
🪙: Or you can make a one-time donation via PayPal: www.paypal.com/paypalme/Felix...
-------
⚙️ My Gear:
(Below are affiliate links to Amazon. If you decide to purchase a product or something else on Amazon through one of these links, I earn a small commission.)
- 🎙️ Microphone: Blue Yeti: amzn.to/3NU7OAs
- ⌨️ Logitech TKL Mechanical Keyboard: amzn.to/3JhEtwp
- 🎨 Gaomon Drawing Tablet (similar to a WACOM Tablet, but cheaper, works flawlessly under Linux): amzn.to/37katmf
- 🔌 Laptop Charger: amzn.to/3ja0imP
- 💻 My Laptop (generally I like the Dell XPS series): amzn.to/38xrABL
- 📱 My Phone: Fairphone 4 (I love the sustainability and repairability aspect of it): amzn.to/3Jr4ZmV
If I had to purchase these items again, I would probably change the following:
- 🎙️ Rode NT: amzn.to/3NUIGtw
- 💻 Framework Laptop (I do not get a commission here, but I love the vision of Framework. It will definitely be my next Ultrabook): frame.work
As an Amazon Associate I earn from qualifying purchases.
-------
Timestamps:
00:00 Intro
01:18 Dataset that somehow looks like a sine function
01:56 Forward pass of the Multilayer Perceptron
03:22 Weight initialization following Xavier Glorot (sketched below)
04:20 Idea of "Learning" as approximate optimization
04:49 Reverse-mode autodiff requires us to only write the forward pass
05:34 Imports
05:52 Constants and Hyperparameters
06:19 Producing the random toy dataset
08:33 Draw initial parameter guesses
12:05 Implementing the forward/primal pass
13:58 Implementing the loss metric
14:57 Transform forward pass to get gradients by autodiff
20:03 Training loop (using plain gradient descent)
23:21 Improving training speed by JIT compilation
24:25 Plotting loss history
24:47 Plotting final network prediction & Discussion
25:44 Summary
26:59 Outro
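For readers following along with the timestamps, below is a minimal sketch of the initialization (03:22, 08:33) and forward/primal pass (01:56, 12:05) steps. The layer sizes and function names are my own illustrative choices, not necessarily those used in the video:

```python
import jax
import jax.numpy as jnp

LAYER_SIZES = [1, 32, 32, 1]  # hypothetical architecture for 1-D regression

def init_params(key, sizes):
    """Draw weights with Xavier/Glorot scaling; biases start at zero."""
    params = []
    for n_in, n_out in zip(sizes[:-1], sizes[1:]):
        key, subkey = jax.random.split(key)
        scale = jnp.sqrt(2.0 / (n_in + n_out))  # Glorot-normal std
        W = scale * jax.random.normal(subkey, (n_out, n_in))
        b = jnp.zeros((n_out,))
        params.append((W, b))
    return params

def forward(params, x):
    """MLP forward (primal) pass: affine layers with tanh, linear output."""
    for W, b in params[:-1]:
        x = jnp.tanh(W @ x + b)
    W, b = params[-1]
    return W @ x + b
```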
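Continuing the sketch: the MSE loss metric (13:58), the autodiff transform (14:57), and a JIT-compiled plain gradient-descent loop (20:03, 23:21). The noisy-sine toy data, learning rate, and epoch count are placeholders, not the video's exact hyperparameters:

```python
batched_forward = jax.vmap(forward, in_axes=(None, 0))  # vectorize over samples

def loss_fn(params, xs, ys):
    preds = batched_forward(params, xs)
    return jnp.mean((preds - ys) ** 2)  # mean squared error

@jax.jit  # compile the whole update step with XLA for a large speed-up
def update(params, xs, ys, lr=0.1):
    # value_and_grad gives the loss and its gradient w.r.t. params in one pass
    loss, grads = jax.value_and_grad(loss_fn)(params, xs, ys)
    # Plain gradient descent on every (W, b) leaf of the parameter pytree
    new_params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
    return new_params, loss

# Toy dataset that "somehow looks like a sine function"
key = jax.random.PRNGKey(0)
xs = jnp.linspace(-jnp.pi, jnp.pi, 200).reshape(-1, 1)
ys = jnp.sin(xs) + 0.1 * jax.random.normal(key, xs.shape)

params = init_params(jax.random.PRNGKey(42), LAYER_SIZES)
loss_history = []
for epoch in range(1000):
    params, loss = update(params, xs, ys)
    loss_history.append(loss)
```

Because update is wrapped in jax.jit, the first call pays a one-time compilation cost and every later call runs the compiled XLA program, which is the speed-up discussed in the 23:21 chapter.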
Thanks a bunch.
I can never find enough JAX tuts lol
Even large language models like ChatGPT often have outdated info because of their knowledge cutoff dates.
It's just this and the docs for me hehe...
Glad I could help! :)
Very nice explanation!
Thanks a lot 😊
Very informative video as always!
Thanks a lot 😊
Great, seems like magic. Thanks!
Nice, thanks :).
Indeed, JAX is really smooth to use. I find the AD interface better than in TensorFlow or PyTorch.