# genrl/tools/utils.py
import math
import random
import time
from collections.abc import MutableMapping
from functools import wraps

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import distributions as pyd
from torch.distributions.utils import _standard_normal

class eval_mode:
    """Context manager that switches the given models to eval mode and restores their previous training flags on exit."""

    def __init__(self, *models):
        self.models = models

    def __enter__(self):
        self.prev_states = []
        for model in self.models:
            self.prev_states.append(model.training)
            model.train(False)

    def __exit__(self, *args):
        for model, state in zip(self.models, self.prev_states):
            model.train(state)
        return False
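
# Illustrative usage sketch (not part of the original module; `actor` and `critic`
# are hypothetical nn.Modules): the training flags are restored when the block exits.
#
#     actor, critic = nn.Linear(4, 2), nn.Linear(4, 1)
#     with eval_mode(actor, critic):
#         assert not actor.training and not critic.training
#     assert actor.training and critic.training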

def set_seed_everywhere(seed):
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)

def soft_update_params(net, target_net, tau):
    """Polyak-average the target network: target <- tau * net + (1 - tau) * target, parameter-wise."""
    for param, target_param in zip(net.parameters(), target_net.parameters()):
        target_param.data.copy_(tau * param.data + (1 - tau) * target_param.data)
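
# Hedged example of the Polyak update above; `critic` and `critic_target` are
# hypothetical modules with identical architectures:
#
#     critic = nn.Linear(8, 1)
#     critic_target = nn.Linear(8, 1)
#     critic_target.load_state_dict(critic.state_dict())
#     soft_update_params(critic, critic_target, tau=0.01)  # target <- 0.01*net + 0.99*target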

def hard_update_params(net, target_net):
    for param, target_param in zip(net.parameters(), target_net.parameters()):
        target_param.data.copy_(param.data)

def weight_init(m):
    """Custom weight init for Conv2D and Linear layers."""
    if isinstance(m, nn.Linear):
        nn.init.orthogonal_(m.weight.data)
        if hasattr(m.bias, 'data'):
            m.bias.data.fill_(0.0)
    elif isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
        gain = nn.init.calculate_gain('relu')
        nn.init.orthogonal_(m.weight.data, gain)
        if hasattr(m.bias, 'data'):
            m.bias.data.fill_(0.0)
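
# Typical usage sketch (an assumption, not dictated by this file): pass the function to
# nn.Module.apply so every Linear/Conv2d/ConvTranspose2d submodule gets orthogonal init.
#
#     encoder = nn.Sequential(nn.Conv2d(3, 32, 3), nn.ReLU(), nn.Flatten(), nn.Linear(128, 50))
#     encoder.apply(weight_init)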

class Until:
    """Returns True while `step` is below `until // action_repeat`; always True when `until` is None."""

    def __init__(self, until, action_repeat=1):
        self._until = until
        self._action_repeat = action_repeat

    def __call__(self, step):
        if self._until is None:
            return True
        until = self._until // self._action_repeat
        return step < until

class Every:
    """Returns True every `every // action_repeat` steps; always False when `every` is None."""

    def __init__(self, every, action_repeat=1):
        self._every = every
        self._action_repeat = action_repeat

    def __call__(self, step):
        if self._every is None:
            return False
        every = self._every // self._action_repeat
        if step % every == 0:
            return True
        return False
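
# Sketch of how these two helpers are typically combined in a training loop
# (illustrative only; the step counts below are made-up hyperparameters):
#
#     train_until_step = Until(1_000_000, action_repeat=2)
#     eval_every_step = Every(10_000, action_repeat=2)
#     step = 0
#     while train_until_step(step):
#         if eval_every_step(step):
#             pass  # run evaluation
#         step += 1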

class Timer:
    def __init__(self):
        self._start_time = time.time()
        self._last_time = time.time()

    def reset(self):
        elapsed_time = time.time() - self._last_time
        self._last_time = time.time()
        total_time = time.time() - self._start_time
        return elapsed_time, total_time

    def total_time(self):
        return time.time() - self._start_time

class TruncatedNormal(pyd.Normal):
    """Normal distribution whose samples are clamped to [low, high] with a straight-through gradient."""

    def __init__(self, loc, scale, low=-1.0, high=1.0, eps=1e-6):
        super().__init__(loc, scale, validate_args=False)
        self.low = low
        self.high = high
        self.eps = eps

    def _clamp(self, x):
        clamped_x = torch.clamp(x, self.low + self.eps, self.high - self.eps)
        # Straight-through estimator: the forward pass uses the clamped value,
        # the backward pass uses the gradient of the unclamped sample.
        x = x - x.detach() + clamped_x.detach()
        return x

    def sample(self, sample_shape=torch.Size(), stddev_clip=None):
        shape = self._extended_shape(sample_shape)
        eps = _standard_normal(shape, dtype=self.loc.dtype, device=self.loc.device)
        eps *= self.scale
        if stddev_clip is not None:
            eps = torch.clamp(eps, -stddev_clip, stddev_clip)
        x = self.loc + eps
        return self._clamp(x)
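
# Hedged usage sketch (clipped exploration noise in the DrQ-v2 style is one common
# pattern, but the exact call sites in this repo may differ):
#
#     mean = torch.zeros(2)
#     dist = TruncatedNormal(mean, 0.1 * torch.ones_like(mean))
#     action = dist.sample(stddev_clip=0.3)  # noise clamped to +/-0.3, result clamped into [-1, 1]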

class TanhTransform(pyd.transforms.Transform):
    domain = pyd.constraints.real
    codomain = pyd.constraints.interval(-1.0, 1.0)
    bijective = True
    sign = +1

    def __init__(self, cache_size=1):
        super().__init__(cache_size=cache_size)

    @staticmethod
    def atanh(x):
        return 0.5 * (x.log1p() - (-x).log1p())

    def __eq__(self, other):
        return isinstance(other, TanhTransform)

    def _call(self, x):
        return x.tanh()

    def _inverse(self, y):
        # We do not clamp to the boundary here because it may degrade the performance of certain algorithms;
        # one should use `cache_size=1` instead.
        return self.atanh(y)

    def log_abs_det_jacobian(self, x, y):
        # Numerically stable form of log|dy/dx| = log(1 - tanh(x)^2); see
        # https://github.com/tensorflow/probability/commit/ef6bb176e0ebd1cf6e25c6b5cecdd2428c22963f#diff-e120f70e92e6741bca649f04fcd907b7
        return 2. * (math.log(2.) - x - F.softplus(-2. * x))

class SquashedNormal(pyd.transformed_distribution.TransformedDistribution):
    """Normal distribution squashed into (-1, 1) by a tanh transform."""

    def __init__(self, loc, scale):
        self.loc = loc
        self.scale = scale
        self.base_dist = pyd.Normal(loc, scale)
        transforms = [TanhTransform()]
        super().__init__(self.base_dist, transforms)

    @property
    def mean(self):
        mu = self.loc
        for tr in self.transforms:
            mu = tr(mu)
        return mu
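
# Illustrative SAC-style usage (an assumption about intent, not a call made in this file):
# draw a bounded action with reparameterization and evaluate its log-probability.
#
#     mu, std = torch.zeros(2), torch.ones(2)
#     dist = SquashedNormal(mu, std)
#     action = dist.rsample()                   # differentiable sample in (-1, 1)
#     log_prob = dist.log_prob(action).sum(-1)  # tanh Jacobian handled by TanhTransform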

def retry(func):
    """Decorator that retries the wrapped function for up to 1000 attempts,
    sleeping 0.1 s after each OSError or PermissionError."""
    @wraps(func)
    def wrapper(*args, **kwargs):
        attempts = 0
        max_attempts = 1000
        while attempts < max_attempts:
            try:
                return func(*args, **kwargs)
            except (OSError, PermissionError):
                attempts += 1
                time.sleep(0.1)
        raise OSError("Retry failed")
    return wrapper
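
# Hedged example of the decorator (the helper and path below are hypothetical,
# chosen only to show the pattern of retrying flaky filesystem writes):
#
#     @retry
#     def save_checkpoint(path, payload):
#         with open(path, 'wb') as f:
#             f.write(payload)
#
#     save_checkpoint('/tmp/ckpt.bin', b'...')  # retried on transient OSError/PermissionError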

def flatten_dict(dictionary, parent_key='', separator='_'):
    """Flatten a nested mapping into a single-level dict, joining keys with `separator`."""
    items = []
    for key in dictionary.keys():
        try:
            value = dictionary[key]
        except Exception:
            value = '??? <MISSING>'
        new_key = parent_key + separator + key if parent_key else key
        if isinstance(value, MutableMapping):
            items.extend(flatten_dict(value, new_key, separator=separator).items())
        else:
            items.append((new_key, value))
    return dict(items)
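
# Quick illustration (made-up config values):
#
#     cfg = {'agent': {'lr': 1e-4, 'critic': {'hidden': 1024}}, 'seed': 1}
#     flatten_dict(cfg)
#     # -> {'agent_lr': 0.0001, 'agent_critic_hidden': 1024, 'seed': 1}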

def slerp(t, v0, v1, DOT_THRESHOLD=0.9995):
    '''
    Spherical linear interpolation.
    Args:
        t (float/np.ndarray): Float value between 0.0 and 1.0
        v0 (np.ndarray): Starting vector
        v1 (np.ndarray): Final vector
        DOT_THRESHOLD (float): Threshold for considering the two vectors as
                               collinear. Not recommended to alter this.
    Returns:
        v2 (np.ndarray): Interpolation vector between v0 and v1
    '''
    c = False
    if not isinstance(v0, np.ndarray):
        c = True
        v0 = v0.detach().cpu().numpy()
    if not isinstance(v1, np.ndarray):
        c = True
        v1 = v1.detach().cpu().numpy()
    if len(v0.shape) == 1:
        v0 = v0.reshape(1, -1)
    if len(v1.shape) == 1:
        v1 = v1.reshape(1, -1)
    # Copy the vectors to reuse them later
    v0_copy = np.copy(v0)
    v1_copy = np.copy(v1)
    # Normalize the vectors to get the directions and angles
    v0 = v0 / np.linalg.norm(v0, axis=-1, keepdims=True)
    v1 = v1 / np.linalg.norm(v1, axis=-1, keepdims=True)
    # Row-wise dot product of the normalized vectors
    dot = np.sum(v0 * v1, axis=-1)
    # If the absolute value of the dot product is almost 1, the vectors are ~collinear,
    # so linear interpolation would be needed instead
    if (np.abs(dot) > DOT_THRESHOLD).any():
        raise NotImplementedError('lerp not implemented')  # return lerp(t, v0_copy, v1_copy)
    # Calculate initial angle between v0 and v1
    theta_0 = np.arccos(dot)
    sin_theta_0 = np.sin(theta_0)
    # Angle at timestep t
    theta_t = theta_0 * t
    sin_theta_t = np.sin(theta_t)
    # Finish the slerp algorithm
    s0 = np.sin(theta_0 - theta_t) / sin_theta_0
    s1 = sin_theta_t / sin_theta_0
    v2 = s0.reshape(-1, 1) * v0_copy + s1.reshape(-1, 1) * v1_copy
    if c:
        res = torch.from_numpy(v2).to("cuda")
    else:
        res = v2
    return res
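
# Usage sketch (illustrative vectors; 1-D inputs are promoted to shape (1, dim), and
# torch inputs are returned on "cuda", which assumes a GPU is available):
#
#     a = np.array([1.0, 0.0])
#     b = np.array([0.0, 1.0])
#     mid = slerp(0.5, a, b)  # -> array([[0.7071..., 0.7071...]]), halfway along the unit arc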