Source code for espnet2.enh.loss.criterions.abs_loss

from abc import ABC
from abc import abstractmethod


import torch

# Machine epsilon of torch's default floating-point dtype. It is evaluated
# once, when this module is imported, so a later call to
# torch.set_default_dtype() does not change the value stored here.
EPS = torch.finfo(torch.get_default_dtype()).eps


class AbsEnhLoss(torch.nn.Module, ABC):
    """Abstract base class for enhancement loss criterions.

    Concrete subclasses must implement :meth:`forward` and are expected to
    override :attr:`name`. Because ``forward`` is abstract, this class cannot
    be instantiated directly.
    """

    @property
    def name(self) -> str:
        """Return the key under which this loss appears in the reporter.

        Raises:
            NotImplementedError: If the subclass does not override this
                property.
        """
        # The name will be the key that appears in the reporter. Raise (rather
        # than return the exception class) so a missing override fails loudly
        # instead of leaking a non-string value to callers.
        raise NotImplementedError

    @abstractmethod
    def forward(
        self,
        ref,
        inf,
    ) -> torch.Tensor:
        """Compute the loss from ``ref`` and ``inf``.

        Args:
            ref: First input to the criterion (presumably the reference
                signal -- confirm against the concrete subclasses).
            inf: Second input to the criterion (presumably the inferred /
                estimated signal -- confirm against the concrete subclasses).

        Returns:
            A tensor that should have shape ``(batch,)``.

        Raises:
            NotImplementedError: Always, in this base implementation.
        """
        # the return tensor should be shape of (batch)
        raise NotImplementedError