Spaces:
Sleeping
Sleeping
import math | |
import random | |
import time | |
from functools import wraps | |
import numpy as np | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from torch import distributions as pyd | |
from torch.distributions.utils import _standard_normal | |
from collections.abc import MutableMapping | |
class eval_mode:
    """Context manager that temporarily switches modules to eval mode.

    On entry, each model's current ``training`` flag is saved and the model is
    switched with ``train(False)``. On exit, the saved flags are re-applied in
    the same order. ``__enter__`` returns ``None``, and exceptions raised in
    the ``with`` body are not suppressed.
    """

    def __init__(self, *models):
        self.models = models

    def __enter__(self):
        self.prev_states = []
        remember = self.prev_states.append
        for net in self.models:
            # Record and then switch each model in turn, so a model that was
            # already switched earlier in the loop is recorded as eval.
            remember(net.training)
            net.train(False)

    def __exit__(self, *args):
        for net, was_training in zip(self.models, self.prev_states):
            net.train(was_training)
        # Falsy return: let any exception from the with-body propagate.
        return False
def set_seed_everywhere(seed):
    """Seed every global RNG this module relies on with ``seed``.

    Covers PyTorch (CPU, plus all CUDA devices when CUDA is available),
    NumPy's global RNG and Python's ``random`` module.
    """
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    for seeder in (np.random.seed, random.seed):
        seeder(seed)
def soft_update_params(net, target_net, tau):
    """Blend ``net``'s parameters into ``target_net`` in place.

    Each target parameter becomes ``tau * source + (1 - tau) * target``.
    Parameters are paired positionally with ``zip``, so if one module has more
    parameters than the other, the extras are ignored.
    """
    pairs = zip(net.parameters(), target_net.parameters())
    for src, dst in pairs:
        blended = tau * src.data + (1 - tau) * dst.data
        dst.data.copy_(blended)
def hard_update_params(net, target_net):
    """Overwrite each parameter of ``target_net`` with the matching one from ``net``.

    Parameters are paired positionally with ``zip``; values are copied in place
    into the existing target tensors.
    """
    for dst, src in zip(target_net.parameters(), net.parameters()):
        dst.data.copy_(src.data)
def weight_init(m):
    """Custom weight init for Conv2D and Linear layers.

    - ``nn.Linear``: orthogonal weights with the default gain.
    - ``nn.Conv2d`` / ``nn.ConvTranspose2d``: orthogonal weights scaled by the
      ReLU gain from ``nn.init.calculate_gain('relu')``.
    - In both cases the bias, when present, is zeroed.
    - Any other module is left untouched.

    Presumably meant to be used via ``module.apply(weight_init)``; confirm
    with callers.
    """
    if isinstance(m, nn.Linear):
        nn.init.orthogonal_(m.weight.data)
    elif isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
        nn.init.orthogonal_(m.weight.data, nn.init.calculate_gain('relu'))
    else:
        return
    # Layers built with bias=False have m.bias == None, which has no .data.
    if hasattr(m.bias, 'data'):
        m.bias.data.fill_(0.0)
class Until:
    """Callable predicate: true while ``step`` is below a threshold.

    The threshold is ``until // action_repeat``. Presumably ``until`` is
    counted in environment frames while ``step`` counts agent steps; confirm
    with callers. ``until=None`` means "forever" (always true).
    """

    def __init__(self, until, action_repeat=1):
        self._until = until
        self._action_repeat = action_repeat

    def __call__(self, step):
        return self._until is None or step < self._until // self._action_repeat
class Every:
    """Callable predicate: true on every ``every // action_repeat``-th step.

    ``every=None`` disables the predicate (always false). Note that
    ``__call__`` raises ZeroDivisionError when ``every // action_repeat`` is 0.
    """

    def __init__(self, every, action_repeat=1):
        self._every = every
        self._action_repeat = action_repeat

    def __call__(self, step):
        if self._every is None:
            return False
        period = self._every // self._action_repeat
        # bool() so callers always get a plain Python bool, even for array-like steps.
        return bool(step % period == 0)
class Timer:
    """Wall-clock stopwatch based on ``time.time()``.

    Reports the time since the previous ``reset()`` (a "lap") and the time
    since the timer was created.
    """

    def __init__(self):
        # One clock reading so that the start and the first lap begin at
        # exactly the same instant.
        now = time.time()
        self._start_time = now
        self._last_time = now

    def reset(self):
        """Close the current lap and start a new one.

        Returns:
            tuple[float, float]: ``(elapsed_time, total_time)``. These are the
            seconds since the previous reset (or creation) and the seconds
            since creation, both measured at the same instant.
        """
        # Read the clock once: separate time.time() calls let the gap between
        # measuring the lap and starting the next one go unaccounted.
        now = time.time()
        elapsed_time = now - self._last_time
        total_time = now - self._start_time
        self._last_time = now
        return elapsed_time, total_time

    def total_time(self):
        """Return the seconds elapsed since the timer was created."""
        return time.time() - self._start_time
