Spaces:
Sleeping
Sleeping
import os | |
import warnings | |
import numpy as np | |
import torch | |
from torch import nn | |
from ..masknn import activations | |
from ..utils.torch_utils import pad_x_to_y | |
def _unsqueeze_to_3d(x): | |
if x.ndim == 1: | |
return x.reshape(1, 1, -1) | |
elif x.ndim == 2: | |
return x.unsqueeze(1) | |
else: | |
return x | |
class BaseModel(nn.Module):
    """Base class for separation models.

    Provides the inference entry points (`separate` and its torch / numpy /
    filename variants) and (de)serialization helpers (`serialize`,
    `from_pretrained`). Subclasses must implement `forward` and
    `get_model_args`.
    """

    def __init__(self):
        super().__init__()

    def forward(self, *args, **kwargs):
        raise NotImplementedError

    def separate(self, wav, output_dir=None, force_overwrite=False, **kwargs):
        """Infer separated sources from input waveforms.
        Also supports filenames.

        Args:
            wav (Union[torch.Tensor, numpy.ndarray, str]): waveform array/tensor,
                or path to an audio file.
                Shape: 1D, 2D or 3D tensor, time last.
            output_dir (str): path to save all the wav files. If None,
                estimated sources will be saved next to the original ones.
                Only used when `wav` is a filename.
            force_overwrite (bool): whether to overwrite existing files.
            **kwargs: keyword arguments to be passed to `_separate`.

        Returns:
            Union[torch.Tensor, numpy.ndarray, None], the estimated sources
            (same type as the input; None for filenames, whose estimates are
            written to disk instead).
            (batch, n_src, time) or (n_src, time) w/o batch dim.

        Raises:
            ValueError: if `wav` is not a str, numpy array or torch tensor.

        .. note::
            By default, `separate` calls `_separate` which calls `forward`.
            For models whose `forward` doesn't return waveform tensors,
            overwrite `_separate` to return waveform tensors.
        """
        if isinstance(wav, str):
            self.file_separate(
                wav, output_dir=output_dir, force_overwrite=force_overwrite, **kwargs
            )
            return None
        if isinstance(wav, np.ndarray):
            return self.numpy_separate(wav, **kwargs)
        if isinstance(wav, torch.Tensor):
            return self.torch_separate(wav, **kwargs)
        raise ValueError(
            f"Only support filenames, numpy arrays and torch tensors, received {type(wav)}"
        )

    @torch.no_grad()
    def torch_separate(self, wav: torch.Tensor, **kwargs) -> torch.Tensor:
        """Core logic of `separate`.

        Runs `_separate` on the model's device without gradient tracking,
        rescales the estimates so their summed absolute amplitude matches the
        input's, and returns them on the input's original device.
        """
        input_device = wav.device
        # A parameter-free model has no device of its own: stay on the input's device.
        first_param = next(self.parameters(), None)
        model_device = first_param.device if first_param is not None else input_device
        wav = wav.to(model_device)
        out_wavs = self._separate(wav, **kwargs)
        # FIXME: for now this is the best we can do.
        # Rescale out-of-place: `.to()` (and `torch.from_numpy` upstream) may not
        # copy, so an in-place op on an output that aliases the input would
        # mutate the caller's data. Silent estimates are left as-is instead of
        # turning into NaN through a division by zero.
        out_norm = out_wavs.abs().sum()
        if out_norm > 0:
            out_wavs = out_wavs * (wav.abs().sum() / out_norm)
        return out_wavs.to(input_device)

    def numpy_separate(self, wav: np.ndarray, **kwargs) -> np.ndarray:
        """Numpy interface to `separate`: wraps `torch_separate`."""
        out_wavs = self.torch_separate(torch.from_numpy(wav), **kwargs)
        return out_wavs.detach().cpu().numpy()

    def file_separate(
        self, filename: str, output_dir=None, force_overwrite=False, **kwargs
    ) -> None:
        """Filename interface to `separate`.

        Reads `filename` (first channel only), separates it and writes one file
        per estimated source, named ``<stem>_est<k><ext>`` (k starting at 1),
        next to the input or inside `output_dir` when given. If a target file
        already exists and `force_overwrite` is False, a UserWarning is emitted
        and no further files are written.
        """
        import soundfile as sf

        wav, fs = sf.read(filename, dtype="float32", always_2d=True)
        # FIXME: support only single-channel files for now.
        to_save = self.numpy_separate(wav[:, 0], **kwargs)
        for src_idx, est_src in enumerate(to_save):
            # Resolve the final path (including output_dir) before the existence check.
            save_name = self._est_save_name(filename, src_idx, output_dir)
            if os.path.isfile(save_name) and not force_overwrite:
                warnings.warn(
                    f"File {save_name} already exists, pass `force_overwrite=True` to overwrite it",
                    UserWarning,
                )
                return
            sf.write(save_name, est_src, fs)

    @staticmethod
    def _est_save_name(filename, src_idx, output_dir=None):
        """Return the output path for the 0-based source `src_idx` of `filename`."""
        # splitext only considers the last path component, so dotted directories are safe.
        base, ext = os.path.splitext(filename)
        save_name = f"{base}_est{src_idx + 1}{ext}"
        if output_dir is not None:
            save_name = os.path.join(output_dir, os.path.basename(save_name))
        return save_name

    def _separate(self, wav, *args, **kwargs):
        """Hidden separation method.

        Args:
            wav (torch.Tensor): waveform tensor, already on the model's device.
                Shape: 1D, 2D or 3D tensor, time last.

        Returns:
            The output of self(wav, *args, **kwargs).
        """
        return self(wav, *args, **kwargs)

    @classmethod
    def from_pretrained(cls, pretrained_model_conf_or_path, *args, **kwargs):
        """Instantiate separation model from a model config (file or dict).

        Args:
            pretrained_model_conf_or_path (Union[dict, str]): model conf as
                returned by `serialize`, or path to it. Needs to contain
                `model_name`, `model_args` and `state_dict` keys.
            *args: Positional arguments to be passed to the model.
            **kwargs: Keyword arguments to be passed to the model.
                They overwrite the ones in the model package.

        Returns:
            nn.Module corresponding to the pretrained model conf.

        Raises:
            ValueError if the path is not an existing file, or if the config
            doesn't contain the keys `model_name`, `model_args` or `state_dict`.
        """
        from . import get  # Avoid circular imports

        if isinstance(pretrained_model_conf_or_path, str):
            if not os.path.isfile(pretrained_model_conf_or_path):
                raise ValueError(
                    "Model {} is not a file or doesn't exist.".format(pretrained_model_conf_or_path)
                )
            # NOTE(security): torch.load unpickles the file; only load trusted checkpoints.
            conf = torch.load(pretrained_model_conf_or_path, map_location="cpu")
        else:
            conf = pretrained_model_conf_or_path

        for required_key in ("model_name", "state_dict", "model_args"):
            if required_key not in conf:
                raise ValueError(
                    "Expected config dictionary to have field "
                    "`{}`. Found only: {}".format(required_key, conf.keys())
                )

        # Build a new dict so the caller's config is never mutated; kwargs overwrite config.
        model_args = {**conf["model_args"], **kwargs}
        # Attempt to find the model and instantiate it.
        try:
            model_class = get(conf["model_name"])
        except ValueError:  # Couldn't get the model, maybe custom.
            model_class = cls  # Child class.
        model = model_class(*args, **model_args)
        model.load_state_dict(conf["state_dict"])
        return model

    def serialize(self):
        """Serialize model and output dictionary.

        Returns:
            dict, serialized model with keys `model_name`, `state_dict`,
            `model_args` and `infos` (software versions).
        """
        import pytorch_lightning as pl  # Not used in torch.hub
        from .. import __version__ as asteroid_version  # Avoid circular imports

        model_conf = dict(
            model_name=self.__class__.__name__,
            state_dict=self.get_state_dict(),
            model_args=self.get_model_args(),
        )
        # Additional infos
        infos = dict()
        infos["software_versions"] = dict(
            torch_version=torch.__version__,
            pytorch_lightning_version=pl.__version__,
            asteroid_version=asteroid_version,
        )
        model_conf["infos"] = infos
        return model_conf

    def get_state_dict(self):
        """In case the state dict needs to be modified before sharing the model."""
        return self.state_dict()

    def get_model_args(self):
        """Arguments needed to re-instantiate the model. Must be overridden."""
        raise NotImplementedError

    def cached_download(self, filename_or_url):
        """Return `filename_or_url` if it is an existing local file.

        URLs are not supported by this implementation.

        Raises:
            ValueError: if `filename_or_url` is not an existing file.
        """
        if os.path.isfile(filename_or_url):
            return filename_or_url
        raise ValueError("Model {} is not a file or doesn't exist.".format(filename_or_url))


