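# NOTE: the lines below are residue from the web file viewer this module was copied from,
# not part of the program. They are kept as comments so the file remains valid Python.
# poiqazwsx's picture
# Upload 57 files
# 51e2f90
# raw
# history blame
# No virus
# 14.6 kB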
from dataclasses import dataclass
from torchaudio._internal import load_state_dict_from_url
import math
from typing import List, Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
def transform_wb_pesq_range(x: float) -> float:
    """Map a raw PESQ score onto the wide-band PESQ scale with a logistic curve.

    The metric defined by ITU-T P.862 is often called 'PESQ score'; it is defined
    for narrow-band signals and has a value range of [-0.5, 4.5]. This module uses
    the metric defined by ITU-T P.862.2, commonly known as 'wide-band PESQ', and
    refers to it as "PESQ score".

    The curve is ``0.999 + 4.0 / (1 + exp(-1.3669 * x + 3.8224))``, so the result
    always lies between 0.999 and 4.999.

    Args:
        x (float): Raw PESQ score to be mapped.

    Returns:
        (float): The corresponding wide-band PESQ score.
    """
    try:
        denominator = 1 + math.exp(-1.3669 * x + 3.8224)
    except OverflowError:
        # For very negative ``x`` the exponent exceeds the float range. The logistic
        # term has vanished by then, so return the lower asymptote instead of raising.
        return 0.999
    return 0.999 + (4.999 - 0.999) / denominator
# (lower, upper) bounds of the wide-band PESQ score; used as the output range of the PESQ branch.
PESQRange: Tuple[float, float] = (
    1.0,  # Lower bound. P.862.2 uses a different input filter than P.862, so the lower bound of
    # the raw score is no longer -0.5 and the true lower bound is hard to determine.
    # 1.0 is used as a reasonable approximation.
    transform_wb_pesq_range(4.5),  # Upper bound: the mapped value of 4.5, the top of the P.862 range.
)
class RangeSigmoid(nn.Module):
    """Sigmoid activation whose output is rescaled from (0, 1) to ``val_range``.

    Args:
        val_range (Tuple[float, float], optional): ``(low, high)`` bounds of the output.
            Must be a tuple of length 2. (Default: ``(0.0, 1.0)``)
    """

    def __init__(self, val_range: Tuple[float, float] = (0.0, 1.0)) -> None:
        super().__init__()
        assert isinstance(val_range, tuple) and len(val_range) == 2
        self.val_range: Tuple[float, float] = val_range
        self.sigmoid: nn.modules.Module = nn.Sigmoid()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the sigmoid, then stretch and shift it into ``val_range``.

        Args:
            x (torch.Tensor): Input tensor of any shape.

        Returns:
            (torch.Tensor): Tensor of the same shape with values between the two bounds.
        """
        low, high = self.val_range[0], self.val_range[1]
        return self.sigmoid(x) * (high - low) + low
class Encoder(nn.Module):
    """Encoder module that transforms a 1D waveform into a 2D representation.

    A single bias-free ``Conv1d`` with kernel size ``win_len`` and stride
    ``win_len // 2`` is applied, followed by a ReLU.

    Args:
        feat_dim (int, optional): The feature dimension after Encoder module. (Default: 512)
        win_len (int, optional): Kernel size in the Conv1D layer. (Default: 32)
    """

    def __init__(self, feat_dim: int = 512, win_len: int = 32) -> None:
        super().__init__()
        hop = win_len // 2
        self.conv1d = nn.Conv1d(1, feat_dim, win_len, stride=hop, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Pass waveforms through the convolutional layer and a ReLU.

        Args:
            x (torch.Tensor): Input waveforms. Tensor with dimensions `(batch, time)`.

        Returns:
            (torch.Tensor): Feature Tensor with dimensions `(batch, channel, frame)`.
        """
        # Insert the single input-channel axis expected by Conv1d.
        with_channel = torch.unsqueeze(x, 1)
        return F.relu(self.conv1d(with_channel))
class SingleRNN(nn.Module):
    """One-layer bidirectional RNN followed by a linear projection back to the input size.

    Args:
        rnn_type (str): Name of the recurrent class looked up on ``torch.nn``
            (e.g. "RNN", "LSTM", "GRU").
        input_size (int): Feature dimension of the input, and of the projected output.
        hidden_size (int): Hidden dimension of each RNN direction.
        dropout (float, optional): Dropout value forwarded to the RNN constructor. (Default: 0.0)
    """

    def __init__(self, rnn_type: str, input_size: int, hidden_size: int, dropout: float = 0.0) -> None:
        super().__init__()
        self.rnn_type = rnn_type
        self.input_size = input_size
        self.hidden_size = hidden_size

        rnn_cls = getattr(nn, rnn_type)
        self.rnn: nn.modules.Module = rnn_cls(
            input_size,
            hidden_size,
            1,
            dropout=dropout,
            batch_first=True,
            bidirectional=True,
        )
        # Both directions are concatenated, hence the doubled input width.
        self.proj = nn.Linear(2 * hidden_size, input_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the RNN over ``x`` and project every step back to ``input_size``.

        Args:
            x (torch.Tensor): Tensor with dimensions `(batch, seq, dim)`.

        Returns:
            (torch.Tensor): Tensor with the same dimensions as the input.
        """
        hidden_states, _ = self.rnn(x)
        return self.proj(hidden_states)
class DPRNN(nn.Module):
    """*Dual-path recurrent neural networks (DPRNN)* :cite:`luo2020dual`.

    The input sequence is padded and cut into two streams of ``chunk_size``-long chunks
    that are offset from each other by ``chunk_stride``. Each block runs one RNN along
    the positions inside a chunk and one RNN across chunks, each followed by a
    normalization and a residual connection. A 1x1 convolution then maps the features
    to ``d_model`` channels and the chunks are merged back into a sequence.

    Args:
        feat_dim (int, optional): The feature dimension after Encoder module. (Default: 64)
        hidden_dim (int, optional): Hidden dimension in the RNN layer of DPRNN. (Default: 128)
        num_blocks (int, optional): Number of DPRNN layers. (Default: 6)
        rnn_type (str, optional): Type of RNN in DPRNN. Valid options are ["RNN", "LSTM", "GRU"]. (Default: "LSTM")
        d_model (int, optional): The number of expected features in the input. (Default: 256)
        chunk_size (int, optional): Chunk size of input for DPRNN. (Default: 100)
        chunk_stride (int, optional): Stride of chunk input for DPRNN. (Default: 50)
    """

    def __init__(
        self,
        feat_dim: int = 64,
        hidden_dim: int = 128,
        num_blocks: int = 6,
        rnn_type: str = "LSTM",
        d_model: int = 256,
        chunk_size: int = 100,
        chunk_stride: int = 50,
    ) -> None:
        super().__init__()
        self.num_blocks = num_blocks

        self.row_rnn = nn.ModuleList([])
        self.col_rnn = nn.ModuleList([])
        self.row_norm = nn.ModuleList([])
        self.col_norm = nn.ModuleList([])
        # Sub-modules are created block by block (row RNN, column RNN, then the norms) so the
        # order in which random initial weights are drawn stays fixed.
        for _ in range(num_blocks):
            self.row_rnn.append(SingleRNN(rnn_type, feat_dim, hidden_dim))
            self.col_rnn.append(SingleRNN(rnn_type, feat_dim, hidden_dim))
            self.row_norm.append(nn.GroupNorm(1, feat_dim, eps=1e-8))
            self.col_norm.append(nn.GroupNorm(1, feat_dim, eps=1e-8))

        # 1x1 convolution mapping feat_dim channels to d_model channels.
        self.conv = nn.Sequential(
            nn.Conv2d(feat_dim, d_model, 1),
            nn.PReLU(),
        )
        self.chunk_size = chunk_size
        self.chunk_stride = chunk_stride

    def pad_chunk(self, x: torch.Tensor) -> Tuple[torch.Tensor, int]:
        """Zero-pad the last axis so both chunk streams split into whole chunks.

        Args:
            x (torch.Tensor): Tensor with dimensions `(B, N, T)`.

        Returns:
            (torch.Tensor, int): The padded tensor and ``rest``, the number of extra
            frames added on the right in addition to the ``chunk_stride`` frames added
            on both sides.
        """
        size, hop = self.chunk_size, self.chunk_stride
        length = x.shape[-1]
        rest = size - (hop + length % size) % size
        padded = F.pad(x, [hop, rest + hop])
        return padded, rest

    def chunking(self, x: torch.Tensor) -> Tuple[torch.Tensor, int]:
        """Split the sequence into two interleaved streams of chunks.

        Args:
            x (torch.Tensor): Tensor with dimensions `(B, N, T)`.

        Returns:
            (torch.Tensor, int): Tensor with dimensions `(B, N, chunk_size, num_chunks)`
            and the ``rest`` value reported by :meth:`pad_chunk`.
        """
        padded, rest = self.pad_chunk(x)
        batch_size, feat_dim, _ = padded.shape
        size, hop = self.chunk_size, self.chunk_stride

        # Stream 1 drops the last `hop` frames, stream 2 drops the first `hop` frames.
        leading = padded[:, :, :-hop].contiguous().view(batch_size, feat_dim, -1, size)
        trailing = padded[:, :, hop:].contiguous().view(batch_size, feat_dim, -1, size)

        # Concatenating along the last axis and re-viewing interleaves the two streams chunk by chunk.
        interleaved = torch.cat([leading, trailing], dim=3).view(batch_size, feat_dim, -1, size)
        return interleaved.transpose(2, 3).contiguous(), rest

    def merging(self, x: torch.Tensor, rest: int) -> torch.Tensor:
        """Undo :meth:`chunking` by summing the two chunk streams and removing the padding.

        Args:
            x (torch.Tensor): Tensor with dimensions `(B, dim, chunk_size, num_chunks)`.
            rest (int): Extra right padding reported by :meth:`pad_chunk`.

        Returns:
            (torch.Tensor): Tensor with dimensions `(B, dim, T)`.
        """
        batch_size, dim, _, _ = x.shape
        size, hop = self.chunk_size, self.chunk_stride

        # Each row holds one chunk of stream 1 followed by the matching chunk of stream 2.
        paired = x.transpose(2, 3).contiguous().view(batch_size, dim, -1, size * 2)
        stream1 = paired[:, :, :, :size].contiguous().view(batch_size, dim, -1)
        stream2 = paired[:, :, :, size:].contiguous().view(batch_size, dim, -1)

        # Trim each stream to the frames they share, then add them.
        merged = stream1[:, :, hop:] + stream2[:, :, :-hop]
        if rest > 0:
            merged = merged[:, :, :-rest]
        return merged.contiguous()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the dual-path blocks to a feature sequence.

        Args:
            x (torch.Tensor): Tensor with dimensions `(B, N, T)`.

        Returns:
            (torch.Tensor): Tensor with dimensions `(B, T, d_model)`.
        """
        chunks, rest = self.chunking(x)
        batch_size, _, dim1, dim2 = chunks.shape

        out = chunks
        blocks = zip(self.row_rnn, self.row_norm, self.col_rnn, self.col_norm)
        for row_rnn, row_norm, col_rnn, col_norm in blocks:
            # Row pass: sequences run along dim1 (positions inside a chunk).
            rows = out.permute(0, 3, 2, 1).contiguous().view(batch_size * dim2, dim1, -1)
            rows = row_rnn(rows).view(batch_size, dim2, dim1, -1)
            out = out + row_norm(rows.permute(0, 3, 2, 1).contiguous())

            # Column pass: sequences run along dim2 (across chunks).
            cols = out.permute(0, 2, 3, 1).contiguous().view(batch_size * dim1, dim2, -1)
            cols = col_rnn(cols).view(batch_size, dim1, dim2, -1)
            out = out + col_norm(cols.permute(0, 3, 1, 2).contiguous())

        merged = self.merging(self.conv(out), rest)
        return merged.transpose(1, 2).contiguous()
class AutoPool(nn.Module):
    """Softmax-weighted pooling along one dimension with a learnable scale ``alpha``.

    Args:
        pool_dim (int, optional): Dimension that is pooled (summed out). (Default: 1)
    """

    def __init__(self, pool_dim: int = 1) -> None:
        super().__init__()
        self.pool_dim: int = pool_dim
        self.softmax: nn.modules.Module = nn.Softmax(dim=pool_dim)
        # Learnable scalar that scales the input before the softmax; initialised to 1.
        self.alpha = nn.Parameter(torch.ones(1))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Pool ``x`` along ``pool_dim``.

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            (torch.Tensor): ``x`` weighted by ``softmax(alpha * x)`` and summed over
            ``pool_dim``; that dimension is removed from the result.
        """
        attention = self.softmax(x * self.alpha)
        return (x * attention).sum(dim=self.pool_dim)
class SquimObjective(nn.Module):
    """Speech Quality and Intelligibility Measures (SQUIM) model that predicts **objective** metric scores
    for speech enhancement (e.g., STOI, PESQ, and SI-SDR).

    Args:
        encoder (torch.nn.Module): Encoder module to transform 1D waveform to 2D feature representation.
        dprnn (torch.nn.Module): DPRNN module to model sequential feature.
        branches (torch.nn.ModuleList): Transformer branches in which each branch estimates one
            objective metric score.
    """

    def __init__(
        self,
        encoder: nn.Module,
        dprnn: nn.Module,
        branches: nn.ModuleList,
    ):
        super(SquimObjective, self).__init__()
        self.encoder = encoder
        self.dprnn = dprnn
        self.branches = branches

    def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
        """Predict one score per branch for each waveform in the batch.

        Each waveform is first divided by 20 times its root-mean-square value.
        Waveforms that are entirely zero have an RMS of zero; they are left as zeros
        rather than being turned into NaN by a 0/0 division.

        Args:
            x (torch.Tensor): Input waveforms. Tensor with dimensions `(batch, time)`.

        Returns:
            List(torch.Tensor): List of score Tensors, one per branch.
            Each Tensor is with dimension `(batch,)`.

        Raises:
            ValueError: If ``x`` is not a 2D Tensor.
        """
        if x.ndim != 2:
            raise ValueError(f"The input must be a 2D Tensor. Found dimension {x.ndim}.")
        scale = torch.mean(x**2, dim=1, keepdim=True) ** 0.5 * 20
        # Only rows with a positive scale are divided. A zero scale means the whole row is
        # zero, so the row is passed through unchanged. Rows whose scale is NaN also pass
        # through unchanged, which keeps their NaN values visible.
        x = torch.where(scale > 0, x / scale, x)
        out = self.encoder(x)
        out = self.dprnn(out)
        return [branch(out).squeeze(dim=1) for branch in self.branches]
def _create_branch(d_model: int, nhead: int, metric: str) -> nn.modules.Module:
    """Create branch module after DPRNN model for predicting metric score.

    A branch is a Transformer encoder layer, an :class:`AutoPool` layer, and a
    two-layer regression head. For "stoi" the head ends with a :class:`RangeSigmoid`
    over its default range, for "pesq" with one over ``PESQRange``; for any other
    metric name the head output is left unbounded.

    Args:
        d_model (int): The number of expected features in the input.
        nhead (int): Number of heads in the multi-head attention model.
        metric (str): The metric name to predict.

    Returns:
        (nn.Module): Returned module to predict corresponding metric score.
    """
    attention = nn.TransformerEncoderLayer(d_model, nhead, d_model * 4, dropout=0.0, batch_first=True)
    pooling = AutoPool()

    head_layers: List[nn.Module] = [
        nn.Linear(d_model, d_model),
        nn.PReLU(),
        nn.Linear(d_model, 1),
    ]
    # Only the bounded metrics get a range-limiting activation on top of the head.
    if metric == "stoi":
        head_layers.append(RangeSigmoid())
    elif metric == "pesq":
        head_layers.append(RangeSigmoid(val_range=PESQRange))

    return nn.Sequential(attention, pooling, nn.Sequential(*head_layers))
def squim_objective_model(
    feat_dim: int,
    win_len: int,
    d_model: int,
    nhead: int,
    hidden_dim: int,
    num_blocks: int,
    rnn_type: str,
    chunk_size: int,
    chunk_stride: Optional[int] = None,
) -> SquimObjective:
    """Build a custom :class:`torchaudio.prototype.models.SquimObjective` model.

    The model has one branch per metric, in the order "stoi", "pesq", "sisdr".

    Args:
        feat_dim (int): The feature dimension after Encoder module.
        win_len (int): Kernel size in the Encoder module.
        d_model (int): The number of expected features in the input.
        nhead (int): Number of heads in the multi-head attention model.
        hidden_dim (int): Hidden dimension in the RNN layer of DPRNN.
        num_blocks (int): Number of DPRNN layers.
        rnn_type (str): Type of RNN in DPRNN. Valid options are ["RNN", "LSTM", "GRU"].
        chunk_size (int): Chunk size of input for DPRNN.
        chunk_stride (int or None, optional): Stride of chunk input for DPRNN.
            When ``None``, ``chunk_size // 2`` is used. (Default: ``None``)

    Returns:
        SquimObjective: The assembled model.
    """
    stride = chunk_size // 2 if chunk_stride is None else chunk_stride

    # Modules are built in a fixed order: encoder, DPRNN, then the three branches.
    encoder = Encoder(feat_dim, win_len)
    dprnn = DPRNN(feat_dim, hidden_dim, num_blocks, rnn_type, d_model, chunk_size, stride)
    metric_names = ("stoi", "pesq", "sisdr")
    branches = nn.ModuleList([_create_branch(d_model, nhead, name) for name in metric_names])
    return SquimObjective(encoder, dprnn, branches)
def squim_objective_base() -> SquimObjective:
    """Build :class:`torchaudio.prototype.models.SquimObjective` model with default arguments."""
    # chunk_stride is not given, so squim_objective_model falls back to its own default.
    base_config = {
        "feat_dim": 256,
        "win_len": 64,
        "d_model": 256,
        "nhead": 4,
        "hidden_dim": 256,
        "num_blocks": 2,
        "rnn_type": "LSTM",
        "chunk_size": 71,
    }
    return squim_objective_model(**base_config)
@dataclass
class SquimObjectiveBundle:
    """Bundle of a pretrained SquimObjective weight file and the sample rate it goes with."""

    # File name of the weights, relative to the torchaudio model download URL.
    _path: str
    # Sample rate reported by the ``sample_rate`` property.
    _sample_rate: float

    def _get_state_dict(self, dl_kwargs):
        """Fetch the state dict for ``_path``, forwarding ``dl_kwargs`` to the downloader."""
        download_options = dl_kwargs if dl_kwargs is not None else {}
        weights_url = f"https://download.pytorch.org/torchaudio/models/{self._path}"
        return load_state_dict_from_url(weights_url, **download_options)

    def get_model(self, *, dl_kwargs=None) -> SquimObjective:
        """Construct the SquimObjective model, and load the pretrained weight.

        The weight file is downloaded from the internet and cached with
        :func:`torch.hub.load_state_dict_from_url`

        Args:
            dl_kwargs (dictionary of keyword arguments): Passed to :func:`torch.hub.load_state_dict_from_url`.

        Returns:
            Variation of :py:class:`~torchaudio.models.SquimObjective`, switched to evaluation mode.
        """
        model = squim_objective_base()
        weights = self._get_state_dict(dl_kwargs)
        model.load_state_dict(weights)
        model.eval()
        return model

    @property
    def sample_rate(self):
        """Sample rate of the audio that the model is trained on.

        :type: float
        """
        return self._sample_rate
# Pretrained bundle instance: the weight file name and the sample rate it is paired with.
SQUIM_OBJECTIVE = SquimObjectiveBundle(_path="squim_objective_dns2020.pth", _sample_rate=16000)
# The documentation text is attached to this instance rather than to the class.
SQUIM_OBJECTIVE.__doc__ = """SquimObjective pipeline trained using approach described in
:cite:`kumar2023torchaudio` on the *DNS 2020 Dataset* :cite:`reddy2020interspeech`.
The underlying model is constructed by :py:func:`torchaudio.models.squim_objective_base`.
The weights are under `Creative Commons Attribution 4.0 International License
<https://github.com/microsoft/DNS-Challenge/blob/interspeech2020/master/LICENSE>`__.
Please refer to :py:class:`SquimObjectiveBundle` for usage instructions.
"""