Source code for espnet.nets.pytorch_backend.frontends.frontend

from typing import List
from typing import Optional
from typing import Tuple
from typing import Union

import numpy
import torch
import torch.nn as nn
from torch_complex.tensor import ComplexTensor

from espnet.nets.pytorch_backend.frontends.dnn_beamformer import DNN_Beamformer
from espnet.nets.pytorch_backend.frontends.dnn_wpe import DNN_WPE

class Frontend(nn.Module):
    """Multi-channel enhancement frontend: optional WPE followed by beamforming.

    Two optional stages are chained on 4-dimensional (multi-channel) input:

    1. ``DNN_WPE``, built when ``use_wpe`` is True.
    2. ``DNN_Beamformer``, built when ``use_beamformer`` is True.

    3-dimensional input is passed through untouched by both stages.

    Args:
        idim: Input feature dimension, forwarded to ``DNN_WPE`` (``widim``)
            and ``DNN_Beamformer`` (``bidim``).
        use_wpe: Whether to build the WPE stage.
        wtype, wlayers, wunits, wprojs, wdropout_rate: Forwarded unchanged to
            ``DNN_WPE`` (``wdropout_rate`` as ``dropout_rate``).
        taps, delay: Forwarded unchanged to ``DNN_WPE``.
        use_dnn_mask_for_wpe: If True, ``DNN_WPE`` is built with
            ``use_dnn_mask=True`` and 1 iteration; otherwise with
            ``use_dnn_mask=False`` and 2 iterations.
        use_beamformer: Whether to build the beamformer stage.
        btype, blayers, bunits, bprojs, badim, ref_channel, bdropout_rate:
            Forwarded unchanged to ``DNN_Beamformer`` (``bdropout_rate`` as
            ``dropout_rate``).
        bnmask: Forwarded to ``DNN_Beamformer``. Values > 2 additionally make
            the frontend be applied to every training batch (see
            ``_select_stages``).
    """

    def __init__(
        self,
        idim: int,
        # WPE options
        use_wpe: bool = False,
        wtype: str = "blstmp",
        wlayers: int = 3,
        wunits: int = 300,
        wprojs: int = 320,
        wdropout_rate: float = 0.0,
        taps: int = 5,
        delay: int = 3,
        use_dnn_mask_for_wpe: bool = True,
        # Beamformer options
        use_beamformer: bool = False,
        btype: str = "blstmp",
        blayers: int = 3,
        bunits: int = 300,
        bprojs: int = 320,
        bnmask: int = 2,
        badim: int = 320,
        ref_channel: int = -1,
        bdropout_rate: float = 0.0,
    ):
        super().__init__()
        self.use_beamformer = use_beamformer
        self.use_wpe = use_wpe
        self.use_dnn_mask_for_wpe = use_dnn_mask_for_wpe
        # Use the frontend for all the data (never randomly skip it during
        # training), e.g. in the case of multi-speaker speech separation.
        self.use_frontend_for_all = bnmask > 2

        if self.use_wpe:
            if self.use_dnn_mask_for_wpe:
                # Use DNN for power estimation
                # (no significant gains observed from more iterations)
                iterations = 1
            else:
                # Perform conventional WPE, without the DNN estimator
                iterations = 2

            self.wpe = DNN_WPE(
                wtype=wtype,
                widim=idim,
                wunits=wunits,
                wprojs=wprojs,
                wlayers=wlayers,
                taps=taps,
                delay=delay,
                dropout_rate=wdropout_rate,
                iterations=iterations,
                use_dnn_mask=use_dnn_mask_for_wpe,
            )
        else:
            self.wpe = None

        if self.use_beamformer:
            self.beamformer = DNN_Beamformer(
                btype=btype,
                bidim=idim,
                bunits=bunits,
                bprojs=bprojs,
                blayers=blayers,
                bnmask=bnmask,
                dropout_rate=bdropout_rate,
                badim=badim,
                ref_channel=ref_channel,
            )
        else:
            self.beamformer = None

    def _select_stages(self) -> Tuple[bool, bool]:
        """Return ``(use_wpe, use_beamformer)`` for the current 4-dim batch.

        In eval mode the configured flags are returned, so every built stage
        runs. In training mode one option is drawn uniformly at random via
        ``numpy.random.randint`` from:

        - no enhancement (omitted when ``use_frontend_for_all``),
        - WPE only (if WPE is enabled),
        - beamformer only (if the beamformer is enabled).
        """
        if not self.training:
            return self.use_wpe, self.use_beamformer

        choices = [] if self.use_frontend_for_all else [(False, False)]
        if self.use_wpe:
            choices.append((True, False))
        if self.use_beamformer:
            choices.append((False, True))

        if not choices:
            # use_frontend_for_all with no stage enabled: there is nothing to
            # apply, so pass the input through (same as eval mode) instead of
            # failing inside numpy.random.randint(0).
            return False, False
        return choices[numpy.random.randint(len(choices))]

    def forward(
        self,
        x: ComplexTensor,
        ilens: Union[torch.LongTensor, numpy.ndarray, List[int]],
    ) -> Tuple[ComplexTensor, torch.LongTensor, Optional[ComplexTensor]]:
        """Apply the selected enhancement stages to ``x``.

        Args:
            x: Complex-valued input of shape (B, T, F) or (B, T, C, F).
            ilens: Per-utterance lengths, one entry per batch element. Non-tensor
                input is converted to a tensor on ``x.device``.

        Returns:
            Tuple of the (possibly enhanced) features, the (possibly updated)
            lengths, and the mask returned by the last stage that ran
            (``None`` if no stage ran).

        Raises:
            ValueError: if ``x`` is neither 3- nor 4-dimensional.
        """
        assert len(x) == len(ilens), (len(x), len(ilens))
        # (B, T, F) or (B, T, C, F)
        if x.dim() not in (3, 4):
            raise ValueError(f"Input dim must be 3 or 4: {x.dim()}")
        if not torch.is_tensor(ilens):
            ilens = torch.from_numpy(numpy.asarray(ilens)).to(x.device)

        mask = None
        h = x
        # Only multi-channel input goes through the stages; (B, T, F) input
        # is returned as is.
        if h.dim() == 4:
            use_wpe, use_beamformer = self._select_stages()

            # 1. WPE
            if use_wpe:
                # h: (B, T, C, F) -> h: (B, T, C, F)
                h, ilens, mask = self.wpe(h, ilens)

            # 2. Beamformer
            if use_beamformer:
                # h: (B, T, C, F) -> h: (B, T, F)
                h, ilens, mask = self.beamformer(h, ilens)

        return h, ilens, mask
