Spaces:
Sleeping
Sleeping
""" | |
Various handy Python and PyTorch utils. | |
Author: Paul-Edouard Sarlin (skydes) | |
""" | |
import os | |
import random | |
import time | |
from collections.abc import Iterable | |
from contextlib import contextmanager | |
import numpy as np | |
import torch | |
class AverageMetric:
    """Running mean over 1-D tensors that skips NaN entries."""

    def __init__(self):
        # Running sum and count of the non-NaN values seen so far.
        self._sum = 0
        self._num_examples = 0

    def update(self, tensor):
        """Accumulate the non-NaN entries of a 1-D tensor."""
        assert tensor.dim() == 1
        valid = tensor[~torch.isnan(tensor)]
        self._sum += valid.sum().item()
        self._num_examples += len(valid)

    def compute(self):
        """Return the mean of every accumulated value, or NaN if none were seen."""
        if not self._num_examples:
            return np.nan
        return self._sum / self._num_examples
# Same as AverageMetric, but additionally keeps every raw element (NaNs
# included) in ``_elements``.
class FAverageMetric:
    """Running NaN-skipping mean that also records every element it sees."""

    def __init__(self):
        self._sum = 0
        self._num_examples = 0
        # Raw values from every update, NaNs included.
        self._elements = []

    def update(self, tensor):
        """Record all entries of a 1-D tensor and accumulate its non-NaN ones.

        The shape is validated before any state is modified, so a rejected
        tensor leaves the metric unchanged.
        """
        assert tensor.dim() == 1
        self._elements += tensor.cpu().numpy().tolist()
        tensor = tensor[~torch.isnan(tensor)]
        self._sum += tensor.sum().item()
        self._num_examples += len(tensor)

    def compute(self):
        """Return the mean of the non-NaN values, or NaN if none were seen."""
        if self._num_examples == 0:
            return np.nan
        return self._sum / self._num_examples
class MedianMetric:
    """Collects 1-D tensors and reports the NaN-ignoring median of all values."""

    def __init__(self):
        self._elements = []

    def update(self, tensor):
        """Append the entries of a 1-D tensor to the collected values."""
        assert tensor.dim() == 1
        self._elements.extend(tensor.cpu().numpy().tolist())

    def compute(self):
        """Return the median of collected values ignoring NaNs; NaN if empty."""
        return np.nanmedian(self._elements) if self._elements else np.nan
class PRMetric:
    """Accumulates flattened labels and predictions, optionally filtered by a mask."""

    def __init__(self):
        self.labels = []
        self.predictions = []

    def update(self, labels, predictions, mask=None):
        """Append labels and predictions, keeping only entries selected by ``mask``.

        ``labels`` and ``predictions`` must have identical shapes; when
        ``mask`` is given it is used to index both.
        """
        assert labels.shape == predictions.shape
        for store, values in ((self.labels, labels), (self.predictions, predictions)):
            selected = values if mask is None else values[mask]
            store.extend(selected.cpu().numpy().tolist())

    def compute(self):
        """Return ``(labels, predictions)`` as NumPy arrays."""
        return np.array(self.labels), np.array(self.predictions)

    def reset(self):
        """Discard everything accumulated so far."""
        self.labels, self.predictions = [], []
class QuantileMetric:
    """Collects 1-D tensors and reports the NaN-ignoring ``q``-quantile."""

    def __init__(self, q=0.05):
        self._elements = []
        # Quantile level forwarded to np.nanquantile (which expects [0, 1]).
        self.q = q

    def update(self, tensor):
        """Append the entries of a 1-D tensor to the collected values."""
        assert tensor.dim() == 1
        self._elements.extend(tensor.cpu().numpy().tolist())

    def compute(self):
        """Return the ``q``-quantile of collected values; NaN if nothing was collected."""
        if not self._elements:
            return np.nan
        return np.nanquantile(self._elements, self.q)
