Source code for espnet2.hubert.espnet_model

#!/usr/bin/env python3
# -*- coding: utf-8 -*-

# Thanks to Abdelrahman Mohamed and Wei-Ning Hsu for their help with this
# implementation. Their original HuBERT work is in:
#     Paper: https://arxiv.org/pdf/2106.07447.pdf
#     Code in Fairseq: https://github.com/pytorch/fairseq/tree/master/examples/hubert

from contextlib import contextmanager
from distutils.version import LooseVersion
from typing import Dict
from typing import List
from typing import Optional
from typing import Tuple
from typing import Union

import torch
from typeguard import check_argument_types

from espnet.nets.e2e_asr_common import ErrorCalculator

from espnet2.asr.encoder.abs_encoder import AbsEncoder
from espnet2.asr.frontend.abs_frontend import AbsFrontend
from espnet2.asr.preencoder.abs_preencoder import AbsPreEncoder
from espnet2.asr.specaug.abs_specaug import AbsSpecAug
from espnet2.hubert.hubert_loss import HubertPretrainLoss
from espnet2.layers.abs_normalize import AbsNormalize
from espnet2.torch_utils.device_funcs import force_gatherable
from espnet2.train.abs_espnet_model import AbsESPnetModel

if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
    from torch.cuda.amp import autocast
else:
    # torch<1.6.0 does not provide torch.cuda.amp; fall back to a no-op autocast
    @contextmanager
    def autocast(enabled=True):
        yield


class HubertPretrainModel(AbsESPnetModel):
    """HuBERT pretraining model."""

    def __init__(
        self,
        vocab_size: int,
        token_list: Union[Tuple[str, ...], List[str]],
        frontend: Optional[AbsFrontend],
        specaug: Optional[AbsSpecAug],
        normalize: Optional[AbsNormalize],
        preencoder: Optional[AbsPreEncoder],
        encoder: AbsEncoder,
        ignore_id: int = -1,
        lsm_weight: float = 0.0,
        length_normalized_loss: bool = False,
        report_cer: bool = False,
        report_wer: bool = False,
        sym_space: str = "<space>",
        sym_blank: str = "<blank>",
        pred_masked_weight: float = 1.0,
        pred_nomask_weight: float = 0.0,
        loss_weights: float = 0.0,
    ):
        assert check_argument_types()
        super().__init__()
        # note that eos is the same as sos (equivalent ID)
        self.sos = vocab_size - 1
        self.eos = vocab_size - 1
        self.vocab_size = vocab_size
        self.ignore_id = ignore_id
        self.token_list = token_list.copy()

        self.frontend = frontend
        self.specaug = specaug
        self.normalize = normalize
        self.preencoder = preencoder
        self.encoder = encoder
        self.criterion_att = HubertPretrainLoss(
            pred_masked_weight,
            pred_nomask_weight,
            loss_weights,
        )
        self.pred_masked_weight = pred_masked_weight
        self.pred_nomask_weight = pred_nomask_weight
        self.loss_weights = loss_weights

        if report_cer or report_wer:
            self.error_calculator = ErrorCalculator(
                token_list, sym_space, sym_blank, report_cer, report_wer
            )
        else:
            self.error_calculator = None

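    # Note: mirroring fairseq's HubertCriterion (an assumption about
    # HubertPretrainLoss, which is defined elsewhere), pred_masked_weight
    # scales the prediction loss on masked frames, pred_nomask_weight the
    # loss on unmasked frames, and loss_weights any auxiliary losses such
    # as the feature penalty.
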
    def forward(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        """Frontend + Encoder + Calc loss

        Args:
            speech: (Batch, Length, ...)
            speech_lengths: (Batch, )
            text: (Batch, Length)
            text_lengths: (Batch,)
        """
        assert text_lengths.dim() == 1, text_lengths.shape
        # Check that batch_size is unified
        assert (
            speech.shape[0]
            == speech_lengths.shape[0]
            == text.shape[0]
            == text_lengths.shape[0]
        ), (speech.shape, speech_lengths.shape, text.shape, text_lengths.shape)
        batch_size = speech.shape[0]

        # for data-parallel
        text = text[:, : text_lengths.max()]

        # 1. Encoder
        encoder_out = self.encode(speech, speech_lengths, text, text_lengths)

        # 2a. Hubert criterion
        loss, acc_mask, acc_unmask = self._calc_hubert_loss(
            encoder_out,
        )

        stats = dict(
            loss=loss.detach(),
            acc_mask=acc_mask,
            acc_unmask=acc_unmask,
            acc=acc_mask,
        )

        # force_gatherable: to-device and to-tensor if scalar for DataParallel
        loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
        return loss, stats, weight

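    # Usage sketch (hypothetical shapes): for a batch of 2 raw waveforms,
    #   speech: (2, 16000), speech_lengths: (2,)
    #   text: (2, 50) frame-level pseudo-labels (e.g. k-means cluster IDs),
    #   text_lengths: (2,)
    # forward() returns (loss, stats, weight) with weight == batch_size so
    # that DataParallel can average the gathered losses correctly.
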
    def collect_feats(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
    ) -> Dict[str, torch.Tensor]:
        feats, feats_lengths = self._extract_feats(speech, speech_lengths)
        return {"feats": feats, "feats_lengths": feats_lengths}

    def encode(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        y_pad: torch.Tensor,
        y_pad_length: torch.Tensor,
    ) -> torch.Tensor:
        """Frontend + Encoder.

        Note that this method is used by asr_inference.py

        Args:
            speech: (Batch, Length, ...)
            speech_lengths: (Batch, )
            y_pad: (Batch, Length, ...)
            y_pad_length: (Batch, )
        """
        with autocast(False):
            # 1. Extract feats
            feats, feats_lengths = self._extract_feats(speech, speech_lengths)

            # 2. Data augmentation
            if self.specaug is not None and self.training:
                feats, feats_lengths = self.specaug(feats, feats_lengths)

            # 3. Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
            if self.normalize is not None:
                feats, feats_lengths = self.normalize(feats, feats_lengths)

        # Pre-encoder, e.g. used for raw input data
        if self.preencoder is not None:
            feats, feats_lengths = self.preencoder(feats, feats_lengths)

        # 4. Forward encoder
        # feats: (Batch, Length, Dim)
        # -> encoder_out: (Batch, Length2, Dim2)
        encoder_out = self.encoder(feats, feats_lengths, y_pad, y_pad_length)

        if hasattr(self.encoder, "encoder"):
            logp_m_list = self.encoder.encoder.get_logits(encoder_out, True)
            assert self.pred_masked_weight == 0 or len(logp_m_list) > 0
            logp_u_list = self.encoder.encoder.get_logits(encoder_out, False)
            assert self.pred_nomask_weight == 0 or len(logp_u_list) > 0

        return encoder_out

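    # Note (assumption): `self.encoder.encoder` implies the ESPnet encoder
    # wraps an inner fairseq-style HubertModel that exposes get_logits();
    # the assertions in encode() only require logits to exist for a branch
    # whose prediction weight is nonzero.
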
    def _extract_feats(
        self, speech: torch.Tensor, speech_lengths: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        assert speech_lengths.dim() == 1, speech_lengths.shape

        # for data-parallel
        speech = speech[:, : speech_lengths.max()]

        if self.frontend is not None:
            # Frontend
            #  e.g. STFT and Feature extract
            #       data_loader may send time-domain signal in this case
            # speech (Batch, NSamples) -> feats: (Batch, NFrames, Dim)
            feats, feats_lengths = self.frontend(speech, speech_lengths)
        else:
            # No frontend and no feature extract
            feats, feats_lengths = speech, speech_lengths
        return feats, feats_lengths

    def compute_correct(
        self,
        logits,
    ):
        if logits.numel() == 0:
            return 0, 0
        else:
            assert logits.dim() > 1, logits.shape
            # In fairseq-style HuBERT logits the positive example sits at
            # index 0, so a frame counts as correct when argmax hits 0.
            # Frames whose logits are all equal (argmax and argmin both 0)
            # are excluded from the correct count.
            max_hit = logits.argmax(-1) == 0
            min_hit = logits.argmin(-1) == 0
            both = max_hit & min_hit
            corr = max_hit.long().sum().item() - both.long().sum().item()
            count = max_hit.numel()
            return corr, count

    def _calc_hubert_loss(
        self,
        encoder_out: Dict[str, torch.Tensor],
    ):
        # 1. Compute HuBERT pretraining loss
        loss_att, logp_m_list, logp_u_list = self.criterion_att(
            self.encoder.encoder, encoder_out
        )

        corr_masked, count_masked = 0, 0
        corr_unmask, count_unmask = 0, 0
        with torch.no_grad():
            for logp_m in logp_m_list:
                corr_m, count_m = self.compute_correct(logp_m)
                corr_masked += corr_m
                count_masked += count_m
            for logp_u in logp_u_list:
                corr_u, count_u = self.compute_correct(logp_u)
                corr_unmask += corr_u
                count_unmask += count_u

        acc_att_m = corr_masked / (count_masked + 1e-10)
        acc_att_u = corr_unmask / (count_unmask + 1e-10)

        return loss_att, acc_att_m, acc_att_u
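
if __name__ == "__main__":
    # Minimal sketch, not part of the original module: exercise the
    # compute_correct counting logic on toy logits, assuming the positive
    # example occupies index 0 as in fairseq-style HuBERT logits.
    toy_logits = torch.tensor(
        [
            [2.0, 0.5, 0.1],  # argmax == 0 -> counted as correct
            [0.1, 3.0, 0.2],  # argmax == 1 -> counted as wrong
            [1.0, 1.0, 1.0],  # all-equal tie -> not counted as correct
        ]
    )
    # self is unused by compute_correct, so None suffices for this sketch
    corr, count = HubertPretrainModel.compute_correct(None, toy_logits)
    print(f"corr={corr}, count={count}")  # -> corr=1, count=3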