# Source code for espnet2.enh.layers.dnn_beamformer

"""DNN beamformer module."""
from distutils.version import LooseVersion
from typing import List
from typing import Optional
from typing import Tuple
from typing import Union

import logging
import torch
from torch.nn import functional as F
from torch_complex.tensor import ComplexTensor

from espnet2.enh.layers.beamformer import apply_beamforming_vector
from espnet2.enh.layers.beamformer import blind_analytic_normalization
from espnet2.enh.layers.beamformer import get_gev_vector
from espnet2.enh.layers.beamformer import get_lcmv_vector_with_rtf
from espnet2.enh.layers.beamformer import get_mvdr_vector
from espnet2.enh.layers.beamformer import get_mvdr_vector_with_rtf
from espnet2.enh.layers.beamformer import get_mwf_vector
from espnet2.enh.layers.beamformer import get_rank1_mwf_vector
from espnet2.enh.layers.beamformer import get_rtf_matrix
from espnet2.enh.layers.beamformer import get_sdw_mwf_vector
from espnet2.enh.layers.beamformer import get_WPD_filter_v2
from espnet2.enh.layers.beamformer import get_WPD_filter_with_rtf
from espnet2.enh.layers.beamformer import perform_WPD_filtering
from espnet2.enh.layers.beamformer import prepare_beamformer_stats
from espnet2.enh.layers.complex_utils import stack
from espnet2.enh.layers.complex_utils import to_double
from espnet2.enh.layers.complex_utils import to_float
from espnet2.enh.layers.mask_estimator import MaskEstimator


is_torch_1_9_plus = LooseVersion(torch.__version__) >= LooseVersion("1.9.0")

# All beamformer_type values accepted by DNN_Beamformer.  Variants ending in
# "_souden" use Souden's mask-based solution; the others use an RTF-based
# (relative transfer function) formula.
BEAMFORMER_TYPES = (
    # Minimum Variance Distortionless Response beamformer
    "mvdr",  # RTF-based formula
    "mvdr_souden",  # Souden's solution
    # Minimum Power Distortionless Response beamformer
    "mpdr",  # RTF-based formula
    "mpdr_souden",  # Souden's solution
    # weighted MPDR beamformer
    "wmpdr",  # RTF-based formula
    "wmpdr_souden",  # Souden's solution
    # Weighted Power minimization Distortionless response beamformer
    "wpd",  # RTF-based formula
    "wpd_souden",  # Souden's solution
    # Multi-channel Wiener Filter (MWF) and weighted MWF
    "mwf",
    "wmwf",
    # Speech Distortion Weighted (SDW) MWF
    "sdw_mwf",
    # Rank-1 MWF
    "r1mwf",
    # Linearly Constrained Minimum Variance beamformer
    "lcmv",
    # Linearly Constrained Minimum Power beamformer
    "lcmp",
    # weighted Linearly Constrained Minimum Power beamformer
    "wlcmp",
    # Generalized Eigenvalue beamformer
    "gev",
    "gev_ban",  # with blind analytic normalization (BAN) post-filtering
    # time-frequency-bin-wise switching (TFS) MVDR beamformer
    "mvdr_tfs",
    "mvdr_tfs_souden",
)


