Spaces:
Runtime error
Runtime error
| import os | |
| import warnings | |
| import numpy as np | |
| import torch | |
| from torch import nn | |
| from ..masknn import activations | |
| from ..utils.torch_utils import pad_x_to_y | |
| def _unsqueeze_to_3d(x): | |
| if x.ndim == 1: | |
| return x.reshape(1, 1, -1) | |
| elif x.ndim == 2: | |
| return x.unsqueeze(1) | |
| else: | |
| return x | |
class BaseModel(nn.Module):
    """Base class for separation models.

    Provides the generic `separate` entry points (torch tensor, numpy array
    and filename interfaces) plus (de)serialization helpers shared by all
    separation models. Subclasses must implement `forward` and
    `get_model_args`.
    """

    def forward(self, *args, **kwargs):
        raise NotImplementedError

    def separate(self, wav, output_dir=None, force_overwrite=False, **kwargs):
        """Infer separated sources from input waveforms.
        Also supports filenames.

        Args:
            wav (Union[torch.Tensor, numpy.ndarray, str]): waveform array/tensor.
                Shape: 1D, 2D or 3D tensor, time last.
            output_dir (str): path to save all the wav files. If None,
                estimated sources will be saved next to the original ones.
            force_overwrite (bool): whether to overwrite existing files.
            **kwargs: keyword arguments to be passed to `_separate`.

        Returns:
            Union[torch.Tensor, numpy.ndarray, None], the estimated sources.
                (batch, n_src, time) or (n_src, time) w/o batch dim.

        Raises:
            ValueError: if `wav` is neither a filename, a numpy array nor a
                torch tensor.

        .. note::
            By default, `separate` calls `_separate` which calls `forward`.
            For models whose `forward` doesn't return waveform tensors,
            overwrite `_separate` to return waveform tensors.
        """
        if isinstance(wav, str):
            # Filename interface writes the estimates to disk, returns None.
            self.file_separate(
                wav, output_dir=output_dir, force_overwrite=force_overwrite, **kwargs
            )
        elif isinstance(wav, np.ndarray):
            return self.numpy_separate(wav, **kwargs)
        elif isinstance(wav, torch.Tensor):
            return self.torch_separate(wav, **kwargs)
        else:
            raise ValueError(
                f"Only support filenames, numpy arrays and torch tensors, received {type(wav)}"
            )

    def torch_separate(self, wav: torch.Tensor, **kwargs) -> torch.Tensor:
        """Core logic of `separate`.

        Runs the forward pass on the model's device and returns the estimates
        on the input's device.
        """
        # Handle device placement
        input_device = wav.device
        model_device = next(self.parameters()).device
        wav = wav.to(model_device)
        # Forward
        out_wavs = self._separate(wav, **kwargs)
        # FIXME: for now this is the best we can do.
        # Rescale so the estimates' overall energy matches the mixture's.
        out_wavs *= wav.abs().sum() / (out_wavs.abs().sum())
        # Back to input device (and numpy if necessary)
        out_wavs = out_wavs.to(input_device)
        return out_wavs

    def numpy_separate(self, wav: np.ndarray, **kwargs) -> np.ndarray:
        """Numpy interface to `separate`."""
        wav = torch.from_numpy(wav)
        out_wav = self.torch_separate(wav, **kwargs)
        out_wav = out_wav.data.numpy()
        return out_wav

    def file_separate(
        self, filename: str, output_dir=None, force_overwrite=False, **kwargs
    ) -> None:
        """Filename interface to `separate`.

        Saves the estimated sources next to `filename` (or under `output_dir`
        if given) as `<name>_est1.<ext>`, `<name>_est2.<ext>`, etc.
        """
        import soundfile as sf

        wav, fs = sf.read(filename, dtype="float32", always_2d=True)
        # FIXME: support only single-channel files for now.
        to_save = self.numpy_separate(wav[:, 0], **kwargs)
        # Save wav files to filename_est1.wav etc...
        base, _, ext = filename.rpartition(".")
        for src_idx, est_src in enumerate(to_save):
            save_name = "{}_est{}.{}".format(base, src_idx + 1, ext)
            if output_dir is not None:
                save_name = os.path.join(output_dir, os.path.basename(save_name))
            # Existence is checked on the final path (after the output_dir
            # join) so the overwrite guard tests the file actually written.
            if os.path.isfile(save_name) and not force_overwrite:
                warnings.warn(
                    f"File {save_name} already exists, pass `force_overwrite=True` to overwrite it",
                    UserWarning,
                )
                return
            sf.write(save_name, est_src, fs)

    def _separate(self, wav, *args, **kwargs):
        """Hidden separation method

        Args:
            wav (Union[torch.Tensor, numpy.ndarray, str]): waveform array/tensor.
                Shape: 1D, 2D or 3D tensor, time last.

        Returns:
            The output of self(wav, *args, **kwargs).
        """
        return self(wav, *args, **kwargs)

    @classmethod
    def from_pretrained(cls, pretrained_model_conf_or_path, *args, **kwargs):
        """Instantiate separation model from a model config (file or dict).

        Args:
            pretrained_model_conf_or_path (Union[dict, str]): model conf as
                returned by `serialize`, or path to it. Need to contain
                `model_args` and `state_dict` keys.
            *args: Positional arguments to be passed to the model.
            **kwargs: Keyword arguments to be passed to the model.
                They overwrite the ones in the model package.

        Returns:
            nn.Module corresponding to the pretrained model conf/URL.

        Raises:
            ValueError if the input config file doesn't contain the keys
                `model_name`, `model_args` or `state_dict`.
        """
        from . import get  # Avoid circular imports

        if isinstance(pretrained_model_conf_or_path, str):
            if os.path.isfile(pretrained_model_conf_or_path):
                cached_model = pretrained_model_conf_or_path
            else:
                raise ValueError(
                    "Model {} is not a file or doesn't exist.".format(
                        pretrained_model_conf_or_path
                    )
                )
            conf = torch.load(cached_model, map_location="cpu")
        else:
            conf = pretrained_model_conf_or_path
        # Validate the config before attempting instantiation.
        for required_key in ("model_name", "state_dict", "model_args"):
            if required_key not in conf.keys():
                raise ValueError(
                    "Expected config dictionary to have field "
                    "`{}`. Found only: {}".format(required_key, conf.keys())
                )
        conf["model_args"].update(kwargs)  # kwargs overwrite config.
        # Attempt to find the model and instantiate it.
        try:
            model_class = get(conf["model_name"])
        except ValueError:  # Couldn't get the model, maybe custom.
            model = cls(*args, **conf["model_args"])  # Child class.
        else:
            model = model_class(*args, **conf["model_args"])
        model.load_state_dict(conf["state_dict"])
        return model

    def serialize(self):
        """Serialize model and output dictionary.

        Returns:
            dict, serialized model with keys `model_args` and `state_dict`.
        """
        import pytorch_lightning as pl  # Not used in torch.hub
        from .. import __version__ as asteroid_version  # Avoid circular imports

        model_conf = dict(
            model_name=self.__class__.__name__,
            state_dict=self.get_state_dict(),
            model_args=self.get_model_args(),
        )
        # Additional infos
        infos = dict()
        infos["software_versions"] = dict(
            torch_version=torch.__version__,
            pytorch_lightning_version=pl.__version__,
            asteroid_version=asteroid_version,
        )
        model_conf["infos"] = infos
        return model_conf

    def get_state_dict(self):
        """In case the state dict needs to be modified before sharing the model."""
        return self.state_dict()

    def get_model_args(self):
        """Arguments needed to re-instantiate the model (must be overridden)."""
        raise NotImplementedError

    def cached_download(self, filename_or_url):
        """Resolve a local filename; URL download is not implemented here.

        Returns the path when it is an existing file, otherwise None
        (best-effort stub; `from_pretrained` currently bypasses it).
        """
        if os.path.isfile(filename_or_url):
            return filename_or_url
        else:
            print("Model {} is not a file or doesn't exist.".format(filename_or_url))