class TruncatedNormal(pyd.Normal):
    """Normal distribution whose samples are clamped to ``[low + eps, high - eps]``.

    The clamp uses a straight-through trick: the forward value is the clamped
    sample, while gradients flow back as if no clamp had been applied. Only
    ``sample`` is overridden; ``log_prob`` etc. are those of the plain Normal.
    """

    def __init__(self, loc, scale, low=-1.0, high=1.0, eps=1e-6):
        super().__init__(loc, scale, validate_args=False)
        self.low, self.high, self.eps = low, high, eps

    def _clamp(self, x):
        """Clamp ``x`` into the open bounds while keeping an identity gradient."""
        lower = self.low + self.eps
        upper = self.high - self.eps
        bounded = torch.clamp(x, lower, upper)
        # (x - x.detach()) is zero in value but still carries x's gradient.
        return x - x.detach() + bounded.detach()

    def sample(self, sample_shape=torch.Size(), stddev_clip=None):
        """Draw a clamped sample ``loc + scale * noise``.

        Unlike ``Distribution.sample``, this does not disable autograd, so
        gradients reach ``loc`` (and ``scale``) through the sample.

        Args:
            sample_shape: Extra leading batch dimensions of the sample.
            stddev_clip: If given, the scaled noise is clipped to
                ``[-stddev_clip, stddev_clip]`` before being added to ``loc``.
        """
        noise = _standard_normal(self._extended_shape(sample_shape),
                                 dtype=self.loc.dtype,
                                 device=self.loc.device)
        # In place, so the noise keeps loc's dtype.
        noise *= self.scale
        if stddev_clip is not None:
            noise = torch.clamp(noise, -stddev_clip, stddev_clip)
        return self._clamp(self.loc + noise)
class TanhTransform(pyd.transforms.Transform):
    """Bijective transform ``y = tanh(x)`` from the real line onto (-1, 1)."""

    domain = pyd.constraints.real
    codomain = pyd.constraints.interval(-1.0, 1.0)
    bijective = True
    sign = +1

    def __init__(self, cache_size=1):
        super().__init__(cache_size=cache_size)

    @staticmethod
    def atanh(x):
        """Inverse hyperbolic tangent: ``0.5 * (log(1 + x) - log(1 - x))``."""
        # Must be a staticmethod: it is called as ``self.atanh(y)`` in
        # ``_inverse``. Without the decorator, ``self`` would bind to ``x``
        # and the call would fail with a TypeError.
        return 0.5 * (x.log1p() - (-x).log1p())

    def __eq__(self, other):
        return isinstance(other, TanhTransform)

    def _call(self, x):
        return x.tanh()

    def _inverse(self, y):
        # We do not clamp to the boundary here as it may degrade the
        # performance of certain algorithms; one should use `cache_size=1`
        # instead, so the inverse of a just-transformed value comes from the cache.
        return self.atanh(y)

    def log_abs_det_jacobian(self, x, y):
        # log|d tanh(x)/dx| = log(1 - tanh(x)^2), rewritten in the more
        # numerically stable form 2 * (log 2 - x - softplus(-2x)); see
        # https://github.com/tensorflow/probability/commit/ef6bb176e0ebd1cf6e25c6b5cecdd2428c22963f#diff-e120f70e92e6741bca649f04fcd907b7
        return 2. * (math.log(2.) - x - F.softplus(-2. * x))
class SquashedNormal(pyd.transformed_distribution.TransformedDistribution):
    """``Normal(loc, scale)`` pushed through ``TanhTransform`` (samples lie in (-1, 1))."""

    def __init__(self, loc, scale):
        self.loc = loc
        self.scale = scale
        self.base_dist = pyd.Normal(loc, scale)
        transforms = [TanhTransform()]
        super().__init__(self.base_dist, transforms)

    @property
    def mean(self):
        """``loc`` mapped through the transforms, i.e. ``tanh(loc)``.

        This overrides the ``Distribution.mean`` property, so it must itself be
        a property: as a plain method, ``dist.mean`` would yield a bound method
        instead of a tensor. Note that this is the image of the base mean, not
        the true expectation ``E[tanh(X)]``.
        """
        mu = self.loc
        for tr in self.transforms:
            mu = tr(mu)
        return mu
def retry(func):
    """
    A Decorator to retry a function for a certain amount of attempts

    The wrapped call is retried whenever it raises ``OSError`` (which includes
    ``PermissionError``). There are up to 1000 attempts, with a 0.1 s sleep
    after each failure, so it can block for roughly 100 s in the worst case.
    Any other exception propagates immediately.

    Raises:
        OSError: ``"Retry failed"``, chained from the last underlying error,
            once every attempt has failed.
    """
    max_attempts = 1000

    # Preserve func's __name__/__doc__ on the wrapper.
    @wraps(func)
    def wrapper(*args, **kwargs):
        last_error = None
        for _ in range(max_attempts):
            try:
                return func(*args, **kwargs)
            except OSError as err:  # PermissionError is a subclass of OSError
                last_error = err
                time.sleep(0.1)
        raise OSError("Retry failed") from last_error

    return wrapper