def frontend_for(args, idim):
    """Construct a :class:`Frontend` from an options object.

    Args:
        args: Object exposing the frontend options as attributes (presumably
            an ``argparse.Namespace`` — confirm against callers). Every
            attribute named in ``option_names`` below must exist.
        idim: Input feature dimension, passed as ``Frontend(idim=...)``.

    Returns:
        Frontend: the newly constructed frontend module.
    """
    # (Frontend keyword argument, attribute read from ``args``). Most names
    # coincide; the WPE taps/delay live under ``wpe_``-prefixed attributes.
    # The order matches Frontend's signature and fixes the read order.
    option_names = (
        # WPE options
        ("use_wpe", "use_wpe"),
        ("wtype", "wtype"),
        ("wlayers", "wlayers"),
        ("wunits", "wunits"),
        ("wprojs", "wprojs"),
        ("wdropout_rate", "wdropout_rate"),
        ("taps", "wpe_taps"),
        ("delay", "wpe_delay"),
        ("use_dnn_mask_for_wpe", "use_dnn_mask_for_wpe"),
        # Beamformer options
        ("use_beamformer", "use_beamformer"),
        ("btype", "btype"),
        ("blayers", "blayers"),
        ("bunits", "bunits"),
        ("bprojs", "bprojs"),
        ("bnmask", "bnmask"),
        ("badim", "badim"),
        ("ref_channel", "ref_channel"),
        ("bdropout_rate", "bdropout_rate"),
    )
    options = {keyword: getattr(args, attr) for keyword, attr in option_names}
    return Frontend(idim=idim, **options)