|
from distutils.version import LooseVersion |
|
from typing import List |
|
from typing import Tuple |
|
from typing import Union |
|
|
|
import logging |
|
import torch |
|
from torch.nn import functional as F |
|
from torch_complex import functional as FC |
|
from torch_complex.tensor import ComplexTensor |
|
|
|
from espnet.nets.pytorch_backend.frontends.beamformer import apply_beamforming_vector |
|
from espnet.nets.pytorch_backend.frontends.beamformer import ( |
|
get_power_spectral_density_matrix, |
|
) |
|
from espnet2.enh.layers.beamformer import get_covariances |
|
from espnet2.enh.layers.beamformer import get_mvdr_vector |
|
from espnet2.enh.layers.beamformer import get_mvdr_vector_with_rtf |
|
from espnet2.enh.layers.beamformer import get_WPD_filter_v2 |
|
from espnet2.enh.layers.beamformer import get_WPD_filter_with_rtf |
|
from espnet2.enh.layers.beamformer import perform_WPD_filtering |
|
from espnet2.enh.layers.mask_estimator import MaskEstimator |
|
|
|
# Torch-version capability flags: bool masks/tensors are only supported from
# torch 1.2 (tensors) / 1.3 (masked ops); older versions fall back to uint8.
# Used by AttentionReference.forward below.
is_torch_1_2_plus = LooseVersion(torch.__version__) >= LooseVersion("1.2.0")
is_torch_1_3_plus = LooseVersion(torch.__version__) >= LooseVersion("1.3.0")


