|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from typing import Dict, Optional |
|
|
|
|
|
import torch |
|
|
|
|
|
from nemo.core.classes import NeuralModule, typecheck |
|
|
from nemo.core.neural_types import NeuralType, SpectrogramType |
|
|
|
|
|
|
|
|
class MixtureConsistencyProjection(NeuralModule):
    """Ensure estimated sources are consistent with the input mixture.

    The difference between the mixture and the sum of the estimated sources
    is distributed back over the sources, so that the corrected estimates
    add up to the mixture (exactly for uniform weighting; up to the `eps`
    regularization for power weighting).

    Note that the input mixture is assumed to be a single-channel signal.

    Args:
        weighting: Optional weighting mode for the consistency constraint.
            If `None`, use uniform weighting. If `power`, use the power of the
            estimated source as the weight.
        eps: Small positive value for regularization

    Raises:
        NotImplementedError: If `weighting` is neither `None` nor `power`.

    Reference:
        Wisdom et al, Differentiable consistency constraints for improved deep speech enhancement, 2018
    """

    def __init__(self, weighting: Optional[str] = None, eps: float = 1e-8):
        super().__init__()
        self.weighting = weighting
        self.eps = eps

        # Fail early on an unsupported mode instead of at the first forward call.
        if self.weighting not in [None, 'power']:
            raise NotImplementedError(f'Weighting mode {self.weighting} not implemented')

    @property
    def input_types(self) -> Dict[str, NeuralType]:
        """Returns definitions of module input ports."""
        return {
            "mixture": NeuralType(('B', 'C', 'D', 'T'), SpectrogramType()),
            "estimate": NeuralType(('B', 'C', 'D', 'T'), SpectrogramType()),
        }

    @property
    def output_types(self) -> Dict[str, NeuralType]:
        """Returns definitions of module output ports."""
        return {
            "output": NeuralType(('B', 'C', 'D', 'T'), SpectrogramType()),
        }

    @typecheck()
    def forward(self, mixture: torch.Tensor, estimate: torch.Tensor) -> torch.Tensor:
        """Enforce mixture consistency on the estimated sources.

        Args:
            mixture: Single-channel mixture, shape (B, 1, F, N)
            estimate: M estimated sources, shape (B, M, F, N)

        Returns:
            Source estimates consistent with the mixture, shape (B, M, F, N)

        Raises:
            ValueError: If `mixture` does not have exactly one channel.
            NotImplementedError: If `self.weighting` is not a supported mode.
        """
        if mixture.size(-3) != 1:
            raise ValueError(f'Mixture must have a single channel, got shape {mixture.shape}')

        # Number of estimated sources (size of the channel dimension).
        M = estimate.size(-3)

        # Mixture implied by the current estimates: sum over the source dimension.
        estimated_mixture = torch.sum(estimate, dim=-3, keepdim=True)

        # Per-source share of the residual that gets added back.
        if self.weighting is None:
            # Uniform: every source receives an equal share.
            weight = 1 / M
        elif self.weighting == 'power':
            # Share proportional to each source's power, normalized across
            # sources; eps keeps the denominator away from zero.
            weight = estimate.abs().pow(2)
            weight = weight / (weight.sum(dim=-3, keepdim=True) + self.eps)
        else:
            # Only reachable if `self.weighting` was changed after construction.
            raise NotImplementedError(f'Weighting mode {self.weighting} not implemented')

        # Distribute the residual (mixture minus estimated mixture) over the sources.
        consistent_estimate = estimate + weight * (mixture - estimated_mixture)

        return consistent_estimate
|
|
|