class RecallMetric:
    """Fraction of collected values that fall strictly below given thresholds.

    Args:
        ths: a single threshold or an iterable of thresholds.
        elements: optional initial values; they are copied, so the caller's
            sequence is never mutated by ``update``.
    """

    def __init__(self, ths, elements=None):
        # A fresh list per instance: a mutable ``[]`` default would be shared
        # by every RecallMetric and leak updates between them.
        self._elements = [] if elements is None else list(elements)
        self.ths = ths

    def update(self, tensor):
        """Append the entries of a 1-D tensor to the collected values."""
        assert tensor.dim() == 1
        self._elements += tensor.cpu().numpy().tolist()

    def compute(self):
        """Return a list of recalls for iterable ``ths``, else a single recall."""
        if isinstance(self.ths, Iterable):
            return [self.compute_(th) for th in self.ths]
        # A scalar threshold is not subscriptable; use it directly.
        return self.compute_(self.ths)

    def compute_(self, th):
        """Return the fraction of values ``< th``, or NaN if none were collected."""
        if len(self._elements) == 0:
            return np.nan
        s = (np.array(self._elements) < th).sum()
        return s / len(self._elements)
def cal_error_auc(errors, thresholds):
    """Normalized area under the cumulative error (recall) curve.

    The errors are sorted and turned into a recall curve whose value at the
    i-th sorted error is ``(i + 1) / len(errors)``. A ``(0, 0)`` origin is
    prepended, which assumes errors are non-negative (TODO confirm with
    callers). For each threshold ``t`` the curve is truncated at ``t``
    (the last recall value is held flat up to ``t``), integrated with the
    trapezoidal rule, and divided by ``t``.

    Args:
        errors: sequence (list, tuple or array) of error values.
        thresholds: iterable of positive thresholds.

    Returns:
        List with one AUC per threshold, each rounded to 4 decimals.
    """
    # np.trapz is deprecated since NumPy 2.0 in favour of np.trapezoid.
    trapezoid = getattr(np, "trapezoid", None)
    if trapezoid is None:  # NumPy < 2.0
        trapezoid = np.trapz

    errors = np.sort(np.asarray(errors, dtype=float))
    recall = (np.arange(len(errors)) + 1) / len(errors)
    # Anchor both curves at the origin.
    errors = np.r_[0.0, errors]
    recall = np.r_[0.0, recall]

    aucs = []
    for t in thresholds:
        last_index = np.searchsorted(errors, t)
        # Keep the curve below t and extend the last recall value flat to t.
        r = np.r_[recall[:last_index], recall[last_index - 1]]
        e = np.r_[errors[:last_index], t]
        aucs.append(np.round(trapezoid(r, x=e) / t, 4))
    return aucs
class AUCMetric:
    """Collects 1-D error tensors and reports ``cal_error_auc`` per threshold.

    Args:
        thresholds: a single threshold or a list/tuple of thresholds.
        elements: optional initial values; they are copied, so the caller's
            sequence is never mutated by ``update``.
    """

    def __init__(self, thresholds, elements=None):
        # Always start from a real list: with a ``None`` default, ``update``
        # (``None += list``) and ``compute`` (``len(None)``) would crash.
        self._elements = [] if elements is None else list(elements)
        if isinstance(thresholds, (list, tuple)):
            self.thresholds = list(thresholds)
        else:
            self.thresholds = [thresholds]

    def update(self, tensor):
        """Append the entries of a 1-D tensor to the collected errors."""
        assert tensor.dim() == 1
        self._elements += tensor.cpu().numpy().tolist()

    def compute(self):
        """Return the list of AUCs (one per threshold), or NaN if no errors were collected."""
        if len(self._elements) == 0:
            return np.nan
        return cal_error_auc(self._elements, self.thresholds)
class Timer(object):
    """A simple timer context object.

    Usage:
    ```
    > with Timer('mytimer'):
    >   # some computations
    [mytimer] Elapsed: X
    ```

    On exit the elapsed time in seconds is stored in ``duration`` and, if a
    ``name`` was given, printed. Exceptions raised in the body propagate.
    """

    def __init__(self, name=None):
        self.name = name

    def __enter__(self):
        # perf_counter is monotonic and high-resolution; time.time() can jump
        # when the system clock is adjusted and corrupt the measured duration.
        self.tstart = time.perf_counter()
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.duration = time.perf_counter() - self.tstart
        if self.name is not None:
            print("[%s] Elapsed: %s" % (self.name, self.duration))
