# -*- coding: utf-8 -*-
import datetime
import importlib
import logging
import os
import re
import subprocess
import sys
from pathlib import Path
from typing import Dict

import fsspec
import torch


def to_cuda(x: torch.Tensor) -> torch.Tensor:
    """Make a tensor contiguous and move it to the GPU when CUDA is available.

    Args:
        x (torch.Tensor): Input tensor. ``None`` and non-tensor values are returned unchanged.

    Returns:
        torch.Tensor: The contiguous tensor (moved to CUDA if available), ``None`` for ``None`` input,
        or ``x`` itself when it is not a tensor.
    """
    if x is None:
        return None
    if torch.is_tensor(x):
        x = x.contiguous()
        if torch.cuda.is_available():
            x = x.cuda(non_blocking=True)
    return x


def get_cuda():
    """Return whether CUDA is available and the matching default device.

    Returns:
        Tuple[bool, torch.device]: ``(use_cuda, device)`` where ``device`` is ``cuda:0`` when CUDA is
        available and ``cpu`` otherwise.
    """
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda:0" if use_cuda else "cpu")
    return use_cuda, device


def get_git_branch():
    """Return the name of the currently checked-out git branch.

    Returns:
        str: The branch name; ``"inside_docker"`` if ``git branch`` exits with an error (presumably because
        there is no git checkout, e.g. inside a docker image); ``"unknown"`` if git is not installed or no
        line of the output is marked as the current branch.
    """
    try:
        out = subprocess.check_output(["git", "branch"]).decode("utf8")
        current = next(line for line in out.split("\n") if line.startswith("*"))
        # `git branch` marks the current branch as "* <name>". str.replace returns a new string, so the
        # result must be assigned back; strip() also drops a trailing "\r" from CRLF output.
        current = current.replace("* ", "", 1).strip()
    except subprocess.CalledProcessError:
        current = "inside_docker"
    except (FileNotFoundError, StopIteration):
        current = "unknown"
    return current


def get_commit_hash():
    """Return the short hash of the current git ``HEAD`` commit.

    Based on https://stackoverflow.com/questions/14989858/get-the-current-git-hash-in-a-python-script

    Returns:
        str: The short commit hash, or ``"0000000"`` when git is not installed or the command fails
        (e.g. the ``.git`` folder is not copied into a docker container).
    """
    try:
        commit = subprocess.check_output(["git", "rev-parse", "--short", "HEAD"]).decode().strip()
    except (subprocess.CalledProcessError, FileNotFoundError):
        commit = "0000000"
    return commit


def get_experiment_folder_path(root_path, model_name):
    """Get an experiment folder path with the current date and time.

    Args:
        root_path (str): Parent directory of the experiment folder.
        model_name (str): Model name used as the folder-name prefix.

    Returns:
        str: ``<root_path>/<model_name>-<Month-DD-YYYY_HH+MMAM/PM>-<short commit hash>``.
    """
    date_str = datetime.datetime.now().strftime("%B-%d-%Y_%I+%M%p")
    commit_hash = get_commit_hash()
    output_folder = os.path.join(root_path, model_name + "-" + date_str + "-" + commit_hash)
    return output_folder


def remove_experiment_folder(experiment_path):
    """Remove an experiment folder unless it contains at least one ``*.pth`` checkpoint.

    Works on any filesystem supported by ``fsspec`` (local paths or remote URLs).

    Args:
        experiment_path (str): Path or URL of the experiment folder.
    """
    fs = fsspec.get_mapper(experiment_path).fs
    checkpoint_files = fs.glob(experiment_path + "/*.pth")
    if not checkpoint_files:
        if fs.exists(experiment_path):
            fs.rm(experiment_path, recursive=True)
            print(" ! Run is removed from {}".format(experiment_path))
    else:
        print(" ! Run is kept in {}".format(experiment_path))


def count_parameters(model):
    r"""Count the number of trainable parameters in a network.

    Args:
        model (torch.nn.Module): The network.

    Returns:
        int: Total number of elements across parameters with ``requires_grad=True``.
    """
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