# All beamformer variants accepted by DNN_Beamformer(beamformer_type=...).
# Variants with the "_souden" suffix use the Souden-style solution
# (get_mvdr_vector / get_WPD_filter_v2); the others estimate a relative
# transfer function (RTF) explicitly (get_*_with_rtf).
BEAMFORMER_TYPES = (
    # minimum variance distortionless response (MVDR)
    "mvdr",
    "mvdr_souden",
    # minimum power distortionless response (MPDR)
    "mpdr",
    "mpdr_souden",
    # weighted MPDR (uses the estimated speech power for weighting)
    "wmpdr",
    "wmpdr_souden",
    # weighted power minimization distortionless response (WPD)
    "wpd",
    "wpd_souden",
)
|
|
|
|
|
class DNN_Beamformer(torch.nn.Module):
    """DNN mask based Beamformer.

    A neural mask estimator predicts per-source (and optionally noise) T-F
    masks, from which spatial covariance statistics are computed and a
    beamforming filter (MVDR / MPDR / wMPDR / WPD, RTF-based or Souden-style)
    is derived and applied.

    Citation:
        Multichannel End-to-end Speech Recognition; T. Ochiai et al., 2017;
        http://proceedings.mlr.press/v70/ochiai17a/ochiai17a.pdf

    """

    def __init__(
        self,
        bidim,
        btype: str = "blstmp",
        blayers: int = 3,
        bunits: int = 300,
        bprojs: int = 320,
        num_spk: int = 1,
        use_noise_mask: bool = True,
        nonlinear: str = "sigmoid",
        dropout_rate: float = 0.0,
        badim: int = 320,
        ref_channel: int = -1,
        beamformer_type: str = "mvdr_souden",
        rtf_iterations: int = 2,
        eps: float = 1e-6,
        diagonal_loading: bool = True,
        diag_eps: float = 1e-7,
        mask_flooring: bool = False,
        flooring_thres: float = 1e-6,
        use_torch_solver: bool = True,
        # only used by the WPD-related variants ("wpd", "wpd_souden")
        btaps: int = 5,
        bdelay: int = 3,
    ):
        """Initialize DNN_Beamformer.

        Args:
            bidim: input feature dimension of the mask estimator
            btype: mask estimator network type (passed to MaskEstimator)
            blayers: number of layers in the mask estimator
            bunits: number of units in the mask estimator
            bprojs: number of projection units in the mask estimator
            num_spk: number of speakers to separate (>= 1)
            use_noise_mask: whether to estimate a dedicated noise mask in
                addition to the speech mask(s)
            nonlinear: output nonlinearity of the mask estimator
            dropout_rate: dropout rate in the mask estimator
            badim: attention dimension for AttentionReference
                (only used when ref_channel < 0)
            ref_channel: fixed reference microphone index; a negative value
                enables attention-based reference selection instead
            beamformer_type: one of BEAMFORMER_TYPES
            rtf_iterations: number of iterations passed to the RTF-based
                filter solvers (non-"_souden" variants only)
            eps: flooring value when inverting the estimated power spectrum
            diagonal_loading: regularize covariance matrices before inversion
            diag_eps: diagonal-loading coefficient
            mask_flooring: clamp the estimated masks from below
            flooring_thres: lower clamp value used when mask_flooring is True
            use_torch_solver: use torch's solver inside the filter solvers
            btaps: number of filter taps for WPD (0 disables the taps)
            bdelay: prediction delay for WPD (forced to 1 when btaps == 0)
        """
        super().__init__()
        # one mask per speaker, plus an extra mask for noise when requested
        bnmask = num_spk + 1 if use_noise_mask else num_spk
        self.mask = MaskEstimator(
            btype,
            bidim,
            blayers,
            bunits,
            bprojs,
            dropout_rate,
            nmask=bnmask,
            nonlinear=nonlinear,
        )
        # ref_channel < 0: learn the reference with attention; otherwise the
        # fixed channel index is used directly and no module is needed.
        self.ref = AttentionReference(bidim, badim) if ref_channel < 0 else None
        self.ref_channel = ref_channel

        self.use_noise_mask = use_noise_mask
        assert num_spk >= 1, num_spk
        self.num_spk = num_spk
        self.nmask = bnmask

        if beamformer_type not in BEAMFORMER_TYPES:
            raise ValueError("Not supporting beamformer_type=%s" % beamformer_type)
        # These variants need a noise PSD estimate; warn about the fallback
        # used when no dedicated noise mask is estimated.
        if (
            beamformer_type == "mvdr_souden" or not beamformer_type.endswith("_souden")
        ) and not use_noise_mask:
            if num_spk == 1:
                logging.warning(
                    "Initializing %s beamformer without noise mask "
                    "estimator (single-speaker case)" % beamformer_type.upper()
                )
                logging.warning(
                    "(1 - speech_mask) will be used for estimating noise "
                    "PSD in %s beamformer!" % beamformer_type.upper()
                )
            else:
                logging.warning(
                    "Initializing %s beamformer without noise mask "
                    "estimator (multi-speaker case)" % beamformer_type.upper()
                )
                logging.warning(
                    "Interference speech masks will be used for estimating "
                    "noise PSD in %s beamformer!" % beamformer_type.upper()
                )

        self.beamformer_type = beamformer_type
        if not beamformer_type.endswith("_souden"):
            # RTF-based variants need at least 2 solver iterations
            assert rtf_iterations >= 2, rtf_iterations
        # rtf_iterations is only consumed by the RTF-based solvers
        self.rtf_iterations = rtf_iterations

        assert btaps >= 0 and bdelay >= 0, (btaps, bdelay)
        self.btaps = btaps
        self.bdelay = bdelay if self.btaps > 0 else 1
        self.eps = eps
        self.diagonal_loading = diagonal_loading
        self.diag_eps = diag_eps
        self.mask_flooring = mask_flooring
        self.flooring_thres = flooring_thres
        self.use_torch_solver = use_torch_solver

    def forward(
        self,
        data: ComplexTensor,
        ilens: torch.LongTensor,
        powers: Union[List[torch.Tensor], None] = None,
    ) -> Tuple[ComplexTensor, torch.LongTensor, torch.Tensor]:
        """DNN_Beamformer forward function.

        Notation:
            B: Batch
            C: Channel
            T: Time or Sequence length
            F: Freq

        Args:
            data (ComplexTensor): (B, T, C, F)
            ilens (torch.Tensor): (B,)
            powers (List[torch.Tensor] or None): used for wMPDR or WPD (B, F, T)
        Returns:
            enhanced (ComplexTensor): (B, T, F)
                (a list of such tensors when num_spk > 1)
            ilens (torch.Tensor): (B,)
            masks (torch.Tensor): (B, T, C, F)
        """

        def apply_beamforming(data, ilens, psd_n, psd_speech, psd_distortion=None):
            """Beamforming with the provided statistics.

            Args:
                data (ComplexTensor): (B, F, C, T)
                ilens (torch.Tensor): (B,)
                psd_n (ComplexTensor):
                    Noise covariance matrix for MVDR (B, F, C, C)
                    Observation covariance matrix for MPDR/wMPDR (B, F, C, C)
                    Stacked observation covariance for WPD (B,F,(btaps+1)*C,(btaps+1)*C)
                psd_speech (ComplexTensor): Speech covariance matrix (B, F, C, C)
                psd_distortion (ComplexTensor): Noise covariance matrix (B, F, C, C)
                    (only consumed by the RTF-based, non-"_souden" variants)
            Return:
                enhanced (ComplexTensor): (B, F, T)
                ws (ComplexTensor): (B, F) or (B, F, (btaps+1)*C)
            """
            # Determine the reference vector u for the distortionless constraint.
            if self.ref_channel < 0:
                # attention-based reference: the attention module runs in the
                # input dtype, but the solvers below operate in double
                u, _ = self.ref(psd_speech.to(dtype=data.dtype), ilens)
                u = u.double()
            else:
                if self.beamformer_type.endswith("_souden"):
                    # one-hot vector selecting the fixed reference microphone
                    u = torch.zeros(
                        *(data.size()[:-3] + (data.size(-2),)),
                        device=data.device,
                        dtype=torch.double
                    )
                    u[..., self.ref_channel].fill_(1)
                else:
                    # RTF-based solvers accept the raw channel index instead
                    u = self.ref_channel

            # All statistics are promoted to double for numerical stability;
            # results are cast back to the input dtype before returning.
            if self.beamformer_type in ("mvdr", "mpdr", "wmpdr"):
                ws = get_mvdr_vector_with_rtf(
                    psd_n.double(),
                    psd_speech.double(),
                    psd_distortion.double(),
                    iterations=self.rtf_iterations,
                    reference_vector=u,
                    normalize_ref_channel=self.ref_channel,
                    use_torch_solver=self.use_torch_solver,
                    diagonal_loading=self.diagonal_loading,
                    diag_eps=self.diag_eps,
                )
                enhanced = apply_beamforming_vector(ws, data.double())
            elif self.beamformer_type in ("mpdr_souden", "mvdr_souden", "wmpdr_souden"):
                ws = get_mvdr_vector(
                    psd_speech.double(),
                    psd_n.double(),
                    u,
                    use_torch_solver=self.use_torch_solver,
                    diagonal_loading=self.diagonal_loading,
                    diag_eps=self.diag_eps,
                )
                enhanced = apply_beamforming_vector(ws, data.double())
            elif self.beamformer_type == "wpd":
                ws = get_WPD_filter_with_rtf(
                    psd_n.double(),
                    psd_speech.double(),
                    psd_distortion.double(),
                    iterations=self.rtf_iterations,
                    reference_vector=u,
                    normalize_ref_channel=self.ref_channel,
                    use_torch_solver=self.use_torch_solver,
                    diagonal_loading=self.diagonal_loading,
                    diag_eps=self.diag_eps,
                )
                enhanced = perform_WPD_filtering(
                    ws, data.double(), self.bdelay, self.btaps
                )
            elif self.beamformer_type == "wpd_souden":
                ws = get_WPD_filter_v2(
                    psd_speech.double(),
                    psd_n.double(),
                    u,
                    diagonal_loading=self.diagonal_loading,
                    diag_eps=self.diag_eps,
                )
                enhanced = perform_WPD_filtering(
                    ws, data.double(), self.bdelay, self.btaps
                )
            else:
                raise ValueError(
                    "Not supporting beamformer_type={}".format(self.beamformer_type)
                )

            return enhanced.to(dtype=data.dtype), ws.to(dtype=data.dtype)

        # data: (B, T, C, F) -> (B, F, C, T)
        data = data.permute(0, 3, 2, 1)
        data_d = data.double()

        # mask estimation: self.nmask masks of shape (B, F, C, T)
        masks, _ = self.mask(data, ilens)
        assert self.nmask == len(masks), len(masks)

        # floor masks to increase numerical stability
        if self.mask_flooring:
            masks = [torch.clamp(m, min=self.flooring_thres) for m in masks]

        if self.num_spk == 1:  # single-speaker case
            if self.use_noise_mask:
                # masks = (mask_speech, mask_noise)
                mask_speech, mask_noise = masks
            else:
                # fall back to (1 - speech mask) as the noise mask
                mask_speech = masks[0]
                mask_noise = 1 - mask_speech

            # wMPDR/WPD additionally need the inverse speech power spectrum
            if self.beamformer_type.startswith(
                "wmpdr"
            ) or self.beamformer_type.startswith("wpd"):
                if powers is None:
                    power_input = data_d.real ** 2 + data_d.imag ** 2
                    # masked power averaged over channels: (B, F, T)
                    powers = (power_input * mask_speech.double()).mean(dim=-2)
                else:
                    assert len(powers) == 1, len(powers)
                    powers = powers[0]
                inverse_power = 1 / torch.clamp(powers, min=self.eps)

            psd_speech = get_power_spectral_density_matrix(data_d, mask_speech.double())
            if mask_noise is not None and (
                self.beamformer_type == "mvdr_souden"
                or not self.beamformer_type.endswith("_souden")
            ):
                # the Souden MPDR/wMPDR/WPD variants do not need a noise PSD
                psd_noise = get_power_spectral_density_matrix(
                    data_d, mask_noise.double()
                )
            if self.beamformer_type == "mvdr":
                enhanced, ws = apply_beamforming(
                    data, ilens, psd_noise, psd_speech, psd_distortion=psd_noise
                )
            elif self.beamformer_type == "mvdr_souden":
                enhanced, ws = apply_beamforming(data, ilens, psd_noise, psd_speech)
            elif self.beamformer_type == "mpdr":
                # observation covariance: (B, F, C, C)
                psd_observed = FC.einsum("...ct,...et->...ce", [data_d, data_d.conj()])
                enhanced, ws = apply_beamforming(
                    data, ilens, psd_observed, psd_speech, psd_distortion=psd_noise
                )
            elif self.beamformer_type == "mpdr_souden":
                psd_observed = FC.einsum("...ct,...et->...ce", [data_d, data_d.conj()])
                enhanced, ws = apply_beamforming(data, ilens, psd_observed, psd_speech)
            elif self.beamformer_type == "wmpdr":
                # power-weighted observation covariance: (B, F, C, C)
                psd_observed = FC.einsum(
                    "...ct,...et->...ce",
                    [data_d * inverse_power[..., None, :], data_d.conj()],
                )
                enhanced, ws = apply_beamforming(
                    data, ilens, psd_observed, psd_speech, psd_distortion=psd_noise
                )
            elif self.beamformer_type == "wmpdr_souden":
                psd_observed = FC.einsum(
                    "...ct,...et->...ce",
                    [data_d * inverse_power[..., None, :], data_d.conj()],
                )
                enhanced, ws = apply_beamforming(data, ilens, psd_observed, psd_speech)
            elif self.beamformer_type == "wpd":
                # stacked (tap-expanded) observation covariance
                psd_observed_bar = get_covariances(
                    data_d, inverse_power, self.bdelay, self.btaps, get_vector=False
                )
                enhanced, ws = apply_beamforming(
                    data, ilens, psd_observed_bar, psd_speech, psd_distortion=psd_noise
                )
            elif self.beamformer_type == "wpd_souden":
                psd_observed_bar = get_covariances(
                    data_d, inverse_power, self.bdelay, self.btaps, get_vector=False
                )
                enhanced, ws = apply_beamforming(
                    data, ilens, psd_observed_bar, psd_speech
                )
            else:
                raise ValueError(
                    "Not supporting beamformer_type={}".format(self.beamformer_type)
                )

            # (B, F, T) -> (B, T, F)
            enhanced = enhanced.transpose(-1, -2)
        else:  # multi-speaker case
            if self.use_noise_mask:
                # masks = (mask_speech1, ..., mask_speechX, mask_noise)
                mask_speech = list(masks[:-1])
                mask_noise = masks[-1]
            else:
                # masks = (mask_speech1, ..., mask_speechX)
                mask_speech = list(masks)
                mask_noise = None

            # wMPDR/WPD need one inverse power spectrum per speaker
            if self.beamformer_type.startswith(
                "wmpdr"
            ) or self.beamformer_type.startswith("wpd"):
                if powers is None:
                    power_input = data_d.real ** 2 + data_d.imag ** 2
                    # masked power per speaker, averaged over channels
                    powers = [
                        (power_input * m.double()).mean(dim=-2) for m in mask_speech
                    ]
                else:
                    assert len(powers) == self.num_spk, len(powers)
                inverse_power = [1 / torch.clamp(p, min=self.eps) for p in powers]

            # one speech PSD per speaker
            psd_speeches = [
                get_power_spectral_density_matrix(data_d, mask.double())
                for mask in mask_speech
            ]
            if mask_noise is not None and (
                self.beamformer_type == "mvdr_souden"
                or not self.beamformer_type.endswith("_souden")
            ):
                # the Souden MPDR/wMPDR/WPD variants do not need a noise PSD
                psd_noise = get_power_spectral_density_matrix(
                    data_d, mask_noise.double()
                )
            if self.beamformer_type in ("mpdr", "mpdr_souden"):
                psd_observed = FC.einsum("...ct,...et->...ce", [data_d, data_d.conj()])
            elif self.beamformer_type in ("wmpdr", "wmpdr_souden"):
                psd_observed = [
                    FC.einsum(
                        "...ct,...et->...ce",
                        [data_d * inv_p[..., None, :], data_d.conj()],
                    )
                    for inv_p in inverse_power
                ]
            elif self.beamformer_type in ("wpd", "wpd_souden"):
                psd_observed_bar = [
                    get_covariances(
                        data_d, inv_p, self.bdelay, self.btaps, get_vector=False
                    )
                    for inv_p in inverse_power
                ]

            enhanced, ws = [], []
            for i in range(self.num_spk):
                # temporarily remove speaker i's PSD so that psd_speeches
                # holds only the interference PSDs for this iteration
                psd_speech = psd_speeches.pop(i)
                if (
                    self.beamformer_type == "mvdr_souden"
                    or not self.beamformer_type.endswith("_souden")
                ):
                    # interference speakers (+ noise, if available) act as noise
                    psd_noise_i = (
                        psd_noise + sum(psd_speeches)
                        if mask_noise is not None
                        else sum(psd_speeches)
                    )

                if self.beamformer_type == "mvdr":
                    enh, w = apply_beamforming(
                        data, ilens, psd_noise_i, psd_speech, psd_distortion=psd_noise_i
                    )
                elif self.beamformer_type == "mvdr_souden":
                    enh, w = apply_beamforming(data, ilens, psd_noise_i, psd_speech)
                elif self.beamformer_type == "mpdr":
                    enh, w = apply_beamforming(
                        data,
                        ilens,
                        psd_observed,
                        psd_speech,
                        psd_distortion=psd_noise_i,
                    )
                elif self.beamformer_type == "mpdr_souden":
                    enh, w = apply_beamforming(data, ilens, psd_observed, psd_speech)
                elif self.beamformer_type == "wmpdr":
                    enh, w = apply_beamforming(
                        data,
                        ilens,
                        psd_observed[i],
                        psd_speech,
                        psd_distortion=psd_noise_i,
                    )
                elif self.beamformer_type == "wmpdr_souden":
                    enh, w = apply_beamforming(data, ilens, psd_observed[i], psd_speech)
                elif self.beamformer_type == "wpd":
                    enh, w = apply_beamforming(
                        data,
                        ilens,
                        psd_observed_bar[i],
                        psd_speech,
                        psd_distortion=psd_noise_i,
                    )
                elif self.beamformer_type == "wpd_souden":
                    enh, w = apply_beamforming(
                        data, ilens, psd_observed_bar[i], psd_speech
                    )
                else:
                    raise ValueError(
                        "Not supporting beamformer_type={}".format(self.beamformer_type)
                    )
                # restore speaker i's PSD for the next iterations
                psd_speeches.insert(i, psd_speech)

                # (B, F, T) -> (B, T, F)
                enh = enh.transpose(-1, -2)
                enhanced.append(enh)
                ws.append(w)

        # masks: (B, F, C, T) -> (B, T, C, F)
        masks = [m.transpose(-1, -3) for m in masks]
        return enhanced, ilens, masks

    def predict_mask(
        self, data: ComplexTensor, ilens: torch.LongTensor
    ) -> Tuple[Tuple[torch.Tensor, ...], torch.LongTensor]:
        """Predict masks for beamforming.

        Args:
            data (ComplexTensor): (B, T, C, F), double precision
            ilens (torch.Tensor): (B,)
        Returns:
            masks (torch.Tensor): (B, T, C, F)
            ilens (torch.Tensor): (B,)
        """
        # (B, T, C, F) -> (B, F, C, T), float for the mask estimator
        masks, _ = self.mask(data.permute(0, 3, 2, 1).float(), ilens)
        # (B, F, C, T) -> (B, T, C, F)
        masks = [m.transpose(-1, -3) for m in masks]
        return masks, ilens