class DNN_Beamformer(torch.nn.Module):
    """DNN mask based Beamformer.

    Citation:
        Multichannel End-to-end Speech Recognition; T. Ochiai et al., 2017;
        http://proceedings.mlr.press/v70/ochiai17a/ochiai17a.pdf
    """

    def __init__(
        self,
        bidim,
        btype: str = "blstmp",
        blayers: int = 3,
        bunits: int = 300,
        bprojs: int = 320,
        num_spk: int = 1,
        use_noise_mask: bool = True,
        nonlinear: str = "sigmoid",
        dropout_rate: float = 0.0,
        badim: int = 320,
        ref_channel: int = -1,
        beamformer_type: str = "mvdr_souden",
        rtf_iterations: int = 2,
        mwf_mu: float = 1.0,
        eps: float = 1e-6,
        diagonal_loading: bool = True,
        diag_eps: float = 1e-7,
        mask_flooring: bool = False,
        flooring_thres: float = 1e-6,
        use_torch_solver: bool = True,
        # only for WPD beamformer
        btaps: int = 5,
        bdelay: int = 3,
    ):
        super().__init__()
        # one extra mask for the noise when use_noise_mask is enabled
        bnmask = num_spk + 1 if use_noise_mask else num_spk
        self.mask = MaskEstimator(
            btype,
            bidim,
            blayers,
            bunits,
            bprojs,
            dropout_rate,
            nmask=bnmask,
            nonlinear=nonlinear,
        )
        # ref_channel < 0 means the reference microphone is estimated by an
        # attention module; otherwise a fixed channel index is used
        self.ref = (
            AttentionReference(bidim, badim, eps=eps) if ref_channel < 0 else None
        )
        self.ref_channel = ref_channel

        self.use_noise_mask = use_noise_mask
        assert num_spk >= 1, num_spk
        self.num_spk = num_spk
        self.nmask = bnmask

        if beamformer_type not in BEAMFORMER_TYPES:
            raise ValueError("Not supporting beamformer_type=%s" % beamformer_type)
        if (
            beamformer_type == "mvdr_souden" or not beamformer_type.endswith("_souden")
        ) and not use_noise_mask:
            if num_spk == 1:
                logging.warning(
                    "Initializing %s beamformer without noise mask "
                    "estimator (single-speaker case)" % beamformer_type.upper()
                )
                logging.warning(
                    "(1 - speech_mask) will be used for estimating noise "
                    "PSD in %s beamformer!" % beamformer_type.upper()
                )
            else:
                logging.warning(
                    "Initializing %s beamformer without noise mask "
                    "estimator (multi-speaker case)" % beamformer_type.upper()
                )
                logging.warning(
                    "Interference speech masks will be used for estimating "
                    "noise PSD in %s beamformer!" % beamformer_type.upper()
                )

        self.beamformer_type = beamformer_type
        if not beamformer_type.endswith("_souden"):
            # RTF-based formulas need at least 2 power-method iterations
            assert rtf_iterations >= 2, rtf_iterations
        # number of iterations in power method for estimating the RTF
        self.rtf_iterations = rtf_iterations
        # noise suppression weight in SDW-MWF
        self.mwf_mu = mwf_mu

        assert btaps >= 0 and bdelay >= 0, (btaps, bdelay)
        self.btaps = btaps
        # without filter taps the prediction delay degenerates to 1
        self.bdelay = bdelay if self.btaps > 0 else 1
        self.eps = eps
        self.diagonal_loading = diagonal_loading
        self.diag_eps = diag_eps
        self.mask_flooring = mask_flooring
        self.flooring_thres = flooring_thres
        self.use_torch_solver = use_torch_solver
    def forward(
        self,
        data: Union[torch.Tensor, ComplexTensor],
        ilens: torch.LongTensor,
        powers: Optional[List[torch.Tensor]] = None,
        oracle_masks: Optional[List[torch.Tensor]] = None,
    ) -> Tuple[Union[torch.Tensor, ComplexTensor], torch.LongTensor, torch.Tensor]:
        """DNN_Beamformer forward function.

        Notation:
            B: Batch
            C: Channel
            T: Time or Sequence length
            F: Freq

        Args:
            data (torch.complex64/ComplexTensor): (B, T, C, F)
            ilens (torch.Tensor): (B,)
            powers (List[torch.Tensor] or None): used for wMPDR or WPD (B, F, T)
            oracle_masks (List[torch.Tensor] or None): oracle masks (B, F, C, T)
                if not None, oracle_masks will be used instead of self.mask
        Returns:
            enhanced (torch.complex64/ComplexTensor): (B, T, F)
            ilens (torch.Tensor): (B,)
            masks (torch.Tensor): (B, T, C, F)
        """
        # data (B, T, C, F) -> (B, F, C, T)
        data = data.permute(0, 3, 2, 1)
        # double precision copy used for computing the beamformer statistics
        data_d = to_double(data)

        # mask: [(B, F, C, T)]
        if oracle_masks is not None:
            masks = oracle_masks
        else:
            masks, _ = self.mask(data, ilens)
        assert self.nmask == len(masks), len(masks)
        # floor masks to increase numerical stability
        if self.mask_flooring:
            masks = [torch.clamp(m, min=self.flooring_thres) for m in masks]

        if self.num_spk == 1:  # single-speaker case
            if self.use_noise_mask:
                # (mask_speech, mask_noise)
                mask_speech, mask_noise = masks
            else:
                # (mask_speech,) -- the noise mask is derived from the speech mask
                mask_speech = masks[0]
                mask_noise = 1 - mask_speech

            if self.beamformer_type in ("lcmv", "lcmp", "wlcmp"):
                raise NotImplementedError("Single source is not supported yet")

            beamformer_stats = prepare_beamformer_stats(
                data_d,
                [mask_speech],
                mask_noise,
                powers=powers,
                beamformer_type=self.beamformer_type,
                bdelay=self.bdelay,
                btaps=self.btaps,
                eps=self.eps,
            )

            if self.beamformer_type in ("mvdr", "mpdr", "wmpdr", "wpd"):
                # RTF-based variants need the distortion (noise) PSD as well
                enhanced, ws = self.apply_beamforming(
                    data,
                    ilens,
                    beamformer_stats["psd_n"],
                    beamformer_stats["psd_speech"],
                    psd_distortion=beamformer_stats["psd_distortion"],
                )
            elif (
                self.beamformer_type.endswith("_souden")
                or self.beamformer_type == "mwf"
                or self.beamformer_type == "wmwf"
                or self.beamformer_type == "sdw_mwf"
                or self.beamformer_type == "r1mwf"
                or self.beamformer_type.startswith("gev")
            ):
                # mask-based variants only need psd_n / psd_speech
                enhanced, ws = self.apply_beamforming(
                    data,
                    ilens,
                    beamformer_stats["psd_n"],
                    beamformer_stats["psd_speech"],
                )
            else:
                raise ValueError(
                    "Not supporting beamformer_type={}".format(self.beamformer_type)
                )

            # (..., F, T) -> (..., T, F)
            enhanced = enhanced.transpose(-1, -2)
        else:  # multi-speaker case
            if self.use_noise_mask:
                # (mask_speech1, ..., mask_noise)
                mask_speech = list(masks[:-1])
                mask_noise = masks[-1]
            else:
                # (mask_speech1, ..., mask_speechX)
                mask_speech = list(masks)
                mask_noise = None

            beamformer_stats = prepare_beamformer_stats(
                data_d,
                mask_speech,
                mask_noise,
                powers=powers,
                beamformer_type=self.beamformer_type,
                bdelay=self.bdelay,
                btaps=self.btaps,
                eps=self.eps,
            )
            if self.beamformer_type in ("lcmv", "lcmp", "wlcmp"):
                # LCMV/LCMP constrain on the RTF matrix of all speakers
                rtf_mat = get_rtf_matrix(
                    beamformer_stats["psd_speech"],
                    beamformer_stats["psd_distortion"],
                    diagonal_loading=self.diagonal_loading,
                    ref_channel=self.ref_channel,
                    rtf_iterations=self.rtf_iterations,
                    use_torch_solver=self.use_torch_solver,
                    diag_eps=self.diag_eps,
                )

            enhanced, ws = [], []
            for i in range(self.num_spk):
                # treat all other speakers' psd_speech as noises
                if self.beamformer_type in ("mvdr", "mvdr_tfs", "wmpdr", "wpd"):
                    enh, w = self.apply_beamforming(
                        data,
                        ilens,
                        beamformer_stats["psd_n"][i],
                        beamformer_stats["psd_speech"][i],
                        psd_distortion=beamformer_stats["psd_distortion"][i],
                    )
                elif self.beamformer_type in (
                    "mvdr_souden",
                    "mvdr_tfs_souden",
                    "wmpdr_souden",
                    "wpd_souden",
                    "wmwf",
                    "sdw_mwf",
                    "r1mwf",
                    "gev",
                    "gev_ban",
                ):
                    enh, w = self.apply_beamforming(
                        data,
                        ilens,
                        beamformer_stats["psd_n"][i],
                        beamformer_stats["psd_speech"][i],
                    )
                elif self.beamformer_type == "mpdr":
                    # MPDR uses the (shared) observation covariance as psd_n
                    enh, w = self.apply_beamforming(
                        data,
                        ilens,
                        beamformer_stats["psd_n"],
                        beamformer_stats["psd_speech"][i],
                        psd_distortion=beamformer_stats["psd_distortion"][i],
                    )
                elif self.beamformer_type in ("mpdr_souden", "mwf"):
                    enh, w = self.apply_beamforming(
                        data,
                        ilens,
                        beamformer_stats["psd_n"],
                        beamformer_stats["psd_speech"][i],
                    )
                elif self.beamformer_type == "lcmp":
                    enh, w = self.apply_beamforming(
                        data,
                        ilens,
                        beamformer_stats["psd_n"],
                        beamformer_stats["psd_speech"][i],
                        rtf_mat=rtf_mat,
                        spk=i,
                    )
                elif self.beamformer_type in ("lcmv", "wlcmp"):
                    enh, w = self.apply_beamforming(
                        data,
                        ilens,
                        beamformer_stats["psd_n"][i],
                        beamformer_stats["psd_speech"][i],
                        rtf_mat=rtf_mat,
                        spk=i,
                    )
                else:
                    raise ValueError(
                        "Not supporting beamformer_type={}".format(self.beamformer_type)
                    )

                # (..., F, T) -> (..., T, F)
                enh = enh.transpose(-1, -2)
                enhanced.append(enh)
                ws.append(w)

        # (..., F, C, T) -> (..., T, C, F)
        masks = [m.transpose(-1, -3) for m in masks]
        return enhanced, ilens, masks
    def apply_beamforming(
        self,
        data,
        ilens,
        psd_n,
        psd_speech,
        psd_distortion=None,
        rtf_mat=None,
        spk=0,
    ):
        """Beamforming with the provided statistics.

        Args:
            data (torch.complex64/ComplexTensor): (B, F, C, T)
            ilens (torch.Tensor): (B,)
            psd_n (torch.complex64/ComplexTensor):
                Noise covariance matrix for MVDR (B, F, C, C)
                Observation covariance matrix for MPDR/wMPDR (B, F, C, C)
                Stacked observation covariance for WPD (B,F,(btaps+1)*C,(btaps+1)*C)
            psd_speech (torch.complex64/ComplexTensor):
                Speech covariance matrix (B, F, C, C)
            psd_distortion (torch.complex64/ComplexTensor):
                Noise covariance matrix (B, F, C, C)
            rtf_mat (torch.complex64/ComplexTensor):
                RTF matrix (B, F, C, num_spk)
            spk (int): speaker index
        Return:
            enhanced (torch.complex64/ComplexTensor): (B, F, T)
            ws (torch.complex64/ComplexTensor): (B, F) or (B, F, (btaps+1)*C)
        """
        # u: (B, C) -- reference vector (or channel index) for the beamformer
        if self.ref_channel < 0:
            # estimate the reference channel with the attention module
            u, _ = self.ref(psd_speech.to(dtype=data.dtype), ilens)
            u = u.double()
        else:
            if self.beamformer_type.endswith("_souden"):
                # (optional) Create onehot vector for fixed reference microphone
                u = torch.zeros(
                    *(data.size()[:-3] + (data.size(-2),)),
                    device=data.device,
                    dtype=torch.double
                )
                u[..., self.ref_channel].fill_(1)
            else:
                # for simplifying computation in RTF-based beamforming
                u = self.ref_channel

        if self.beamformer_type in ("mvdr", "mpdr", "wmpdr"):
            ws = get_mvdr_vector_with_rtf(
                to_double(psd_n),
                to_double(psd_speech),
                to_double(psd_distortion),
                iterations=self.rtf_iterations,
                reference_vector=u,
                normalize_ref_channel=self.ref_channel,
                use_torch_solver=self.use_torch_solver,
                diagonal_loading=self.diagonal_loading,
                diag_eps=self.diag_eps,
            )
            enhanced = apply_beamforming_vector(ws, to_double(data))
        elif self.beamformer_type == "mvdr_tfs":
            # time-frequency-bin-wise switching: one filter per noise PSD,
            # then pick the minimum-magnitude output per bin
            assert isinstance(psd_n, (list, tuple))
            ws = [
                get_mvdr_vector_with_rtf(
                    to_double(psd_n_i),
                    to_double(psd_speech),
                    to_double(psd_distortion),
                    iterations=self.rtf_iterations,
                    reference_vector=u,
                    normalize_ref_channel=self.ref_channel,
                    use_torch_solver=self.use_torch_solver,
                    diagonal_loading=self.diagonal_loading,
                    diag_eps=self.diag_eps,
                )
                for psd_n_i in psd_n
            ]
            enhanced = stack([apply_beamforming_vector(w, to_double(data)) for w in ws])
            with torch.no_grad():
                index = enhanced.abs().argmin(dim=0, keepdims=True)
            enhanced = enhanced.gather(0, index).squeeze(0)
            ws = stack(ws, dim=0)
        elif self.beamformer_type in (
            "mpdr_souden",
            "mvdr_souden",
            "wmpdr_souden",
        ):
            ws = get_mvdr_vector(
                to_double(psd_speech),
                to_double(psd_n),
                u,
                use_torch_solver=self.use_torch_solver,
                diagonal_loading=self.diagonal_loading,
                diag_eps=self.diag_eps,
            )
            enhanced = apply_beamforming_vector(ws, to_double(data))
        elif self.beamformer_type == "mvdr_tfs_souden":
            assert isinstance(psd_n, (list, tuple))
            ws = [
                get_mvdr_vector(
                    to_double(psd_speech),
                    to_double(psd_n_i),
                    u,
                    use_torch_solver=self.use_torch_solver,
                    diagonal_loading=self.diagonal_loading,
                    diag_eps=self.diag_eps,
                )
                for psd_n_i in psd_n
            ]
            enhanced = stack([apply_beamforming_vector(w, to_double(data)) for w in ws])
            with torch.no_grad():
                index = enhanced.abs().argmin(dim=0, keepdims=True)
            enhanced = enhanced.gather(0, index).squeeze(0)
            ws = stack(ws, dim=0)
        elif self.beamformer_type == "wpd":
            ws = get_WPD_filter_with_rtf(
                to_double(psd_n),
                to_double(psd_speech),
                to_double(psd_distortion),
                iterations=self.rtf_iterations,
                reference_vector=u,
                normalize_ref_channel=self.ref_channel,
                use_torch_solver=self.use_torch_solver,
                diagonal_loading=self.diagonal_loading,
                diag_eps=self.diag_eps,
            )
            enhanced = perform_WPD_filtering(
                ws, to_double(data), self.bdelay, self.btaps
            )
        elif self.beamformer_type == "wpd_souden":
            ws = get_WPD_filter_v2(
                to_double(psd_speech),
                to_double(psd_n),
                u,
                diagonal_loading=self.diagonal_loading,
                diag_eps=self.diag_eps,
            )
            enhanced = perform_WPD_filtering(
                ws, to_double(data), self.bdelay, self.btaps
            )
        elif self.beamformer_type in ("mwf", "wmwf"):
            ws = get_mwf_vector(
                to_double(psd_speech),
                to_double(psd_n),
                u,
                use_torch_solver=self.use_torch_solver,
                diagonal_loading=self.diagonal_loading,
                diag_eps=self.diag_eps,
            )
            enhanced = apply_beamforming_vector(ws, to_double(data))
        elif self.beamformer_type == "sdw_mwf":
            ws = get_sdw_mwf_vector(
                to_double(psd_speech),
                to_double(psd_n),
                u,
                denoising_weight=self.mwf_mu,
                use_torch_solver=self.use_torch_solver,
                diagonal_loading=self.diagonal_loading,
                diag_eps=self.diag_eps,
            )
            enhanced = apply_beamforming_vector(ws, to_double(data))
        elif self.beamformer_type == "r1mwf":
            ws = get_rank1_mwf_vector(
                to_double(psd_speech),
                to_double(psd_n),
                u,
                denoising_weight=self.mwf_mu,
                use_torch_solver=self.use_torch_solver,
                diagonal_loading=self.diagonal_loading,
                diag_eps=self.diag_eps,
            )
            enhanced = apply_beamforming_vector(ws, to_double(data))
        elif self.beamformer_type in ("lcmp", "wlcmp", "lcmv"):
            ws = get_lcmv_vector_with_rtf(
                to_double(psd_n),
                to_double(rtf_mat),
                reference_vector=spk,
                use_torch_solver=self.use_torch_solver,
                diagonal_loading=self.diagonal_loading,
                diag_eps=self.diag_eps,
            )
            enhanced = apply_beamforming_vector(ws, to_double(data))
        elif self.beamformer_type.startswith("gev"):
            ws = get_gev_vector(
                to_double(psd_n),
                to_double(psd_speech),
                mode="power",
                diagonal_loading=self.diagonal_loading,
                diag_eps=self.diag_eps,
            )
            enhanced = apply_beamforming_vector(ws, to_double(data))
            if self.beamformer_type == "gev_ban":
                # blind analytic normalization (BAN) post-filter gain
                gain = blind_analytic_normalization(ws, to_double(psd_n))
                enhanced = enhanced * gain.unsqueeze(-1)
        else:
            raise ValueError(
                "Not supporting beamformer_type={}".format(self.beamformer_type)
            )

        # cast back from double to the input dtype
        return enhanced.to(dtype=data.dtype), ws.to(dtype=data.dtype)