def to_camel(text):
    """Convert a snake_case module name to its CamelCase class name.

    "TTS" and "VC" acronyms are upper-cased, e.g. ``"glow_tts" -> "GlowTTS"``, ``"freevc" -> "FreeVC"``.

    Args:
        text (str): Snake-case name.

    Returns:
        str: CamelCase name.
    """
    text = text.capitalize()
    # Drop every underscore that is not the first character and upper-case the letter following it.
    text = re.sub(r"(?!^)_([a-zA-Z])", lambda m: m.group(1).upper(), text)
    text = text.replace("Tts", "TTS")
    text = text.replace("vc", "VC")
    return text


def find_module(module_path: str, module_name: str) -> object:
    """Import ``<module_path>.<module_name>`` and return its class named ``to_camel(module_name)``.

    Args:
        module_path (str): Dotted package path containing the module.
        module_name (str): Module name; lower-cased before importing.

    Returns:
        object: The class object found in the imported module.
    """
    module_name = module_name.lower()
    module = importlib.import_module(module_path + "." + module_name)
    class_name = to_camel(module_name)
    return getattr(module, class_name)


def import_class(module_path: str) -> object:
    """Import a class from a module path.

    Args:
        module_path (str): The module path of the class, e.g. ``"package.module.ClassName"``.

    Returns:
        object: The imported class.
    """
    class_name = module_path.split(".")[-1]
    module_path = ".".join(module_path.split(".")[:-1])
    module = importlib.import_module(module_path)
    return getattr(module, class_name)


def get_import_path(obj: object) -> str:
    """Get the import path of a class.

    Args:
        obj (object): An instance whose class path is wanted.

    Returns:
        str: ``"<module>.<ClassName>"`` of ``type(obj)``; the inverse of :func:`import_class`.
    """
    return ".".join([type(obj).__module__, type(obj).__name__])


def get_user_data_dir(appname):
    """Return the per-user data directory for ``appname``.

    Resolution order: ``$TTS_HOME``, ``$XDG_DATA_HOME``, the Windows "Local AppData" shell folder,
    ``~/Library/Application Support/`` on macOS, and ``~/.local/share`` otherwise.

    Args:
        appname (str): Sub-directory name appended to the base data directory.

    Returns:
        pathlib.Path: ``<base data dir>/<appname>`` (not created here).
    """
    TTS_HOME = os.environ.get("TTS_HOME")
    XDG_DATA_HOME = os.environ.get("XDG_DATA_HOME")
    if TTS_HOME is not None:
        ans = Path(TTS_HOME).expanduser().resolve(strict=False)
    elif XDG_DATA_HOME is not None:
        ans = Path(XDG_DATA_HOME).expanduser().resolve(strict=False)
    elif sys.platform == "win32":
        import winreg  # pylint: disable=import-outside-toplevel

        # Use the key as a context manager so the registry handle is closed after the lookup.
        with winreg.OpenKey(
            winreg.HKEY_CURRENT_USER, r"Software\Microsoft\Windows\CurrentVersion\Explorer\Shell Folders"
        ) as key:
            dir_, _ = winreg.QueryValueEx(key, "Local AppData")
        ans = Path(dir_).resolve(strict=False)
    elif sys.platform == "darwin":
        ans = Path("~/Library/Application Support/").expanduser()
    else:
        ans = Path.home().joinpath(".local/share")
    return ans.joinpath(appname)


def set_init_dict(model_dict, checkpoint_state, c):
    """Partially initialize a model state dict from a checkpoint.

    Checkpoint entries are skipped when the key is absent from the model, when the element counts differ,
    or when the key contains any name listed in ``c.reinit_layers``.

    Args:
        model_dict (Dict[str, torch.Tensor]): The model's current state dict; updated in place.
        checkpoint_state (Dict[str, torch.Tensor]): State dict loaded from the checkpoint.
        c: Config object exposing ``has("reinit_layers")`` and ``reinit_layers``.

    Returns:
        Dict[str, torch.Tensor]: ``model_dict`` after the update.
    """
    # Partial initialization: if there is a mismatch with new and old layer, it is skipped.
    for k in checkpoint_state:
        if k not in model_dict:
            print(" | > Layer missing in the model definition: {}".format(k))
    # 1. filter out unnecessary keys
    pretrained_dict = {k: v for k, v in checkpoint_state.items() if k in model_dict}
    # 2. filter out different size layers (compared by element count, not by exact shape)
    pretrained_dict = {k: v for k, v in pretrained_dict.items() if v.numel() == model_dict[k].numel()}
    # 3. skip reinit layers
    if c.has("reinit_layers") and c.reinit_layers is not None:
        for reinit_layer_name in c.reinit_layers:
            pretrained_dict = {k: v for k, v in pretrained_dict.items() if reinit_layer_name not in k}
    # 4. overwrite entries in the existing state dict
    model_dict.update(pretrained_dict)
    print(" | > {} / {} layers are restored.".format(len(pretrained_dict), len(model_dict)))
    return model_dict