class BaseEncoderMaskerDecoder(BaseModel):
    """Base class for encoder-masker-decoder separation models.

    Args:
        encoder (Encoder): Encoder instance.
        masker (nn.Module): masker network.
        decoder (Decoder): Decoder instance.
        encoder_activation (Optional[str], optional): Activation to apply after encoder.
            See ``asteroid.masknn.activations`` for valid values.
    """

    def __init__(self, encoder, masker, decoder, encoder_activation=None):
        super().__init__()
        self.encoder = encoder
        self.masker = masker
        self.decoder = decoder
        # Keep the raw name for `get_model_args`; None falls back to "linear".
        self.encoder_activation = encoder_activation
        self.enc_activation = activations.get(encoder_activation or "linear")()

    def forward(self, wav):
        """Enc/Mask/Dec model forward.

        Pipeline: encoder -> `postprocess_encoded` -> encoder activation ->
        masker -> `postprocess_masks` -> masking -> `postprocess_masked` ->
        decoder -> `postprocess_decoded` -> pad to the input length.

        Args:
            wav (torch.Tensor): waveform tensor. 1D, 2D or 3D tensor, time last.

        Returns:
            torch.Tensor, of shape (batch, n_src, time) or (n_src, time)
            when the input was 1D.
        """
        # Handle 1D, 2D or n-D inputs
        was_one_d = wav.ndim == 1
        # Reshape to (batch, n_mix, time)
        wav = _unsqueeze_to_3d(wav)
        # Real forward
        tf_rep = self.encoder(wav)
        tf_rep = self.postprocess_encoded(tf_rep)
        tf_rep = self.enc_activation(tf_rep)
        est_masks = self.masker(tf_rep)
        est_masks = self.postprocess_masks(est_masks)
        # Add a source axis to the representation so every mask applies to it.
        masked_tf_rep = est_masks * tf_rep.unsqueeze(1)
        masked_tf_rep = self.postprocess_masked(masked_tf_rep)
        decoded = self.decoder(masked_tf_rep)
        decoded = self.postprocess_decoded(decoded)
        # Match the time length of the (3D) input.
        reconstructed = pad_x_to_y(decoded, wav)
        if was_one_d:
            return reconstructed.squeeze(0)
        return reconstructed

    def postprocess_encoded(self, tf_rep):
        """Hook to perform transformations on the encoded, time-frequency domain
        representation (output of the encoder) before encoder activation is applied.

        Args:
            tf_rep (Tensor of shape (batch, freq, time)):
                Output of the encoder, before encoder activation is applied.

        Return:
            Transformed `tf_rep`
        """
        return tf_rep

    def postprocess_masks(self, masks):
        """Hook to perform transformations on the masks (output of the masker) before
        masks are applied.

        Args:
            masks (Tensor of shape (batch, n_src, freq, time)):
                Output of the masker

        Return:
            Transformed `masks`
        """
        return masks

    def postprocess_masked(self, masked_tf_rep):
        """Hook to perform transformations on the masked time-frequency domain
        representation (result of masking in the time-frequency domain) before decoding.

        Args:
            masked_tf_rep (Tensor of shape (batch, n_src, freq, time)):
                Masked time-frequency representation, before decoding.

        Return:
            Transformed `masked_tf_rep`
        """
        return masked_tf_rep

    def postprocess_decoded(self, decoded):
        """Hook to perform transformations on the decoded, time domain representation
        (output of the decoder) before original shape reconstruction.

        Args:
            decoded (Tensor of shape (batch, n_src, time)):
                Output of the decoder, before original shape reconstruction.

        Return:
            Transformed `decoded`
        """
        return decoded

    def get_model_args(self):
        """Arguments needed to re-instantiate the model.

        Merges the filterbank config, the masker config and `encoder_activation`
        into a single flat dict.

        Raises:
            AssertionError: if the filterbank and masker configs share keys.
        """
        fb_config = self.encoder.filterbank.get_config()
        masknet_config = self.masker.get_config()
        # The configs are merged flat, so a shared key would be silently overwritten.
        shared_keys = [k for k in masknet_config if k in fb_config]
        if shared_keys:
            raise AssertionError(
                "Filterbank and Mask network config share common keys "
                f"{shared_keys}. Merging them is not safe."
            )
        # Merge all args under model_args.
        return {
            **fb_config,
            **masknet_config,
            "encoder_activation": self.encoder_activation,
        }


# Backwards compatibility: alias kept for code that still refers to `BaseTasNet`.
BaseTasNet = BaseEncoderMaskerDecoder