from contextlib import contextmanager
from distutils.version import LooseVersion
from itertools import groupby
import logging
from typing import Dict
from typing import List
from typing import Optional
from typing import Tuple
from typing import Union
import numpy
import torch
from typeguard import check_argument_types
from espnet.nets.beam_search import Hypothesis
from espnet.nets.e2e_asr_common import ErrorCalculator
from espnet.nets.pytorch_backend.maskctc.add_mask_token import mask_uniform
from espnet.nets.pytorch_backend.nets_utils import th_accuracy
from espnet.nets.pytorch_backend.transformer.label_smoothing_loss import (
LabelSmoothingLoss, # noqa: H301
)
from espnet2.asr.ctc import CTC
from espnet2.asr.decoder.mlm_decoder import MLMDecoder
from espnet2.asr.encoder.abs_encoder import AbsEncoder
from espnet2.asr.espnet_model import ESPnetASRModel
from espnet2.asr.frontend.abs_frontend import AbsFrontend
from espnet2.asr.postencoder.abs_postencoder import AbsPostEncoder
from espnet2.asr.preencoder.abs_preencoder import AbsPreEncoder
from espnet2.asr.specaug.abs_specaug import AbsSpecAug
from espnet2.layers.abs_normalize import AbsNormalize
from espnet2.text.token_id_converter import TokenIDConverter
from espnet2.torch_utils.device_funcs import force_gatherable
if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
from torch.cuda.amp import autocast
else:
    # torch<1.6.0 has no native amp: fall back to a no-op autocast context
@contextmanager
def autocast(enabled=True):
yield


class MaskCTCModel(ESPnetASRModel):
"""Hybrid CTC/Masked LM Encoder-Decoder model (Mask-CTC)"""
def __init__(
self,
vocab_size: int,
token_list: Union[Tuple[str, ...], List[str]],
frontend: Optional[AbsFrontend],
specaug: Optional[AbsSpecAug],
normalize: Optional[AbsNormalize],
preencoder: Optional[AbsPreEncoder],
encoder: AbsEncoder,
postencoder: Optional[AbsPostEncoder],
decoder: MLMDecoder,
ctc: CTC,
joint_network: Optional[torch.nn.Module] = None,
ctc_weight: float = 0.5,
interctc_weight: float = 0.0,
ignore_id: int = -1,
lsm_weight: float = 0.0,
length_normalized_loss: bool = False,
report_cer: bool = True,
report_wer: bool = True,
sym_space: str = "<space>",
sym_blank: str = "<blank>",
sym_mask: str = "<mask>",
extract_feats_in_collect_stats: bool = True,
):
assert check_argument_types()
super().__init__(
vocab_size=vocab_size,
token_list=token_list,
frontend=frontend,
specaug=specaug,
normalize=normalize,
preencoder=preencoder,
encoder=encoder,
postencoder=postencoder,
decoder=decoder,
ctc=ctc,
joint_network=joint_network,
ctc_weight=ctc_weight,
interctc_weight=interctc_weight,
ignore_id=ignore_id,
lsm_weight=lsm_weight,
length_normalized_loss=length_normalized_loss,
report_cer=report_cer,
report_wer=report_wer,
sym_space=sym_space,
sym_blank=sym_blank,
extract_feats_in_collect_stats=extract_feats_in_collect_stats,
)
# Add <mask> and override inherited fields
token_list.append(sym_mask)
vocab_size += 1
self.vocab_size = vocab_size
self.mask_token = vocab_size - 1
self.token_list = token_list.copy()
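        # e.g. an original vocabulary of size 100 (ids 0..99) grows to 101
        # entries here, and the new last id (100) becomes the <mask> token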
# MLM loss
del self.criterion_att
self.criterion_mlm = LabelSmoothingLoss(
size=vocab_size,
padding_idx=ignore_id,
smoothing=lsm_weight,
normalize_length=length_normalized_loss,
)
self.error_calculator = None
if report_cer or report_wer:
self.error_calculator = ErrorCalculator(
token_list, sym_space, sym_blank, report_cer, report_wer
)

    def forward(
self,
speech: torch.Tensor,
speech_lengths: torch.Tensor,
text: torch.Tensor,
text_lengths: torch.Tensor,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
"""Frontend + Encoder + Decoder + Calc loss
Args:
speech: (Batch, Length, ...)
speech_lengths: (Batch, )
text: (Batch, Length)
text_lengths: (Batch,)
"""
assert text_lengths.dim() == 1, text_lengths.shape
# Check that batch_size is unified
assert (
speech.shape[0]
== speech_lengths.shape[0]
== text.shape[0]
== text_lengths.shape[0]
), (speech.shape, speech_lengths.shape, text.shape, text_lengths.shape)
batch_size = speech.shape[0]
# For data-parallel
text = text[:, : text_lengths.max()]
# Define stats to report
loss_mlm, acc_mlm = None, None
loss_ctc, cer_ctc = None, None
stats = dict()
# 1. Encoder
encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
intermediate_outs = None
if isinstance(encoder_out, tuple):
intermediate_outs = encoder_out[1]
encoder_out = encoder_out[0]
# 2. CTC branch
if self.ctc_weight != 0.0:
loss_ctc, cer_ctc = self._calc_ctc_loss(
encoder_out, encoder_out_lens, text, text_lengths
)
# Collect CTC branch stats
stats["loss_ctc"] = loss_ctc.detach() if loss_ctc is not None else None
stats["cer_ctc"] = cer_ctc
# 2a. Intermediate CTC (optional)
loss_interctc = 0.0
if self.interctc_weight != 0.0 and intermediate_outs is not None:
for layer_idx, intermediate_out in intermediate_outs:
# we assume intermediate_out has the same length & padding
# as those of encoder_out
loss_ic, cer_ic = self._calc_ctc_loss(
intermediate_out, encoder_out_lens, text, text_lengths
)
loss_interctc = loss_interctc + loss_ic
                # Collect intermediate CTC stats
stats["loss_interctc_layer{}".format(layer_idx)] = (
loss_ic.detach() if loss_ic is not None else None
)
stats["cer_interctc_layer{}".format(layer_idx)] = cer_ic
loss_interctc = loss_interctc / len(intermediate_outs)
# calculate whole encoder loss
loss_ctc = (
1 - self.interctc_weight
) * loss_ctc + self.interctc_weight * loss_interctc
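            # e.g. with interctc_weight=0.3 the combined CTC loss is
            # 0.7 * (final-layer CTC loss) + 0.3 * (mean of the
            # intermediate-layer CTC losses accumulated above)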
# 3. MLM decoder branch
if self.ctc_weight != 1.0:
loss_mlm, acc_mlm = self._calc_mlm_loss(
encoder_out, encoder_out_lens, text, text_lengths
)
# 4. CTC/MLM loss definition
if self.ctc_weight == 0.0:
loss = loss_mlm
elif self.ctc_weight == 1.0:
loss = loss_ctc
else:
loss = self.ctc_weight * loss_ctc + (1 - self.ctc_weight) * loss_mlm
# Collect MLM branch stats
stats["loss_mlm"] = loss_mlm.detach() if loss_mlm is not None else None
stats["acc_mlm"] = acc_mlm
# Collect total loss stats
stats["loss"] = loss.detach()
# force_gatherable: to-device and to-tensor if scalar for DataParallel
loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
return loss, stats, weight
def _calc_mlm_loss(
self,
encoder_out: torch.Tensor,
encoder_out_lens: torch.Tensor,
ys_pad: torch.Tensor,
ys_pad_lens: torch.Tensor,
):
# 1. Apply masks
ys_in_pad, ys_out_pad = mask_uniform(
ys_pad, self.mask_token, self.eos, self.ignore_id
)
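        # For each sequence, mask_uniform samples the number of masks uniformly
        # between one and the sequence length, replaces those positions in
        # ys_in_pad with <mask>, and sets ys_out_pad to the original tokens at
        # the masked positions (ignore_id elsewhere), so the MLM loss below is
        # computed only over the masked positions.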
# 2. Forward decoder
decoder_out, _ = self.decoder(
encoder_out, encoder_out_lens, ys_in_pad, ys_pad_lens
)
# 3. Compute mlm loss
loss_mlm = self.criterion_mlm(decoder_out, ys_out_pad)
acc_mlm = th_accuracy(
decoder_out.view(-1, self.vocab_size),
ys_out_pad,
ignore_label=self.ignore_id,
)
return loss_mlm, acc_mlm

    def nll(
self,
encoder_out: torch.Tensor,
encoder_out_lens: torch.Tensor,
ys_pad: torch.Tensor,
ys_pad_lens: torch.Tensor,
) -> torch.Tensor:
raise NotImplementedError

    def batchify_nll(
self,
encoder_out: torch.Tensor,
encoder_out_lens: torch.Tensor,
ys_pad: torch.Tensor,
ys_pad_lens: torch.Tensor,
batch_size: int = 100,
):
raise NotImplementedError


class MaskCTCInference(torch.nn.Module):
"""Mask-CTC-based non-autoregressive inference"""
def __init__(
self,
asr_model: MaskCTCModel,
n_iterations: int,
threshold_probability: float,
):
"""Initialize Mask-CTC inference"""
super().__init__()
self.ctc = asr_model.ctc
self.mlm = asr_model.decoder
self.mask_token = asr_model.mask_token
self.n_iterations = n_iterations
self.threshold_probability = threshold_probability
self.converter = TokenIDConverter(token_list=asr_model.token_list)

    def ids2text(self, ids: List[int]):
text = "".join(self.converter.ids2tokens(ids))
return text.replace("<mask>", "_").replace("<space>", " ")

    def forward(self, enc_out: torch.Tensor) -> Hypothesis:
"""Perform Mask-CTC inference"""
# greedy ctc outputs
enc_out = enc_out.unsqueeze(0)
ctc_probs, ctc_ids = torch.exp(self.ctc.log_softmax(enc_out)).max(dim=-1)
y_hat = torch.stack([x[0] for x in groupby(ctc_ids[0])])
y_idx = torch.nonzero(y_hat != 0).squeeze(-1)
logging.info("ctc:{}".format(self.ids2text(y_hat[y_idx].tolist())))
# calculate token-level ctc probabilities by taking
# the maximum probability of consecutive frames with
# the same ctc symbols
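        # e.g. frame-level ctc_ids [5, 5, 0, 3, 3] with frame probabilities
        # [0.9, 0.8, 0.6, 0.7, 0.95] collapse to y_hat [5, 0, 3] with
        # token-level probs_hat [0.9, 0.6, 0.95]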
probs_hat = []
cnt = 0
for i, y in enumerate(y_hat.tolist()):
probs_hat.append(-1)
while cnt < ctc_ids.shape[1] and y == ctc_ids[0][cnt]:
if probs_hat[i] < ctc_probs[0][cnt]:
probs_hat[i] = ctc_probs[0][cnt].item()
cnt += 1
probs_hat = torch.from_numpy(numpy.array(probs_hat))
# mask ctc outputs based on ctc probabilities
p_thres = self.threshold_probability
mask_idx = torch.nonzero(probs_hat[y_idx] < p_thres).squeeze(-1)
confident_idx = torch.nonzero(probs_hat[y_idx] >= p_thres).squeeze(-1)
mask_num = len(mask_idx)
y_in = torch.zeros(1, len(y_idx), dtype=torch.long) + self.mask_token
y_in[0][confident_idx] = y_hat[y_idx][confident_idx]
logging.info("msk:{}".format(self.ids2text(y_in[0].tolist())))
# iterative decoding
        if mask_num != 0:
K = self.n_iterations
num_iter = K if mask_num >= K and K > 0 else mask_num
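            # e.g. with mask_num=10 and K=3, each of the first num_iter - 1 = 2
            # passes fills in the 10 // 3 = 3 most confident predictions, and
            # the final pass below fills the remaining 4 masks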
for t in range(num_iter - 1):
pred, _ = self.mlm(enc_out, [enc_out.size(1)], y_in, [y_in.size(1)])
pred_score, pred_id = pred[0][mask_idx].max(dim=-1)
cand = torch.topk(pred_score, mask_num // num_iter, -1)[1]
y_in[0][mask_idx[cand]] = pred_id[cand]
mask_idx = torch.nonzero(y_in[0] == self.mask_token).squeeze(-1)
logging.info("msk:{}".format(self.ids2text(y_in[0].tolist())))
# predict leftover masks (|masks| < mask_num // num_iter)
pred, _ = self.mlm(enc_out, [enc_out.size(1)], y_in, [y_in.size(1)])
y_in[0][mask_idx] = pred[0][mask_idx].argmax(dim=-1)
logging.info("msk:{}".format(self.ids2text(y_in[0].tolist())))
# pad with mask tokens to ensure compatibility with sos/eos tokens
yseq = torch.tensor(
[self.mask_token] + y_in.tolist()[0] + [self.mask_token], device=y_in.device
)
return Hypothesis(yseq=yseq)
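

# A minimal usage sketch (illustrative, not part of the original file):
# `asr_model` is assumed to be a trained MaskCTCModel and `enc_out` a single
# encoder output of shape (T, D), e.g. taken from the first element of the
# tuple returned by asr_model.encode(speech, speech_lengths):
#
#     inference = MaskCTCInference(
#         asr_model=asr_model,
#         n_iterations=10,
#         threshold_probability=0.99,
#     )
#     hyp = inference(enc_out)  # Hypothesis; yseq is framed by <mask> ids
#     print(inference.ids2text(hyp.yseq[1:-1].tolist()))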