# espnet2/enh/layers/beamformer.py
from distutils.version import LooseVersion
from typing import List
from typing import Optional
from typing import Union
import numpy as np
import torch
from torch_complex import functional as FC
from torch_complex.tensor import ComplexTensor

is_torch_1_1_plus = LooseVersion(torch.__version__) >= LooseVersion("1.1.0")
EPS = torch.finfo(torch.double).eps


def complex_norm(c: ComplexTensor) -> torch.Tensor:
    """Compute the L2 norm along the last dimension of a complex tensor.

    A small positive constant is added inside the square root for
    numerical stability.
    """
    return torch.sqrt((c.real ** 2 + c.imag ** 2).sum(dim=-1, keepdim=True) + EPS)


def get_rtf(
psd_speech: ComplexTensor,
psd_noise: ComplexTensor,
reference_vector: Union[int, torch.Tensor, None] = None,
iterations: int = 3,
use_torch_solver: bool = True,
) -> ComplexTensor:
"""Calculate the relative transfer function (RTF) using the power method.
Algorithm:
1) rtf = reference_vector
2) for i in range(iterations):
rtf = (psd_noise^-1 @ psd_speech) @ rtf
rtf = rtf / ||rtf||_2 # this normalization can be skipped
3) rtf = psd_noise @ rtf
4) rtf = rtf / rtf[..., ref_channel, :]
Note: 4) Normalization at the reference channel is not performed here.
Args:
psd_speech (ComplexTensor): speech covariance matrix (..., F, C, C)
psd_noise (ComplexTensor): noise covariance matrix (..., F, C, C)
reference_vector (torch.Tensor or int): (..., C) or scalar
iterations (int): number of iterations in power method
use_torch_solver (bool): Whether to use `solve` instead of `inverse`
Returns:
rtf (ComplexTensor): (..., F, C, 1)
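
    Examples:
        A minimal shape check; the PSD matrices below are random Hermitian
        positive semidefinite matrices, for illustration only:

        >>> B, F, C = 2, 3, 4
        >>> s = ComplexTensor(torch.randn(B, F, C, C), torch.randn(B, F, C, C))
        >>> n = ComplexTensor(torch.randn(B, F, C, C), torch.randn(B, F, C, C))
        >>> psd_speech = FC.matmul(s, s.conj().transpose(-1, -2))
        >>> psd_noise = FC.matmul(n, n.conj().transpose(-1, -2))
        >>> get_rtf(psd_speech, psd_noise, reference_vector=0).shape
        torch.Size([2, 3, 4, 1])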
"""
if use_torch_solver and is_torch_1_1_plus:
        # torch.solve is used, which is only available in PyTorch 1.1.0 and later
phi = FC.solve(psd_speech, psd_noise)[0]
else:
phi = FC.matmul(psd_noise.inverse2(), psd_speech)
rtf = (
phi[..., reference_vector, None]
if isinstance(reference_vector, int)
else FC.matmul(phi, reference_vector[..., None, :, None])
)
for _ in range(iterations - 2):
rtf = FC.matmul(phi, rtf)
# rtf = rtf / complex_norm(rtf)
rtf = FC.matmul(psd_speech, rtf)
return rtf


def get_mvdr_vector(
psd_s: ComplexTensor,
psd_n: ComplexTensor,
reference_vector: torch.Tensor,
use_torch_solver: bool = True,
diagonal_loading: bool = True,
diag_eps: float = 1e-7,
eps: float = 1e-8,
) -> ComplexTensor:
"""Return the MVDR (Minimum Variance Distortionless Response) vector:
h = (Npsd^-1 @ Spsd) / (Tr(Npsd^-1 @ Spsd)) @ u
Reference:
On optimal frequency-domain multichannel linear filtering
for noise reduction; M. Souden et al., 2010;
https://ieeexplore.ieee.org/document/5089420
Args:
psd_s (ComplexTensor): speech covariance matrix (..., F, C, C)
psd_n (ComplexTensor): observation/noise covariance matrix (..., F, C, C)
reference_vector (torch.Tensor): (..., C)
use_torch_solver (bool): Whether to use `solve` instead of `inverse`
diagonal_loading (bool): Whether to add a tiny term to the diagonal of psd_n
diag_eps (float):
eps (float):
Returns:
beamform_vector (ComplexTensor): (..., F, C)
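
    Examples:
        A minimal shape check with random Hermitian PSD matrices
        (illustrative values, not real statistics):

        >>> B, F, C = 2, 3, 4
        >>> s = ComplexTensor(torch.randn(B, F, C, C), torch.randn(B, F, C, C))
        >>> n = ComplexTensor(torch.randn(B, F, C, C), torch.randn(B, F, C, C))
        >>> psd_s = FC.matmul(s, s.conj().transpose(-1, -2))
        >>> psd_n = FC.matmul(n, n.conj().transpose(-1, -2))
        >>> u = torch.eye(C)[0].expand(B, C)  # one-hot vector for channel 0
        >>> get_mvdr_vector(psd_s, psd_n, u).shape
        torch.Size([2, 3, 4])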
""" # noqa: D400
if diagonal_loading:
psd_n = tik_reg(psd_n, reg=diag_eps, eps=eps)
if use_torch_solver and is_torch_1_1_plus:
        # torch.solve is used, which is only available in PyTorch 1.1.0 and later
numerator = FC.solve(psd_s, psd_n)[0]
else:
numerator = FC.matmul(psd_n.inverse2(), psd_s)
# ws: (..., C, C) / (...,) -> (..., C, C)
ws = numerator / (FC.trace(numerator)[..., None, None] + eps)
# h: (..., F, C_1, C_2) x (..., C_2) -> (..., F, C_1)
beamform_vector = FC.einsum("...fec,...c->...fe", [ws, reference_vector])
return beamform_vector


def get_mvdr_vector_with_rtf(
psd_n: ComplexTensor,
psd_speech: ComplexTensor,
psd_noise: ComplexTensor,
iterations: int = 3,
reference_vector: Union[int, torch.Tensor, None] = None,
normalize_ref_channel: Optional[int] = None,
use_torch_solver: bool = True,
diagonal_loading: bool = True,
diag_eps: float = 1e-7,
eps: float = 1e-8,
) -> ComplexTensor:
"""Return the MVDR (Minimum Variance Distortionless Response) vector
calculated with RTF:
h = (Npsd^-1 @ rtf) / (rtf^H @ Npsd^-1 @ rtf)
Reference:
On optimal frequency-domain multichannel linear filtering
for noise reduction; M. Souden et al., 2010;
https://ieeexplore.ieee.org/document/5089420
Args:
psd_n (ComplexTensor): observation/noise covariance matrix (..., F, C, C)
psd_speech (ComplexTensor): speech covariance matrix (..., F, C, C)
psd_noise (ComplexTensor): noise covariance matrix (..., F, C, C)
iterations (int): number of iterations in power method
reference_vector (torch.Tensor or int): (..., C) or scalar
normalize_ref_channel (int): reference channel for normalizing the RTF
use_torch_solver (bool): Whether to use `solve` instead of `inverse`
diagonal_loading (bool): Whether to add a tiny term to the diagonal of psd_n
diag_eps (float):
eps (float):
Returns:
beamform_vector (ComplexTensor): (..., F, C)
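
    Examples:
        A minimal shape check (random Hermitian PSDs; for simplicity, the
        noise PSD is reused as the observation PSD `psd_n`):

        >>> B, F, C = 2, 3, 4
        >>> s = ComplexTensor(torch.randn(B, F, C, C), torch.randn(B, F, C, C))
        >>> n = ComplexTensor(torch.randn(B, F, C, C), torch.randn(B, F, C, C))
        >>> psd_speech = FC.matmul(s, s.conj().transpose(-1, -2))
        >>> psd_noise = FC.matmul(n, n.conj().transpose(-1, -2))
        >>> w = get_mvdr_vector_with_rtf(
        ...     psd_noise, psd_speech, psd_noise, reference_vector=0
        ... )
        >>> w.shape
        torch.Size([2, 3, 4])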
""" # noqa: H405, D205, D400
if diagonal_loading:
psd_noise = tik_reg(psd_noise, reg=diag_eps, eps=eps)
# (B, F, C, 1)
rtf = get_rtf(
psd_speech,
psd_noise,
reference_vector,
iterations=iterations,
use_torch_solver=use_torch_solver,
)
# numerator: (..., C_1, C_2) x (..., C_2, 1) -> (..., C_1)
if use_torch_solver and is_torch_1_1_plus:
        # torch.solve is used, which is only available in PyTorch 1.1.0 and later
numerator = FC.solve(rtf, psd_n)[0].squeeze(-1)
else:
numerator = FC.matmul(psd_n.inverse2(), rtf).squeeze(-1)
denominator = FC.einsum("...d,...d->...", [rtf.squeeze(-1).conj(), numerator])
if normalize_ref_channel is not None:
scale = rtf.squeeze(-1)[..., normalize_ref_channel, None].conj()
beamforming_vector = numerator * scale / (denominator.real.unsqueeze(-1) + eps)
else:
beamforming_vector = numerator / (denominator.real.unsqueeze(-1) + eps)
return beamforming_vector


def signal_framing(
signal: Union[torch.Tensor, ComplexTensor],
frame_length: int,
frame_step: int,
bdelay: int,
do_padding: bool = False,
pad_value: int = 0,
indices: List = None,
) -> Union[torch.Tensor, ComplexTensor]:
"""Expand `signal` into several frames, with each frame of length `frame_length`.
Args:
signal : (..., T)
frame_length: length of each segment
frame_step: step for selecting frames
bdelay: delay for WPD
do_padding: whether or not to pad the input signal at the beginning
of the time dimension
pad_value: value to fill in the padding
Returns:
torch.Tensor:
if do_padding: (..., T, frame_length)
else: (..., T - bdelay - frame_length + 2, frame_length)
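
    Examples:
        A small real-valued sketch (frame_length=3, frame_step=1, bdelay=2,
        no padding): each frame holds `frame_length - 1` consecutive samples
        plus one sample `bdelay` steps after them:

        >>> signal_framing(torch.arange(8.0), 3, 1, 2)
        tensor([[0., 1., 3.],
                [1., 2., 4.],
                [2., 3., 5.],
                [3., 4., 6.],
                [4., 5., 7.]])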
"""
frame_length2 = frame_length - 1
# pad to the right at the last dimension of `signal` (time dimension)
if do_padding:
# (..., T) --> (..., T + bdelay + frame_length - 2)
signal = FC.pad(signal, (bdelay + frame_length2 - 1, 0), "constant", pad_value)
do_padding = False
if indices is None:
# [[ 0, 1, ..., frame_length2 - 1, frame_length2 - 1 + bdelay ],
# [ 1, 2, ..., frame_length2, frame_length2 + bdelay ],
# [ 2, 3, ..., frame_length2 + 1, frame_length2 + 1 + bdelay ],
# ...
# [ T-bdelay-frame_length2, ..., T-1-bdelay, T-1 ]]
indices = [
[*range(i, i + frame_length2), i + frame_length2 + bdelay - 1]
for i in range(0, signal.shape[-1] - frame_length2 - bdelay + 1, frame_step)
]
if isinstance(signal, ComplexTensor):
real = signal_framing(
signal.real,
frame_length,
frame_step,
bdelay,
do_padding,
pad_value,
indices,
)
imag = signal_framing(
signal.imag,
frame_length,
frame_step,
bdelay,
do_padding,
pad_value,
indices,
)
return ComplexTensor(real, imag)
else:
# (..., T - bdelay - frame_length + 2, frame_length)
signal = signal[..., indices]
# signal[..., :-1] = -signal[..., :-1]
return signal


def get_covariances(
Y: ComplexTensor,
inverse_power: torch.Tensor,
bdelay: int,
btaps: int,
get_vector: bool = False,
) -> ComplexTensor:
"""Calculates the power normalized spatio-temporal covariance
matrix of the framed signal.
Args:
Y : Complext STFT signal with shape (B, F, C, T)
inverse_power : Weighting factor with shape (B, F, T)
Returns:
Correlation matrix: (B, F, (btaps+1) * C, (btaps+1) * C)
Correlation vector: (B, F, btaps + 1, C, C)
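
    Examples:
        A shape sketch with random inputs (B=2, F=3, C=2, T=20, btaps=2,
        bdelay=3; values are illustrative only):

        >>> Y = ComplexTensor(torch.randn(2, 3, 2, 20), torch.randn(2, 3, 2, 20))
        >>> inverse_power = torch.rand(2, 3, 20)
        >>> get_covariances(Y, inverse_power, bdelay=3, btaps=2).shape
        torch.Size([2, 3, 6, 6])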
""" # noqa: H405, D205, D400, D401
assert inverse_power.dim() == 3, inverse_power.dim()
assert inverse_power.size(0) == Y.size(0), (inverse_power.size(0), Y.size(0))
Bs, Fdim, C, T = Y.shape
# (B, F, C, T - bdelay - btaps + 1, btaps + 1)
Psi = signal_framing(Y, btaps + 1, 1, bdelay, do_padding=False)[
..., : T - bdelay - btaps + 1, :
]
# Reverse along btaps-axis:
    # [tau, tau-bdelay, tau-bdelay-1, ..., tau-bdelay-btaps+1]
Psi = FC.reverse(Psi, dim=-1)
Psi_norm = Psi * inverse_power[..., None, bdelay + btaps - 1 :, None]
# let T' = T - bdelay - btaps + 1
# (B, F, C, T', btaps + 1) x (B, F, C, T', btaps + 1)
# -> (B, F, btaps + 1, C, btaps + 1, C)
covariance_matrix = FC.einsum("bfdtk,bfetl->bfkdle", (Psi, Psi_norm.conj()))
# (B, F, btaps + 1, C, btaps + 1, C)
# -> (B, F, (btaps + 1) * C, (btaps + 1) * C)
covariance_matrix = covariance_matrix.view(
Bs, Fdim, (btaps + 1) * C, (btaps + 1) * C
)
if get_vector:
# (B, F, C, T', btaps + 1) x (B, F, C, T')
# --> (B, F, btaps +1, C, C)
covariance_vector = FC.einsum(
"bfdtk,bfet->bfked", (Psi_norm, Y[..., bdelay + btaps - 1 :].conj())
)
return covariance_matrix, covariance_vector
else:
return covariance_matrix


def get_WPD_filter(
Phi: ComplexTensor,
Rf: ComplexTensor,
reference_vector: torch.Tensor,
use_torch_solver: bool = True,
diagonal_loading: bool = True,
diag_eps: float = 1e-7,
eps: float = 1e-8,
) -> ComplexTensor:
"""Return the WPD vector.
WPD is the Weighted Power minimization Distortionless response
convolutional beamformer. As follows:
h = (Rf^-1 @ Phi_{xx}) / tr[(Rf^-1) @ Phi_{xx}] @ u
Reference:
T. Nakatani and K. Kinoshita, "A Unified Convolutional Beamformer
for Simultaneous Denoising and Dereverberation," in IEEE Signal
Processing Letters, vol. 26, no. 6, pp. 903-907, June 2019, doi:
10.1109/LSP.2019.2911179.
https://ieeexplore.ieee.org/document/8691481
Args:
Phi (ComplexTensor): (B, F, (btaps+1) * C, (btaps+1) * C)
is the PSD of zero-padded speech [x^T(t,f) 0 ... 0]^T.
Rf (ComplexTensor): (B, F, (btaps+1) * C, (btaps+1) * C)
is the power normalized spatio-temporal covariance matrix.
reference_vector (torch.Tensor): (B, (btaps+1) * C)
is the reference_vector.
use_torch_solver (bool): Whether to use `solve` instead of `inverse`
diagonal_loading (bool): Whether to add a tiny term to the diagonal of psd_n
diag_eps (float):
eps (float):
Returns:
filter_matrix (ComplexTensor): (B, F, (btaps + 1) * C)
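
    Examples:
        A shape sketch with random Hermitian matrices (in practice, `Phi`
        is the zero-padded speech PSD and `Rf` comes from `get_covariances`):

        >>> B, F, C, btaps = 2, 3, 2, 2
        >>> D = (btaps + 1) * C
        >>> p = ComplexTensor(torch.randn(B, F, D, D), torch.randn(B, F, D, D))
        >>> r = ComplexTensor(torch.randn(B, F, D, D), torch.randn(B, F, D, D))
        >>> Phi = FC.matmul(p, p.conj().transpose(-1, -2))
        >>> Rf = FC.matmul(r, r.conj().transpose(-1, -2))
        >>> u = torch.eye(D)[0].expand(B, D)
        >>> get_WPD_filter(Phi, Rf, u).shape
        torch.Size([2, 3, 6])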
"""
if diagonal_loading:
Rf = tik_reg(Rf, reg=diag_eps, eps=eps)
# numerator: (..., C_1, C_2) x (..., C_2, C_3) -> (..., C_1, C_3)
if use_torch_solver and is_torch_1_1_plus:
        # torch.solve is used, which is only available in PyTorch 1.1.0 and later
numerator = FC.solve(Phi, Rf)[0]
else:
numerator = FC.matmul(Rf.inverse2(), Phi)
# ws: (..., C, C) / (...,) -> (..., C, C)
ws = numerator / (FC.trace(numerator)[..., None, None] + eps)
# h: (..., F, C_1, C_2) x (..., C_2) -> (..., F, C_1)
beamform_vector = FC.einsum("...fec,...c->...fe", [ws, reference_vector])
# (B, F, (btaps + 1) * C)
return beamform_vector


def get_WPD_filter_v2(
Phi: ComplexTensor,
Rf: ComplexTensor,
reference_vector: torch.Tensor,
diagonal_loading: bool = True,
diag_eps: float = 1e-7,
eps: float = 1e-8,
) -> ComplexTensor:
"""Return the WPD vector (v2).
This implementaion is more efficient than `get_WPD_filter` as
it skips unnecessary computation with zeros.
Args:
Phi (ComplexTensor): (B, F, C, C)
is speech PSD.
Rf (ComplexTensor): (B, F, (btaps+1) * C, (btaps+1) * C)
is the power normalized spatio-temporal covariance matrix.
reference_vector (torch.Tensor): (B, C)
is the reference_vector.
diagonal_loading (bool): Whether to add a tiny term to the diagonal of psd_n
diag_eps (float):
eps (float):
Returns:
filter_matrix (ComplexTensor): (B, F, (btaps+1) * C)
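
    Examples:
        A shape sketch; note that `Phi` is C x C here, while `Rf` is
        ((btaps+1) * C) x ((btaps+1) * C):

        >>> B, F, C, btaps = 2, 3, 2, 2
        >>> D = (btaps + 1) * C
        >>> p = ComplexTensor(torch.randn(B, F, C, C), torch.randn(B, F, C, C))
        >>> r = ComplexTensor(torch.randn(B, F, D, D), torch.randn(B, F, D, D))
        >>> Phi = FC.matmul(p, p.conj().transpose(-1, -2))
        >>> Rf = FC.matmul(r, r.conj().transpose(-1, -2))
        >>> u = torch.eye(C)[0].expand(B, C)
        >>> get_WPD_filter_v2(Phi, Rf, u).shape
        torch.Size([2, 3, 6])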
"""
C = reference_vector.shape[-1]
if diagonal_loading:
Rf = tik_reg(Rf, reg=diag_eps, eps=eps)
inv_Rf = Rf.inverse2()
# (B, F, (btaps+1) * C, C)
inv_Rf_pruned = inv_Rf[..., :C]
# numerator: (..., C_1, C_2) x (..., C_2, C_3) -> (..., C_1, C_3)
numerator = FC.matmul(inv_Rf_pruned, Phi)
# ws: (..., (btaps+1) * C, C) / (...,) -> (..., (btaps+1) * C, C)
ws = numerator / (FC.trace(numerator[..., :C, :])[..., None, None] + eps)
# h: (..., F, C_1, C_2) x (..., C_2) -> (..., F, C_1)
beamform_vector = FC.einsum("...fec,...c->...fe", [ws, reference_vector])
# (B, F, (btaps+1) * C)
return beamform_vector


def get_WPD_filter_with_rtf(
psd_observed_bar: ComplexTensor,
psd_speech: ComplexTensor,
psd_noise: ComplexTensor,
iterations: int = 3,
reference_vector: Union[int, torch.Tensor, None] = None,
normalize_ref_channel: Optional[int] = None,
use_torch_solver: bool = True,
diagonal_loading: bool = True,
diag_eps: float = 1e-7,
eps: float = 1e-15,
) -> ComplexTensor:
"""Return the WPD vector calculated with RTF.
WPD is the Weighted Power minimization Distortionless response
convolutional beamformer. As follows:
h = (Rf^-1 @ vbar) / (vbar^H @ R^-1 @ vbar)
Reference:
T. Nakatani and K. Kinoshita, "A Unified Convolutional Beamformer
for Simultaneous Denoising and Dereverberation," in IEEE Signal
Processing Letters, vol. 26, no. 6, pp. 903-907, June 2019, doi:
10.1109/LSP.2019.2911179.
https://ieeexplore.ieee.org/document/8691481
Args:
psd_observed_bar (ComplexTensor): stacked observation covariance matrix
psd_speech (ComplexTensor): speech covariance matrix (..., F, C, C)
psd_noise (ComplexTensor): noise covariance matrix (..., F, C, C)
iterations (int): number of iterations in power method
reference_vector (torch.Tensor or int): (..., C) or scalar
normalize_ref_channel (int): reference channel for normalizing the RTF
use_torch_solver (bool): Whether to use `solve` instead of `inverse`
diagonal_loading (bool): Whether to add a tiny term to the diagonal of psd_n
diag_eps (float):
eps (float):
Returns:
beamform_vector (ComplexTensor)r: (..., F, C)
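
    Examples:
        A shape sketch with random Hermitian matrices (illustrative only):

        >>> B, F, C, btaps = 2, 3, 2, 2
        >>> D = (btaps + 1) * C
        >>> o = ComplexTensor(torch.randn(B, F, D, D), torch.randn(B, F, D, D))
        >>> s = ComplexTensor(torch.randn(B, F, C, C), torch.randn(B, F, C, C))
        >>> n = ComplexTensor(torch.randn(B, F, C, C), torch.randn(B, F, C, C))
        >>> psd_observed_bar = FC.matmul(o, o.conj().transpose(-1, -2))
        >>> psd_speech = FC.matmul(s, s.conj().transpose(-1, -2))
        >>> psd_noise = FC.matmul(n, n.conj().transpose(-1, -2))
        >>> w = get_WPD_filter_with_rtf(
        ...     psd_observed_bar, psd_speech, psd_noise, reference_vector=0
        ... )
        >>> w.shape
        torch.Size([2, 3, 6])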
"""
C = psd_noise.size(-1)
if diagonal_loading:
psd_noise = tik_reg(psd_noise, reg=diag_eps, eps=eps)
# (B, F, C, 1)
rtf = get_rtf(
psd_speech,
psd_noise,
reference_vector,
iterations=iterations,
use_torch_solver=use_torch_solver,
)
# (B, F, (K+1)*C, 1)
rtf = FC.pad(rtf, (0, 0, 0, psd_observed_bar.shape[-1] - C), "constant", 0)
# numerator: (..., C_1, C_2) x (..., C_2, 1) -> (..., C_1)
if use_torch_solver and is_torch_1_1_plus:
        # torch.solve is used, which is only available in PyTorch 1.1.0 and later
numerator = FC.solve(rtf, psd_observed_bar)[0].squeeze(-1)
else:
numerator = FC.matmul(psd_observed_bar.inverse2(), rtf).squeeze(-1)
denominator = FC.einsum("...d,...d->...", [rtf.squeeze(-1).conj(), numerator])
if normalize_ref_channel is not None:
scale = rtf.squeeze(-1)[..., normalize_ref_channel, None].conj()
beamforming_vector = numerator * scale / (denominator.real.unsqueeze(-1) + eps)
else:
beamforming_vector = numerator / (denominator.real.unsqueeze(-1) + eps)
return beamforming_vector


def perform_WPD_filtering(
filter_matrix: ComplexTensor, Y: ComplexTensor, bdelay: int, btaps: int
) -> ComplexTensor:
"""Perform WPD filtering.
Args:
filter_matrix: Filter matrix (B, F, (btaps + 1) * C)
Y : Complex STFT signal with shape (B, F, C, T)
Returns:
enhanced (ComplexTensor): (B, F, T)
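
    Examples:
        A shape sketch with a random filter and signal; the output keeps
        the original number of frames thanks to the internal zero-padding:

        >>> B, F, C, T, btaps, bdelay = 2, 3, 2, 20, 2, 3
        >>> Y = ComplexTensor(torch.randn(B, F, C, T), torch.randn(B, F, C, T))
        >>> h = ComplexTensor(
        ...     torch.randn(B, F, (btaps + 1) * C), torch.randn(B, F, (btaps + 1) * C)
        ... )
        >>> perform_WPD_filtering(h, Y, bdelay, btaps).shape
        torch.Size([2, 3, 20])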
"""
# (B, F, C, T) --> (B, F, C, T, btaps + 1)
Ytilde = signal_framing(Y, btaps + 1, 1, bdelay, do_padding=True, pad_value=0)
Ytilde = FC.reverse(Ytilde, dim=-1)
Bs, Fdim, C, T = Y.shape
# --> (B, F, T, btaps + 1, C) --> (B, F, T, (btaps + 1) * C)
Ytilde = Ytilde.permute(0, 1, 3, 4, 2).contiguous().view(Bs, Fdim, T, -1)
    # (B, F, T)
enhanced = FC.einsum("...tc,...c->...t", [Ytilde, filter_matrix.conj()])
return enhanced


def tik_reg(mat: ComplexTensor, reg: float = 1e-8, eps: float = 1e-8) -> ComplexTensor:
"""Perform Tikhonov regularization (only modifying real part).
Args:
mat (ComplexTensor): input matrix (..., C, C)
reg (float): regularization factor
eps (float)
Returns:
ret (ComplexTensor): regularized matrix (..., C, C)
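
    Examples:
        An all-zero matrix is floored to `eps * I`, making it invertible:

        >>> m = ComplexTensor(torch.zeros(2, 2), torch.zeros(2, 2))
        >>> tik_reg(m, reg=1e-7, eps=1e-8).real
        tensor([[1.0000e-08, 0.0000e+00],
                [0.0000e+00, 1.0000e-08]])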
"""
# Add eps
C = mat.size(-1)
eye = torch.eye(C, dtype=mat.dtype, device=mat.device)
shape = [1 for _ in range(mat.dim() - 2)] + [C, C]
eye = eye.view(*shape).repeat(*mat.shape[:-2], 1, 1)
with torch.no_grad():
epsilon = FC.trace(mat).real[..., None, None] * reg
# in case that correlation_matrix is all-zero
epsilon = epsilon + eps
mat = mat + epsilon * eye
return mat


##############################################
# Below are for Multi-Frame MVDR beamforming #
##############################################
# modified from https://gitlab.uni-oldenburg.de/hura4843/deep-mfmvdr/-/blob/master/deep_mfmvdr (# noqa: E501)


def get_adjacent(spec: ComplexTensor, filter_length: int = 5) -> ComplexTensor:
"""Zero-pad and unfold stft, i.e.,
add zeros to the beginning so that, using the multi-frame signal model,
there will be as many output frames as input frames.
Args:
spec (ComplexTensor): input spectrum (B, F, T)
filter_length (int): length for frame extension
Returns:
ret (ComplexTensor): output spectrum (B, F, T, filter_length)
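
    Examples:
        With `filter_length=3`, each output frame stacks the current frame
        and the two preceding ones (zeros before the start):

        >>> spec = ComplexTensor(torch.arange(4.0).view(1, 1, 4), torch.zeros(1, 1, 4))
        >>> get_adjacent(spec, filter_length=3).real
        tensor([[[[0., 0., 0.],
                  [0., 0., 1.],
                  [0., 1., 2.],
                  [1., 2., 3.]]]])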
""" # noqa: D400
return (
FC.pad(spec, pad=[filter_length - 1, 0])
.unfold(dim=-1, size=filter_length, step=1)
.contiguous()
)


def get_adjacent_th(spec: torch.Tensor, filter_length: int = 5) -> torch.Tensor:
"""Zero-pad and unfold stft, i.e.,
add zeros to the beginning so that, using the multi-frame signal model,
there will be as many output frames as input frames.
Args:
spec (torch.Tensor): input spectrum (B, F, T, 2)
filter_length (int): length for frame extension
Returns:
ret (torch.Tensor): output spectrum (B, F, T, filter_length, 2)
""" # noqa: D400
return (
torch.nn.functional.pad(spec, pad=[0, 0, filter_length - 1, 0])
.unfold(dimension=-2, size=filter_length, step=1)
.transpose(-2, -1)
.contiguous()
)


def vector_to_Hermitian(vec):
"""Construct a Hermitian matrix from a vector of N**2 independent
real-valued elements.
Args:
vec (torch.Tensor): (..., N ** 2)
Returns:
mat (ComplexTensor): (..., N, N)
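
    Examples:
        For N = 2, the first three values fill the real upper triangle
        (mirrored below) and the last value fills the imaginary strict
        upper triangle (anti-mirrored below):

        >>> m = vector_to_Hermitian(torch.tensor([1.0, 2.0, 3.0, 4.0]))
        >>> m.real
        tensor([[1., 2.],
                [2., 3.]])
        >>> m.imag
        tensor([[ 0.,  4.],
                [-4.,  0.]])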
""" # noqa: H405, D205, D400
N = int(np.sqrt(vec.shape[-1]))
mat = torch.zeros(size=vec.shape[:-1] + (N, N, 2), device=vec.device)
# real component
triu = np.triu_indices(N, 0)
triu2 = np.triu_indices(N, 1) # above main diagonal
tril = (triu2[1], triu2[0]) # below main diagonal; for symmetry
    # index arrays must be integer-typed to be usable as tensor indices
    mat[(...,) + triu + (np.zeros(triu[0].shape[0], dtype=np.int64),)] = vec[
        ..., : triu[0].shape[0]
    ]
    start = triu[0].shape[0]
    mat[(...,) + tril + (np.zeros(tril[0].shape[0], dtype=np.int64),)] = mat[
        (...,) + triu2 + (np.zeros(triu2[0].shape[0], dtype=np.int64),)
    ]
    # imaginary component
    mat[(...,) + triu2 + (np.ones(triu2[0].shape[0], dtype=np.int64),)] = vec[
        ..., start : start + triu2[0].shape[0]
    ]
    mat[(...,) + tril + (np.ones(tril[0].shape[0], dtype=np.int64),)] = -mat[
        (...,) + triu2 + (np.ones(triu2[0].shape[0], dtype=np.int64),)
    ]
return ComplexTensor(mat[..., 0], mat[..., 1])


def get_mfmvdr_vector(gammax, Phi, use_torch_solver: bool = True, eps: float = EPS):
"""Compute conventional MFMPDR/MFMVDR filter.
Args:
gammax (ComplexTensor): (..., L, N)
Phi (ComplexTensor): (..., L, N, N)
use_torch_solver (bool): Whether to use `solve` instead of `inverse`
eps (float)
Returns:
beamforming_vector (ComplexTensor): (..., L, N)
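
    Examples:
        A shape sketch with random Hermitian correlation matrices
        (illustrative only):

        >>> L, N = 10, 5
        >>> gammax = ComplexTensor(torch.randn(L, N), torch.randn(L, N))
        >>> z = ComplexTensor(torch.randn(L, N, N), torch.randn(L, N, N))
        >>> Phi = FC.matmul(z, z.conj().transpose(-1, -2))
        >>> get_mfmvdr_vector(gammax, Phi).shape
        torch.Size([10, 5])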
"""
# (..., L, N)
if use_torch_solver and is_torch_1_1_plus:
        # torch.solve is used, which is only available in PyTorch 1.1.0 and later
numerator = FC.solve(gammax.unsqueeze(-1), Phi)[0].squeeze(-1)
else:
numerator = FC.matmul(Phi.inverse2(), gammax.unsqueeze(-1)).squeeze(-1)
denominator = FC.einsum("...d,...d->...", [gammax.conj(), numerator])
return numerator / (denominator.real.unsqueeze(-1) + eps)


def filter_minimum_gain_like(
G_min, w, y, alpha=None, k: float = 10.0, eps: float = EPS
):
"""Approximate a minimum gain operation.
speech_estimate = alpha w^H y + (1 - alpha) G_min Y,
where alpha = 1 / (1 + exp(-2 k x)), x = w^H y - G_min Y
Args:
G_min (float): minimum gain
w (ComplexTensor): filter coefficients (..., L, N)
y (ComplexTensor): buffered and stacked input (..., L, N)
alpha: mixing factor
k (float): scaling in tanh-like function
esp (float)
Returns:
output (ComplexTensor): minimum gain-filtered output
alpha (float): optional
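
    Examples:
        A shape sketch; when `alpha` is None, the internally computed
        mixing factor is returned as well:

        >>> L, N = 10, 5
        >>> w = ComplexTensor(torch.randn(L, N), torch.randn(L, N))
        >>> y = ComplexTensor(torch.randn(L, N), torch.randn(L, N))
        >>> output, alpha = filter_minimum_gain_like(0.1, w, y)
        >>> output.shape, alpha.shape
        (torch.Size([10]), torch.Size([10]))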
"""
# (..., L)
filtered_input = FC.einsum("...d,...d->...", [w.conj(), y])
# (..., L)
Y = y[..., -1]
return minimum_gain_like(G_min, Y, filtered_input, alpha, k, eps)


def minimum_gain_like(
    G_min, Y, filtered_input, alpha=None, k: float = 10.0, eps: float = EPS
):
    """Mix `filtered_input` with the minimum-gain-scaled input `Y`.

    When `alpha` is None, a sigmoid-shaped soft decision is computed from
    the magnitude difference between the two signals and returned together
    with the output.
    """
if alpha is None:
diff = (filtered_input + eps).abs() - (G_min * Y + eps).abs()
alpha = 1.0 / (1.0 + torch.exp(-2 * k * diff))
return_alpha = True
else:
return_alpha = False
output = alpha * filtered_input + (1 - alpha) * G_min * Y
if return_alpha:
return output, alpha
else:
return output