# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# Authors: Yossi Adi (adiyoss) and Alexandre Defossez (adefossez)

import functools
import inspect
import logging
import math
import os
import time
from contextlib import contextmanager

import torch

logger = logging.getLogger(__name__)


def capture_init(init):
    """
    Decorate `__init__` with this, and you can then
    recover the `*args` and `**kwargs` passed to it in `self._init_args_kwargs`.
    """
    @functools.wraps(init)
    def __init__(self, *args, **kwargs):
        self._init_args_kwargs = (args, kwargs)
        init(self, *args, **kwargs)
    return __init__
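

# Usage sketch (illustrative only; `MyModel` is a hypothetical module):
#
#     class MyModel(torch.nn.Module):
#         @capture_init
#         def __init__(self, hidden=64):
#             super().__init__()
#             self.linear = torch.nn.Linear(hidden, hidden)
#
#     model = MyModel(hidden=32)
#     model._init_args_kwargs  # -> ((), {'hidden': 32})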


def deserialize_model(package, strict=False):
    """Rebuild a model from a package produced by `serialize_model`.

    With `strict=False`, kwargs that the class no longer accepts are
    dropped with a warning instead of raising a `TypeError`.
    """
    klass = package['class']
    if strict:
        model = klass(*package['args'], **package['kwargs'])
    else:
        sig = inspect.signature(klass)
        kw = package['kwargs']
        for key in list(kw):
            if key not in sig.parameters:
                logger.warning("Dropping nonexistent parameter %s", key)
                del kw[key]
        model = klass(*package['args'], **kw)
    model.load_state_dict(package['state'])
    return model


def copy_state(state):
    """Return a detached CPU copy of a state dict."""
    return {k: v.cpu().clone() for k, v in state.items()}


def serialize_model(model):
    """Package a model (whose `__init__` uses `capture_init`) for later rebuilding."""
    args, kwargs = model._init_args_kwargs
    state = copy_state(model.state_dict())
    return {"class": model.__class__, "args": args, "kwargs": kwargs, "state": state}


@contextmanager
def swap_state(model, state):
    """Temporarily load `state` into `model`, restoring the old state on exit."""
    old_state = copy_state(model.state_dict())
    model.load_state_dict(state)
    try:
        yield
    finally:
        model.load_state_dict(old_state)
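

# Usage sketch, e.g. evaluating with averaged weights (`ema_state` is a
# hypothetical alternative state dict):
#
#     with swap_state(model, ema_state):
#         evaluate(model)  # model temporarily holds `ema_state`
#     # the original weights are restored here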


@contextmanager
def swap_cwd(cwd):
    """Temporarily change the working directory to `cwd`."""
    old_cwd = os.getcwd()
    os.chdir(cwd)
    try:
        yield
    finally:
        os.chdir(old_cwd)


def pull_metric(history, name):
    """Extract the values of metric `name` from a list of metric dicts."""
    out = []
    for metrics in history:
        if name in metrics:
            out.append(metrics[name])
    return out
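

# For instance (metric names are illustrative):
#
#     history = [{'loss': 0.9, 'sisnr': 5.0}, {'loss': 0.7}, {'loss': 0.5, 'sisnr': 7.2}]
#     pull_metric(history, 'sisnr')  # -> [5.0, 7.2]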


class LogProgress:
    """
    Like `tqdm`, but reporting progress through log lines rather than a
    real-time progress bar. At most `updates` log lines are emitted over
    the whole iteration; extra metrics can be attached with `update()`
    and will appear in the next log line.
    """

    def __init__(self, logger, iterable, updates=5, total=None,
                 name="LogProgress", level=logging.INFO):
        self.iterable = iterable
        self.total = total or len(iterable)
        self.updates = updates
        self.name = name
        self.logger = logger
        self.level = level

    def update(self, **infos):
        self._infos = infos

    def __iter__(self):
        self._iterator = iter(self.iterable)
        self._index = -1
        self._infos = {}
        self._begin = time.time()
        return self

    def __next__(self):
        self._index += 1
        try:
            value = next(self._iterator)
        except StopIteration:
            raise
        else:
            return value
        finally:
            log_every = max(1, self.total // self.updates)
            # logging is delayed by one iteration so that the metrics
            # passed to `update()` for the current batch are included
            if self._index >= 1 and self._index % log_every == 0:
                self._log()

    def _log(self):
        self._speed = (1 + self._index) / (time.time() - self._begin)
        infos = " | ".join(f"{k.capitalize()} {v}" for k, v in self._infos.items())
        if self._speed < 1e-4:
            speed = "oo sec/it"
        elif self._speed < 0.1:
            speed = f"{1 / self._speed:.1f} sec/it"
        else:
            speed = f"{self._speed:.1f} it/sec"
        out = f"{self.name} | {self._index}/{self.total} | {speed}"
        if infos:
            out += " | " + infos
        self.logger.log(self.level, out)
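

# Usage sketch over a training loader (`loader` and `train_step` are
# illustrative names):
#
#     logprog = LogProgress(logger, loader, updates=10, name="Train")
#     for batch in logprog:
#         loss = train_step(batch)
#         logprog.update(loss=f"{loss:.4f}")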


def colorize(text, color):
    """Wrap `text` in the ANSI escape code `color` (e.g. "1" for bold)."""
    code = f"\033[{color}m"
    restore = "\033[0m"
    return "".join([code, text, restore])


def bold(text):
    return colorize(text, "1")


def calculate_grad_norm(model):
    """Return the 2-norm of all parameter gradients, concatenated into one vector."""
    # Gradients live on `p.grad`; parameters without a gradient
    # (e.g. frozen ones) are skipped.
    grads = [p.grad.flatten() for p in model.parameters() if p.grad is not None]
    if not grads:
        return torch.tensor(0.)
    return torch.cat(grads).norm(2)


def calculate_weight_norm(model):
    """Return the 2-norm of all parameters, concatenated into one vector."""
    weights = [p.data.flatten() for p in model.parameters()]
    if not weights:
        return torch.tensor(0.)
    return torch.cat(weights).norm(2)
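

# For monitoring during training (after `loss.backward()` has populated
# the gradients):
#
#     grad_norm = calculate_grad_norm(model)
#     weight_norm = calculate_weight_norm(model)
#     logger.info("grad norm %.4f | weight norm %.4f", grad_norm, weight_norm)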


def remove_pad(inputs, inputs_lengths):
    """
    Args:
        inputs: torch.Tensor, [B, C, T] or [B, T], B is batch size
        inputs_lengths: torch.Tensor, [B]
    Returns:
        results: a list containing B items, each item is [C, T], T varies
    """
    results = []
    dim = inputs.dim()
    if dim == 3:
        C = inputs.size(1)
    for inp, length in zip(inputs, inputs_lengths):
        if dim == 3:    # [B, C, T]
            results.append(inp[:, :length].view(C, -1).cpu().numpy())
        elif dim == 2:  # [B, T]
            results.append(inp[:length].view(-1).cpu().numpy())
    return results
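

# For instance, trimming a padded batch back to per-utterance lengths:
#
#     inputs = torch.randn(2, 1, 10)     # [B, C, T]
#     lengths = torch.tensor([10, 7])    # [B]
#     outs = remove_pad(inputs, lengths)
#     [o.shape for o in outs]            # -> [(1, 10), (1, 7)]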


def overlap_and_add(signal, frame_step):
    """Reconstructs a signal from a framed representation.

    Adds potentially overlapping frames of a signal with shape
    `[..., frames, frame_length]`, offsetting subsequent frames by `frame_step`.
    The resulting tensor has shape `[..., output_size]` where

        output_size = (frames - 1) * frame_step + frame_length

    Args:
        signal: A [..., frames, frame_length] Tensor. All dimensions may be
            unknown, and rank must be at least 2.
        frame_step: An integer denoting overlap offsets. Must be less than
            or equal to frame_length.
    Returns:
        A Tensor with shape [..., output_size] containing the overlap-added
        frames of signal's inner-most two dimensions.

    Based on
    https://github.com/tensorflow/tensorflow/blob/r1.12/tensorflow/contrib/signal/python/ops/reconstruction_ops.py
    """
    outer_dimensions = signal.size()[:-2]
    frames, frame_length = signal.size()[-2:]

    # Split each frame into subframes whose length is the greatest common
    # divisor of frame_length and frame_step, so that frame offsets become
    # integer numbers of subframes.
    subframe_length = math.gcd(frame_length, frame_step)
    subframe_step = frame_step // subframe_length
    subframes_per_frame = frame_length // subframe_length
    output_size = frame_step * (frames - 1) + frame_length
    output_subframes = output_size // subframe_length

    subframe_signal = signal.view(*outer_dimensions, -1, subframe_length)

    # For each frame, the indices of the output subframes it contributes to.
    # Built directly on signal's device, as signal may be on GPU or CPU.
    frame = torch.arange(0, output_subframes, device=signal.device).unfold(
        0, subframes_per_frame, subframe_step)
    frame = frame.contiguous().view(-1)

    result = signal.new_zeros(*outer_dimensions, output_subframes, subframe_length)
    result.index_add_(-2, frame, subframe_signal)
    result = result.view(*outer_dimensions, -1)
    return result
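

# Shape sketch, with 50% overlap (frame_step = frame_length // 2):
#
#     x = torch.randn(4, 8)              # [frames, frame_length]
#     y = overlap_and_add(x, 4)
#     y.shape                            # -> torch.Size([20]), i.e. (4 - 1) * 4 + 8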