# Source path: conex/espnet2/enh/layers/dnn_beamformer.py
# Repository page metadata: tobiasc -- "Initial commit" (ad16788)
from distutils.version import LooseVersion
from typing import List
from typing import Tuple
from typing import Union
import logging
import torch
from torch.nn import functional as F
from torch_complex import functional as FC
from torch_complex.tensor import ComplexTensor
from espnet.nets.pytorch_backend.frontends.beamformer import apply_beamforming_vector
from espnet.nets.pytorch_backend.frontends.beamformer import (
get_power_spectral_density_matrix, # noqa: H301
)
from espnet2.enh.layers.beamformer import get_covariances
from espnet2.enh.layers.beamformer import get_mvdr_vector
from espnet2.enh.layers.beamformer import get_mvdr_vector_with_rtf
from espnet2.enh.layers.beamformer import get_WPD_filter_v2
from espnet2.enh.layers.beamformer import get_WPD_filter_with_rtf
from espnet2.enh.layers.beamformer import perform_WPD_filtering
from espnet2.enh.layers.mask_estimator import MaskEstimator
# Version feature flags: parse the installed PyTorch version once and record
# whether it is at least 1.2.0 / 1.3.0.
_torch_version = LooseVersion(torch.__version__)
is_torch_1_2_plus = _torch_version >= LooseVersion("1.2.0")
is_torch_1_3_plus = _torch_version >= LooseVersion("1.3.0")
# Supported beamformer types. Every family exists in two variants: the bare
# name selects the RTF-based formula, the "_souden" suffix Souden's solution.
#   mvdr  - Minimum Variance Distortionless Response beamformer
#   mpdr  - Minimum Power Distortionless Response beamformer
#   wmpdr - weighted MPDR beamformer
#   wpd   - Weighted Power minimization Distortionless response beamformer
BEAMFORMER_TYPES = tuple(
    family + variant
    for family in ("mvdr", "mpdr", "wmpdr", "wpd")
    for variant in ("", "_souden")
)
class DNN_Beamformer(torch.nn.Module):
    """DNN mask based Beamformer.

    A mask estimator predicts one mask per speaker (plus an optional noise
    mask) from the multi-channel input. The masks are turned into covariance
    statistics, from which the filter selected by ``beamformer_type`` is
    computed and applied to the input.

    Citation:
        Multichannel End-to-end Speech Recognition; T. Ochiai et al., 2017;
        http://proceedings.mlr.press/v70/ochiai17a/ochiai17a.pdf

    """

    def __init__(
        self,
        bidim,
        btype: str = "blstmp",
        blayers: int = 3,
        bunits: int = 300,
        bprojs: int = 320,
        num_spk: int = 1,
        use_noise_mask: bool = True,
        nonlinear: str = "sigmoid",
        dropout_rate: float = 0.0,
        badim: int = 320,
        ref_channel: int = -1,
        beamformer_type: str = "mvdr_souden",
        rtf_iterations: int = 2,
        eps: float = 1e-6,
        diagonal_loading: bool = True,
        diag_eps: float = 1e-7,
        mask_flooring: bool = False,
        flooring_thres: float = 1e-6,
        use_torch_solver: bool = True,
        # only for WPD beamformer
        btaps: int = 5,
        bdelay: int = 3,
    ):
        """Build the mask estimator and store the beamforming configuration.

        Args:
            bidim: input dimension, forwarded to ``MaskEstimator`` and
                ``AttentionReference``.
            btype, blayers, bunits, bprojs, dropout_rate, nonlinear:
                forwarded unchanged to ``MaskEstimator``.
            num_spk: number of speakers (>= 1); one speech mask each.
            use_noise_mask: estimate one additional mask for the noise.
            badim: attention dimension of ``AttentionReference``.
            ref_channel: fixed reference channel; a negative value enables
                the attention-based reference estimation instead.
            beamformer_type: one of ``BEAMFORMER_TYPES``.
            rtf_iterations: number of iterations in the power method for
                estimating the RTF (>= 2 for non-Souden types).
            eps: lower bound applied to the power before it is inverted
                (wMPDR / WPD only).
            diagonal_loading, diag_eps, use_torch_solver: forwarded to the
                filter estimation functions.
            mask_flooring: clamp the estimated masks from below.
            flooring_thres: lower bound used when ``mask_flooring`` is set.
            btaps, bdelay: filter taps and delay, only for WPD; the delay is
                forced to 1 when ``btaps`` is 0.

        Raises:
            ValueError: if ``beamformer_type`` is not supported.
            AssertionError: if ``num_spk``, ``rtf_iterations``, ``btaps`` or
                ``bdelay`` is out of range.
        """
        super().__init__()
        # Check num_spk before it is used to size the mask estimator.
        assert num_spk >= 1, num_spk
        bnmask = num_spk + 1 if use_noise_mask else num_spk
        self.mask = MaskEstimator(
            btype,
            bidim,
            blayers,
            bunits,
            bprojs,
            dropout_rate,
            nmask=bnmask,
            nonlinear=nonlinear,
        )
        self.ref = AttentionReference(bidim, badim) if ref_channel < 0 else None
        self.ref_channel = ref_channel

        self.use_noise_mask = use_noise_mask
        self.num_spk = num_spk
        self.nmask = bnmask

        if beamformer_type not in BEAMFORMER_TYPES:
            raise ValueError("Not supporting beamformer_type=%s" % beamformer_type)
        needs_noise_psd = (
            beamformer_type == "mvdr_souden" or not beamformer_type.endswith("_souden")
        )
        if needs_noise_psd and not use_noise_mask:
            # These types need a noise PSD; tell the user what replaces the
            # missing noise mask.
            type_name = beamformer_type.upper()
            if num_spk == 1:
                logging.warning(
                    "Initializing %s beamformer without noise mask "
                    "estimator (single-speaker case)",
                    type_name,
                )
                logging.warning(
                    "(1 - speech_mask) will be used for estimating noise "
                    "PSD in %s beamformer!",
                    type_name,
                )
            else:
                logging.warning(
                    "Initializing %s beamformer without noise mask "
                    "estimator (multi-speaker case)",
                    type_name,
                )
                logging.warning(
                    "Interference speech masks will be used for estimating "
                    "noise PSD in %s beamformer!",
                    type_name,
                )

        self.beamformer_type = beamformer_type
        if not beamformer_type.endswith("_souden"):
            assert rtf_iterations >= 2, rtf_iterations
        # number of iterations in power method for estimating the RTF
        self.rtf_iterations = rtf_iterations

        assert btaps >= 0 and bdelay >= 0, (btaps, bdelay)
        self.btaps = btaps
        self.bdelay = bdelay if self.btaps > 0 else 1

        self.eps = eps
        self.diagonal_loading = diagonal_loading
        self.diag_eps = diag_eps
        self.mask_flooring = mask_flooring
        self.flooring_thres = flooring_thres
        self.use_torch_solver = use_torch_solver

    def _uses_noise_psd(self) -> bool:
        """Return True for MVDR (both variants) and all RTF-based types."""
        return (
            self.beamformer_type == "mvdr_souden"
            or not self.beamformer_type.endswith("_souden")
        )

    def _apply_beamforming(self, data, ilens, psd_n, psd_speech, psd_distortion=None):
        """Beamforming with the provided statistics.

        Args:
            data (ComplexTensor): (B, F, C, T)
            ilens (torch.Tensor): (B,)
            psd_n (ComplexTensor):
                Noise covariance matrix for MVDR (B, F, C, C)
                Observation covariance matrix for MPDR/wMPDR (B, F, C, C)
                Stacked observation covariance for WPD (B,F,(btaps+1)*C,(btaps+1)*C)
            psd_speech (ComplexTensor): Speech covariance matrix (B, F, C, C)
            psd_distortion (ComplexTensor): Noise covariance matrix (B, F, C, C);
                required by the RTF-based types, unused by the Souden ones.
        Return:
            enhanced (ComplexTensor): (B, F, T)
            ws (ComplexTensor): (B, F) or (B, F, (btaps+1)*C)
        Raises:
            ValueError: if ``self.beamformer_type`` is not supported.
        """
        # u: (B, C)
        if self.ref_channel < 0:
            u, _ = self.ref(psd_speech.to(dtype=data.dtype), ilens)
            u = u.double()
        elif self.beamformer_type.endswith("_souden"):
            # (optional) Create onehot vector for fixed reference microphone
            u = torch.zeros(
                *(data.size()[:-3] + (data.size(-2),)),
                device=data.device,
                dtype=torch.double
            )
            u[..., self.ref_channel].fill_(1)
        else:
            # for simplifying computation in RTF-based beamforming
            u = self.ref_channel

        btype = self.beamformer_type
        if btype in ("mvdr", "mpdr", "wmpdr"):
            ws = get_mvdr_vector_with_rtf(
                psd_n.double(),
                psd_speech.double(),
                psd_distortion.double(),
                iterations=self.rtf_iterations,
                reference_vector=u,
                normalize_ref_channel=self.ref_channel,
                use_torch_solver=self.use_torch_solver,
                diagonal_loading=self.diagonal_loading,
                diag_eps=self.diag_eps,
            )
            enhanced = apply_beamforming_vector(ws, data.double())
        elif btype in ("mpdr_souden", "mvdr_souden", "wmpdr_souden"):
            ws = get_mvdr_vector(
                psd_speech.double(),
                psd_n.double(),
                u,
                use_torch_solver=self.use_torch_solver,
                diagonal_loading=self.diagonal_loading,
                diag_eps=self.diag_eps,
            )
            enhanced = apply_beamforming_vector(ws, data.double())
        elif btype == "wpd":
            ws = get_WPD_filter_with_rtf(
                psd_n.double(),
                psd_speech.double(),
                psd_distortion.double(),
                iterations=self.rtf_iterations,
                reference_vector=u,
                normalize_ref_channel=self.ref_channel,
                use_torch_solver=self.use_torch_solver,
                diagonal_loading=self.diagonal_loading,
                diag_eps=self.diag_eps,
            )
            enhanced = perform_WPD_filtering(
                ws, data.double(), self.bdelay, self.btaps
            )
        elif btype == "wpd_souden":
            ws = get_WPD_filter_v2(
                psd_speech.double(),
                psd_n.double(),
                u,
                diagonal_loading=self.diagonal_loading,
                diag_eps=self.diag_eps,
            )
            enhanced = perform_WPD_filtering(
                ws, data.double(), self.bdelay, self.btaps
            )
        else:
            raise ValueError(
                "Not supporting beamformer_type={}".format(self.beamformer_type)
            )

        # Statistics are computed in double precision; cast back to the input's.
        return enhanced.to(dtype=data.dtype), ws.to(dtype=data.dtype)

    def forward(
        self,
        data: ComplexTensor,
        ilens: torch.LongTensor,
        powers: Union[List[torch.Tensor], None] = None,
    ) -> Tuple[
        Union[ComplexTensor, List[ComplexTensor]], torch.LongTensor, List[torch.Tensor]
    ]:
        """DNN_Beamformer forward function.

        Notation:
            B: Batch
            C: Channel
            T: Time or Sequence length
            F: Freq

        Args:
            data (ComplexTensor): (B, T, C, F)
            ilens (torch.Tensor): (B,)
            powers (List[torch.Tensor] or None): used for wMPDR or WPD (B, F, T);
                one entry per speaker. When None, the power is estimated from
                the masked input.
        Returns:
            enhanced (ComplexTensor or List[ComplexTensor]): (B, T, F);
                a single tensor when ``num_spk == 1``, otherwise one per speaker
            ilens (torch.Tensor): (B,)
            masks (List[torch.Tensor]): each (B, T, C, F)
        Raises:
            ValueError: if ``self.beamformer_type`` is not supported.
        """
        # data (B, T, C, F) -> (B, F, C, T)
        data = data.permute(0, 3, 2, 1)
        data_d = data.double()

        # mask: [(B, F, C, T)]
        masks, _ = self.mask(data, ilens)
        assert self.nmask == len(masks), len(masks)
        # floor masks to increase numerical stability
        if self.mask_flooring:
            masks = [torch.clamp(m, min=self.flooring_thres) for m in masks]

        if self.use_noise_mask:
            # (mask_speech1, ..., mask_noise)
            speech_masks, noise_mask = list(masks[:-1]), masks[-1]
        else:
            # (mask_speech1, ..., mask_speechX)
            speech_masks = list(masks)
            # A single speaker without noise mask uses the complement of the
            # speech mask; with several speakers the others act as noise.
            noise_mask = 1 - speech_masks[0] if self.num_spk == 1 else None

        btype = self.beamformer_type
        uses_noise_psd = self._uses_noise_psd()

        inverse_powers = None
        if btype.startswith(("wmpdr", "wpd")):
            if powers is None:
                power_input = data_d.real ** 2 + data_d.imag ** 2
                # Averaging along the channel axis: (..., C, T) -> (..., T)
                powers = [
                    (power_input * m.double()).mean(dim=-2) for m in speech_masks
                ]
            else:
                assert len(powers) == self.num_spk, len(powers)
            inverse_powers = [1 / torch.clamp(p, min=self.eps) for p in powers]

        psd_speeches = [
            get_power_spectral_density_matrix(data_d, m.double())
            for m in speech_masks
        ]

        psd_noise = None
        if noise_mask is not None and uses_noise_psd:
            # MVDR or other RTF-based formulas
            psd_noise = get_power_spectral_density_matrix(
                data_d, noise_mask.double()
            )

        # Per-speaker observation statistics (None for the MVDR types).
        psd_observed = None
        if btype in ("mpdr", "mpdr_souden"):
            # Unweighted, hence identical for every speaker.
            shared = FC.einsum("...ct,...et->...ce", [data_d, data_d.conj()])
            psd_observed = [shared] * self.num_spk
        elif btype in ("wmpdr", "wmpdr_souden"):
            psd_observed = [
                FC.einsum(
                    "...ct,...et->...ce",
                    [data_d * inv_p[..., None, :], data_d.conj()],
                )
                for inv_p in inverse_powers
            ]
        elif btype in ("wpd", "wpd_souden"):
            psd_observed = [
                get_covariances(
                    data_d, inv_p, self.bdelay, self.btaps, get_vector=False
                )
                for inv_p in inverse_powers
            ]

        outputs = []
        for i, psd_speech in enumerate(psd_speeches):
            psd_noise_i = None
            if uses_noise_psd:
                # treat all other speakers' psd_speech as noises
                interferers = psd_speeches[:i] + psd_speeches[i + 1 :]
                if not interferers:
                    psd_noise_i = psd_noise
                elif psd_noise is None:
                    psd_noise_i = sum(interferers)
                else:
                    psd_noise_i = psd_noise + sum(interferers)

            if btype in ("mvdr", "mvdr_souden"):
                psd_n = psd_noise_i
            elif psd_observed is not None:
                psd_n = psd_observed[i]
            else:
                raise ValueError(
                    "Not supporting beamformer_type={}".format(self.beamformer_type)
                )

            # Only the RTF-based formulas take a distortion covariance.
            psd_distortion = None if btype.endswith("_souden") else psd_noise_i
            enh, _ = self._apply_beamforming(
                data, ilens, psd_n, psd_speech, psd_distortion=psd_distortion
            )
            # (..., F, T) -> (..., T, F)
            outputs.append(enh.transpose(-1, -2))

        enhanced = outputs[0] if self.num_spk == 1 else outputs

        # (..., F, C, T) -> (..., T, C, F)
        masks = [m.transpose(-1, -3) for m in masks]
        return enhanced, ilens, masks

    def predict_mask(
        self, data: ComplexTensor, ilens: torch.LongTensor
    ) -> Tuple[List[torch.Tensor], torch.LongTensor]:
        """Predict masks for beamforming.

        Unlike ``forward``, the masks are returned as estimated, without
        mask flooring.

        Args:
            data (ComplexTensor): (B, T, C, F), double precision
            ilens (torch.Tensor): (B,)
        Returns:
            masks (List[torch.Tensor]): each (B, T, C, F)
            ilens (torch.Tensor): (B,)
        """
        # (B, T, C, F) -> (B, F, C, T); the estimator is fed single precision.
        masks, _ = self.mask(data.permute(0, 3, 2, 1).float(), ilens)
        # (B, F, C, T) -> (B, T, C, F)
        masks = [m.transpose(-1, -3) for m in masks]
        return masks, ilens
