Source code for espnet2.enh.loss.criterions.tf_domain

from abc import ABC
from abc import abstractmethod
from distutils.version import LooseVersion
from functools import reduce

import torch

from espnet2.enh.layers.complex_utils import is_complex
from espnet2.enh.loss.criterions.abs_loss import AbsEnhLoss


is_torch_1_9_plus = LooseVersion(torch.__version__) >= LooseVersion("1.9.0")

EPS = torch.finfo(torch.get_default_dtype()).eps


def _create_mask_label(mix_spec, ref_spec, mask_type="IAM"):
    """Create mask label.

    Args:
        mix_spec: ComplexTensor(B, T, [C,] F)
        ref_spec: List[ComplexTensor(B, T, [C,] F), ...]
        mask_type: str
    Returns:
        labels: List[Tensor(B, T, [C,] F), ...] or List[ComplexTensor(B, T, F), ...]
    """

    # Must be upper case
    assert mask_type in [
        "IBM",
        "IRM",
        "IAM",
        "PSM",
        "NPSM",
        "PSM^2",
    ], f"mask type {mask_type} not supported"
    mask_label = []
    for r in ref_spec:
        mask = None
        if mask_type == "IBM":
            flags = [abs(r) >= abs(n) for n in ref_spec]
            mask = reduce(lambda x, y: x * y, flags)
            mask = mask.int()
        elif mask_type == "IRM":
            # TODO(Wangyou): need to fix this,
            #  as noise referecens are provided separately
            mask = abs(r) / (sum(([abs(n) for n in ref_spec])) + EPS)
        elif mask_type == "IAM":
            mask = abs(r) / (abs(mix_spec) + EPS)
            mask = mask.clamp(min=0, max=1)
        elif mask_type == "PSM" or mask_type == "NPSM":
            phase_r = r / (abs(r) + EPS)
            phase_mix = mix_spec / (abs(mix_spec) + EPS)
            # cos(a - b) = cos(a)*cos(b) + sin(a)*sin(b)
            cos_theta = phase_r.real * phase_mix.real + phase_r.imag * phase_mix.imag
            mask = (abs(r) / (abs(mix_spec) + EPS)) * cos_theta
            mask = (
                mask.clamp(min=0, max=1)
                if mask_type == "NPSM"
                else mask.clamp(min=-1, max=1)
            )
        elif mask_type == "PSM^2":
            # This is for training beamforming masks
            phase_r = r / (abs(r) + EPS)
            phase_mix = mix_spec / (abs(mix_spec) + EPS)
            # cos(a - b) = cos(a)*cos(b) + sin(a)*sin(b)
            cos_theta = phase_r.real * phase_mix.real + phase_r.imag * phase_mix.imag
            mask = (abs(r).pow(2) / (abs(mix_spec).pow(2) + EPS)) * cos_theta
            mask = mask.clamp(min=-1, max=1)
        assert mask is not None, f"mask type {mask_type} not supported"
        mask_label.append(mask)
    return mask_label


[docs]class FrequencyDomainLoss(AbsEnhLoss, ABC): # The loss will be computed on mask or on spectrum @property @abstractmethod def compute_on_mask() -> bool: pass # the mask type @property @abstractmethod def mask_type() -> str: pass
[docs] def create_mask_label(self, mix_spec, ref_spec): return _create_mask_label( mix_spec=mix_spec, ref_spec=ref_spec, mask_type=self.mask_type )
[docs]class FrequencyDomainMSE(FrequencyDomainLoss): def __init__(self, compute_on_mask=False, mask_type="IBM"): super().__init__() self._compute_on_mask = compute_on_mask self._mask_type = mask_type @property def compute_on_mask(self) -> bool: return self._compute_on_mask @property def mask_type(self) -> str: return self._mask_type @property def name(self) -> str: if self.compute_on_mask: return f"MSE_on_{self.mask_type}" else: return "MSE_on_Spec"
[docs] def forward(self, ref, inf) -> torch.Tensor: """time-frequency MSE loss. Args: ref: (Batch, T, F) or (Batch, T, C, F) inf: (Batch, T, F) or (Batch, T, C, F) Returns: loss: (Batch,) """ assert ref.shape == inf.shape, (ref.shape, inf.shape) diff = ref - inf if is_complex(diff): mseloss = diff.real**2 + diff.imag**2 else: mseloss = diff**2 if ref.dim() == 3: mseloss = mseloss.mean(dim=[1, 2]) elif ref.dim() == 4: mseloss = mseloss.mean(dim=[1, 2, 3]) else: raise ValueError( "Invalid input shape: ref={}, inf={}".format(ref.shape, inf.shape) ) return mseloss
[docs]class FrequencyDomainL1(FrequencyDomainLoss): def __init__(self, compute_on_mask=False, mask_type="IBM"): super().__init__() self._compute_on_mask = compute_on_mask self._mask_type = mask_type @property def compute_on_mask(self) -> bool: return self._compute_on_mask @property def mask_type(self) -> str: return self._mask_type @property def name(self) -> str: if self.compute_on_mask: return f"L1_on_{self.mask_type}" else: return "L1_on_Spec"
[docs] def forward(self, ref, inf) -> torch.Tensor: """time-frequency L1 loss. Args: ref: (Batch, T, F) or (Batch, T, C, F) inf: (Batch, T, F) or (Batch, T, C, F) Returns: loss: (Batch,) """ assert ref.shape == inf.shape, (ref.shape, inf.shape) if is_complex(inf): l1loss = abs(ref - inf + EPS) else: l1loss = abs(ref - inf) if ref.dim() == 3: l1loss = l1loss.mean(dim=[1, 2]) elif ref.dim() == 4: l1loss = l1loss.mean(dim=[1, 2, 3]) else: raise ValueError( "Invalid input shape: ref={}, inf={}".format(ref.shape, inf.shape) ) return l1loss