|
|
|
|
|
class AttentionReference(torch.nn.Module):
    """Attention-based reference-channel weighting.

    Produces a soft (softmax) weight over channels from the speech PSD
    matrix, to be used as the reference vector of a beamformer.
    """

    def __init__(self, bidim, att_dim):
        super().__init__()
        # feature projection followed by a scalar scoring layer
        self.mlp_psd = torch.nn.Linear(bidim, att_dim)
        self.gvec = torch.nn.Linear(att_dim, 1)

    def forward(
        self, psd_in: ComplexTensor, ilens: torch.LongTensor, scaling: float = 2.0
    ) -> Tuple[torch.Tensor, torch.LongTensor]:
        """Attention-based reference forward function.

        Args:
            psd_in (ComplexTensor): (B, F, C, C)
            ilens (torch.Tensor): (B,)
            scaling (float): sharpening factor applied before the softmax
        Returns:
            u (torch.Tensor): (B, C)
            ilens (torch.Tensor): (B,)
        """
        n_chan = psd_in.size(2)
        assert psd_in.size(2) == psd_in.size(3), psd_in.size()

        # bool is only usable for torch.eye from 1.3 / as a mask from 1.2;
        # fall back to uint8 on older torch versions
        eye_dtype = torch.bool if is_torch_1_3_plus else torch.uint8
        mask_dtype = torch.bool if is_torch_1_2_plus else torch.uint8
        diag_mask = torch.eye(n_chan, dtype=eye_dtype, device=psd_in.device)
        # zero out the diagonal (self-channel) entries of the PSD
        off_diag = psd_in.masked_fill(diag_mask.type(mask_dtype), 0)

        # average the off-diagonal entries and move channels before freq:
        # (B, F, C) -> (B, C, F)
        avg_psd = (off_diag.sum(dim=-1) / (n_chan - 1)).transpose(-1, -2)

        # magnitude of the averaged cross-PSD as the attention feature
        feat = (avg_psd.real ** 2 + avg_psd.imag ** 2) ** 0.5

        # score each channel and normalize over channels
        score = self.gvec(torch.tanh(self.mlp_psd(feat))).squeeze(-1)
        weights = F.softmax(scaling * score, dim=-1)
        return weights, ilens
|
|