def get_class(mod_path, BaseClass):
    """Return the single class defined in module ``mod_path`` that subclasses ``BaseClass``.

    Args:
        mod_path: dotted path of the module to import.
        BaseClass: the base class the returned class must derive from.

    Raises:
        AssertionError: (when assertions are enabled) if zero or several
            matching classes are found; the message lists the candidates.
    """
    import inspect

    module = __import__(mod_path, fromlist=[""])
    # Keep classes that are defined in this module (not merely imported into
    # it) and that derive from BaseClass.
    candidates = [
        (name, obj)
        for name, obj in inspect.getmembers(module, inspect.isclass)
        if obj.__module__ == mod_path and issubclass(obj, BaseClass)
    ]
    assert len(candidates) == 1, candidates
    return candidates[0][1]
def set_num_threads(nt):
    """Force numpy, torch and other libraries to use at most ``nt`` threads.

    Sets the MKL thread count (if the ``mkl`` module is importable), the
    torch intra-op thread count, and the OPENBLAS/NUMEXPR/OMP/MKL
    ``*_NUM_THREADS`` environment variables. The environment variables
    typically only take effect for libraries initialized after this call.
    """
    try:
        import mkl
    except ImportError:
        pass
    else:
        mkl.set_num_threads(nt)
    # Use ``nt`` for torch too, consistent with every other library here.
    torch.set_num_threads(nt)
    # NOTE(review): purpose of IPC_ENABLE is not evident from this file; kept as-is.
    os.environ["IPC_ENABLE"] = "1"
    for var in (
        "OPENBLAS_NUM_THREADS",
        "NUMEXPR_NUM_THREADS",
        "OMP_NUM_THREADS",
        "MKL_NUM_THREADS",
    ):
        os.environ[var] = str(nt)
def set_seed(seed):
    """Seed Python's ``random``, NumPy's global RNG, torch, and CUDA when available."""
    for seed_fn in (random.seed, torch.manual_seed, np.random.seed):
        seed_fn(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
def get_random_state(with_cuda):
    """Snapshot the global RNG states.

    Returns:
        Tuple ``(torch_state, numpy_state, python_state, cuda_states)`` where
        ``cuda_states`` is ``None`` unless CUDA is available and
        ``with_cuda`` is truthy.
    """
    host_states = [torch.get_rng_state(), np.random.get_state(), random.getstate()]
    use_cuda = torch.cuda.is_available() and with_cuda
    cuda_states = torch.cuda.get_rng_state_all() if use_cuda else None
    return (*host_states, cuda_states)
def set_random_state(state):
    """Restore RNG states in the tuple layout returned by ``get_random_state``.

    CUDA states are restored only when present, CUDA is available, and the
    number of saved states equals the current device count.
    """
    torch_state, numpy_state, python_state, cuda_states = state
    torch.set_rng_state(torch_state)
    np.random.set_state(numpy_state)
    random.setstate(python_state)
    if cuda_states is None or not torch.cuda.is_available():
        return
    if len(cuda_states) != torch.cuda.device_count():
        return
    torch.cuda.set_rng_state_all(cuda_states)
@contextmanager
def fork_rng(seed=None, with_cuda=True):
    """Run the ``with`` body under temporarily forked global RNG state.

    On entry the torch/NumPy/Python RNG states (and CUDA states, if
    ``with_cuda`` and CUDA is available) are saved and, if ``seed`` is given,
    all generators are reseeded with it. On exit the saved states are
    restored, even if the body raises.

    Usage:
        with fork_rng(seed=0):
            ...  # reproducible random draws that do not disturb outer state
    """
    # @contextmanager turns this generator into a ``with``-able context manager.
    state = get_random_state(with_cuda)
    if seed is not None:
        set_seed(seed)
    try:
        yield
    finally:
        set_random_state(state)