Source code for espnet2.enh.layers.dnn_beamformer
"""DNN beamformer module."""
from distutils.version import LooseVersion
from typing import List
from typing import Optional
from typing import Tuple
from typing import Union
import logging
import torch
from torch.nn import functional as F
from torch_complex.tensor import ComplexTensor
from espnet2.enh.layers.beamformer import apply_beamforming_vector
from espnet2.enh.layers.beamformer import blind_analytic_normalization
from espnet2.enh.layers.beamformer import get_gev_vector
from espnet2.enh.layers.beamformer import get_lcmv_vector_with_rtf
from espnet2.enh.layers.beamformer import get_mvdr_vector
from espnet2.enh.layers.beamformer import get_mvdr_vector_with_rtf
from espnet2.enh.layers.beamformer import get_mwf_vector
from espnet2.enh.layers.beamformer import get_rank1_mwf_vector
from espnet2.enh.layers.beamformer import get_rtf_matrix
from espnet2.enh.layers.beamformer import get_sdw_mwf_vector
from espnet2.enh.layers.beamformer import get_WPD_filter_v2
from espnet2.enh.layers.beamformer import get_WPD_filter_with_rtf
from espnet2.enh.layers.beamformer import perform_WPD_filtering
from espnet2.enh.layers.beamformer import prepare_beamformer_stats
from espnet2.enh.layers.complex_utils import stack
from espnet2.enh.layers.complex_utils import to_double
from espnet2.enh.layers.complex_utils import to_float
from espnet2.enh.layers.mask_estimator import MaskEstimator
is_torch_1_9_plus = LooseVersion(torch.__version__) >= LooseVersion("1.9.0")
BEAMFORMER_TYPES = (
# Minimum Variance Distortionless Response beamformer
"mvdr", # RTF-based formula
"mvdr_souden", # Souden's solution
# Minimum Power Distortionless Response beamformer
"mpdr", # RTF-based formula
"mpdr_souden", # Souden's solution
# weighted MPDR beamformer
"wmpdr", # RTF-based formula
"wmpdr_souden", # Souden's solution
# Weighted Power minimization Distortionless response beamformer
"wpd", # RTF-based formula
"wpd_souden", # Souden's solution
# Multi-channel Wiener Filter (MWF) and weighted MWF
"mwf",
"wmwf",
# Speech Distortion Weighted (SDW) MWF
"sdw_mwf",
# Rank-1 MWF
"r1mwf",
# Linearly Constrained Minimum Variance beamformer
"lcmv",
# Linearly Constrained Minimum Power beamformer
"lcmp",
# weighted Linearly Constrained Minimum Power beamformer
"wlcmp",
# Generalized Eigenvalue beamformer
"gev",
"gev_ban", # with blind analytic normalization (BAN) post-filtering
# time-frequency-bin-wise switching (TFS) MVDR beamformer
"mvdr_tfs",
"mvdr_tfs_souden",
)
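
# The tuple above is the registry of supported beamformer types; a minimal
# sketch of how it is consumed (mirroring the validation performed in
# DNN_Beamformer.__init__ below):
#
#     beamformer_type = "mvdr_souden"
#     if beamformer_type not in BEAMFORMER_TYPES:
#         raise ValueError("Unsupported beamformer_type=%s" % beamformer_type)
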
class DNN_Beamformer(torch.nn.Module):
"""DNN mask based Beamformer.
Citation:
Multichannel End-to-end Speech Recognition; T. Ochiai et al., 2017;
http://proceedings.mlr.press/v70/ochiai17a/ochiai17a.pdf
"""
def __init__(
self,
bidim,
btype: str = "blstmp",
blayers: int = 3,
bunits: int = 300,
bprojs: int = 320,
num_spk: int = 1,
use_noise_mask: bool = True,
nonlinear: str = "sigmoid",
dropout_rate: float = 0.0,
badim: int = 320,
ref_channel: int = -1,
beamformer_type: str = "mvdr_souden",
rtf_iterations: int = 2,
mwf_mu: float = 1.0,
eps: float = 1e-6,
diagonal_loading: bool = True,
diag_eps: float = 1e-7,
mask_flooring: bool = False,
flooring_thres: float = 1e-6,
use_torch_solver: bool = True,
# only for WPD beamformer
btaps: int = 5,
bdelay: int = 3,
):
super().__init__()
bnmask = num_spk + 1 if use_noise_mask else num_spk
self.mask = MaskEstimator(
btype,
bidim,
blayers,
bunits,
bprojs,
dropout_rate,
nmask=bnmask,
nonlinear=nonlinear,
)
self.ref = (
AttentionReference(bidim, badim, eps=eps) if ref_channel < 0 else None
)
self.ref_channel = ref_channel
self.use_noise_mask = use_noise_mask
assert num_spk >= 1, num_spk
self.num_spk = num_spk
self.nmask = bnmask
if beamformer_type not in BEAMFORMER_TYPES:
raise ValueError("Not supporting beamformer_type=%s" % beamformer_type)
if (
beamformer_type == "mvdr_souden" or not beamformer_type.endswith("_souden")
) and not use_noise_mask:
if num_spk == 1:
logging.warning(
"Initializing %s beamformer without noise mask "
"estimator (single-speaker case)" % beamformer_type.upper()
)
logging.warning(
"(1 - speech_mask) will be used for estimating noise "
"PSD in %s beamformer!" % beamformer_type.upper()
)
else:
logging.warning(
"Initializing %s beamformer without noise mask "
"estimator (multi-speaker case)" % beamformer_type.upper()
)
logging.warning(
"Interference speech masks will be used for estimating "
"noise PSD in %s beamformer!" % beamformer_type.upper()
)
self.beamformer_type = beamformer_type
if not beamformer_type.endswith("_souden"):
assert rtf_iterations >= 2, rtf_iterations
# number of iterations in power method for estimating the RTF
self.rtf_iterations = rtf_iterations
# noise suppression weight in SDW-MWF
self.mwf_mu = mwf_mu
assert btaps >= 0 and bdelay >= 0, (btaps, bdelay)
self.btaps = btaps
self.bdelay = bdelay if self.btaps > 0 else 1
self.eps = eps
self.diagonal_loading = diagonal_loading
self.diag_eps = diag_eps
self.mask_flooring = mask_flooring
self.flooring_thres = flooring_thres
self.use_torch_solver = use_torch_solver
    def forward(
self,
data: Union[torch.Tensor, ComplexTensor],
ilens: torch.LongTensor,
powers: Optional[List[torch.Tensor]] = None,
oracle_masks: Optional[List[torch.Tensor]] = None,
) -> Tuple[Union[torch.Tensor, ComplexTensor], torch.LongTensor, torch.Tensor]:
"""DNN_Beamformer forward function.
Notation:
B: Batch
C: Channel
T: Time or Sequence length
F: Freq
Args:
data (torch.complex64/ComplexTensor): (B, T, C, F)
ilens (torch.Tensor): (B,)
powers (List[torch.Tensor] or None): used for wMPDR or WPD (B, F, T)
oracle_masks (List[torch.Tensor] or None): oracle masks (B, F, C, T)
if not None, oracle_masks will be used instead of self.mask
        Returns:
            enhanced (torch.complex64/ComplexTensor): (B, T, F)
                (a list of (B, T, F) tensors in the multi-speaker case)
            ilens (torch.Tensor): (B,)
            masks (List[torch.Tensor]): [(B, T, C, F)]
"""
# data (B, T, C, F) -> (B, F, C, T)
data = data.permute(0, 3, 2, 1)
data_d = to_double(data)
# mask: [(B, F, C, T)]
if oracle_masks is not None:
masks = oracle_masks
else:
masks, _ = self.mask(data, ilens)
assert self.nmask == len(masks), len(masks)
# floor masks to increase numerical stability
if self.mask_flooring:
masks = [torch.clamp(m, min=self.flooring_thres) for m in masks]
if self.num_spk == 1: # single-speaker case
if self.use_noise_mask:
# (mask_speech, mask_noise)
mask_speech, mask_noise = masks
else:
# (mask_speech,)
mask_speech = masks[0]
mask_noise = 1 - mask_speech
if self.beamformer_type in ("lcmv", "lcmp", "wlcmp"):
raise NotImplementedError("Single source is not supported yet")
beamformer_stats = prepare_beamformer_stats(
data_d,
[mask_speech],
mask_noise,
powers=powers,
beamformer_type=self.beamformer_type,
bdelay=self.bdelay,
btaps=self.btaps,
eps=self.eps,
)
if self.beamformer_type in ("mvdr", "mpdr", "wmpdr", "wpd"):
enhanced, ws = self.apply_beamforming(
data,
ilens,
beamformer_stats["psd_n"],
beamformer_stats["psd_speech"],
psd_distortion=beamformer_stats["psd_distortion"],
)
elif (
self.beamformer_type.endswith("_souden")
or self.beamformer_type == "mwf"
or self.beamformer_type == "wmwf"
or self.beamformer_type == "sdw_mwf"
or self.beamformer_type == "r1mwf"
or self.beamformer_type.startswith("gev")
):
enhanced, ws = self.apply_beamforming(
data,
ilens,
beamformer_stats["psd_n"],
beamformer_stats["psd_speech"],
)
else:
                raise ValueError(
                    "Unsupported beamformer_type={}".format(self.beamformer_type)
                )
# (..., F, T) -> (..., T, F)
enhanced = enhanced.transpose(-1, -2)
else: # multi-speaker case
if self.use_noise_mask:
# (mask_speech1, ..., mask_noise)
mask_speech = list(masks[:-1])
mask_noise = masks[-1]
else:
# (mask_speech1, ..., mask_speechX)
mask_speech = list(masks)
mask_noise = None
beamformer_stats = prepare_beamformer_stats(
data_d,
mask_speech,
mask_noise,
powers=powers,
beamformer_type=self.beamformer_type,
bdelay=self.bdelay,
btaps=self.btaps,
eps=self.eps,
)
if self.beamformer_type in ("lcmv", "lcmp", "wlcmp"):
rtf_mat = get_rtf_matrix(
beamformer_stats["psd_speech"],
beamformer_stats["psd_distortion"],
diagonal_loading=self.diagonal_loading,
ref_channel=self.ref_channel,
rtf_iterations=self.rtf_iterations,
use_torch_solver=self.use_torch_solver,
diag_eps=self.diag_eps,
)
enhanced, ws = [], []
for i in range(self.num_spk):
# treat all other speakers' psd_speech as noises
if self.beamformer_type in ("mvdr", "mvdr_tfs", "wmpdr", "wpd"):
enh, w = self.apply_beamforming(
data,
ilens,
beamformer_stats["psd_n"][i],
beamformer_stats["psd_speech"][i],
psd_distortion=beamformer_stats["psd_distortion"][i],
)
elif self.beamformer_type in (
"mvdr_souden",
"mvdr_tfs_souden",
"wmpdr_souden",
"wpd_souden",
"wmwf",
"sdw_mwf",
"r1mwf",
"gev",
"gev_ban",
):
enh, w = self.apply_beamforming(
data,
ilens,
beamformer_stats["psd_n"][i],
beamformer_stats["psd_speech"][i],
)
elif self.beamformer_type == "mpdr":
enh, w = self.apply_beamforming(
data,
ilens,
beamformer_stats["psd_n"],
beamformer_stats["psd_speech"][i],
psd_distortion=beamformer_stats["psd_distortion"][i],
)
elif self.beamformer_type in ("mpdr_souden", "mwf"):
enh, w = self.apply_beamforming(
data,
ilens,
beamformer_stats["psd_n"],
beamformer_stats["psd_speech"][i],
)
elif self.beamformer_type == "lcmp":
enh, w = self.apply_beamforming(
data,
ilens,
beamformer_stats["psd_n"],
beamformer_stats["psd_speech"][i],
rtf_mat=rtf_mat,
spk=i,
)
elif self.beamformer_type in ("lcmv", "wlcmp"):
enh, w = self.apply_beamforming(
data,
ilens,
beamformer_stats["psd_n"][i],
beamformer_stats["psd_speech"][i],
rtf_mat=rtf_mat,
spk=i,
)
else:
                    raise ValueError(
                        "Unsupported beamformer_type={}".format(self.beamformer_type)
                    )
# (..., F, T) -> (..., T, F)
enh = enh.transpose(-1, -2)
enhanced.append(enh)
ws.append(w)
# (..., F, C, T) -> (..., T, C, F)
masks = [m.transpose(-1, -3) for m in masks]
return enhanced, ilens, masks
    def apply_beamforming(
self,
data,
ilens,
psd_n,
psd_speech,
psd_distortion=None,
rtf_mat=None,
spk=0,
):
"""Beamforming with the provided statistics.
Args:
data (torch.complex64/ComplexTensor): (B, F, C, T)
ilens (torch.Tensor): (B,)
psd_n (torch.complex64/ComplexTensor):
Noise covariance matrix for MVDR (B, F, C, C)
Observation covariance matrix for MPDR/wMPDR (B, F, C, C)
Stacked observation covariance for WPD (B,F,(btaps+1)*C,(btaps+1)*C)
psd_speech (torch.complex64/ComplexTensor):
Speech covariance matrix (B, F, C, C)
            psd_distortion (torch.complex64/ComplexTensor):
                Covariance matrix of the distortion signal, used for RTF
                estimation (B, F, C, C)
rtf_mat (torch.complex64/ComplexTensor):
RTF matrix (B, F, C, num_spk)
spk (int): speaker index
        Returns:
            enhanced (torch.complex64/ComplexTensor): (B, F, T)
            ws (torch.complex64/ComplexTensor): (B, F, C) or (B, F, (btaps + 1) * C)
"""
# u: (B, C)
if self.ref_channel < 0:
u, _ = self.ref(psd_speech.to(dtype=data.dtype), ilens)
u = u.double()
else:
if self.beamformer_type.endswith("_souden"):
# (optional) Create onehot vector for fixed reference microphone
u = torch.zeros(
*(data.size()[:-3] + (data.size(-2),)),
device=data.device,
dtype=torch.double
)
u[..., self.ref_channel].fill_(1)
else:
# for simplifying computation in RTF-based beamforming
u = self.ref_channel
if self.beamformer_type in ("mvdr", "mpdr", "wmpdr"):
ws = get_mvdr_vector_with_rtf(
to_double(psd_n),
to_double(psd_speech),
to_double(psd_distortion),
iterations=self.rtf_iterations,
reference_vector=u,
normalize_ref_channel=self.ref_channel,
use_torch_solver=self.use_torch_solver,
diagonal_loading=self.diagonal_loading,
diag_eps=self.diag_eps,
)
enhanced = apply_beamforming_vector(ws, to_double(data))
elif self.beamformer_type == "mvdr_tfs":
assert isinstance(psd_n, (list, tuple))
ws = [
get_mvdr_vector_with_rtf(
to_double(psd_n_i),
to_double(psd_speech),
to_double(psd_distortion),
iterations=self.rtf_iterations,
reference_vector=u,
normalize_ref_channel=self.ref_channel,
use_torch_solver=self.use_torch_solver,
diagonal_loading=self.diagonal_loading,
diag_eps=self.diag_eps,
)
for psd_n_i in psd_n
]
enhanced = stack([apply_beamforming_vector(w, to_double(data)) for w in ws])
with torch.no_grad():
                index = enhanced.abs().argmin(dim=0, keepdim=True)
enhanced = enhanced.gather(0, index).squeeze(0)
ws = stack(ws, dim=0)
elif self.beamformer_type in (
"mpdr_souden",
"mvdr_souden",
"wmpdr_souden",
):
ws = get_mvdr_vector(
to_double(psd_speech),
to_double(psd_n),
u,
use_torch_solver=self.use_torch_solver,
diagonal_loading=self.diagonal_loading,
diag_eps=self.diag_eps,
)
enhanced = apply_beamforming_vector(ws, to_double(data))
elif self.beamformer_type == "mvdr_tfs_souden":
assert isinstance(psd_n, (list, tuple))
ws = [
get_mvdr_vector(
to_double(psd_speech),
to_double(psd_n_i),
u,
use_torch_solver=self.use_torch_solver,
diagonal_loading=self.diagonal_loading,
diag_eps=self.diag_eps,
)
for psd_n_i in psd_n
]
enhanced = stack([apply_beamforming_vector(w, to_double(data)) for w in ws])
with torch.no_grad():
                index = enhanced.abs().argmin(dim=0, keepdim=True)
enhanced = enhanced.gather(0, index).squeeze(0)
ws = stack(ws, dim=0)
elif self.beamformer_type == "wpd":
ws = get_WPD_filter_with_rtf(
to_double(psd_n),
to_double(psd_speech),
to_double(psd_distortion),
iterations=self.rtf_iterations,
reference_vector=u,
normalize_ref_channel=self.ref_channel,
use_torch_solver=self.use_torch_solver,
diagonal_loading=self.diagonal_loading,
diag_eps=self.diag_eps,
)
enhanced = perform_WPD_filtering(
ws, to_double(data), self.bdelay, self.btaps
)
elif self.beamformer_type == "wpd_souden":
ws = get_WPD_filter_v2(
to_double(psd_speech),
to_double(psd_n),
u,
diagonal_loading=self.diagonal_loading,
diag_eps=self.diag_eps,
)
enhanced = perform_WPD_filtering(
ws, to_double(data), self.bdelay, self.btaps
)
elif self.beamformer_type in ("mwf", "wmwf"):
ws = get_mwf_vector(
to_double(psd_speech),
to_double(psd_n),
u,
use_torch_solver=self.use_torch_solver,
diagonal_loading=self.diagonal_loading,
diag_eps=self.diag_eps,
)
enhanced = apply_beamforming_vector(ws, to_double(data))
elif self.beamformer_type == "sdw_mwf":
ws = get_sdw_mwf_vector(
to_double(psd_speech),
to_double(psd_n),
u,
denoising_weight=self.mwf_mu,
use_torch_solver=self.use_torch_solver,
diagonal_loading=self.diagonal_loading,
diag_eps=self.diag_eps,
)
enhanced = apply_beamforming_vector(ws, to_double(data))
elif self.beamformer_type == "r1mwf":
ws = get_rank1_mwf_vector(
to_double(psd_speech),
to_double(psd_n),
u,
denoising_weight=self.mwf_mu,
use_torch_solver=self.use_torch_solver,
diagonal_loading=self.diagonal_loading,
diag_eps=self.diag_eps,
)
enhanced = apply_beamforming_vector(ws, to_double(data))
elif self.beamformer_type in ("lcmp", "wlcmp", "lcmv"):
ws = get_lcmv_vector_with_rtf(
to_double(psd_n),
to_double(rtf_mat),
reference_vector=spk,
use_torch_solver=self.use_torch_solver,
diagonal_loading=self.diagonal_loading,
diag_eps=self.diag_eps,
)
enhanced = apply_beamforming_vector(ws, to_double(data))
elif self.beamformer_type.startswith("gev"):
ws = get_gev_vector(
to_double(psd_n),
to_double(psd_speech),
mode="power",
diagonal_loading=self.diagonal_loading,
diag_eps=self.diag_eps,
)
enhanced = apply_beamforming_vector(ws, to_double(data))
if self.beamformer_type == "gev_ban":
gain = blind_analytic_normalization(ws, to_double(psd_n))
enhanced = enhanced * gain.unsqueeze(-1)
else:
            raise ValueError(
                "Unsupported beamformer_type={}".format(self.beamformer_type)
            )
return enhanced.to(dtype=data.dtype), ws.to(dtype=data.dtype)
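
    # A minimal sketch of driving apply_beamforming directly with statistics
    # from prepare_beamformer_stats, as forward() does above; `data` is a
    # (B, F, C, T) complex spectrum and `mask_speech`/`mask_noise` are
    # hypothetical (B, F, C, T) masks used only for illustration:
    #
    #     stats = prepare_beamformer_stats(
    #         to_double(data), [mask_speech], mask_noise,
    #         beamformer_type="mvdr_souden",
    #     )
    #     enhanced, ws = self.apply_beamforming(
    #         data, ilens, stats["psd_n"], stats["psd_speech"]
    #     )
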
    def predict_mask(
self, data: Union[torch.Tensor, ComplexTensor], ilens: torch.LongTensor
) -> Tuple[Tuple[torch.Tensor, ...], torch.LongTensor]:
"""Predict masks for beamforming.
Args:
data (torch.complex64/ComplexTensor): (B, T, C, F), double precision
ilens (torch.Tensor): (B,)
        Returns:
            masks (List[torch.Tensor]): [(B, T, C, F)]
            ilens (torch.Tensor): (B,)
"""
masks, _ = self.mask(to_float(data.permute(0, 3, 2, 1)), ilens)
# (B, F, C, T) -> (B, T, C, F)
masks = [m.transpose(-1, -3) for m in masks]
return masks, ilens
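
# A minimal usage sketch for DNN_Beamformer, assuming a random (B, T, C, F)
# complex STFT input with B=2, T=100, C=4 channels, and F=257 frequency bins;
# the shapes follow the forward() docstring above:
#
#     import torch
#     from torch_complex.tensor import ComplexTensor
#
#     beamformer = DNN_Beamformer(bidim=257, ref_channel=0)
#     data = ComplexTensor(
#         torch.randn(2, 100, 4, 257), torch.randn(2, 100, 4, 257)
#     )
#     ilens = torch.LongTensor([100, 98])
#     # single-speaker case: enhanced is (B, T, F); masks is a list of
#     # (B, T, C, F) tensors
#     enhanced, ilens, masks = beamformer(data, ilens)
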
class AttentionReference(torch.nn.Module):
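    """Attention-based reference channel estimator.

    Computes a soft reference vector u over channels from a speech PSD matrix;
    used by DNN_Beamformer when ref_channel < 0.
    """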
def __init__(self, bidim, att_dim, eps=1e-6):
super().__init__()
self.mlp_psd = torch.nn.Linear(bidim, att_dim)
self.gvec = torch.nn.Linear(att_dim, 1)
self.eps = eps
    def forward(
self,
psd_in: Union[torch.Tensor, ComplexTensor],
ilens: torch.LongTensor,
scaling: float = 2.0,
) -> Tuple[torch.Tensor, torch.LongTensor]:
"""Attention-based reference forward function.
Args:
psd_in (torch.complex64/ComplexTensor): (B, F, C, C)
ilens (torch.Tensor): (B,)
            scaling (float): scaling factor applied to the attention logits
Returns:
u (torch.Tensor): (B, C)
ilens (torch.Tensor): (B,)
"""
B, _, C = psd_in.size()[:3]
assert psd_in.size(2) == psd_in.size(3), psd_in.size()
# psd_in: (B, F, C, C)
        psd = psd_in.masked_fill(
            torch.eye(C, dtype=torch.bool, device=psd_in.device), 0
        )
# psd: (B, F, C, C) -> (B, C, F)
psd = (psd.sum(dim=-1) / (C - 1)).transpose(-1, -2)
# Calculate amplitude
psd_feat = (psd.real**2 + psd.imag**2 + self.eps) ** 0.5
# (B, C, F) -> (B, C, F2)
mlp_psd = self.mlp_psd(psd_feat)
# (B, C, F2) -> (B, C, 1) -> (B, C)
e = self.gvec(torch.tanh(mlp_psd)).squeeze(-1)
u = F.softmax(scaling * e, dim=-1)
return u, ilens
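
# A minimal, self-contained sketch exercising AttentionReference on a random
# Hermitian, PSD-like input (shapes as in its forward() docstring); the values
# are arbitrary and only illustrate the expected tensor shapes:
if __name__ == "__main__":
    B, F_bins, C = 2, 257, 4
    x = ComplexTensor(torch.randn(B, F_bins, C, C), torch.randn(B, F_bins, C, C))
    psd = (x + x.conj().transpose(-1, -2)) / 2  # make the matrix Hermitian
    att_ref = AttentionReference(bidim=F_bins, att_dim=10)
    u, _ = att_ref(psd, torch.LongTensor([100, 100]))
    print(u.shape)  # torch.Size([2, 4]): softmax weights over the C channels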