Source code for pararealml.operators.ml.pidon.loss

from __future__ import annotations

from typing import Optional, Tuple, Sequence, NamedTuple

import tensorflow as tf


class Loss(NamedTuple):
    """
    A collection of the various losses of a physics-informed DeepONet.
    """

    # the differential equation loss
    diff_eq_loss: tf.Tensor
    # the initial condition loss
    ic_loss: tf.Tensor
    # the Dirichlet and Neumann boundary condition losses (in this order), or
    # None if there are no boundary condition losses
    bc_losses: Optional[Tuple[tf.Tensor, tf.Tensor]]
    # the total loss; `construct` computes it as the weighted sum of the
    # components above
    weighted_total_loss: tf.Tensor

    def __str__(self):
        string = f'Weighted Total: {self.weighted_total_loss}; ' + \
            f'DE: {self.diff_eq_loss}; ' + \
            f'IC: {self.ic_loss}'
        if self.bc_losses:
            string += f'; Dirichlet BC: {self.bc_losses[0]}; ' + \
                f'Neumann BC: {self.bc_losses[1]}'
        return string

    @classmethod
    def construct(
            cls,
            diff_eq_loss: tf.Tensor,
            ic_loss: tf.Tensor,
            bc_losses: Optional[Tuple[tf.Tensor, tf.Tensor]],
            diff_eq_loss_weight: float,
            ic_loss_weight: float,
            bc_loss_weight: float) -> Loss:
        """
        Calculates the weighted total loss given the weights for the different
        components of the total loss and returns a Loss instance.

        The weighted total is
        diff_eq_loss_weight * diff_eq_loss + ic_loss_weight * ic_loss, plus
        bc_loss_weight * (Dirichlet BC loss + Neumann BC loss) if boundary
        condition losses are provided.

        :param diff_eq_loss: the differential equation loss tensor
        :param ic_loss: the initial condition loss tensor
        :param bc_losses: a tuple of the Dirichlet and Neumann boundary
            condition loss tensors, or None
        :param diff_eq_loss_weight: the weight of the differential equation
            part of the total physics-informed loss
        :param ic_loss_weight: the weight of the initial condition part of the
            total physics-informed loss
        :param bc_loss_weight: the weight of the boundary condition part of the
            total physics-informed loss
        :return: the losses including the weighted total
        """
        weighted_total_loss = \
            tf.scalar_mul(diff_eq_loss_weight, diff_eq_loss) + \
            tf.scalar_mul(ic_loss_weight, ic_loss)
        if bc_losses:
            # the Dirichlet and Neumann components share a single weight
            weighted_total_loss += \
                tf.scalar_mul(bc_loss_weight, bc_losses[0] + bc_losses[1])
        return cls(diff_eq_loss, ic_loss, bc_losses, weighted_total_loss)

    @classmethod
    def mean(
            cls,
            losses: Sequence[Loss],
            diff_eq_loss_weight: float,
            ic_loss_weight: float,
            bc_loss_weight: float) -> Loss:
        """
        Computes the element-wise mean of the provided losses (along the
        stacking axis) and the corresponding weighted total loss.

        Boundary condition losses are averaged only over those losses that
        have them; if none of the losses has boundary condition losses, the
        mean loss has none either.

        :param losses: the losses to average over
        :param diff_eq_loss_weight: the weight of the differential equation
            part of the total physics-informed loss
        :param ic_loss_weight: the weight of the initial condition part of the
            total physics-informed loss
        :param bc_loss_weight: the weight of the boundary condition part of the
            total physics-informed loss
        :return: the mean loss
        :raises ValueError: if no losses are provided
        """
        diff_eq_losses = []
        ic_losses = []
        dirichlet_bc_losses = []
        neumann_bc_losses = []
        for loss in losses:
            diff_eq_losses.append(loss.diff_eq_loss)
            ic_losses.append(loss.ic_loss)
            if loss.bc_losses:
                dirichlet_bc_losses.append(loss.bc_losses[0])
                neumann_bc_losses.append(loss.bc_losses[1])

        if not diff_eq_losses:
            # averaging an empty stack does not fail loudly (the mean of an
            # empty tensor is NaN), so reject empty input explicitly
            raise ValueError('at least one loss is required to compute a mean')

        mean_diff_eq_loss = tf.reduce_mean(tf.stack(diff_eq_losses), axis=0)
        mean_ic_loss = tf.reduce_mean(tf.stack(ic_losses), axis=0)
        # the Dirichlet and Neumann lists are always filled together, so
        # checking one of them suffices
        mean_bc_losses = None if not dirichlet_bc_losses else (
            tf.reduce_mean(tf.stack(dirichlet_bc_losses), axis=0),
            tf.reduce_mean(tf.stack(neumann_bc_losses), axis=0)
        )

        return cls.construct(
            mean_diff_eq_loss,
            mean_ic_loss,
            mean_bc_losses,
            diff_eq_loss_weight,
            ic_loss_weight,
            bc_loss_weight)