class AttentionReference(torch.nn.Module):
    """Attention over channels producing a soft reference-channel vector.

    Args:
        bidim: input size of the first linear layer; ``forward`` applies it
            along the frequency axis of the PSD features.
        att_dim: size of the hidden attention projection.
    """

    def __init__(self, bidim, att_dim):
        super().__init__()
        self.mlp_psd = torch.nn.Linear(bidim, att_dim)
        self.gvec = torch.nn.Linear(att_dim, 1)

    def forward(
        self, psd_in: ComplexTensor, ilens: torch.LongTensor, scaling: float = 2.0
    ) -> Tuple[torch.Tensor, torch.LongTensor]:
        """Attention-based reference forward function.

        Args:
            psd_in (ComplexTensor): (B, F, C, C)
            ilens (torch.Tensor): (B,)
            scaling (float): factor applied to the attention scores before
                the softmax
        Returns:
            u (torch.Tensor): (B, C), softmax weights over the channels
            ilens (torch.Tensor): (B,), returned unchanged
        """
        _, _, C = psd_in.size()[:3]
        assert psd_in.size(2) == psd_in.size(3), psd_in.size()

        # psd_in: (B, F, C, C)
        # The identity mask is created as bool on torch >= 1.3 (uint8 before)
        # and then cast to bool on torch >= 1.2 (uint8 before).
        datatype = torch.bool if is_torch_1_3_plus else torch.uint8
        datatype2 = torch.bool if is_torch_1_2_plus else torch.uint8
        diagonal = torch.eye(C, dtype=datatype, device=psd_in.device).type(datatype2)
        # Zero the diagonal so that only cross-channel terms remain.
        psd = psd_in.masked_fill(diagonal, 0)

        # psd: (B, F, C, C) -> (B, C, F), averaged over the C - 1 off-diagonal
        # entries. A single channel has none, so divide by 1 instead of 0;
        # otherwise 0 / 0 would turn the attention weights into NaN.
        psd = (psd.sum(dim=-1) / max(C - 1, 1)).transpose(-1, -2)

        # Calculate amplitude
        psd_feat = (psd.real ** 2 + psd.imag ** 2) ** 0.5

        # (B, C, F) -> (B, C, F2)
        mlp_psd = self.mlp_psd(psd_feat)
        # (B, C, F2) -> (B, C, 1) -> (B, C)
        e = self.gvec(torch.tanh(mlp_psd)).squeeze(-1)
        u = F.softmax(scaling * e, dim=-1)
        return u, ilens