# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# Authors: Yossi Adi (adiyoss) and Alexandre Defossez (adefossez)

import functools
import inspect
import logging
import math
import os
import time
from contextlib import contextmanager

import torch

logger = logging.getLogger(__name__)


def capture_init(init):
    """
    Decorate `__init__` with this, and you can then
    recover the *args and **kwargs passed to it in `self._init_args_kwargs`
    """
    @functools.wraps(init)
    def __init__(self, *args, **kwargs):
        self._init_args_kwargs = (args, kwargs)
        init(self, *args, **kwargs)

    return __init__


def deserialize_model(package, strict=False):
    """Rebuild a model from a package created by `serialize_model`.

    Args:
        package: dict with keys "class", "args", "kwargs" and "state".
        strict: if False, kwargs that the class constructor no longer
            accepts are dropped with a warning instead of raising.

    Returns:
        The reconstructed model, with its state dict loaded.
    """
    klass = package['class']
    if strict:
        model = klass(*package['args'], **package['kwargs'])
    else:
        sig = inspect.signature(klass)
        kw = package['kwargs']
        # iterate over a copy of the keys since we delete while iterating
        for key in list(kw):
            if key not in sig.parameters:
                logger.warning("Dropping nonexistent parameter %s", key)
                del kw[key]
        model = klass(*package['args'], **kw)
    model.load_state_dict(package['state'])
    return model


def copy_state(state):
    """Return a detached CPU copy of a model state dict."""
    return {k: v.cpu().clone() for k, v in state.items()}


def serialize_model(model):
    """Package a model (class, init args and weights) for `deserialize_model`.

    The model's `__init__` must have been decorated with `capture_init`
    so that `_init_args_kwargs` is available.
    """
    args, kwargs = model._init_args_kwargs
    state = copy_state(model.state_dict())
    return {"class": model.__class__, "args": args, "kwargs": kwargs, "state": state}


@contextmanager
def swap_state(model, state):
    """Temporarily load `state` into `model`, restoring the original
    state on exit (even if an exception is raised in the `with` body)."""
    old_state = copy_state(model.state_dict())
    model.load_state_dict(state)
    try:
        yield
    finally:
        model.load_state_dict(old_state)


@contextmanager
def swap_cwd(cwd):
    """Temporarily change the current working directory to `cwd`."""
    old_cwd = os.getcwd()
    os.chdir(cwd)
    try:
        yield
    finally:
        os.chdir(old_cwd)


def pull_metric(history, name):
    """Collect the values of metric `name` across a list of metric dicts,
    skipping entries where the metric is absent."""
    return [metrics[name] for metrics in history if name in metrics]


