Source code for espnet2.diar.espnet_model

# Copyright 2021 Jiatong Shi
#  Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)

from contextlib import contextmanager
from distutils.version import LooseVersion
from itertools import permutations
from typing import Dict
from typing import Optional
from typing import Tuple

import numpy as np
import torch
from typeguard import check_argument_types

from espnet.nets.pytorch_backend.nets_utils import to_device
from espnet2.asr.encoder.abs_encoder import AbsEncoder
from espnet2.asr.frontend.abs_frontend import AbsFrontend
from espnet2.asr.specaug.abs_specaug import AbsSpecAug
from espnet2.diar.attractor.abs_attractor import AbsAttractor
from espnet2.diar.decoder.abs_decoder import AbsDecoder
from espnet2.layers.abs_normalize import AbsNormalize
from espnet2.torch_utils.device_funcs import force_gatherable
from espnet2.train.abs_espnet_model import AbsESPnetModel

if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
    from torch.cuda.amp import autocast
else:
    # Nothing to do if torch<1.6.0
    @contextmanager
    def autocast(enabled=True):
        yield


[docs]class ESPnetDiarizationModel(AbsESPnetModel): """Speaker Diarization model If "attractor" is "None", SA-EEND will be used. Else if "attractor" is not "None", EEND-EDA will be used. For the details about SA-EEND and EEND-EDA, refer to the following papers: SA-EEND: https://arxiv.org/pdf/1909.06247.pdf EEND-EDA: https://arxiv.org/pdf/2005.09921.pdf, https://arxiv.org/pdf/2106.10654.pdf """ def __init__( self, frontend: Optional[AbsFrontend], specaug: Optional[AbsSpecAug], normalize: Optional[AbsNormalize], label_aggregator: torch.nn.Module, encoder: AbsEncoder, decoder: AbsDecoder, attractor: Optional[AbsAttractor], attractor_weight: float = 1.0, ): assert check_argument_types() super().__init__() self.encoder = encoder self.normalize = normalize self.frontend = frontend self.specaug = specaug self.label_aggregator = label_aggregator self.attractor_weight = attractor_weight self.attractor = attractor self.decoder = decoder if self.attractor is not None: self.decoder = None elif self.decoder is not None: self.num_spk = decoder.num_spk else: raise NotImplementedError
[docs] def forward( self, speech: torch.Tensor, speech_lengths: torch.Tensor = None, spk_labels: torch.Tensor = None, spk_labels_lengths: torch.Tensor = None, ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]: """Frontend + Encoder + Decoder + Calc loss Args: speech: (Batch, samples) speech_lengths: (Batch,) default None for chunk interator, because the chunk-iterator does not have the speech_lengths returned. see in espnet2/iterators/chunk_iter_factory.py spk_labels: (Batch, ) """ assert speech.shape[0] == spk_labels.shape[0], (speech.shape, spk_labels.shape) batch_size = speech.shape[0] # 1. Encoder encoder_out, encoder_out_lens = self.encode(speech, speech_lengths) if self.attractor is None: # 2a. Decoder (baiscally a predction layer after encoder_out) pred = self.decoder(encoder_out, encoder_out_lens) else: # 2b. Encoder Decoder Attractors # Shuffle the chronological order of encoder_out, then calculate attractor encoder_out_shuffled = encoder_out.clone() for i in range(len(encoder_out_lens)): encoder_out_shuffled[i, : encoder_out_lens[i], :] = encoder_out[ i, torch.randperm(encoder_out_lens[i]), : ] attractor, att_prob = self.attractor( encoder_out_shuffled, encoder_out_lens, to_device( self, torch.zeros( encoder_out.size(0), spk_labels.size(2) + 1, encoder_out.size(2) ), ), ) # Remove the final attractor which does not correspond to a speaker # Then multiply the attractors and encoder_out pred = torch.bmm(encoder_out, attractor[:, :-1, :].permute(0, 2, 1)) # 3. Aggregate time-domain labels spk_labels, spk_labels_lengths = self.label_aggregator( spk_labels, spk_labels_lengths ) # If encoder uses conv* as input_layer (i.e., subsampling), # the sequence length of 'pred' might be slighly less than the # length of 'spk_labels'. Here we force them to be equal. length_diff_tolerance = 2 length_diff = spk_labels.shape[1] - pred.shape[1] if length_diff > 0 and length_diff <= length_diff_tolerance: spk_labels = spk_labels[:, 0 : pred.shape[1], :] if self.attractor is None: loss_pit, loss_att = None, None loss, perm_idx, perm_list, label_perm = self.pit_loss( pred, spk_labels, encoder_out_lens ) else: loss_pit, perm_idx, perm_list, label_perm = self.pit_loss( pred, spk_labels, encoder_out_lens ) loss_att = self.attractor_loss(att_prob, spk_labels) loss = loss_pit + self.attractor_weight * loss_att ( correct, num_frames, speech_scored, speech_miss, speech_falarm, speaker_scored, speaker_miss, speaker_falarm, speaker_error, ) = self.calc_diarization_error(pred, label_perm, encoder_out_lens) if speech_scored > 0 and num_frames > 0: sad_mr, sad_fr, mi, fa, cf, acc, der = ( speech_miss / speech_scored, speech_falarm / speech_scored, speaker_miss / speaker_scored, speaker_falarm / speaker_scored, speaker_error / speaker_scored, correct / num_frames, (speaker_miss + speaker_falarm + speaker_error) / speaker_scored, ) else: sad_mr, sad_fr, mi, fa, cf, acc, der = 0, 0, 0, 0, 0, 0, 0 stats = dict( loss=loss.detach(), loss_att=loss_att.detach() if loss_att is not None else None, loss_pit=loss_pit.detach() if loss_pit is not None else None, sad_mr=sad_mr, sad_fr=sad_fr, mi=mi, fa=fa, cf=cf, acc=acc, der=der, ) loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device) return loss, stats, weight
[docs] def collect_feats( self, speech: torch.Tensor, speech_lengths: torch.Tensor, spk_labels: torch.Tensor = None, spk_labels_lengths: torch.Tensor = None, ) -> Dict[str, torch.Tensor]: feats, feats_lengths = self._extract_feats(speech, speech_lengths) return {"feats": feats, "feats_lengths": feats_lengths}
[docs] def encode( self, speech: torch.Tensor, speech_lengths: torch.Tensor ) -> Tuple[torch.Tensor, torch.Tensor]: """Frontend + Encoder Args: speech: (Batch, Length, ...) speech_lengths: (Batch,) """ with autocast(False): # 1. Extract feats feats, feats_lengths = self._extract_feats(speech, speech_lengths) # 2. Data augmentation if self.specaug is not None and self.training: feats, feats_lengths = self.specaug(feats, feats_lengths) # 3. Normalization for feature: e.g. Global-CMVN, Utterance-CMVN if self.normalize is not None: feats, feats_lengths = self.normalize(feats, feats_lengths) # 4. Forward encoder # feats: (Batch, Length, Dim) # -> encoder_out: (Batch, Length2, Dim) encoder_out, encoder_out_lens, _ = self.encoder(feats, feats_lengths) assert encoder_out.size(0) == speech.size(0), ( encoder_out.size(), speech.size(0), ) assert encoder_out.size(1) <= encoder_out_lens.max(), ( encoder_out.size(), encoder_out_lens.max(), ) return encoder_out, encoder_out_lens
def _extract_feats( self, speech: torch.Tensor, speech_lengths: torch.Tensor ) -> Tuple[torch.Tensor, torch.Tensor]: batch_size = speech.shape[0] speech_lengths = ( speech_lengths if speech_lengths is not None else torch.ones(batch_size).int() * speech.shape[1] ) assert speech_lengths.dim() == 1, speech_lengths.shape # for data-parallel speech = speech[:, : speech_lengths.max()] if self.frontend is not None: # Frontend # e.g. STFT and Feature extract # data_loader may send time-domain signal in this case # speech (Batch, NSamples) -> feats: (Batch, NFrames, Dim) feats, feats_lengths = self.frontend(speech, speech_lengths) else: # No frontend and no feature extract feats, feats_lengths = speech, speech_lengths return feats, feats_lengths
[docs] def pit_loss_single_permute(self, pred, label, length): bce_loss = torch.nn.BCEWithLogitsLoss(reduction="none") mask = self.create_length_mask(length, label.size(1), label.size(2)) loss = bce_loss(pred, label) loss = loss * mask loss = torch.sum(torch.mean(loss, dim=2), dim=1) loss = torch.unsqueeze(loss, dim=1) return loss
[docs] def pit_loss(self, pred, label, lengths): # Note (jiatong): Credit to https://github.com/hitachi-speech/EEND num_output = label.size(2) permute_list = [np.array(p) for p in permutations(range(num_output))] loss_list = [] for p in permute_list: label_perm = label[:, :, p] loss_perm = self.pit_loss_single_permute(pred, label_perm, lengths) loss_list.append(loss_perm) loss = torch.cat(loss_list, dim=1) min_loss, min_idx = torch.min(loss, dim=1) loss = torch.sum(min_loss) / torch.sum(lengths.float()) batch_size = len(min_idx) label_list = [] for i in range(batch_size): label_list.append(label[i, :, permute_list[min_idx[i]]].data.cpu().numpy()) label_permute = torch.from_numpy(np.array(label_list)).float() return loss, min_idx, permute_list, label_permute
[docs] def create_length_mask(self, length, max_len, num_output): batch_size = len(length) mask = torch.zeros(batch_size, max_len, num_output) for i in range(batch_size): mask[i, : length[i], :] = 1 mask = to_device(self, mask) return mask
[docs] def attractor_loss(self, att_prob, label): batch_size = len(label) bce_loss = torch.nn.BCEWithLogitsLoss(reduction="none") # create attractor label [1, 1, ..., 1, 0] # att_label: (Batch, num_spk + 1, 1) att_label = to_device(self, torch.zeros(batch_size, label.size(2) + 1, 1)) att_label[:, : label.size(2), :] = 1 loss = bce_loss(att_prob, att_label) loss = torch.mean(torch.mean(loss, dim=1)) return loss
[docs] @staticmethod def calc_diarization_error(pred, label, length): # Note (jiatong): Credit to https://github.com/hitachi-speech/EEND (batch_size, max_len, num_output) = label.size() # mask the padding part mask = np.zeros((batch_size, max_len, num_output)) for i in range(batch_size): mask[i, : length[i], :] = 1 # pred and label have the shape (batch_size, max_len, num_output) label_np = label.data.cpu().numpy().astype(int) pred_np = (pred.data.cpu().numpy() > 0).astype(int) label_np = label_np * mask pred_np = pred_np * mask length = length.data.cpu().numpy() # compute speech activity detection error n_ref = np.sum(label_np, axis=2) n_sys = np.sum(pred_np, axis=2) speech_scored = float(np.sum(n_ref > 0)) speech_miss = float(np.sum(np.logical_and(n_ref > 0, n_sys == 0))) speech_falarm = float(np.sum(np.logical_and(n_ref == 0, n_sys > 0))) # compute speaker diarization error speaker_scored = float(np.sum(n_ref)) speaker_miss = float(np.sum(np.maximum(n_ref - n_sys, 0))) speaker_falarm = float(np.sum(np.maximum(n_sys - n_ref, 0))) n_map = np.sum(np.logical_and(label_np == 1, pred_np == 1), axis=2) speaker_error = float(np.sum(np.minimum(n_ref, n_sys) - n_map)) correct = float(1.0 * np.sum((label_np == pred_np) * mask) / num_output) num_frames = np.sum(length) return ( correct, num_frames, speech_scored, speech_miss, speech_falarm, speaker_scored, speaker_miss, speaker_falarm, speaker_error, )