"""Enhancement model module."""
from distutils.version import LooseVersion
from typing import Dict
from typing import List
from typing import Optional
from typing import Tuple

import torch
from typeguard import check_argument_types

from espnet2.enh.decoder.abs_decoder import AbsDecoder
from espnet2.enh.encoder.abs_encoder import AbsEncoder
from espnet2.enh.loss.criterions.tf_domain import FrequencyDomainLoss
from espnet2.enh.loss.criterions.time_domain import TimeDomainLoss
from espnet2.enh.loss.wrappers.abs_wrapper import AbsLossWrapper
from espnet2.enh.separator.abs_separator import AbsSeparator
from espnet2.torch_utils.device_funcs import force_gatherable
from espnet2.train.abs_espnet_model import AbsESPnetModel

is_torch_1_9_plus = LooseVersion(torch.__version__) >= LooseVersion("1.9.0")
EPS = torch.finfo(torch.get_default_dtype()).eps


class ESPnetEnhancementModel(AbsESPnetModel):
"""Speech enhancement or separation Frontend model"""
def __init__(
self,
encoder: AbsEncoder,
separator: AbsSeparator,
decoder: AbsDecoder,
loss_wrappers: List[AbsLossWrapper],
stft_consistency: bool = False,
loss_type: str = "mask_mse",
mask_type: Optional[str] = None,
):
assert check_argument_types()
super().__init__()
self.encoder = encoder
self.separator = separator
self.decoder = decoder
self.loss_wrappers = loss_wrappers
self.num_spk = separator.num_spk
self.num_noise_type = getattr(self.separator, "num_noise_type", 1)
        # mask type for TF-domain models (only used when loss_type="mask_*")
        # (deprecated, kept only for backward compatibility)
        self.mask_type = mask_type.upper() if mask_type else None
        # loss type for model training
        # (deprecated, kept only for backward compatibility)
        self.loss_type = loss_type
        # whether to compute the TF-domain loss while enforcing STFT
        # consistency (deprecated, kept only for backward compatibility)
        self.stft_consistency = stft_consistency
        # reference channel for multi-channel signals; -1 selects the last
        # channel (deprecated, kept only for backward compatibility)
        self.ref_channel = getattr(self.separator, "ref_channel", -1)

    def forward(
        self,
        speech_mix: torch.Tensor,
        speech_mix_lengths: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
"""Frontend + Encoder + Decoder + Calc loss
Args:
speech_mix: (Batch, samples) or (Batch, samples, channels)
speech_ref: (Batch, num_speaker, samples)
or (Batch, num_speaker, samples, channels)
speech_mix_lengths: (Batch,), default None for chunk interator,
because the chunk-iterator does not have the
speech_lengths returned. see in
espnet2/iterators/chunk_iter_factory.py
"""
# clean speech signal of each speaker
speech_ref = [
kwargs["speech_ref{}".format(spk + 1)] for spk in range(self.num_spk)
]
# (Batch, num_speaker, samples) or (Batch, num_speaker, samples, channels)
speech_ref = torch.stack(speech_ref, dim=1)
if "noise_ref1" in kwargs:
            # noise signal (optional, required when using
            # frontend models with beamforming)
noise_ref = [
kwargs["noise_ref{}".format(n + 1)] for n in range(self.num_noise_type)
]
# (Batch, num_noise_type, samples) or
# (Batch, num_noise_type, samples, channels)
noise_ref = torch.stack(noise_ref, dim=1)
else:
noise_ref = None
# dereverberated (noisy) signal
# (optional, only used for frontend models with WPE)
if "dereverb_ref1" in kwargs:
            # dereverberation references (optional; either one per
            # speaker or a single shared reference)
dereverb_speech_ref = [
kwargs["dereverb_ref{}".format(n + 1)]
for n in range(self.num_spk)
if "dereverb_ref{}".format(n + 1) in kwargs
]
assert len(dereverb_speech_ref) in (1, self.num_spk), len(
dereverb_speech_ref
)
# (Batch, N, samples) or (Batch, N, samples, channels)
dereverb_speech_ref = torch.stack(dereverb_speech_ref, dim=1)
else:
dereverb_speech_ref = None
batch_size = speech_mix.shape[0]
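        # If no lengths are given (e.g., from the chunk iterator),
        # assume every sample in the batch is full length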
        speech_lengths = (
            speech_mix_lengths
            if speech_mix_lengths is not None
            else torch.full((batch_size,), speech_mix.shape[1], dtype=torch.int)
        )
assert speech_lengths.dim() == 1, speech_lengths.shape
# Check that batch_size is unified
assert speech_mix.shape[0] == speech_ref.shape[0] == speech_lengths.shape[0], (
speech_mix.shape,
speech_ref.shape,
speech_lengths.shape,
)
        # for data-parallel: truncate to the longest true length
        speech_ref = speech_ref[..., : speech_lengths.max()]
        # list of (Batch, samples) or (Batch, samples, channels), one per speaker
        speech_ref = speech_ref.unbind(dim=1)
        speech_mix = speech_mix[:, : speech_lengths.max()]
# model forward
feature_mix, flens = self.encoder(speech_mix, speech_lengths)
feature_pre, flens, others = self.separator(feature_mix, flens)
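        # `others` holds auxiliary separator outputs; for TF-domain models it
        # may include predicted masks "mask_spk1", ..., "mask_spk{num_spk}",
        # which the mask-based criteria below consume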
if feature_pre is not None:
speech_pre = [self.decoder(ps, speech_lengths)[0] for ps in feature_pre]
else:
# some models (e.g. neural beamformer trained with mask loss)
# do not predict time-domain signal in the training stage
speech_pre = None
loss = 0.0
stats = dict()
o = {}
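        # Accumulate a weighted sum over all loss wrappers; the dict `o`
        # threads auxiliary wrapper outputs (e.g., the permutation chosen
        # by a PIT wrapper) from one wrapper to the next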
for loss_wrapper in self.loss_wrappers:
criterion = loss_wrapper.criterion
if isinstance(criterion, TimeDomainLoss):
if speech_ref[0].dim() == 3:
# For multi-channel reference,
# only select one channel as the reference
speech_ref = [sr[..., self.ref_channel] for sr in speech_ref]
                # for time-domain criteria
                loss_w, stats_w, o = loss_wrapper(speech_ref, speech_pre, o)
elif isinstance(criterion, FrequencyDomainLoss):
                # for time-frequency-domain criteria
if criterion.compute_on_mask:
# compute on mask
tf_ref = criterion.create_mask_label(
feature_mix,
[self.encoder(sr, speech_lengths)[0] for sr in speech_ref],
)
tf_pre = [
others["mask_spk{}".format(spk + 1)]
for spk in range(self.num_spk)
]
else:
# compute on spectrum
if speech_ref[0].dim() == 3:
# For multi-channel reference,
# only select one channel as the reference
speech_ref = [sr[..., self.ref_channel] for sr in speech_ref]
tf_ref = [self.encoder(sr, speech_lengths)[0] for sr in speech_ref]
tf_pre = feature_pre
                loss_w, stats_w, o = loss_wrapper(tf_ref, tf_pre, o)
            else:
                raise NotImplementedError(
                    "Unsupported loss type: {}".format(type(criterion))
                )
            loss += loss_w * loss_wrapper.weight
            stats.update(stats_w)
stats["loss"] = loss.detach()
# force_gatherable: to-device and to-tensor if scalar for DataParallel
loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
return loss, stats, weight

    def collect_feats(
        self, speech_mix: torch.Tensor, speech_mix_lengths: torch.Tensor, **kwargs
    ) -> Dict[str, torch.Tensor]:
# for data-parallel
speech_mix = speech_mix[:, : speech_mix_lengths.max()]
feats, feats_lengths = speech_mix, speech_mix_lengths
return {"feats": feats, "feats_lengths": feats_lengths}