[docs] def predict_mask( self, data: Union[torch.Tensor, ComplexTensor], ilens: torch.LongTensor ) -> Tuple[Tuple[torch.Tensor, ...], torch.LongTensor]: """Predict masks for beamforming. Args: data (torch.complex64/ComplexTensor): (B, T, C, F), double precision ilens (torch.Tensor): (B,) Returns: masks (torch.Tensor): (B, T, C, F) ilens (torch.Tensor): (B,) """ masks, _ = self.mask(to_float(data.permute(0, 3, 2, 1)), ilens) # (B, F, C, T) -> (B, T, C, F) masks = [m.transpose(-1, -3) for m in masks] return masks, ilens
class AttentionReference(torch.nn.Module):
    """Attention-based reference channel weighting.

    Produces a softmax weight over channels from the speech PSD matrix,
    used as the (soft) reference vector ``u`` in mask-based beamformers
    when no fixed reference channel is configured.

    Args:
        bidim: frequency dimension F of the input PSD
        att_dim: hidden dimension of the attention MLP
        eps: flooring constant added before the square root of the amplitude
    """

    def __init__(self, bidim, att_dim, eps=1e-6):
        super().__init__()
        self.mlp_psd = torch.nn.Linear(bidim, att_dim)
        self.gvec = torch.nn.Linear(att_dim, 1)
        self.eps = eps

    def forward(
        self,
        psd_in: Union[torch.Tensor, "ComplexTensor"],
        ilens: torch.LongTensor,
        scaling: float = 2.0,
    ) -> Tuple[torch.Tensor, torch.LongTensor]:
        """Attention-based reference forward function.

        Args:
            psd_in (torch.complex64/ComplexTensor): (B, F, C, C)
            ilens (torch.Tensor): (B,)
            scaling (float): temperature applied to the logits before softmax
        Returns:
            u (torch.Tensor): (B, C) non-negative channel weights summing to 1
            ilens (torch.Tensor): (B,)
        """
        B, _, C = psd_in.size()[:3]
        assert psd_in.size(2) == psd_in.size(3), psd_in.size()
        # Zero out the diagonal (self-channel) entries of each (C, C) PSD.
        # The original chained a redundant ``.type(torch.bool)`` onto a mask
        # that is already created with dtype=torch.bool; the dead cast is
        # removed here.
        psd = psd_in.masked_fill(
            torch.eye(C, dtype=torch.bool, device=psd_in.device), 0
        )
        # psd: (B, F, C, C) -> (B, C, F), averaging the C-1 off-diagonal terms
        psd = (psd.sum(dim=-1) / (C - 1)).transpose(-1, -2)
        # Calculate amplitude (eps floors the value inside the square root)
        psd_feat = (psd.real**2 + psd.imag**2 + self.eps) ** 0.5
        # (B, C, F) -> (B, C, F2)
        mlp_psd = self.mlp_psd(psd_feat)
        # (B, C, F2) -> (B, C, 1) -> (B, C)
        e = self.gvec(torch.tanh(mlp_psd)).squeeze(-1)
        u = F.softmax(scaling * e, dim=-1)
        return u, ilens