genrl / tools /utils.py
mazpie's picture
Initial commit
2d9a728
raw
history blame
No virus
7.92 kB
import math
import random
import time
from functools import wraps
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import distributions as pyd
from torch.distributions.utils import _standard_normal
from collections.abc import MutableMapping
class eval_mode:
    """Temporarily switch modules to evaluation mode inside a ``with`` block.

    On entry each module's ``training`` flag is remembered and the module is
    switched off via ``train(False)``; on exit the remembered flags are put
    back. Exceptions raised inside the block are never suppressed.
    """

    def __init__(self, *models):
        self.models = models

    def __enter__(self):
        self.prev_states = []
        remember = self.prev_states.append
        for module in self.models:
            # Record, then switch, one module at a time: if the same module (or
            # a parent and its child) is passed twice, the second record sees
            # the already-switched flag, and restoration order mirrors this.
            remember(module.training)
            module.train(False)

    def __exit__(self, *args):
        for module, was_training in zip(self.models, self.prev_states):
            module.train(was_training)
        # Falsy return value: any exception from the with-body propagates.
        return False
def set_seed_everywhere(seed):
    """Seed Python's ``random``, NumPy's global RNG and torch with ``seed``.

    The CUDA generators of all devices are additionally seeded, but only when
    CUDA is available.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    cuda_present = torch.cuda.is_available()
    if cuda_present:
        torch.cuda.manual_seed_all(seed)
def soft_update_params(net, target_net, tau):
    """Blend ``net``'s parameters into ``target_net`` in place.

    Every target parameter is overwritten with
    ``tau * source + (1 - tau) * target``. Parameters are paired positionally
    with ``zip`` (extra parameters of the longer module are ignored), and only
    ``parameters()`` are touched -- buffers are not updated.
    """
    pairs = zip(net.parameters(), target_net.parameters())
    for src, dst in pairs:
        blended = tau * src.data + (1 - tau) * dst.data
        dst.data.copy_(blended)
def hard_update_params(net, target_net):
    """Overwrite ``target_net``'s parameters with copies of ``net``'s, in place.

    Parameters are paired positionally with ``zip``; buffers are not copied.
    """
    for dst, src in zip(target_net.parameters(), net.parameters()):
        dst.data.copy_(src.data)
def weight_init(m):
    """Custom weight init for Conv2D and Linear layers.

    ``nn.Linear`` weights get orthogonal init with the default gain;
    ``nn.Conv2d`` / ``nn.ConvTranspose2d`` weights get orthogonal init with the
    ReLU gain. In both cases an existing bias is zero-filled. Any other module
    type is left untouched, so the function can be passed to ``Module.apply``.
    """
    if isinstance(m, nn.Linear):
        nn.init.orthogonal_(m.weight.data)
    elif isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
        nn.init.orthogonal_(m.weight.data, nn.init.calculate_gain('relu'))
    else:
        return
    # ``bias`` is None for layers built with bias=False; None has no ``.data``.
    if hasattr(m.bias, 'data'):
        m.bias.data.fill_(0.0)
class Until:
    """Predicate that is true while ``step`` is below a limit.

    The limit is ``until // action_repeat`` (floor division); ``until=None``
    means there is no limit and the predicate is always true.
    """

    def __init__(self, until, action_repeat=1):
        self._until = until
        self._action_repeat = action_repeat

    def __call__(self, step):
        limit = self._until
        return True if limit is None else step < limit // self._action_repeat
class Every:
    """Periodic trigger: true on steps that are multiples of a period.

    The period is ``every // action_repeat`` (floor division), in units of
    ``step``. ``every=None`` or ``every=0`` disables the trigger (always
    false).
    """

    def __init__(self, every, action_repeat=1):
        self._every = every
        self._action_repeat = action_repeat

    def __call__(self, step):
        if self._every is None or self._every == 0:
            return False
        every = self._every // self._action_repeat
        if every == 0:
            # 0 < every < action_repeat floors to 0; the requested period is
            # shorter than one step, so fire every step instead of raising
            # ZeroDivisionError in the modulo below.
            every = 1
        return step % every == 0
class Timer:
    """Wall-clock timer reporting time since the last reset and since creation."""

    def __init__(self):
        # A single timestamp so "start" and "last reset" are the same instant.
        now = time.time()
        self._start_time = now
        self._last_time = now

    def reset(self):
        """Close the current interval and start a new one.

        Returns:
            tuple[float, float]: ``(elapsed, total)`` in seconds -- time since
            the previous ``reset`` (or since construction) and time since
            construction, both measured at the same instant. That instant
            also becomes the start of the next interval, so no time is lost
            between consecutive intervals.
        """
        now = time.time()
        elapsed_time = now - self._last_time
        self._last_time = now
        total_time = now - self._start_time
        return elapsed_time, total_time

    def total_time(self):
        """Return seconds elapsed since construction."""
        return time.time() - self._start_time
class TruncatedNormal(pyd.Normal):
    """Normal distribution whose ``sample`` output is clamped into (low, high).

    Samples are clamped to ``[low + eps, high - eps]`` with a straight-through
    estimator: the forward value is the clamped one, while gradients flow as
    if no clamp were applied. Only ``sample`` (and the ``_clamp`` helper) are
    defined here; ``log_prob``, ``rsample`` etc. are inherited from ``Normal``
    unchanged.
    """

    def __init__(self, loc, scale, low=-1.0, high=1.0, eps=1e-6):
        super().__init__(loc, scale, validate_args=False)
        self.low = low
        self.high = high
        self.eps = eps

    def _clamp(self, x):
        # Straight-through clamp: x - x.detach() is zero in value but carries
        # x's gradient; adding the detached clamped tensor supplies the value.
        bounded = torch.clamp(x, self.low + self.eps, self.high - self.eps)
        return x - x.detach() + bounded.detach()

    def sample(self, sample_shape=torch.Size(), stddev_clip=None):
        """Draw ``loc + scale * noise`` and clamp it into the support.

        No ``no_grad`` is used, so gradients reach ``loc`` and ``scale``.
        ``stddev_clip``, when given, bounds the scaled noise to
        ``[-stddev_clip, stddev_clip]`` before it is added to ``loc``.
        """
        noise = _standard_normal(self._extended_shape(sample_shape),
                                 dtype=self.loc.dtype,
                                 device=self.loc.device)
        noise = noise * self.scale
        if stddev_clip is not None:
            noise = noise.clamp(-stddev_clip, stddev_clip)
        return self._clamp(self.loc + noise)
class TanhTransform(pyd.transforms.Transform):
    """Bijective ``tanh`` transform from the real line onto (-1, 1).

    Caching defaults to ``cache_size=1`` so that inverting a ``y`` that this
    transform just produced returns the cached ``x`` rather than recomputing
    atanh, which is numerically fragile near +/-1.
    """
    domain = pyd.constraints.real
    codomain = pyd.constraints.interval(-1.0, 1.0)
    bijective = True
    sign = +1

    def __init__(self, cache_size=1):
        super().__init__(cache_size=cache_size)

    @staticmethod
    def atanh(x):
        """Inverse hyperbolic tangent via 0.5 * (log1p(x) - log1p(-x))."""
        return 0.5 * (torch.log1p(x) - torch.log1p(-x))

    def __eq__(self, other):
        # Every TanhTransform is interchangeable with any other.
        return isinstance(other, TanhTransform)

    def _call(self, x):
        return torch.tanh(x)

    def _inverse(self, y):
        # Deliberately unclamped: clamping to the boundary may degrade the
        # performance of certain algorithms; rely on cache_size=1 instead.
        return self.atanh(y)

    def log_abs_det_jacobian(self, x, y):
        # log(1 - tanh(x)^2) rewritten as 2 * (log 2 - x - softplus(-2x)),
        # a more numerically stable form; see
        # https://github.com/tensorflow/probability/commit/ef6bb176e0ebd1cf6e25c6b5cecdd2428c22963f#diff-e120f70e92e6741bca649f04fcd907b7
        return 2. * (math.log(2.) - x - F.softplus(-2. * x))
class SquashedNormal(pyd.transformed_distribution.TransformedDistribution):
    """Normal(loc, scale) pushed through a single ``TanhTransform``.

    ``loc``, ``scale`` and the underlying ``base_dist`` are kept as attributes.
    """

    def __init__(self, loc, scale):
        self.loc = loc
        self.scale = scale
        self.base_dist = pyd.Normal(loc, scale)
        super().__init__(self.base_dist, [TanhTransform()])

    @property
    def mean(self):
        """``loc`` pushed through the transforms, i.e. tanh(loc).

        Note this is not E[tanh(X)] for X ~ Normal(loc, scale).
        """
        out = self.loc
        for transform in self.transforms:
            out = transform(out)
        return out
def retry(func=None, *, max_attempts=1000, delay=0.1):
    """
    A Decorator to retry a function for a certain amount of attempts.

    The wrapped function is called again whenever it raises ``OSError``
    (which includes ``PermissionError``, ``FileNotFoundError``, ...). Usable
    bare (``@retry``) or with options (``@retry(max_attempts=5, delay=0)``).

    Args:
        func: Function to wrap; supplied automatically when used as ``@retry``.
        max_attempts (int): Maximum number of calls before giving up.
        delay (float): Seconds to sleep between consecutive attempts.

    Returns:
        The wrapped function, or a decorator when called with options only.

    Raises:
        OSError: "Retry failed" once every attempt has failed; the last
            underlying error is attached as ``__cause__``.
    """
    if func is None:
        # Called with options only: return a decorator that applies them.
        return lambda f: retry(f, max_attempts=max_attempts, delay=delay)

    @wraps(func)
    def wrapper(*args, **kwargs):
        last_error = None
        for attempt in range(max_attempts):
            if attempt:
                # Back off only between attempts, not after the final one.
                time.sleep(delay)
            try:
                return func(*args, **kwargs)
            except OSError as err:
                last_error = err
        raise OSError("Retry failed") from last_error
    return wrapper
def flatten_dict(dictionary, parent_key='', separator='_'):
    """Flatten nested mappings into a single-level dict with joined keys.

    Every value that is a ``MutableMapping`` is recursed into, and its keys
    are joined to the parent key with ``separator``
    (``{'a': {'b': 1}}`` -> ``{'a_b': 1}``). All other values are leaves.

    A value whose lookup raises an ``Exception`` is stored as the placeholder
    ``'??? <MISSING>'``. Presumably this targets config objects that raise on
    unset ``???`` entries (e.g. OmegaConf) -- confirm against callers.

    Args:
        dictionary: Mapping to flatten.
        parent_key: Prefix for keys produced at this level ('' for none).
        separator: String inserted between a parent key and a child key.

    Returns:
        dict: Flat mapping from joined keys to leaf values. On key collisions
        the later value wins.
    """
    flat = {}
    for key in dictionary.keys():
        try:
            value = dictionary[key]
        except Exception:
            # Narrower than a bare ``except:`` so KeyboardInterrupt and
            # SystemExit still propagate.
            value = '??? <MISSING>'
        new_key = parent_key + separator + key if parent_key else key
        if isinstance(value, MutableMapping):
            flat.update(flatten_dict(value, new_key, separator=separator))
        else:
            flat[new_key] = value
    return flat
def slerp(t, v0, v1, DOT_THRESHOLD=0.9995):
    '''
    Spherical linear interpolation, applied row-wise.
    Args:
        t (float/np.ndarray): Float value between 0.0 and 1.0 (0 -> v0,
            1 -> v1). An array must broadcast against the rows of v0/v1.
        v0 (np.ndarray/torch.Tensor): Starting vector(s), shape (dim,) or (n, dim)
        v1 (np.ndarray/torch.Tensor): Final vector(s), shape (dim,) or (n, dim)
        DOT_THRESHOLD (float): Rows whose normalized |dot product| exceeds
            this are considered colineal and are linearly interpolated
            instead (slerp would divide by sin(angle) ~ 0). Not recommended
            to alter this.
    Returns:
        v2: 2-D interpolation result, one row per interpolated vector. The
            angle comes from the normalized directions, but the original
            (unnormalized) vectors are mixed. A np.ndarray when both inputs
            are arrays; otherwise a torch.Tensor on the device of the first
            tensor input.
    '''
    # Any non-ndarray input is treated as a torch tensor; remember its device
    # so the result goes back where the caller's data lives.
    device = None
    if not isinstance(v0, np.ndarray):
        device = v0.device
        v0 = v0.detach().cpu().numpy()
    if not isinstance(v1, np.ndarray):
        if device is None:
            device = v1.device
        v1 = v1.detach().cpu().numpy()

    # Single vectors become a batch of one row so the math below is row-wise.
    if v0.ndim == 1:
        v0 = v0.reshape(1, -1)
    if v1.ndim == 1:
        v1 = v1.reshape(1, -1)

    # Normalize to get the directions and angles
    v0_unit = v0 / np.linalg.norm(v0, axis=-1, keepdims=True)
    v1_unit = v1 / np.linalg.norm(v1, axis=-1, keepdims=True)
    # Row-wise dot product (np.dot would form a matrix product for 2-D inputs)
    dot = np.sum(v0_unit * v1_unit, axis=-1)

    colinear = np.abs(dot) > DOT_THRESHOLD
    # Clip so rounding that pushes |dot| past 1 cannot make arccos return NaN
    # (a no-op for non-colinear rows).
    theta_0 = np.arccos(np.clip(dot, -1.0, 1.0))
    sin_theta_0 = np.sin(theta_0)
    # Angle at timestep t
    theta_t = theta_0 * t
    # Colinear rows may divide by ~0 here; those entries are replaced below.
    with np.errstate(divide='ignore', invalid='ignore'):
        s0 = np.sin(theta_0 - theta_t) / sin_theta_0
        s1 = np.sin(theta_t) / sin_theta_0
    # Linear-interpolation weights for colinear rows; astype keeps the dtype
    # the slerp weights already have so the output dtype does not change.
    s0 = np.where(colinear, 1.0 - t, s0).astype(s0.dtype, copy=False)
    s1 = np.where(colinear, t, s1).astype(s1.dtype, copy=False)
    v2 = s0.reshape(-1, 1) * v0 + s1.reshape(-1, 1) * v1

    if device is not None:
        return torch.from_numpy(v2).to(device)
    return v2