""" Various handy Python and PyTorch utils. Author: Paul-Edouard Sarlin (skydes) """ import os import random import time from collections.abc import Iterable from contextlib import contextmanager import numpy as np import torch class AverageMetric: def __init__(self): self._sum = 0 self._num_examples = 0 def update(self, tensor): assert tensor.dim() == 1 tensor = tensor[~torch.isnan(tensor)] self._sum += tensor.sum().item() self._num_examples += len(tensor) def compute(self): if self._num_examples == 0: return np.nan else: return self._sum / self._num_examples # same as AverageMetric, but tracks all elements class FAverageMetric: def __init__(self): self._sum = 0 self._num_examples = 0 self._elements = [] def update(self, tensor): self._elements += tensor.cpu().numpy().tolist() assert tensor.dim() == 1 tensor = tensor[~torch.isnan(tensor)] self._sum += tensor.sum().item() self._num_examples += len(tensor) def compute(self): if self._num_examples == 0: return np.nan else: return self._sum / self._num_examples class MedianMetric: def __init__(self): self._elements = [] def update(self, tensor): assert tensor.dim() == 1 self._elements += tensor.cpu().numpy().tolist() def compute(self): if len(self._elements) == 0: return np.nan else: return np.nanmedian(self._elements) class PRMetric: def __init__(self): self.labels = [] self.predictions = [] @torch.no_grad() def update(self, labels, predictions, mask=None): assert labels.shape == predictions.shape self.labels += ( (labels[mask] if mask is not None else labels).cpu().numpy().tolist() ) self.predictions += ( (predictions[mask] if mask is not None else predictions) .cpu() .numpy() .tolist() ) @torch.no_grad() def compute(self): return np.array(self.labels), np.array(self.predictions) def reset(self): self.labels = [] self.predictions = [] class QuantileMetric: def __init__(self, q=0.05): self._elements = [] self.q = q def update(self, tensor): assert tensor.dim() == 1 self._elements += tensor.cpu().numpy().tolist() def compute(self): if len(self._elements) == 0: return np.nan else: return np.nanquantile(self._elements, self.q) class RecallMetric: def __init__(self, ths, elements=[]): self._elements = elements self.ths = ths def update(self, tensor): assert tensor.dim() == 1 self._elements += tensor.cpu().numpy().tolist() def compute(self): if isinstance(self.ths, Iterable): return [self.compute_(th) for th in self.ths] else: return self.compute_(self.ths[0]) def compute_(self, th): if len(self._elements) == 0: return np.nan else: s = (np.array(self._elements) < th).sum() return s / len(self._elements) def cal_error_auc(errors, thresholds): sort_idx = np.argsort(errors) errors = np.array(errors.copy())[sort_idx] recall = (np.arange(len(errors)) + 1) / len(errors) errors = np.r_[0.0, errors] recall = np.r_[0.0, recall] aucs = [] for t in thresholds: last_index = np.searchsorted(errors, t) r = np.r_[recall[:last_index], recall[last_index - 1]] e = np.r_[errors[:last_index], t] aucs.append(np.round((np.trapz(r, x=e) / t), 4)) return aucs class AUCMetric: def __init__(self, thresholds, elements=None): self._elements = elements self.thresholds = thresholds if not isinstance(thresholds, list): self.thresholds = [thresholds] def update(self, tensor): assert tensor.dim() == 1 self._elements += tensor.cpu().numpy().tolist() def compute(self): if len(self._elements) == 0: return np.nan else: return cal_error_auc(self._elements, self.thresholds) class Timer(object): """A simpler timer context object. Usage: ``` > with Timer('mytimer'): > # some computations [mytimer] Elapsed: X ``` """ def __init__(self, name=None): self.name = name def __enter__(self): self.tstart = time.time() return self def __exit__(self, type, value, traceback): self.duration = time.time() - self.tstart if self.name is not None: print("[%s] Elapsed: %s" % (self.name, self.duration)) def get_class(mod_path, BaseClass): """Get the class object which inherits from BaseClass and is defined in the module named mod_name, child of base_path. """ import inspect mod = __import__(mod_path, fromlist=[""]) classes = inspect.getmembers(mod, inspect.isclass) # Filter classes defined in the module classes = [c for c in classes if c[1].__module__ == mod_path] # Filter classes inherited from BaseModel classes = [c for c in classes if issubclass(c[1], BaseClass)] assert len(classes) == 1, classes return classes[0][1] def set_num_threads(nt): """Force numpy and other libraries to use a limited number of threads.""" try: import mkl except ImportError: pass else: mkl.set_num_threads(nt) torch.set_num_threads(1) os.environ["IPC_ENABLE"] = "1" for o in [ "OPENBLAS_NUM_THREADS", "NUMEXPR_NUM_THREADS", "OMP_NUM_THREADS", "MKL_NUM_THREADS", ]: os.environ[o] = str(nt) def set_seed(seed): random.seed(seed) torch.manual_seed(seed) np.random.seed(seed) if torch.cuda.is_available(): torch.cuda.manual_seed(seed) torch.cuda.manual_seed_all(seed) def get_random_state(with_cuda): pth_state = torch.get_rng_state() np_state = np.random.get_state() py_state = random.getstate() if torch.cuda.is_available() and with_cuda: cuda_state = torch.cuda.get_rng_state_all() else: cuda_state = None return pth_state, np_state, py_state, cuda_state def set_random_state(state): pth_state, np_state, py_state, cuda_state = state torch.set_rng_state(pth_state) np.random.set_state(np_state) random.setstate(py_state) if ( cuda_state is not None and torch.cuda.is_available() and len(cuda_state) == torch.cuda.device_count() ): torch.cuda.set_rng_state_all(cuda_state) @contextmanager def fork_rng(seed=None, with_cuda=True): state = get_random_state(with_cuda) if seed is not None: set_seed(seed) try: yield finally: set_random_state(state)