JustinLin610's picture
first commit
ee21b96
raw
history blame
No virus
4.77 kB
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# author: adefossez
import functools
import logging
from contextlib import contextmanager
import inspect
import time
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Small constant used by `cal_snr`: added to the denominator and to the ratio
# itself so that neither the division nor the log10 sees an exact zero.
EPS = 1e-8
def capture_init(init):
    """Decorator for `__init__` that remembers the constructor arguments.

    Before delegating to the wrapped `__init__`, the positional and keyword
    arguments are stored on the instance as the tuple
    `self._init_args_kwargs = (args, kwargs)`.
    """
    @functools.wraps(init)
    def recording_init(self, *args, **kwargs):
        captured = (args, kwargs)
        self._init_args_kwargs = captured
        init(self, *args, **kwargs)

    return recording_init
def deserialize_model(package, strict=False):
    """Rebuild a model from a package dict.

    Args:
        package (dict): mapping with the keys `'class'`, `'args'`,
            `'kwargs'` and `'state'` (the layout built by `serialize_model`).
        strict (bool): if True, `package['kwargs']` is passed to the class
            unchanged. If False, keyword arguments that are not parameters
            of the class signature are dropped, with a warning logged for
            each one.

    Returns:
        The instance `package['class'](*args, **kwargs)` after
        `load_state_dict(package['state'])` has been called on it.

    The `package` argument is never modified.
    """
    klass = package['class']
    if strict:
        model = klass(*package['args'], **package['kwargs'])
    else:
        accepted = inspect.signature(klass).parameters
        # Filter into a fresh dict: deleting from `package['kwargs']` directly
        # would silently alter the caller's package (e.g. a later strict
        # reload, or re-saving it, would no longer see the dropped keys).
        kw = {}
        for key, value in package['kwargs'].items():
            if key in accepted:
                kw[key] = value
            else:
                logger.warning("Dropping inexistant parameter %s", key)
        model = klass(*package['args'], **kw)
    model.load_state_dict(package['state'])
    return model
def copy_state(state):
    """Return a new dict holding `value.cpu().clone()` for every entry of `state`.

    Keys are kept in the iteration order of `state.items()`.
    """
    copied = {}
    for name, tensor in state.items():
        copied[name] = tensor.cpu().clone()
    return copied
def serialize_model(model):
    """Pack a model into a dict with its class, init arguments and state.

    Reads `model._init_args_kwargs` (the attribute set by `capture_init`) and
    a copy of `model.state_dict()` made with `copy_state`.

    Returns:
        dict with the keys `"class"`, `"args"`, `"kwargs"` and `"state"`.
    """
    ctor_args, ctor_kwargs = model._init_args_kwargs
    weights = copy_state(model.state_dict())
    package = {"class": model.__class__}
    package.update(args=ctor_args, kwargs=ctor_kwargs, state=weights)
    return package
@contextmanager
def swap_state(model, state):
"""
Context manager that swaps the state of a model, e.g:
# model is in old state
with swap_state(model, new_state):
# model in new state
# model back to old state
"""
old_state = copy_state(model.state_dict())
model.load_state_dict(state)
try:
yield
finally:
model.load_state_dict(old_state)
def pull_metric(history, name):
    """Collect `metrics[name]` from every entry of `history` that contains `name`.

    Entries without the key are skipped; order follows `history`.
    """
    return [metrics[name] for metrics in history if name in metrics]
class LogProgress:
    """
    Sort of like tqdm but using log lines and not as real time.

    Args:
        - logger: logger obtained from `logging.getLogger`,
        - iterable: iterable object to wrap
        - updates (int): number of lines that will be printed, e.g.
            if `updates=5`, log every 1/5th of the total length.
        - total (int): length of the iterable, in case it does not support
            `len`.
        - name (str): prefix to use in the log.
        - level: logging level (like `logging.INFO`).
    """
    def __init__(self,
                 logger,
                 iterable,
                 updates=5,
                 total=None,
                 name="LogProgress",
                 level=logging.INFO):
        self.iterable = iterable
        # `len(iterable)` is only evaluated when `total` is falsy.
        self.total = total or len(iterable)
        self.updates = updates
        self.name = name
        self.logger = logger
        self.level = level

    def update(self, **infos):
        """Set the key/value pairs appended to the following log lines."""
        self._infos = infos

    def __iter__(self):
        """Start a fresh pass: reset the index, the infos and the start time."""
        self._iterator = iter(self.iterable)
        self._index = -1
        self._infos = {}
        self._begin = time.time()
        return self

    def __next__(self):
        """Return the next item, logging every `total // updates` items.

        The logging check sits in a `finally`, so it also runs on the call
        that raises `StopIteration`, which lets a last line be emitted once
        the iterable is exhausted.
        """
        self._index += 1
        try:
            return next(self._iterator)
        finally:
            log_every = max(1, self.total // self.updates)
            # logging is delayed by 1 it, in order to have the metrics from update
            if self._index >= 1 and self._index % log_every == 0:
                self._log()

    def _log(self):
        """Emit one progress line: name, index/total, speed and current infos."""
        elapsed = time.time() - self._begin
        # With a coarse clock two readings can be equal; avoid raising
        # ZeroDivisionError from inside the `finally` of `__next__`.
        if elapsed == 0:
            self._speed = float("inf")
        else:
            self._speed = (1 + self._index) / elapsed
        infos = " | ".join(f"{k.capitalize()} {v}" for k, v in self._infos.items())
        if self._speed < 1e-4:
            speed = "oo sec/it"
        elif self._speed < 0.1:
            speed = f"{1/self._speed:.1f} sec/it"
        else:
            speed = f"{self._speed:.1f} it/sec"
        out = f"{self.name} | {self._index}/{self.total} | {speed}"
        if infos:
            out += " | " + infos
        self.logger.log(self.level, out)
def colorize(text, color):
    """
    Wrap `text` between the ANSI escape sequence for `color` and the
    reset sequence, so that it shows up colored in a terminal.
    """
    reset = "\033[0m"
    start = f"\033[{color}m"
    return "".join((start, text, reset))
def bold(text):
    """
    Wrap `text` with ANSI code "1" (bold) via `colorize`.
    """
    bold_code = "1"
    return colorize(text, bold_code)
def cal_snr(lbl, est):
    """Compute `10 * log10` of the energy ratio between `lbl` and the error `est - lbl`.

    Both energies are sums of squares over the last dimension. `EPS` is added
    to the error energy and to the resulting ratio before taking the log.
    """
    import torch

    reference_energy = torch.sum(lbl**2, dim=-1)
    error_energy = torch.sum((est - lbl)**2, dim=-1)
    ratio = reference_energy / (error_energy + EPS) + EPS
    return 10.0 * torch.log10(ratio)