def flatten_dict(dictionary, parent_key='', separator='_'):
    """Flatten a nested mapping into a single-level dict with joined keys.

    Nested ``MutableMapping`` values are recursed into. Their keys are joined
    to the parent key with ``separator`` (non-string keys are formatted with
    ``str``). If looking up a value raises an exception, the placeholder
    string ``'??? <MISSING>'`` is stored instead. Presumably this targets
    config objects (e.g. OmegaConf) that raise on unset mandatory values;
    confirm with callers.

    Args:
        dictionary: The mapping to flatten.
        parent_key: Prefix for every key produced at this level ('' for none).
        separator: String placed between parent and child keys.

    Returns:
        dict: Flattened key -> leaf value. Later keys overwrite earlier ones
        on collision.
    """
    items = {}
    for key in dictionary:
        try:
            value = dictionary[key]
        except Exception:  # not bare: let KeyboardInterrupt/SystemExit through
            value = '??? <MISSING>'
        # f-string so nested non-str keys (e.g. ints) don't raise TypeError.
        new_key = f"{parent_key}{separator}{key}" if parent_key else key
        if isinstance(value, MutableMapping):
            items.update(flatten_dict(value, new_key, separator=separator))
        else:
            items[new_key] = value
    return items
def slerp(t, v0, v1, DOT_THRESHOLD=0.9995):
    '''
    Spherical linear interpolation
    Args:
        t (float/np.ndarray): Interpolation factor, nominally between 0.0 and
            1.0. An array is broadcast against the N per-row angles.
        v0 (np.ndarray/torch.Tensor): Starting vector(s), shape (D,) or (N, D)
        v1 (np.ndarray/torch.Tensor): Final vector(s), broadcast-compatible
            with v0
        DOT_THRESHOLD (float): Threshold on |cos(angle)| above which the two
            vectors are considered colinear. Not recommended to alter this.
    Returns:
        v2 (np.ndarray/torch.Tensor): Interpolation between v0 and v1, shape
            (N, D); 1-D inputs are promoted to (1, D). If either input is a
            tensor, the result is a tensor on that input's device (v0's device
            takes precedence); otherwise it is a NumPy array.
    Raises:
        NotImplementedError: If any row pair is (nearly) colinear, where a
            linear-interpolation fallback would be required.
    '''
    # Return tensors on the device they came from. This was previously
    # hard-coded to "cuda", which fails on CPU-only machines and silently
    # moves CPU tensors to the GPU.
    out_device = None
    if not isinstance(v0, np.ndarray):
        out_device = v0.device
        v0 = v0.detach().cpu().numpy()
    if not isinstance(v1, np.ndarray):
        if out_device is None:
            out_device = v1.device
        v1 = v1.detach().cpu().numpy()
    # Promote single vectors to a batch of one row: (D,) -> (1, D).
    if len(v0.shape) == 1:
        v0 = v0.reshape(1, -1)
    if len(v1.shape) == 1:
        v1 = v1.reshape(1, -1)
    # Keep the un-normalized inputs: the interpolation weights are applied to
    # them, so t=0 / t=1 return v0 / v1 with their original magnitudes.
    v0_copy = np.copy(v0)
    v1_copy = np.copy(v1)
    # Unit directions, used only to measure the angle between the vectors.
    v0 = v0 / np.linalg.norm(v0, axis=-1, keepdims=True)
    v1 = v1 / np.linalg.norm(v1, axis=-1, keepdims=True)
    # Row-wise dot product (cosine of the angle per row); np.dot would
    # contract across rows instead of pairing them.
    dot = np.sum(v0 * v1, axis=-1)
    # If |dot| is almost 1 the vectors are ~colinear and sin(theta_0) ~ 0, so
    # the slerp weights below would be ill-conditioned.
    if (np.abs(dot) > DOT_THRESHOLD).any():
        raise NotImplementedError('lerp not implemented')  # return lerp(t, v0_copy, v1_copy)
    # Calculate initial angle between v0 and v1
    theta_0 = np.arccos(dot)
    sin_theta_0 = np.sin(theta_0)
    # Angle at timestep t
    theta_t = theta_0 * t
    sin_theta_t = np.sin(theta_t)
    # Finish the slerp algorithm: per-row weights, broadcast over D.
    s0 = np.sin(theta_0 - theta_t) / sin_theta_0
    s1 = sin_theta_t / sin_theta_0
    v2 = s0.reshape(-1, 1) * v0_copy + s1.reshape(-1, 1) * v1_copy
    if out_device is not None:
        return torch.from_numpy(v2).to(out_device)
    return v2