class BaseEncoderMaskerDecoder(BaseModel):
    """Base class for encoder-masker-decoder separation models.

    Args:
        encoder (Encoder): Encoder instance.
        masker (nn.Module): masker network.
        decoder (Decoder): Decoder instance.
        encoder_activation (Optional[str], optional): Activation to apply after encoder.
            See ``asteroid.masknn.activations`` for valid values.
    """

    def __init__(self, encoder, masker, decoder, encoder_activation=None):
        super().__init__()
        self.encoder = encoder
        self.masker = masker
        self.decoder = decoder
        self.encoder_activation = encoder_activation
        # "linear" resolves to the identity, used when no name is given.
        self.enc_activation = activations.get(encoder_activation or "linear")()

    def forward(self, wav):
        """Enc/Mask/Dec model forward

        Args:
            wav (torch.Tensor): waveform tensor. 1D, 2D or 3D tensor, time last.

        Returns:
            torch.Tensor, of shape (batch, n_src, time) or (n_src, time).
        """
        # Handle 1D, 2D or n-D inputs; remember if the batch dim was added.
        was_one_d = wav.ndim == 1
        # Reshape to (batch, n_mix, time)
        wav = _unsqueeze_to_3d(wav)
        # Real forward: encode, activate, mask, decode.
        tf_rep = self.encoder(wav)
        tf_rep = self.postprocess_encoded(tf_rep)
        tf_rep = self.enc_activation(tf_rep)
        est_masks = self.masker(tf_rep)
        est_masks = self.postprocess_masks(est_masks)
        # Broadcast the mixture representation over the n_src mask dim.
        masked_tf_rep = est_masks * tf_rep.unsqueeze(1)
        masked_tf_rep = self.postprocess_masked(masked_tf_rep)
        decoded = self.decoder(masked_tf_rep)
        decoded = self.postprocess_decoded(decoded)
        # Pad/trim back to the input length.
        reconstructed = pad_x_to_y(decoded, wav)
        return reconstructed.squeeze(0) if was_one_d else reconstructed

    def postprocess_encoded(self, tf_rep):
        """Hook to perform transformations on the encoded, time-frequency domain
        representation (output of the encoder) before encoder activation is applied.

        Args:
            tf_rep (Tensor of shape (batch, freq, time)):
                Output of the encoder, before encoder activation is applied.

        Return:
            Transformed `tf_rep`
        """
        return tf_rep

    def postprocess_masks(self, masks):
        """Hook to perform transformations on the masks (output of the masker) before
        masks are applied.

        Args:
            masks (Tensor of shape (batch, n_src, freq, time)):
                Output of the masker

        Return:
            Transformed `masks`
        """
        return masks

    def postprocess_masked(self, masked_tf_rep):
        """Hook to perform transformations on the masked time-frequency domain
        representation (result of masking in the time-frequency domain) before decoding.

        Args:
            masked_tf_rep (Tensor of shape (batch, n_src, freq, time)):
                Masked time-frequency representation, before decoding.

        Return:
            Transformed `masked_tf_rep`
        """
        return masked_tf_rep

    def postprocess_decoded(self, decoded):
        """Hook to perform transformations on the decoded, time domain representation
        (output of the decoder) before original shape reconstruction.

        Args:
            decoded (Tensor of shape (batch, n_src, time)):
                Output of the decoder, before original shape reconstruction.

        Return:
            Transformed `decoded`
        """
        return decoded

    def get_model_args(self):
        """Arguments needed to re-instantiate the model."""
        fb_config = self.encoder.filterbank.get_config()
        masknet_config = self.masker.get_config()
        # Both dicts must be disjoint, otherwise keys would silently collide.
        if any(key in fb_config for key in masknet_config):
            raise AssertionError(
                "Filterbank and Mask network config share "
                "common keys. Merging them is not safe."
            )
        # Merge all args under model_args.
        model_args = {
            **fb_config,
            **masknet_config,
            "encoder_activation": self.encoder_activation,
        }
        return model_args
# Backwards compatibility: `BaseTasNet` is the former name of
# `BaseEncoderMaskerDecoder`; kept as an alias so old imports keep working.
BaseTasNet = BaseEncoderMaskerDecoder