Source code for espnet2.asr.transducer.beam_search_transducer

"""Search algorithms for Transducer models."""

from dataclasses import dataclass
from typing import Any
from typing import Dict
from typing import List
from typing import Optional
from typing import Tuple
from typing import Union

import numpy as np
import torch

from espnet.nets.pytorch_backend.transducer.utils import is_prefix
from espnet.nets.pytorch_backend.transducer.utils import recombine_hyps
from espnet.nets.pytorch_backend.transducer.utils import select_k_expansions
from espnet.nets.pytorch_backend.transducer.utils import subtract

from espnet2.asr.decoder.abs_decoder import AbsDecoder
from espnet2.asr.transducer.joint_network import JointNetwork


@dataclass
class Hypothesis:
    """Default hypothesis definition for Transducer search algorithms."""

    score: float
    yseq: List[int]
    dec_state: Union[
        Tuple[torch.Tensor, Optional[torch.Tensor]],
        List[Optional[torch.Tensor]],
        torch.Tensor,
    ]
    lm_state: Union[Dict[str, Any], List[Any]] = None

@dataclass
class ExtendedHypothesis(Hypothesis):
    """Extended hypothesis definition for NSC beam search and mAES."""

    dec_out: List[torch.Tensor] = None
    lm_scores: torch.Tensor = None
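
# Illustrative sketch, not part of the original ESPnet module: building the two
# containers above by hand. The decoder state here is a single placeholder
# tensor; real states depend on the decoder type (RNN tuple, list, or tensor).
def _example_hypotheses() -> Tuple[Hypothesis, ExtendedHypothesis]:
    dec_state = torch.zeros(1, 8)  # placeholder decoder state

    hyp = Hypothesis(score=0.0, yseq=[0], dec_state=dec_state)
    ext_hyp = ExtendedHypothesis(
        score=0.0,
        yseq=[0],
        dec_state=dec_state,
        dec_out=[torch.zeros(8)],  # placeholder decoder output history
    )

    return hyp, ext_hyp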

class BeamSearchTransducer:
    """Beam search implementation for Transducer."""

    def __init__(
        self,
        decoder: AbsDecoder,
        joint_network: JointNetwork,
        beam_size: int,
        lm: torch.nn.Module = None,
        lm_weight: float = 0.1,
        search_type: str = "default",
        max_sym_exp: int = 2,
        u_max: int = 50,
        nstep: int = 1,
        prefix_alpha: int = 1,
        expansion_gamma: float = 2.3,
        expansion_beta: int = 2,
        score_norm: bool = True,
        nbest: int = 1,
    ):
        """Initialize Transducer search module.

        Args:
            decoder: Decoder module.
            joint_network: Joint network module.
            beam_size: Beam size.
            lm: LM class.
            lm_weight: LM weight for soft fusion.
            search_type: Search algorithm to use during inference.
            max_sym_exp: Number of maximum symbol expansions at each time step. (TSD)
            u_max: Maximum output sequence length. (ALSD)
            nstep: Number of maximum expansion steps at each time step. (NSC/mAES)
            prefix_alpha: Maximum prefix length in prefix search. (NSC/mAES)
            expansion_beta: Number of additional candidates for expanded hypotheses selection. (mAES)
            expansion_gamma: Allowed logp difference for prune-by-value method. (mAES)
            score_norm: Normalize final scores by length. ("default")
            nbest: Number of final hypotheses.

        """
        self.decoder = decoder
        self.joint_network = joint_network

        self.beam_size = beam_size
        self.hidden_size = decoder.dunits
        self.vocab_size = decoder.odim
        self.blank_id = decoder.blank_id

        if self.beam_size <= 1:
            self.search_algorithm = self.greedy_search
        elif search_type == "default":
            self.search_algorithm = self.default_beam_search
        elif search_type == "tsd":
            self.max_sym_exp = max_sym_exp

            self.search_algorithm = self.time_sync_decoding
        elif search_type == "alsd":
            self.u_max = u_max

            self.search_algorithm = self.align_length_sync_decoding
        elif search_type == "nsc":
            self.nstep = nstep
            self.prefix_alpha = prefix_alpha

            self.search_algorithm = self.nsc_beam_search
        elif search_type == "maes":
            self.nstep = nstep if nstep > 1 else 2
            self.prefix_alpha = prefix_alpha
            self.expansion_gamma = expansion_gamma
            self.expansion_beta = expansion_beta

            self.search_algorithm = self.modified_adaptive_expansion_search
        else:
            raise NotImplementedError

        self.use_lm = lm is not None
        self.lm = lm
        self.lm_weight = lm_weight

        self.score_norm = score_norm
        self.nbest = nbest

    def __call__(
        self, enc_out: torch.Tensor
    ) -> Union[List[Hypothesis], List[ExtendedHypothesis]]:
        """Perform beam search.

        Args:
            enc_out: Encoder output sequence. (T, D_enc)

        Returns:
            nbest_hyps: N-best decoding results.

        """
        self.decoder.set_device(enc_out.device)

        nbest_hyps = self.search_algorithm(enc_out)

        return nbest_hyps
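
    # Illustrative sketch, not part of the original ESPnet module: a minimal
    # end-to-end call, assuming this instance was built with a trained decoder
    # and joint network. The encoder dimension (256) and length (100) are
    # placeholders; D_enc must match the joint network's encoder input size.
    def _example_call(self) -> Union[List[Hypothesis], List[ExtendedHypothesis]]:
        enc_out = torch.randn(100, 256)  # (T, D_enc), dummy encoder output

        return self(enc_out)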

    def sort_nbest(
        self, hyps: Union[List[Hypothesis], List[ExtendedHypothesis]]
    ) -> Union[List[Hypothesis], List[ExtendedHypothesis]]:
        """Sort hypotheses by score or score given sequence length.

        Args:
            hyps: Hypotheses.

        Return:
            hyps: Sorted hypotheses.

        """
        if self.score_norm:
            hyps.sort(key=lambda x: x.score / len(x.yseq), reverse=True)
        else:
            hyps.sort(key=lambda x: x.score, reverse=True)

        return hyps[: self.nbest]
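
    # Illustrative sketch, not part of the original ESPnet module: how
    # score_norm changes the ranking. With length normalization the longer
    # hypothesis wins here (higher per-token score); without it, the shorter
    # one wins on raw score. Scores and sequences are made-up placeholders.
    def _example_sort_nbest(self) -> List[Hypothesis]:
        dummy_state = torch.zeros(1, self.hidden_size)

        short_hyp = Hypothesis(score=-4.0, yseq=[0, 5], dec_state=dummy_state)
        long_hyp = Hypothesis(score=-4.5, yseq=[0, 5, 7, 9], dec_state=dummy_state)

        # score_norm=True:  -4.5 / 4 = -1.125 > -4.0 / 2 = -2.0  -> long_hyp first
        # score_norm=False: -4.0 > -4.5                          -> short_hyp first
        return self.sort_nbest([short_hyp, long_hyp])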

    def time_sync_decoding(self, enc_out: torch.Tensor) -> List[Hypothesis]:
        """Time synchronous beam search implementation.

        Based on https://ieeexplore.ieee.org/document/9053040

        Args:
            enc_out: Encoder output sequence. (T, D)

        Returns:
            nbest_hyps: N-best hypotheses.

        """
        beam = min(self.beam_size, self.vocab_size)

        beam_state = self.decoder.init_state(beam)

        B = [
            Hypothesis(
                yseq=[self.blank_id],
                score=0.0,
                dec_state=self.decoder.select_state(beam_state, 0),
            )
        ]
        cache = {}

        if self.use_lm:
            B[0].lm_state = self.lm.zero_state()

        for enc_out_t in enc_out:
            A = []
            C = B

            enc_out_t = enc_out_t.unsqueeze(0)

            for v in range(self.max_sym_exp):
                D = []

                beam_dec_out, beam_state, beam_lm_tokens = self.decoder.batch_score(
                    C,
                    beam_state,
                    cache,
                    self.use_lm,
                )

                beam_logp = torch.log_softmax(
                    self.joint_network(enc_out_t, beam_dec_out),
                    dim=-1,
                )
                beam_topk = beam_logp[:, 1:].topk(beam, dim=-1)

                seq_A = [h.yseq for h in A]

                for i, hyp in enumerate(C):
                    if hyp.yseq not in seq_A:
                        A.append(
                            Hypothesis(
                                score=(hyp.score + float(beam_logp[i, 0])),
                                yseq=hyp.yseq[:],
                                dec_state=hyp.dec_state,
                                lm_state=hyp.lm_state,
                            )
                        )
                    else:
                        dict_pos = seq_A.index(hyp.yseq)

                        A[dict_pos].score = np.logaddexp(
                            A[dict_pos].score, (hyp.score + float(beam_logp[i, 0]))
                        )

                if v < (self.max_sym_exp - 1):
                    if self.use_lm:
                        beam_lm_scores, beam_lm_states = self.lm.batch_score(
                            beam_lm_tokens, [c.lm_state for c in C], None
                        )

                    for i, hyp in enumerate(C):
                        for logp, k in zip(beam_topk[0][i], beam_topk[1][i] + 1):
                            new_hyp = Hypothesis(
                                score=(hyp.score + float(logp)),
                                yseq=(hyp.yseq + [int(k)]),
                                dec_state=self.decoder.select_state(beam_state, i),
                                lm_state=hyp.lm_state,
                            )

                            if self.use_lm:
                                new_hyp.score += self.lm_weight * beam_lm_scores[i, k]
                                new_hyp.lm_state = beam_lm_states[i]

                            D.append(new_hyp)

                C = sorted(D, key=lambda x: x.score, reverse=True)[:beam]

            B = sorted(A, key=lambda x: x.score, reverse=True)[:beam]

        return self.sort_nbest(B)
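
    # Illustrative sketch, not part of the original ESPnet module: the
    # prefix-merging step used in time_sync_decoding. When two expansions
    # reach the same label sequence, their log-probabilities are combined
    # with log-add instead of keeping only the better path.
    @staticmethod
    def _example_merge_duplicate_prefix(score_a: float, score_b: float) -> float:
        # log(exp(score_a) + exp(score_b)), computed stably
        return float(np.logaddexp(score_a, score_b))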

    def align_length_sync_decoding(self, enc_out: torch.Tensor) -> List[Hypothesis]:
        """Alignment-length synchronous beam search implementation.

        Based on https://ieeexplore.ieee.org/document/9053040

        Args:
            enc_out: Encoder output sequence. (T, D)

        Returns:
            nbest_hyps: N-best hypotheses.

        """
        beam = min(self.beam_size, self.vocab_size)

        t_max = int(enc_out.size(0))
        u_max = min(self.u_max, (t_max - 1))

        beam_state = self.decoder.init_state(beam)

        B = [
            Hypothesis(
                yseq=[self.blank_id],
                score=0.0,
                dec_state=self.decoder.select_state(beam_state, 0),
            )
        ]
        final = []
        cache = {}

        if self.use_lm:
            B[0].lm_state = self.lm.zero_state()

        for i in range(t_max + u_max):
            A = []

            B_ = []
            B_enc_out = []
            for hyp in B:
                u = len(hyp.yseq) - 1
                t = i - u

                if t > (t_max - 1):
                    continue

                B_.append(hyp)
                B_enc_out.append((t, enc_out[t]))

            if B_:
                beam_dec_out, beam_state, beam_lm_tokens = self.decoder.batch_score(
                    B_,
                    beam_state,
                    cache,
                    self.use_lm,
                )

                beam_enc_out = torch.stack([x[1] for x in B_enc_out])

                beam_logp = torch.log_softmax(
                    self.joint_network(beam_enc_out, beam_dec_out),
                    dim=-1,
                )
                beam_topk = beam_logp[:, 1:].topk(beam, dim=-1)

                if self.use_lm:
                    beam_lm_scores, beam_lm_states = self.lm.batch_score(
                        beam_lm_tokens,
                        [b.lm_state for b in B_],
                        None,
                    )

                for i, hyp in enumerate(B_):
                    new_hyp = Hypothesis(
                        score=(hyp.score + float(beam_logp[i, 0])),
                        yseq=hyp.yseq[:],
                        dec_state=hyp.dec_state,
                        lm_state=hyp.lm_state,
                    )

                    A.append(new_hyp)

                    if B_enc_out[i][0] == (t_max - 1):
                        final.append(new_hyp)

                    for logp, k in zip(beam_topk[0][i], beam_topk[1][i] + 1):
                        new_hyp = Hypothesis(
                            score=(hyp.score + float(logp)),
                            yseq=(hyp.yseq[:] + [int(k)]),
                            dec_state=self.decoder.select_state(beam_state, i),
                            lm_state=hyp.lm_state,
                        )

                        if self.use_lm:
                            new_hyp.score += self.lm_weight * beam_lm_scores[i, k]
                            new_hyp.lm_state = beam_lm_states[i]

                        A.append(new_hyp)

                B = sorted(A, key=lambda x: x.score, reverse=True)[:beam]
                B = recombine_hyps(B)

        if final:
            return self.sort_nbest(final)
        else:
            return B
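
    # Illustrative sketch, not part of the original ESPnet module: the index
    # relation used in align_length_sync_decoding. At step i, a hypothesis
    # with u emitted labels reads encoder frame t = i - u and is dropped once
    # t runs past the last frame.
    @staticmethod
    def _example_alsd_frame_index(i: int, yseq: List[int], t_max: int) -> Optional[int]:
        u = len(yseq) - 1  # yseq includes the leading blank symbol
        t = i - u

        return t if t <= (t_max - 1) else None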