class LogProgress:
    """
    Sort of like tqdm but using log lines and not as real time.

    Args:
        logger: logger to emit progress lines on.
        iterable: the iterable to wrap.
        updates: approximate number of progress lines over the full run.
        total: length of the iterable, required if it has no `len`.
        name: prefix shown on each progress line.
        level: logging level used for the progress lines.
    """
    def __init__(self, logger, iterable, updates=5, total=None,
                 name="LogProgress", level=logging.INFO):
        self.iterable = iterable
        self.total = total or len(iterable)
        self.updates = updates
        self.name = name
        self.logger = logger
        self.level = level

    def update(self, **infos):
        """Attach extra key/value info shown on the next progress line."""
        self._infos = infos

    def __iter__(self):
        self._iterator = iter(self.iterable)
        self._index = -1
        self._infos = {}
        self._begin = time.time()
        return self

    def __next__(self):
        self._index += 1
        try:
            value = next(self._iterator)
        except StopIteration:
            raise
        else:
            return value
        finally:
            log_every = max(1, self.total // self.updates)
            # logging is delayed by 1 it, in order to have the metrics from update
            if self._index >= 1 and self._index % log_every == 0:
                self._log()

    def _log(self):
        self._speed = (1 + self._index) / (time.time() - self._begin)
        infos = " | ".join(f"{k.capitalize()} {v}" for k, v in self._infos.items())
        if self._speed < 1e-4:
            speed = "oo sec/it"
        elif self._speed < 0.1:
            speed = f"{1/self._speed:.1f} sec/it"
        else:
            speed = f"{self._speed:.1f} it/sec"
        out = f"{self.name} | {self._index}/{self.total} | {speed}"
        if infos:
            out += " | " + infos
        self.logger.log(self.level, out)


def colorize(text, color):
    """Wrap `text` in ANSI escape codes for the given SGR `color` code."""
    code = f"\033[{color}m"
    restore = f"\033[0m"
    return "".join([code, text, restore])


def bold(text):
    """Return `text` wrapped in the ANSI bold escape code."""
    return colorize(text, "1")


def calculate_grad_norm(model):
    """Return the overall norm of the gradients of `model`'s parameters.

    Parameters that have no gradient (frozen or unused) are skipped;
    a parameter-less model yields a zero tensor instead of crashing.

    Fixed: the original accessed `p.data.grad`, which is always None
    (`.data` returns a fresh tensor detached from autograd, so its
    `.grad` field is unset), making this function raise AttributeError
    on every call.
    """
    grads = [p.grad.detach().flatten() for p in model.parameters()
             if p.grad is not None]
    if not grads:
        return torch.tensor(0.)
    # NOTE(review): the trailing ** (1. / 2) returns sqrt(||g||_2), not
    # ||g||_2 — kept to match `calculate_weight_norm` below.
    return torch.cat(grads).norm(2) ** (1. / 2)


def calculate_weight_norm(model):
    """Return sqrt of the 2-norm of all of `model`'s parameters, flattened
    into a single vector. A parameter-less model yields a zero tensor.

    NOTE(review): the trailing ** (1. / 2) takes the square root of the
    2-norm; it looks unintentional but is preserved so values logged by
    existing callers do not change.
    """
    weights = [p.detach().flatten() for p in model.parameters()]
    if not weights:
        return torch.tensor(0.)
    return torch.cat(weights).norm(2) ** (1. / 2)


def remove_pad(inputs, inputs_lengths):
    """
    Args:
        inputs: torch.Tensor, [B, C, T] or [B, T], B is batch size
        inputs_lengths: torch.Tensor, [B]
    Returns:
        results: a list containing B items, each item is [C, T], T varies
    """
    results = []
    dim = inputs.dim()
    if dim == 3:
        C = inputs.size(1)
    for inp, length in zip(inputs, inputs_lengths):
        if dim == 3:    # [B, C, T]
            results.append(inp[:, :length].view(C, -1).cpu().numpy())
        elif dim == 2:  # [B, T]
            results.append(inp[:length].view(-1).cpu().numpy())
    return results


def overlap_and_add(signal, frame_step):
    """Reconstructs a signal from a framed representation.

    Adds potentially overlapping frames of a signal with shape
    `[..., frames, frame_length]`, offsetting subsequent frames by `frame_step`.
    The resulting tensor has shape `[..., output_size]` where

        output_size = (frames - 1) * frame_step + frame_length

    Args:
        signal: A [..., frames, frame_length] Tensor. All dimensions may be
            unknown, and rank must be at least 2.
        frame_step: An integer denoting overlap offsets. Must be less than or
            equal to frame_length.

    Returns:
        A Tensor with shape [..., output_size] containing the overlap-added
        frames of signal's inner-most two dimensions.
        output_size = (frames - 1) * frame_step + frame_length

    Based on
    https://github.com/tensorflow/tensorflow/blob/r1.12/tensorflow/contrib/signal/python/ops/reconstruction_ops.py
    """
    outer_dimensions = signal.size()[:-2]
    frames, frame_length = signal.size()[-2:]

    # gcd=Greatest Common Divisor: split frames into subframes so that every
    # frame offset is a whole number of subframes and index_add_ can be used.
    subframe_length = math.gcd(frame_length, frame_step)
    subframe_step = frame_step // subframe_length
    subframes_per_frame = frame_length // subframe_length
    output_size = frame_step * (frames - 1) + frame_length
    output_subframes = output_size // subframe_length

    subframe_signal = signal.view(*outer_dimensions, -1, subframe_length)

    # For each frame, the indices of the output subframes it contributes to.
    frame = torch.arange(0, output_subframes).unfold(
        0, subframes_per_frame, subframe_step)
    frame = frame.clone().detach().long().to(signal.device)
    # frame = signal.new_tensor(frame).clone().long()  # signal may in GPU or CPU
    frame = frame.contiguous().view(-1)

    result = signal.new_zeros(
        *outer_dimensions, output_subframes, subframe_length)
    # Scatter-add every subframe into its position in the output.
    result.index_add_(-2, frame, subframe_signal)
    result = result.view(*outer_dimensions, -1)
    return result