Spaces:
Running
Running
import os | |
import pickle as pickle_tts | |
from typing import Any, Callable, Dict, Union | |
import fsspec | |
import torch | |
from TTS.utils.generic_utils import get_user_data_dir | |
class RenamingUnpickler(pickle_tts.Unpickler):
    """Unpickler that maps the old ``mozilla_voice_tts`` module name to ``TTS``.

    Before a class is looked up, every occurrence of ``mozilla_voice_tts`` in the
    pickled module path is replaced with ``TTS``. The actual lookup is then
    delegated to the standard ``pickle.Unpickler.find_class``.
    """

    def find_class(self, module, name):
        renamed_module = module.replace("mozilla_voice_tts", "TTS")
        return super().find_class(renamed_module, name)
class AttrDict(dict):
    """Dictionary whose keys can also be read and written as attributes.

    The instance's attribute namespace *is* the dictionary itself, so
    ``d.key`` and ``d["key"]`` refer to the same storage.
    """

    def __init__(self, *positional, **named):
        super().__init__(*positional, **named)
        # Alias the attribute namespace to the mapping so both views stay in sync.
        self.__dict__ = self
def load_fsspec(
    path: str,
    map_location: Union[str, Callable, torch.device, Dict[Union[str, torch.device], Union[str, torch.device]]] = None,
    cache: bool = True,
    **kwargs,
) -> Any:
    """Deserialize an object with ``torch.load`` from any fsspec-supported location.

    Works like ``torch.load`` but also accepts remote URLs (e.g. ``s3://``, ``gs://``).

    Args:
        path: Local path or any URL understood by fsspec.
        map_location: Passed through to ``torch.load`` to remap tensor storage
            (a device, a device string, a callable, or a device mapping dict).
        cache: When True and ``path`` is not an existing local file or directory,
            read through fsspec's ``filecache`` so the remote file is stored under
            ``get_user_data_dir("tts_cache")`` and reused on later calls. Defaults to True.
        **kwargs: Extra keyword arguments forwarded to ``torch.load``.

    Returns:
        Whatever object ``torch.load`` produces from the file.
    """
    exists_locally = os.path.isdir(path) or os.path.isfile(path)

    # Existing local paths (or callers that opt out of caching) are opened directly;
    # everything else goes through fsspec's file cache.
    if exists_locally or not cache:
        stream = fsspec.open(path, "rb")
    else:
        stream = fsspec.open(
            f"filecache::{path}",
            filecache={"cache_storage": str(get_user_data_dir("tts_cache"))},
            mode="rb",
        )

    with stream as handle:
        return torch.load(handle, map_location=map_location, **kwargs)
def _renaming_pickle_module():
    """Build a private, module-like copy of ``pickle`` whose ``Unpickler`` is ``RenamingUnpickler``."""
    import types  # local import: only needed on the legacy-checkpoint fallback path

    module = types.ModuleType("tts_renaming_pickle")
    # Copy the public pickle API (load, loads, UnpicklingError, ...) so torch.load can use
    # this object anywhere it would use the real ``pickle`` module.
    module.__dict__.update({key: value for key, value in vars(pickle_tts).items() if not key.startswith("__")})
    module.Unpickler = RenamingUnpickler
    return module


def load_checkpoint(
    model, checkpoint_path, use_cuda=False, eval=False, cache=False
):  # pylint: disable=redefined-builtin
    """Restore ``model`` weights from a checkpoint and return the model with the raw state.

    The checkpoint is always deserialized onto the CPU first. If deserialization fails
    with ``ModuleNotFoundError``, loading is retried with an unpickler that rewrites
    ``mozilla_voice_tts`` module paths to ``TTS``.

    Args:
        model: Object providing ``load_state_dict`` (and ``cuda``/``eval`` if those flags
            are set); its weights are taken from ``state["model"]``.
        checkpoint_path: Local path or fsspec URL, forwarded to ``load_fsspec``.
        use_cuda: If True, call ``model.cuda()`` after the weights are loaded.
        eval: If True, call ``model.eval()`` after the weights are loaded.
        cache: Forwarded to ``load_fsspec``; when True, remote checkpoints are cached locally.

    Returns:
        Tuple ``(model, state)`` where ``state`` is the full deserialized checkpoint.

    Raises:
        KeyError: If the checkpoint has no ``"model"`` entry.
    """
    try:
        state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"), cache=cache)
    except ModuleNotFoundError:
        # Retry with module renaming. Hand torch.load a private pickle-like module instead
        # of assigning to ``pickle.Unpickler``; that assignment would change unpickling
        # behaviour for every other user of ``pickle`` in the process, permanently.
        state = load_fsspec(
            checkpoint_path,
            map_location=torch.device("cpu"),
            pickle_module=_renaming_pickle_module(),
            cache=cache,
        )
    model.load_state_dict(state["model"])
    if use_cuda:
        model.cuda()
    if eval:
        model.eval()
    return model, state