Source code for espnet.nets.pytorch_backend.frontends.dnn_beamformer
from typing import Tuple
import torch
from torch.nn import functional as F
from espnet.nets.pytorch_backend.frontends.beamformer \
    import apply_beamforming_vector
from espnet.nets.pytorch_backend.frontends.beamformer \
    import get_mvdr_vector
from espnet.nets.pytorch_backend.frontends.beamformer \
    import get_power_spectral_density_matrix
from espnet.nets.pytorch_backend.frontends.mask_estimator import MaskEstimator
from torch_complex.tensor import ComplexTensor


class DNN_Beamformer(torch.nn.Module):
    """DNN mask based Beamformer.

    Citation:
        Multichannel End-to-end Speech Recognition; T. Ochiai et al., 2017;
        https://arxiv.org/abs/1703.04783
    """

    def __init__(self,
                 bidim,
                 btype='blstmp',
                 blayers=3,
                 bunits=300,
                 bprojs=320,
                 dropout_rate=0.0,
                 badim=320,
                 ref_channel: int = -1,
                 beamformer_type='mvdr'):
        super().__init__()
        self.mask = MaskEstimator(btype, bidim, blayers, bunits, bprojs,
                                  dropout_rate, nmask=2)
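        # nmask=2: the estimator predicts two masks, one for speech and one
        # for noise; they are unpacked in that order in forward().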
        self.ref = AttentionReference(bidim, badim)
        self.ref_channel = ref_channel

        if beamformer_type != 'mvdr':
            raise ValueError(
                'Not supporting beamformer_type={}'.format(beamformer_type))
        self.beamformer_type = beamformer_type

    def forward(self, data: ComplexTensor, ilens: torch.LongTensor) \
            -> Tuple[ComplexTensor, torch.LongTensor, ComplexTensor]:
        """The forward function.

        Notation:
            B: Batch
            C: Channel
            T: Time or Sequence length
            F: Freq

        Args:
            data (ComplexTensor): (B, T, C, F)
            ilens (torch.Tensor): (B,)
        Returns:
            enhanced (ComplexTensor): (B, T, F)
            ilens (torch.Tensor): (B,)
            mask_speech (torch.Tensor): (B, T, C, F)
        """
        # data (B, T, C, F) -> (B, F, C, T)
        data = data.permute(0, 3, 2, 1)

        # mask: (B, F, C, T)
        (mask_speech, mask_noise), _ = self.mask(data, ilens)
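
        # Mask-weighted PSD (spatial covariance) estimate per frequency:
        # with mask m and observation x, get_power_spectral_density_matrix
        # computes, roughly,
        #   PSD(f) = sum_t m(t, f) x(t, f) x(t, f)^H
        # (normalization details live in that helper). Result: (B, F, C, C).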
        psd_speech = get_power_spectral_density_matrix(data, mask_speech)
        psd_noise = get_power_spectral_density_matrix(data, mask_noise)

        # u: (B, C)
        if self.ref_channel < 0:
            u, _ = self.ref(psd_speech, ilens)
        else:
            # (optional) Create onehot vector for fixed reference microphone
            u = torch.zeros(*(data.size()[:-3] + (data.size(-2),)),
                            device=data.device)
            u[..., self.ref_channel].fill_(1)
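
        # MVDR weights per frequency, following the mask-based MVDR solution
        # of Souden et al.:
        #   w(f) = (PSD_noise(f)^-1 PSD_speech(f)) u / tr(PSD_noise(f)^-1 PSD_speech(f))
        # See get_mvdr_vector for the exact normalization/regularization.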
        ws = get_mvdr_vector(psd_speech, psd_noise, u)
        enhanced = apply_beamforming_vector(ws, data)

        # (..., F, T) -> (..., T, F)
        enhanced = enhanced.transpose(-1, -2)
        mask_speech = mask_speech.transpose(-1, -3)

        return enhanced, ilens, mask_speech


class AttentionReference(torch.nn.Module):
    """Attention-based estimation of the reference-channel weight vector."""

    def __init__(self, bidim, att_dim):
        super().__init__()
        self.mlp_psd = torch.nn.Linear(bidim, att_dim)
        self.gvec = torch.nn.Linear(att_dim, 1)

    def forward(self, psd_in: ComplexTensor, ilens: torch.LongTensor,
                scaling: float = 2.0) -> Tuple[torch.Tensor, torch.LongTensor]:
        """The forward function.

        Args:
            psd_in (ComplexTensor): (B, F, C, C)
            ilens (torch.Tensor): (B,)
            scaling (float): Sharpening factor applied to the attention
                logits before the softmax
        Returns:
            u (torch.Tensor): (B, C)
            ilens (torch.Tensor): (B,)
        """
        B, _, C = psd_in.size()[:3]
        assert psd_in.size(2) == psd_in.size(3), psd_in.size()
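
        # Zero the diagonal (auto-PSD) entries and average the remaining
        # C - 1 cross-channel terms per row, yielding one complex feature
        # vector of length F per channel.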
        # psd_in: (B, F, C, C)
        # (bool mask: masked_fill with byte masks is deprecated in PyTorch)
        psd = psd_in.masked_fill(torch.eye(C, dtype=torch.bool,
                                           device=psd_in.device), 0)
        # psd: (B, F, C, C) -> (B, C, F)
        psd = (psd.sum(dim=-1) / (C - 1)).transpose(-1, -2)

        # Calculate amplitude
        psd_feat = (psd.real ** 2 + psd.imag ** 2) ** 0.5

        # (B, C, F) -> (B, C, F2)
        mlp_psd = self.mlp_psd(psd_feat)
        # (B, C, F2) -> (B, C, 1) -> (B, C)
        e = self.gvec(torch.tanh(mlp_psd)).squeeze(-1)
        u = F.softmax(scaling * e, dim=-1)
        return u, ilens
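

if __name__ == '__main__':
    # Minimal smoke test (a sketch, not part of the original module): build a
    # random multichannel STFT-like input with the (B, T, C, F) layout from
    # the forward() docstring and run the beamformer end to end. Values are
    # random, so only the output shapes are meaningful here.
    B, T, C, F_bin = 2, 100, 4, 257
    data = ComplexTensor(torch.randn(B, T, C, F_bin),
                         torch.randn(B, T, C, F_bin))
    ilens = torch.LongTensor([T, T - 20])

    beamformer = DNN_Beamformer(bidim=F_bin)
    enhanced, olens, mask_speech = beamformer(data, ilens)
    assert enhanced.size() == (B, T, F_bin)        # (B, T, F)
    assert mask_speech.size() == (B, T, C, F_bin)  # (B, T, C, F)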