Source code for espnet2.enh.loss.criterions.time_domain

from abc import ABC

import ci_sdr
import torch

from espnet2.enh.loss.criterions.abs_loss import AbsEnhLoss


class TimeDomainLoss(AbsEnhLoss, ABC):
    """Common base class for the time-domain enhancement losses in this module.

    It adds no attributes or methods of its own on top of ``AbsEnhLoss``;
    it only gives the time-domain criteria a shared parent type.
    """
# Machine epsilon of torch's default floating-point dtype, evaluated once when
# this module is imported. Used as the default ``eps`` of the SNR / SI-SNR
# losses below to keep norms, denominators and log arguments away from zero.
EPS = torch.finfo(torch.get_default_dtype()).eps
class CISDRLoss(TimeDomainLoss):
    """CI-SDR (convolutive transfer function invariant SDR) loss.

    Reference:
        Convolutive Transfer Function Invariant SDR Training
        Criteria for Multi-Channel Reverberant Speech Separation;
        C. Boeddeker et al., 2021;
        https://arxiv.org/abs/2011.15003

    Args:
        filter_length (int): length of the time-invariant filter that allows
            slight distortion via filtering.
    """

    def __init__(self, filter_length=512):
        super().__init__()
        self.filter_length = filter_length

    @property
    def name(self) -> str:
        """Identifier string of this criterion."""
        return "ci_sdr_loss"

    def forward(
        self,
        ref: torch.Tensor,
        inf: torch.Tensor,
    ) -> torch.Tensor:
        """Compute the CI-SDR loss of an estimate against its reference.

        Args:
            ref: reference signal, (Batch, samples)
            inf: estimated signal, (Batch, samples); must have the same
                shape as ``ref``.

        Returns:
            loss: (Batch,)
        """
        assert ref.shape == inf.shape, (ref.shape, inf.shape)

        # ci_sdr expects the estimate first and the reference second.
        # No permutation search is requested here.
        loss = ci_sdr.pt.ci_sdr_loss(
            inf,
            ref,
            compute_permutation=False,
            filter_length=self.filter_length,
        )
        return loss
class SNRLoss(TimeDomainLoss):
    """Negative signal-to-noise ratio (in dB) between reference and estimate.

    Args:
        eps: lower bound applied to both L2 norms before taking the
            logarithm, so that ``log10`` never receives zero. Stored as a
            Python float.
    """

    def __init__(self, eps=EPS):
        super().__init__()
        self.eps = float(eps)

    @property
    def name(self) -> str:
        """Identifier string of this criterion."""
        return "snr_loss"

    def forward(
        self,
        ref: torch.Tensor,
        inf: torch.Tensor,
    ) -> torch.Tensor:
        """Return ``-SNR`` for each batch item.

        The norms are reduced over dim 1, so 2-D ``(Batch, samples)`` inputs
        produce a ``(Batch,)`` result.

        Args:
            ref: reference signal.
            inf: estimated signal.

        Returns:
            Negated SNR in dB, one value per batch item.
        """
        residual = inf - ref

        # Clamp each norm from below so neither logarithm sees zero.
        ref_norm = torch.norm(ref, p=2, dim=1).clamp(min=self.eps)
        residual_norm = torch.norm(residual, p=2, dim=1).clamp(min=self.eps)

        # 20 * log10 of an amplitude ratio, written as a difference of logs.
        snr_db = 20 * (torch.log10(ref_norm) - torch.log10(residual_norm))
        return -snr_db
class SISNRLoss(TimeDomainLoss):
    """Negative scale-invariant SNR (SI-SNR, in dB) between reference and estimate.

    Args:
        eps: small constant added to the target energy, to the noise energy
            and to the ratio inside the logarithm, guarding against division
            by zero and ``log10(0)``. Stored as a Python float.
    """

    def __init__(self, eps=EPS):
        super().__init__()
        self.eps = float(eps)

    @property
    def name(self) -> str:
        """Identifier string of this criterion."""
        return "si_snr_loss"

    def forward(
        self,
        ref: torch.Tensor,
        inf: torch.Tensor,
    ) -> torch.Tensor:
        """Return ``-SI-SNR`` for each batch item.

        Args:
            ref: reference signal, [B, T].
            inf: estimated signal, [B, T]; must have the same size as ``ref``.

        Returns:
            Negated SI-SNR in dB, shape [B].
        """
        assert ref.size() == inf.size()
        # The two-value unpack also rejects inputs that are not 2-D.
        _, num_samples = ref.size()

        # Remove the per-utterance mean from both signals.
        target = ref - torch.sum(ref, dim=1, keepdim=True) / num_samples
        estimate = inf - torch.sum(inf, dim=1, keepdim=True) / num_samples

        # Project the estimate onto the target: <s', s> s / ||s||^2.
        dot = torch.sum(estimate * target, dim=1, keepdim=True)  # [B, 1]
        target_energy = torch.sum(target**2, dim=1, keepdim=True) + self.eps  # [B, 1]
        projection = dot * target / target_energy  # [B, T]

        # Whatever the projection does not explain counts as noise.
        residual = estimate - projection  # [B, T]

        # SI-SNR = 10 * log10(||projection||^2 / ||residual||^2)
        signal_power = torch.sum(projection**2, dim=1)
        noise_power = torch.sum(residual**2, dim=1) + self.eps
        si_snr_db = 10 * torch.log10(signal_power / noise_power + self.eps)  # [B]
        return -si_snr_db