def format_aux_input(def_args: Dict, kwargs: Dict) -> Dict:
    """Format kwargs to handle auxiliary inputs to models.

    Args:
        def_args (Dict): A dictionary of argument names and their default values if not defined in `kwargs`.
        kwargs (Dict): A `dict` or `kwargs` that includes auxiliary inputs to the model. Not modified.

    Returns:
        Dict: A copy of ``kwargs`` where every key of ``def_args`` that is missing or ``None`` is set to
        its default.
    """
    kwargs = kwargs.copy()
    for name, default in def_args.items():
        if kwargs.get(name) is None:
            kwargs[name] = default
    return kwargs


class KeepAverage:
    """Track named running averages (arithmetic mean or exponential moving average)."""

    def __init__(self):
        self.avg_values = {}  # name -> current average
        self.iters = {}  # name -> number of values folded into the average

    def __getitem__(self, key):
        return self.avg_values[key]

    def items(self):
        """Return ``(name, average)`` pairs."""
        return self.avg_values.items()

    def add_value(self, name, init_val=0, init_iter=0):
        """(Re)initialize ``name`` with ``init_val`` and iteration count ``init_iter``."""
        self.avg_values[name] = init_val
        self.iters[name] = init_iter

    def update_value(self, name, value, weighted_avg=False):
        """Fold ``value`` into the average for ``name``.

        Args:
            name (str): Name of the tracked value. Unknown names are initialized with ``value``.
            value (float): New observation.
            weighted_avg (bool): If True, use an exponential moving average with decay 0.99;
                otherwise keep the arithmetic mean of all observations.
        """
        if name not in self.avg_values:
            # add value if not exist before
            self.add_value(name, init_val=value)
        else:
            # else update existing value
            if weighted_avg:
                self.avg_values[name] = 0.99 * self.avg_values[name] + 0.01 * value
                self.iters[name] += 1
            else:
                self.avg_values[name] = self.avg_values[name] * self.iters[name] + value
                self.iters[name] += 1
                self.avg_values[name] /= self.iters[name]

    def add_values(self, name_dict):
        """Initialize several values at once from a ``{name: init_val}`` mapping."""
        for key, value in name_dict.items():
            self.add_value(key, init_val=value)

    def update_values(self, value_dict):
        """Update several arithmetic-mean averages at once from a ``{name: value}`` mapping."""
        for key, value in value_dict.items():
            self.update_value(key, value)


def get_timestamp():
    """Return the current local time formatted as ``YYMMDD-HHMMSS``."""
    # `datetime` is the module here, so the class must be referenced explicitly.
    return datetime.datetime.now().strftime("%y%m%d-%H%M%S")


def setup_logger(logger_name, root, phase, level=logging.INFO, screen=False, tofile=False):
    """Configure the logger ``logger_name`` with optional file and console handlers.

    Args:
        logger_name (str): Name passed to ``logging.getLogger``.
        root (str): Directory for the log file when ``tofile`` is True.
        phase (str): Log-file name prefix; the file is ``<root>/<phase>_<timestamp>.log``.
        level (int): Logging level set on the logger.
        screen (bool): If True, also log to a ``StreamHandler``.
        tofile (bool): If True, log to a new file opened in write mode.

    Note:
        Each call adds new handlers; calling it twice for the same logger duplicates output.
    """
    lg = logging.getLogger(logger_name)
    formatter = logging.Formatter("%(asctime)s.%(msecs)03d - %(levelname)s: %(message)s", datefmt="%y-%m-%d %H:%M:%S")
    lg.setLevel(level)
    if tofile:
        log_file = os.path.join(root, phase + "_{}.log".format(get_timestamp()))
        fh = logging.FileHandler(log_file, mode="w")
        fh.setFormatter(formatter)
        lg.addHandler(fh)
    if screen:
        sh = logging.StreamHandler()
        sh.setFormatter(formatter)
        lg.addHandler(sh)