# Source code for espnet2.asr.maskctc_model

from contextlib import contextmanager
from distutils.version import LooseVersion
from itertools import groupby
import logging
from typing import Dict
from typing import List
from typing import Optional
from typing import Tuple
from typing import Union

import numpy
import torch
from typeguard import check_argument_types

from espnet.nets.beam_search import Hypothesis
from espnet.nets.e2e_asr_common import ErrorCalculator
from espnet.nets.pytorch_backend.maskctc.add_mask_token import mask_uniform
from espnet.nets.pytorch_backend.nets_utils import th_accuracy
from espnet.nets.pytorch_backend.transformer.label_smoothing_loss import (
    LabelSmoothingLoss,  # noqa: H301
from espnet2.asr.ctc import CTC
from espnet2.asr.decoder.mlm_decoder import MLMDecoder
from espnet2.asr.encoder.abs_encoder import AbsEncoder
from espnet2.asr.espnet_model import ESPnetASRModel
from espnet2.asr.frontend.abs_frontend import AbsFrontend
from espnet2.asr.postencoder.abs_postencoder import AbsPostEncoder
from espnet2.asr.preencoder.abs_preencoder import AbsPreEncoder
from espnet2.asr.specaug.abs_specaug import AbsSpecAug
from espnet2.layers.abs_normalize import AbsNormalize
from espnet2.text.token_id_converter import TokenIDConverter
from espnet2.torch_utils.device_funcs import force_gatherable

# Use the native mixed-precision context manager when torch provides it;
# otherwise fall back to a no-op context manager with the same call shape.
if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
    from torch.cuda.amp import autocast
else:
    # Nothing to do if torch<1.6.0
    @contextmanager
    def autocast(enabled=True):
        """No-op stand-in for ``torch.cuda.amp.autocast``.

        Args:
            enabled: accepted for signature compatibility and ignored.
        """
        yield

class MaskCTCModel(ESPnetASRModel):
    """Hybrid CTC/Masked LM Encoder-Decoder model (Mask-CTC)"""

    def __init__(
        self,
        vocab_size: int,
        token_list: Union[Tuple[str, ...], List[str]],
        frontend: Optional[AbsFrontend],
        specaug: Optional[AbsSpecAug],
        normalize: Optional[AbsNormalize],
        preencoder: Optional[AbsPreEncoder],
        encoder: AbsEncoder,
        postencoder: Optional[AbsPostEncoder],
        decoder: MLMDecoder,
        ctc: CTC,
        joint_network: Optional[torch.nn.Module] = None,
        ctc_weight: float = 0.5,
        interctc_weight: float = 0.0,
        ignore_id: int = -1,
        lsm_weight: float = 0.0,
        length_normalized_loss: bool = False,
        report_cer: bool = True,
        report_wer: bool = True,
        sym_space: str = "<space>",
        sym_blank: str = "<blank>",
        sym_mask: str = "<mask>",
        extract_feats_in_collect_stats: bool = True,
    ):
        """Build the Mask-CTC model on top of ``ESPnetASRModel``.

        Every argument except ``sym_mask`` is forwarded unchanged to
        ``ESPnetASRModel.__init__``. Afterwards the vocabulary is extended
        by one entry for the mask symbol and the inherited attention
        criterion is replaced by a label-smoothing MLM criterion.

        Args:
            vocab_size: vocabulary size *without* the mask symbol.
            token_list: token strings; the caller's sequence is not modified.
            sym_mask: token string appended to the vocabulary; its id
                (``vocab_size`` as passed in) is stored as ``mask_token``.
            ignore_id: passed as ``padding_idx`` of the MLM criterion.
            lsm_weight: passed as ``smoothing`` of the MLM criterion.
            length_normalized_loss: passed as ``normalize_length`` of the
                MLM criterion.
            report_cer: with ``report_wer``, enables the error calculator.
            report_wer: see ``report_cer``.
        """
        assert check_argument_types()

        super().__init__(
            vocab_size=vocab_size,
            token_list=token_list,
            frontend=frontend,
            specaug=specaug,
            normalize=normalize,
            preencoder=preencoder,
            encoder=encoder,
            postencoder=postencoder,
            decoder=decoder,
            ctc=ctc,
            joint_network=joint_network,
            ctc_weight=ctc_weight,
            interctc_weight=interctc_weight,
            ignore_id=ignore_id,
            lsm_weight=lsm_weight,
            length_normalized_loss=length_normalized_loss,
            report_cer=report_cer,
            report_wer=report_wer,
            sym_space=sym_space,
            sym_blank=sym_blank,
            extract_feats_in_collect_stats=extract_feats_in_collect_stats,
        )

        # Add <mask> and override inherited fields.
        # Extend a private copy so the caller's sequence is left untouched
        # (and so that a tuple, which has no ``append``, is accepted too).
        token_list = list(token_list)
        token_list.append(sym_mask)
        vocab_size += 1
        self.vocab_size = vocab_size
        self.mask_token = vocab_size - 1
        self.token_list = token_list.copy()

        # MLM loss replaces the attention criterion created by the parent.
        del self.criterion_att
        self.criterion_mlm = LabelSmoothingLoss(
            size=vocab_size,
            padding_idx=ignore_id,
            smoothing=lsm_weight,
            normalize_length=length_normalized_loss,
        )

        # Rebuild the error calculator with the extended token list.
        self.error_calculator = None
        if report_cer or report_wer:
            self.error_calculator = ErrorCalculator(
                token_list, sym_space, sym_blank, report_cer, report_wer
            )

    def forward(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        """Frontend + Encoder + Decoder + Calc loss

        Args:
            speech: (Batch, Length, ...)
            speech_lengths: (Batch, )
            text: (Batch, Length)
            text_lengths: (Batch,)

        Returns:
            The ``(loss, stats, weight)`` triple produced by
            ``force_gatherable`` from the total loss, the collected
            statistics and the batch size.
        """
        assert text_lengths.dim() == 1, text_lengths.shape
        # Check that batch_size is unified
        assert (
            speech.shape[0]
            == speech_lengths.shape[0]
            == text.shape[0]
            == text_lengths.shape[0]
        ), (speech.shape, speech_lengths.shape, text.shape, text_lengths.shape)
        batch_size = speech.shape[0]

        # For data-parallel
        text = text[:, : text_lengths.max()]

        # Define stats to report
        loss_mlm, acc_mlm = None, None
        loss_ctc, cer_ctc = None, None
        stats = dict()

        # 1. Encoder
        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
        intermediate_outs = None
        if isinstance(encoder_out, tuple):
            # A tuple carries (final output, intermediate outputs).
            intermediate_outs = encoder_out[1]
            encoder_out = encoder_out[0]

        # 2. CTC branch
        if self.ctc_weight != 0.0:
            loss_ctc, cer_ctc = self._calc_ctc_loss(
                encoder_out, encoder_out_lens, text, text_lengths
            )

            # Collect CTC branch stats
            stats["loss_ctc"] = loss_ctc.detach() if loss_ctc is not None else None
            stats["cer_ctc"] = cer_ctc

        # 2a. Intermediate CTC (optional)
        loss_interctc = 0.0
        if self.interctc_weight != 0.0 and intermediate_outs is not None:
            for layer_idx, intermediate_out in intermediate_outs:
                # we assume intermediate_out has the same length & padding
                # as those of encoder_out
                loss_ic, cer_ic = self._calc_ctc_loss(
                    intermediate_out, encoder_out_lens, text, text_lengths
                )
                loss_interctc = loss_interctc + loss_ic

                # Collect Intermediate CTC stats
                stats["loss_interctc_layer{}".format(layer_idx)] = (
                    loss_ic.detach() if loss_ic is not None else None
                )
                stats["cer_interctc_layer{}".format(layer_idx)] = cer_ic

            loss_interctc = loss_interctc / len(intermediate_outs)

            # calculate whole encoder loss
            # NOTE(review): loss_ctc is still None here when ctc_weight == 0.0,
            # so that combination would fail on the multiplication below;
            # presumably the configuration is validated elsewhere - confirm.
            loss_ctc = (
                1 - self.interctc_weight
            ) * loss_ctc + self.interctc_weight * loss_interctc

        # 3. MLM decoder branch
        if self.ctc_weight != 1.0:
            loss_mlm, acc_mlm = self._calc_mlm_loss(
                encoder_out, encoder_out_lens, text, text_lengths
            )

        # 4. CTC/MLM loss definition
        if self.ctc_weight == 0.0:
            loss = loss_mlm
        elif self.ctc_weight == 1.0:
            loss = loss_ctc
        else:
            loss = self.ctc_weight * loss_ctc + (1 - self.ctc_weight) * loss_mlm

        # Collect MLM branch stats
        stats["loss_mlm"] = loss_mlm.detach() if loss_mlm is not None else None
        stats["acc_mlm"] = acc_mlm

        # Collect total loss stats
        stats["loss"] = loss.detach()

        # force_gatherable: to-device and to-tensor if scalar for DataParallel
        loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
        return loss, stats, weight

    def _calc_mlm_loss(
        self,
        encoder_out: torch.Tensor,
        encoder_out_lens: torch.Tensor,
        ys_pad: torch.Tensor,
        ys_pad_lens: torch.Tensor,
    ):
        """Compute the masked-LM loss and accuracy for one batch.

        Args:
            encoder_out: encoder output passed to the decoder.
            encoder_out_lens: lengths belonging to ``encoder_out``.
            ys_pad: padded target token ids.
            ys_pad_lens: lengths belonging to ``ys_pad``.

        Returns:
            Tuple ``(loss_mlm, acc_mlm)``.
        """
        # 1. Apply masks
        ys_in_pad, ys_out_pad = mask_uniform(
            ys_pad, self.mask_token, self.eos, self.ignore_id
        )

        # 2. Forward decoder
        decoder_out, _ = self.decoder(
            encoder_out, encoder_out_lens, ys_in_pad, ys_pad_lens
        )

        # 3. Compute mlm loss
        loss_mlm = self.criterion_mlm(decoder_out, ys_out_pad)
        acc_mlm = th_accuracy(
            decoder_out.view(-1, self.vocab_size),
            ys_out_pad,
            ignore_label=self.ignore_id,
        )

        return loss_mlm, acc_mlm

    def nll(
        self,
        encoder_out: torch.Tensor,
        encoder_out_lens: torch.Tensor,
        ys_pad: torch.Tensor,
        ys_pad_lens: torch.Tensor,
    ) -> torch.Tensor:
        """Not supported by this model; always raises ``NotImplementedError``."""
        raise NotImplementedError

    def batchify_nll(
        self,
        encoder_out: torch.Tensor,
        encoder_out_lens: torch.Tensor,
        ys_pad: torch.Tensor,
        ys_pad_lens: torch.Tensor,
        batch_size: int = 100,
    ):
        """Not supported by this model; always raises ``NotImplementedError``."""
        raise NotImplementedError
class MaskCTCInference(torch.nn.Module):
    """Mask-CTC-based non-autoregressive inference"""

    def __init__(
        self,
        asr_model: MaskCTCModel,
        n_iterations: int,
        threshold_probability: float,
    ):
        """Initialize Mask-CTC inference

        Args:
            asr_model: model whose ``ctc``, ``decoder``, ``mask_token`` and
                ``token_list`` are reused for decoding.
            n_iterations: requested number of decoding iterations; the
                number actually run is capped by the number of masked tokens.
            threshold_probability: tokens whose CTC probability is below
                this value are replaced by the mask token.
        """
        super().__init__()
        self.ctc = asr_model.ctc
        self.mlm = asr_model.decoder
        self.mask_token = asr_model.mask_token
        self.n_iterations = n_iterations
        self.threshold_probability = threshold_probability
        self.converter = TokenIDConverter(token_list=asr_model.token_list)

    def ids2text(self, ids: List[int]):
        """Render token ids as a string for logging.

        The mask symbol is shown as ``_`` and the space symbol as a blank.
        """
        text = "".join(self.converter.ids2tokens(ids))
        return text.replace("<mask>", "_").replace("<space>", " ")

    def forward(self, enc_out: torch.Tensor) -> Hypothesis:
        """Perform Mask-CTC inference

        Args:
            enc_out: encoder output for a single utterance; a batch axis is
                added in front (presumably (Length, Dim) - TODO confirm).

        Returns:
            A single ``Hypothesis`` whose ``yseq`` is the decoded id
            sequence with one mask token prepended and one appended.
        """
        # greedy ctc outputs
        enc_out = enc_out.unsqueeze(0)
        ctc_probs, ctc_ids = torch.exp(self.ctc.log_softmax(enc_out)).max(dim=-1)
        # collapse runs of identical frame-level ids
        y_hat = torch.stack([x[0] for x in groupby(ctc_ids[0])])
        # positions of non-zero ids (id 0 is presumably <blank> - TODO confirm)
        y_idx = torch.nonzero(y_hat != 0).squeeze(-1)

        logging.info("ctc:{}".format(self.ids2text(y_hat[y_idx].tolist())))

        # calculate token-level ctc probabilities by taking
        # the maximum probability of consecutive frames with
        # the same ctc symbols
        # Convert once to Python lists so the scan below does not compare
        # individual tensor elements frame by frame.
        frame_ids = ctc_ids[0].tolist()
        frame_probs = ctc_probs[0].tolist()
        n_frames = len(frame_ids)
        probs_hat = []
        cnt = 0
        for y in y_hat.tolist():
            best = -1
            while cnt < n_frames and y == frame_ids[cnt]:
                if best < frame_probs[cnt]:
                    best = frame_probs[cnt]
                cnt += 1
            probs_hat.append(best)
        # Keep the probabilities on the same device as the index tensors
        # derived from the encoder output.
        probs_hat = torch.from_numpy(numpy.array(probs_hat)).to(enc_out.device)

        # mask ctc outputs based on ctc probabilities
        p_thres = self.threshold_probability
        mask_idx = torch.nonzero(probs_hat[y_idx] < p_thres).squeeze(-1)
        confident_idx = torch.nonzero(probs_hat[y_idx] >= p_thres).squeeze(-1)
        mask_num = len(mask_idx)

        # Start from an all-mask sequence and copy in the confident tokens.
        y_in = torch.full(
            (1, len(y_idx)),
            self.mask_token,
            dtype=torch.long,
            device=enc_out.device,
        )
        y_in[0][confident_idx] = y_hat[y_idx][confident_idx]

        logging.info("msk:{}".format(self.ids2text(y_in[0].tolist())))

        # iterative decoding
        if not mask_num == 0:
            K = self.n_iterations
            num_iter = K if mask_num >= K and K > 0 else mask_num

            for _ in range(num_iter - 1):
                pred, _ = self.mlm(enc_out, [enc_out.size(1)], y_in, [y_in.size(1)])
                pred_score, pred_id = pred[0][mask_idx].max(dim=-1)
                # fill in the highest-scoring masked positions of this round
                cand = torch.topk(pred_score, mask_num // num_iter, -1)[1]
                y_in[0][mask_idx[cand]] = pred_id[cand]
                mask_idx = torch.nonzero(y_in[0] == self.mask_token).squeeze(-1)

                logging.info("msk:{}".format(self.ids2text(y_in[0].tolist())))

            # predict leftover masks (|masks| < mask_num // num_iter)
            pred, _ = self.mlm(enc_out, [enc_out.size(1)], y_in, [y_in.size(1)])
            y_in[0][mask_idx] = pred[0][mask_idx].argmax(dim=-1)

            logging.info("msk:{}".format(self.ids2text(y_in[0].tolist())))

        # pad with mask tokens to ensure compatibility with sos/eos tokens
        yseq = torch.tensor(
            [self.mask_token] + y_in.tolist()[0] + [self.mask_token], device=y_in.device
        )

        return Hypothesis(yseq=yseq)