DeepONet Tutorial in JAX

  • Published: Oct 5, 2024

Comments • 18

  • @MachineLearningSimulation
    @MachineLearningSimulation  6 months ago +4

    Sorry for the short, lower-quality audio segments (4x ~10 seconds). I changed the recording setup for this video, and it seems it still requires some further tuning ;)

  • @sabaokangan
    @sabaokangan 6 months ago +1

    Thank you so much for sharing this with us ❤ from Seoul 🇰🇷

  • @user-kn4wt
    @user-kn4wt 6 months ago

    Great vid! What are your thoughts on Flax vs Equinox? It seems like Flax has a few more things implemented natively, but Equinox seems maybe slightly nicer for building your own custom model (FNO etc.) from scratch? Thanks for all the great content!

  • @digambarkilledar003
    @digambarkilledar003 6 months ago +1

    Can you write the code for the 1D Burgers equation using DeepONet, just like you did with the FNO? Thanks!! What would the branch and trunk inputs be for the Burgers equation?

    • @MachineLearningSimulation
      @MachineLearningSimulation  3 months ago

      For now, I want to explore more topics in SciML instead of focusing more closely on a single one 😊. Regarding your particular question: if your mapping problem is from the initial condition of the Burgers PDE to its solution at a later point in time (e.g. at t=1), the branch inputs would be the degrees of freedom of the initial condition, and the trunk input would be the coordinate "x" at which you want to evaluate the solution at that later point in time.
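
      A minimal sketch of that setup in Equinox (hypothetical shapes and layer sizes, not from the video):

      import jax
      import jax.numpy as jnp
      import equinox as eqx

      class DeepONet(eqx.Module):
          branch: eqx.nn.MLP
          trunk: eqx.nn.MLP

          def __init__(self, n_dof, latent_dim, key):
              b_key, t_key = jax.random.split(key)
              # Branch net consumes the discretized initial condition u0
              self.branch = eqx.nn.MLP(n_dof, latent_dim, 64, 3, key=b_key)
              # Trunk net consumes the query coordinate x at the later time (e.g. t=1)
              self.trunk = eqx.nn.MLP(1, latent_dim, 64, 3, key=t_key)

          def __call__(self, u0, x):
              # Inner product of branch and trunk embeddings gives u(x, t=1)
              return jnp.sum(self.branch(u0) * self.trunk(x))

      model = DeepONet(n_dof=128, latent_dim=32, key=jax.random.PRNGKey(0))
      u0 = jnp.zeros(128)   # degrees of freedom of the initial condition
      x = jnp.array([0.5])  # coordinate at which to evaluate the solution at t=1
      u_pred = model(u0, x)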

  • @jesusmtz29
    @jesusmtz29 5 months ago +1

    Great stuff

  • @particularlypythonic
    @particularlypythonic 6 months ago +1

    Is there a reason you used the function form of eqx.filter_value_and_grad instead of the decorator form on the loss function?

    • @MachineLearningSimulation
      @MachineLearningSimulation  3 months ago

      Great question! 👍
      I don't like using it as a decorator because then the name of the function no longer describes its purpose. Instead of "loss_fn", it would then better be called "loss_and_grad_fn". Transforming the function whenever I need it gives me more flexibility. I believe (though not 100% certain) that this has no performance implications, because the "update_fn" is wrapped in a JIT anyway.
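
      For anyone curious, the two forms side by side (a quick sketch with a dummy model and data, not the tutorial's actual loss):

      import jax
      import jax.numpy as jnp
      import equinox as eqx

      model = eqx.nn.MLP(1, 1, 16, 2, key=jax.random.PRNGKey(0))
      x, y = jnp.ones((8, 1)), jnp.zeros((8, 1))

      # Decorator form: the function now returns (loss, grads),
      # so calling it "loss_fn" would be misleading
      @eqx.filter_value_and_grad
      def loss_and_grad_fn(model, x, y):
          return jnp.mean((jax.vmap(model)(x) - y) ** 2)

      # Function form (as in the video): keep a plain loss_fn
      # and apply the transformation at the call site
      def loss_fn(model, x, y):
          return jnp.mean((jax.vmap(model)(x) - y) ** 2)

      loss, grads = eqx.filter_value_and_grad(loss_fn)(model, x, y)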

    • @particularlypythonic
      @particularlypythonic 3 months ago

      @@MachineLearningSimulation that makes sense, tysm

  • @HydronauticaCFD
    @HydronauticaCFD 2 months ago

    Thank you for the video. I have a question regarding saving historical states. With jax.vmap, it is essentially predicting the full domain in one call, is that correct? I would like to store historical states and then use them in the prediction, but I am not sure how to do that in this configuration. I assume it would consist of performing the operator prediction sequentially, saving the historical state, and then repeating until the end of the domain.
    Thank you

    • @HydronauticaCFD
      @HydronauticaCFD 2 months ago

      import jax
      import jax.numpy as jnp
      # (branch_inputs_train, branch_IC_inputs_train, trunk_inputs_train, outputs_train come from the training setup)

      def loss_fn(model, length, epoch):
          predictions_list = []
          historical_states = jnp.zeros((151, 5))
          for i in range(length):
              trunk_input = jnp.hstack((historical_states[i, :], trunk_inputs_train[i, :])).reshape(1, -1)
              predictions = jax.vmap(model)(
                  branch_inputs_train[epoch, :].reshape(1, -1),
                  branch_IC_inputs_train[epoch, :].reshape(1, -1),
                  trunk_input,
              )
              predictions_list.append(predictions)
              # Update the historical state buffer for the next timestep
              if i < length - 1:  # Ensure we don't update beyond the last index
                  # Shift the historical states left and add the new prediction
                  historical_states = historical_states.at[i + 1, 1:].set(historical_states[i, :-1])
                  historical_states = historical_states.at[i + 1, 0].set(predictions[0])
          predictions_array = jnp.vstack(predictions_list)
          mse = jnp.mean(jnp.square(predictions_array - outputs_train[epoch, :length]))
          return mse
      Something like this, but with no loops, of course.
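
      A possible loop-free variant with jax.lax.scan, carrying the history window as the scan state. This is just a sketch: it reuses the array names from the snippet above, assumes the model returns a single value per query point, and (as in the loop version) length needs to be a concrete Python int:

      def loss_fn(model, length, epoch):
          branch_in = branch_inputs_train[epoch, :].reshape(1, -1)
          branch_ic = branch_IC_inputs_train[epoch, :].reshape(1, -1)

          def step(history, trunk_row):
              # history holds the last 5 predictions, newest first
              trunk_input = jnp.hstack((history, trunk_row)).reshape(1, -1)
              prediction = jax.vmap(model)(branch_in, branch_ic, trunk_input)
              # Shift the window and prepend the newest prediction
              new_history = jnp.concatenate((prediction.reshape(-1)[:1], history[:-1]))
              return new_history, prediction.reshape(-1)[0]

          init_history = jnp.zeros(5)
          _, predictions = jax.lax.scan(step, init_history, trunk_inputs_train[:length, :])
          return jnp.mean(jnp.square(predictions - outputs_train[epoch, :length]))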

  • @lksmac1595
    @lksmac1595 6 months ago +1

    Amazing

  •  6 months ago +1

    Nice!