Spaces:
Running
Running
File size: 2,291 Bytes
d5001fd |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 |
import os
import pickle as pickle_tts
from typing import Any, Callable, Dict, Union
import fsspec
import torch
from TTS.utils.generic_utils import get_user_data_dir
class RenamingUnpickler(pickle_tts.Unpickler):
    """Unpickler that resolves the old ``mozilla_voice_tts`` package as ``TTS``.

    Every class reference read from the pickle stream has its module path
    rewritten before lookup, so objects pickled under the previous package
    name load against the renamed package.
    """

    def find_class(self, module, name):
        # str.replace rewrites every occurrence, e.g.
        # "mozilla_voice_tts.tts.models" -> "TTS.tts.models".
        renamed_module = module.replace("mozilla_voice_tts", "TTS")
        return super().find_class(renamed_module, name)
class AttrDict(dict):
    """``dict`` subclass whose keys double as instance attributes.

    Reading or assigning ``obj.key`` is the same as ``obj["key"]``, because
    the instance's attribute namespace is the mapping itself.
    """

    def __init__(self, *args, **kwargs):
        super(AttrDict, self).__init__(*args, **kwargs)
        # Point the attribute namespace at this very mapping. The object now
        # references itself, so it is reclaimed by the cyclic garbage collector.
        self.__dict__ = self
def load_fsspec(
    path: str,
    map_location: Union[str, Callable, torch.device, Dict[Union[str, torch.device], Union[str, torch.device]]] = None,
    cache: bool = True,
    **kwargs,
) -> Any:
    """Load a ``torch.save``'d object from any location fsspec can open.

    Behaves like ``torch.load``, but ``path`` may also be a remote URL
    (for example ``s3://`` or ``gs://``).

    Args:
        path: Local path or fsspec URL of the serialized object.
        map_location: Forwarded unchanged to ``torch.load``. It accepts a
            str, a torch.device, a callable, or a device-remapping dict.
        cache: When True and ``path`` is not an existing local file or
            directory, the file is opened through fsspec's ``filecache``
            layer. That layer stores it under ``get_user_data_dir("tts_cache")``
            so later calls reuse the local copy. Defaults to True.
        **kwargs: Extra keyword arguments passed through to ``torch.load``.

    Returns:
        Whatever ``torch.load`` deserializes from the file.
    """
    exists_locally = os.path.isdir(path) or os.path.isfile(path)
    if exists_locally or not cache:
        # Existing local path, or caching disabled: read straight through fsspec.
        source = fsspec.open(path, "rb")
    else:
        # Chain the "filecache" protocol in front of the URL so the file is
        # copied into the local cache directory on first access.
        source = fsspec.open(
            f"filecache::{path}",
            filecache={"cache_storage": str(get_user_data_dir("tts_cache"))},
            mode="rb",
        )
    with source as stream:
        return torch.load(stream, map_location=map_location, **kwargs)
def load_checkpoint(
    model, checkpoint_path, use_cuda=False, eval=False, cache=False
):  # pylint: disable=redefined-builtin
    """Load a checkpoint from ``checkpoint_path`` into ``model``.

    The checkpoint is first read with the default pickle machinery, mapped to
    CPU. If that raises ``ModuleNotFoundError`` (for example, because the file
    references classes under the old ``mozilla_voice_tts`` package), it is
    read again with ``RenamingUnpickler``, which rewrites that package name
    to ``TTS``. The renaming unpickler is scoped to this one call; the global
    ``pickle`` module is left untouched.

    Args:
        model: Object whose ``load_state_dict`` receives ``state["model"]``.
        checkpoint_path: Local path or fsspec URL of the checkpoint.
        use_cuda: If True, call ``model.cuda()`` after loading the weights.
        eval: If True, call ``model.eval()`` after loading the weights.
        cache: Forwarded to ``load_fsspec`` to cache remote files locally.

    Returns:
        Tuple ``(model, state)``, where ``state`` is the full deserialized
        checkpoint object.

    Raises:
        Any exception raised by the retry (including a ``ModuleNotFoundError``
        that persists after renaming) propagates unchanged.
    """
    import types  # local import: only used to build the fallback pickle module

    cpu = torch.device("cpu")
    try:
        state = load_fsspec(checkpoint_path, map_location=cpu, cache=cache)
    except ModuleNotFoundError:
        # Build a private copy of the pickle module whose Unpickler is
        # RenamingUnpickler, and hand it to torch.load as ``pickle_module``.
        # The old code assigned ``pickle.Unpickler = RenamingUnpickler``, which
        # silently changed unpickling for every other caller in the process.
        # Copying the full namespace keeps any other pickle attribute torch.load
        # may look up available.
        renaming_pickle = types.ModuleType(pickle_tts.__name__)
        vars(renaming_pickle).update(vars(pickle_tts))
        renaming_pickle.Unpickler = RenamingUnpickler
        state = load_fsspec(checkpoint_path, map_location=cpu, pickle_module=renaming_pickle, cache=cache)
    model.load_state_dict(state["model"])
    if use_cuda:
        model.cuda()
    if eval:
        model.eval()
    return model, state
|