Source code for drdmannturb.spectra_fitting.loss_functions

"""
This module implements and exposes the various terms that can be
added and scaled in a generic loss function
"""
from typing import Optional

import torch
from torch.utils.tensorboard import SummaryWriter

from ..parameters import LossParameters


class LossAggregator:
    def __init__(
        self,
        params: LossParameters,
        k1space: torch.Tensor,
        zref: float,
        tb_log_dir: Optional[str] = None,
        tb_comment: str = "",
    ):
        r"""Combines all loss functions and evaluates each term for the optimizer.

        The loss function for spectra fitting is determined by the following minimization problem:

        .. math::
            \min _{\boldsymbol{\theta}}\left\{\operatorname{MSE}[\boldsymbol{\theta}]+\alpha \operatorname{Pen}[\boldsymbol{\theta}]+\beta \operatorname{Reg}\left[\boldsymbol{\theta}_{\mathrm{NN}}\right]\right\}

        where the loss terms are defined in the class's methods. In the following, :math:`L` is the
        number of data points :math:`f_j \in \mathcal{D}`. The data :math:`J_i(f_j)` are evaluated
        using the Kaimal spectra. Note that the norm :math:`\|\cdot\|_{\mathcal{D}}` is defined as

        .. math::
            \|g\|_{\mathcal{D}}^2:=\int_{\mathcal{D}}|g(f)|^2 \mathrm{~d}(\log f).

        Parameters
        ----------
        params : LossParameters
            Dataclass with parameters that determine the loss function.
        k1space : torch.Tensor
            The spectra space for k1; this is assumed to be in logspace.
        zref : float
            Reference height, needed for computing penalty term derivatives with respect to ``k1 z``.
        tb_log_dir : Optional[str]
            Logging directory for the TensorBoard logger. Conventions are those of TensorBoard,
            which by default result in the creation of a ``runs`` subdirectory where the script is
            being run if this parameter is left as None.
        tb_comment : str
            Filename comment used by TensorBoard; useful for distinguishing between architectures
            and hyperparameters. Refer to the TensorBoard documentation for examples of use. By
            default, the empty string, which results in default TensorBoard filenames.
        """
        self.writer = SummaryWriter(log_dir=tb_log_dir, comment=tb_comment)
        self.params = params
        self.zref = zref

        self.k1space = k1space
        self.logk1 = torch.log(self.k1space)
        self.h1 = torch.diff(self.logk1)
        self.h2 = torch.diff(0.5 * (self.logk1[:-1] + self.logk1[1:]))

        t_alphapen1 = (
            lambda y, theta_NN, epoch: self.Pen1stOrder(y, epoch)
            if self.params.alpha_pen1
            else 0.0
        )
        t_alphapen2 = (
            lambda y, theta_NN, epoch: self.Pen2ndOrder(y, epoch)
            if self.params.alpha_pen2
            else 0.0
        )
        t_beta_reg = (
            lambda y, theta_NN, epoch: self.Regularization(theta_NN, epoch)
            if self.params.beta_reg
            else 0.0
        )

        self.loss_func = lambda y, theta_NN, epoch: (
            t_alphapen1(y, theta_NN, epoch)
            + t_alphapen2(y, theta_NN, epoch)
            + t_beta_reg(y, theta_NN, epoch)
        )
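    # Added sketch (not part of the original source): shapes of the log-grid
    # spacings h1 and h2 built in __init__, assuming a hypothetical logspaced
    # k1space of 5 points:
    #
    #   k1space = torch.logspace(-2.0, 2.0, 5)            # hypothetical input
    #   logk1 = torch.log(k1space)                        # 5 nodes, uniform in log
    #   h1 = torch.diff(logk1)                            # 4 spacings between nodes
    #   h2 = torch.diff(0.5 * (logk1[:-1] + logk1[1:]))   # 3 spacings between midpoints
    #
    # h1 scales the first differences in Pen1stOrder and h2 the second pass in
    # Pen2ndOrder, so both penalties approximate derivatives with respect to log(k1).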
    def MSE_term(
        self, model: torch.Tensor, target: torch.Tensor, epoch: int
    ) -> torch.Tensor:
        r"""Evaluates the loss term

        .. math::
            \operatorname{MSE}[\boldsymbol{\theta}]:=\frac{1}{L} \sum_{i=1}^4 \sum_{j=1}^L\left(\log \left|J_i\left(f_j\right)\right|-\log \left|\widetilde{J}_i\left(f_j, \boldsymbol{\theta}\right)\right|\right)^2.

        Parameters
        ----------
        model : torch.Tensor
            Model output of spectra on the k1space provided in the constructor.
        target : torch.Tensor
            True spectra data on the k1space provided in the constructor.
        epoch : int
            Epoch number, used for the TensorBoard loss writer.

        Returns
        -------
        torch.Tensor
            Evaluated MSE loss term.
        """
        mse_loss = torch.mean(torch.log(torch.abs(model / target)).square())
        self.writer.add_scalar("MSE Loss", mse_loss, epoch)

        return mse_loss
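    # Added note (not in the original source): since log|m/t| = log|m| - log|t|,
    # mse_loss above is exactly the squared log-difference appearing in the MSE
    # formula, with torch.mean averaging over every entry of the
    # (components x frequency points) tensor at once.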
    def Pen2ndOrder(self, y: torch.Tensor, epoch: int) -> torch.Tensor:
        r"""Evaluates the second-order penalization term, defined as

        .. math::
            \operatorname{Pen}_2[\boldsymbol{\theta}]:=\frac{1}{|\mathcal{D}|} \sum_{i=1}^4\left\|\operatorname{ReLU}\left(\frac{\partial^2 \log \left|\widetilde{J}_i(\cdot, \boldsymbol{\theta})\right|}{\left(\partial \log k_1\right)^2}\right)\right\|_{\mathcal{D}}^2.

        Parameters
        ----------
        y : torch.Tensor
            Model spectra output.
        epoch : int
            Epoch number, used for the TensorBoard loss writer.

        Returns
        -------
        torch.Tensor
            2nd order penalty loss term.
        """
        logy = torch.log(torch.abs(y))
        d2logy = torch.diff(torch.diff(logy, dim=-1) / self.h1, dim=-1) / self.h2

        pen2ndorder_loss = (
            self.params.alpha_pen2
            * torch.mean(torch.relu(d2logy).square())
            / self.zref**2
        )
        self.writer.add_scalar("2nd Order Penalty", pen2ndorder_loss, epoch)

        return pen2ndorder_loss
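    # Added sanity-check sketch (not part of the original source): for a
    # quadratic logy = (log k1)^2, the two-stage stencil in Pen2ndOrder is
    # exact, recovering d^2 logy / (d log k1)^2 = 2 everywhere, since the first
    # differences equal the derivative at the midpoints:
    #
    #   logk1 = torch.log(torch.logspace(-1.0, 2.0, 20))
    #   h1, h2 = torch.diff(logk1), torch.diff(0.5 * (logk1[:-1] + logk1[1:]))
    #   logy = logk1.square().expand(4, -1)
    #   d2 = torch.diff(torch.diff(logy, dim=-1) / h1, dim=-1) / h2
    #   assert torch.allclose(d2, torch.full_like(d2, 2.0))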
    def Pen1stOrder(self, y: torch.Tensor, epoch: int) -> torch.Tensor:
        r"""Evaluates the first-order penalization term, defined as

        .. math::
            \operatorname{Pen}_1[\boldsymbol{\theta}]:=\frac{1}{|\mathcal{D}|} \sum_{i=1}^4\left\|\operatorname{ReLU}\left(\frac{\partial \log \left|\widetilde{J}_i(\cdot, \boldsymbol{\theta})\right|}{\partial \log k_1}\right)\right\|_{\mathcal{D}}^2.

        Parameters
        ----------
        y : torch.Tensor
            Model spectra output.
        epoch : int
            Epoch number, used for the TensorBoard loss writer.

        Returns
        -------
        torch.Tensor
            1st order penalty loss.
        """
        logy = torch.log(torch.abs(y))
        d1logy = torch.diff(logy, dim=-1) / self.h1

        pen1storder_loss = (
            self.params.alpha_pen1 * torch.mean(torch.relu(d1logy).square()) / self.zref
        )
        self.writer.add_scalar("1st Order Penalty", pen1storder_loss, epoch)

        return pen1storder_loss
    def Regularization(self, theta_NN: Optional[torch.Tensor], epoch: int) -> torch.Tensor:
        r"""Evaluates the regularization term, defined as

        .. math::
            \operatorname{Reg}\left[\boldsymbol{\theta}_{\mathrm{NN}}\right]:=\frac{1}{N} \sum_{i=1}^N \theta_{\mathrm{NN}, i}^2.

        Parameters
        ----------
        theta_NN : Optional[torch.Tensor]
            Neural network parameters.
        epoch : int
            Epoch number, used for the TensorBoard loss writer.

        Returns
        -------
        torch.Tensor
            Regularization loss of the neural network model.

        Raises
        ------
        ValueError
            If ``theta_NN`` is None.
        """
        if theta_NN is None:
            raise ValueError(
                "Regularization loss requires a neural network model to be used. "
                "Set the regularization hyperparameter (beta_reg) to 0 if using a "
                "non-neural network model."
            )

        reg_loss = self.params.beta_reg * theta_NN.square().mean()
        self.writer.add_scalar("Regularization", reg_loss, epoch)

        return reg_loss
    def eval(
        self,
        y: torch.Tensor,
        y_data: torch.Tensor,
        theta_NN: Optional[torch.Tensor],
        epoch: int,
    ) -> torch.Tensor:
        """Evaluation method for computing the full loss term at a given epoch.

        This method sequentially evaluates each term in the loss and returns the total loss.

        Parameters
        ----------
        y : torch.Tensor
            Model spectra output on the k1space provided in the constructor.
        y_data : torch.Tensor
            True spectra data on the k1space provided in the constructor.
        theta_NN : Optional[torch.Tensor]
            Neural network parameters, used in the regularization loss term if activated. Can be
            set to None if no neural network is used.
        epoch : int
            Epoch number, used for the TensorBoard loss writer.

        Returns
        -------
        torch.Tensor
            Evaluated total loss, as a weighted sum of all terms with non-zero hyperparameters.
        """
        total_loss = self.MSE_term(y, y_data, epoch) + self.loss_func(y, theta_NN, epoch)
        self.writer.add_scalar("Total Loss", total_loss, epoch)

        return total_loss
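
if __name__ == "__main__":
    # Minimal usage sketch (added; not part of the original source). The
    # LossParameters field values and the k1 grid below are illustrative
    # assumptions; this module only reads alpha_pen1, alpha_pen2, and beta_reg.
    # Run with `python -m drdmannturb.spectra_fitting.loss_functions` so the
    # relative import at the top of the module resolves.
    example_params = LossParameters(alpha_pen1=1.0, alpha_pen2=1.0, beta_reg=0.0)
    example_k1 = torch.logspace(-1.0, 2.0, 20)  # hypothetical logspaced k1 grid

    aggregator = LossAggregator(example_params, example_k1, zref=90.0)

    # Stand-in spectra: 4 components evaluated at the 20 grid points.
    y_model = torch.rand(4, 20) + 0.1
    y_observed = torch.rand(4, 20) + 0.1

    # theta_NN is None here, so beta_reg must be 0 (see Regularization).
    total = aggregator.eval(y_model, y_observed, theta_NN=None, epoch=0)
    print(f"Total loss: {total.item():.4f}")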