# Source snapshot: revision 2d9a728 (7,919 bytes).
import math
import random
import time
from collections.abc import MutableMapping
from functools import wraps

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import distributions as pyd
from torch.distributions.utils import _standard_normal
class eval_mode:
    """Context manager that temporarily switches modules to evaluation mode.

    On entry each module's current ``training`` flag is recorded and the module
    is put in eval mode via ``train(False)``; on exit every module is restored
    to the exact flag it had before, so modules that were already in eval mode
    stay in eval mode.

    Args:
        *models: objects exposing a ``training`` attribute and a
            ``train(mode)`` method (e.g. ``torch.nn.Module``).
    """

    def __init__(self, *models):
        self.models = models

    def __enter__(self):
        # Snapshot the flags before changing them so __exit__ undoes exactly
        # what was changed.
        self.prev_states = []
        for model in self.models:
            self.prev_states.append(model.training)
            model.train(False)

    def __exit__(self, *args):
        for model, state in zip(self.models, self.prev_states):
            model.train(state)
        # Returning False lets any exception raised inside the block propagate.
        return False
def set_seed_everywhere(seed):
    """Seed every RNG the training code draws from.

    Seeds torch's CPU generator, all CUDA devices' generators when CUDA is
    available, NumPy's global RNG and Python's ``random`` module.

    Args:
        seed: integer seed applied to all of the generators above.
    """
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
def soft_update_params(net, target_net, tau):
    """Polyak-average ``net``'s parameters into ``target_net`` in place.

    Each target parameter becomes ``tau * param + (1 - tau) * target_param``.
    Parameters are paired positionally through ``parameters()``, so both
    modules are expected to have the same architecture; extra parameters on
    either side are silently ignored by ``zip``.

    Args:
        net: source module.
        target_net: module whose parameters are overwritten.
        tau: interpolation weight of the source parameters (1.0 = hard copy).
    """
    for param, target_param in zip(net.parameters(), target_net.parameters()):
        # Work on ``.data`` so the update is not recorded by autograd.
        target_param.data.copy_(tau * param.data + (1 - tau) * target_param.data)
def hard_update_params(net, target_net):
    """Copy ``net``'s parameter values into ``target_net`` in place.

    Parameters are paired positionally through ``parameters()``; the target
    keeps its own tensors (values are copied, not shared).

    Args:
        net: source module.
        target_net: module whose parameters are overwritten.
    """
    for param, target_param in zip(net.parameters(), target_net.parameters()):
        target_param.data.copy_(param.data)
def weight_init(m):
    """Custom weight init for Conv2D and Linear layers.

    Meant to be passed to ``module.apply``. ``nn.Linear`` weights get an
    orthogonal init with the default gain; ``nn.Conv2d`` and
    ``nn.ConvTranspose2d`` weights get an orthogonal init scaled by the ReLU
    gain. Biases, when the layer has one, are zeroed. Any other module type is
    left untouched.

    Args:
        m: the module to initialize.
    """
    if isinstance(m, nn.Linear):
        nn.init.orthogonal_(m.weight.data)
        # ``bias`` is None for layers built with bias=False.
        if hasattr(m.bias, 'data'):
            m.bias.data.fill_(0.0)
    elif isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
        gain = nn.init.calculate_gain('relu')
        nn.init.orthogonal_(m.weight.data, gain)
        if hasattr(m.bias, 'data'):
            m.bias.data.fill_(0.0)
class Until:
    """Callable predicate that stays true while ``step`` is below a limit.

    The configured ``until`` is floor-divided by ``action_repeat`` before the
    comparison (presumably converting a frame budget into agent steps —
    confirm against callers). A limit of ``None`` means "no limit": the
    predicate is then always true.
    """

    def __init__(self, until, action_repeat=1):
        self._until = until
        self._action_repeat = action_repeat

    def __call__(self, step):
        limit = self._until
        if limit is not None:
            return step < limit // self._action_repeat
        return True
class Every:
    """Callable predicate that fires on multiples of a (repeat-scaled) period.

    The period is ``every // action_repeat``; the predicate returns ``True``
    exactly when ``step`` is a multiple of it. A period of ``None`` disables
    the predicate, which then always returns ``False``.
    """

    def __init__(self, every, action_repeat=1):
        self._every = every
        self._action_repeat = action_repeat

    def __call__(self, step):
        if self._every is None:
            return False
        period = self._every // self._action_repeat
        # bool() keeps the return a plain Python bool even for numpy/tensor steps.
        return bool(step % period == 0)
class Timer:
    """Wall-clock timer (``time.time``) tracking per-interval and total time."""

    def __init__(self):
        # One clock read so the start and the first interval begin together.
        now = time.time()
        self._start_time = now
        self._last_time = now

    def reset(self):
        """Close the current interval and start a new one.

        Returns:
            ``(elapsed_time, total_time)`` in seconds: time since the previous
            ``reset`` (or construction) and time since construction. Both are
            measured against the same clock reading that starts the next
            interval, so consecutive intervals add up exactly to the total.
        """
        now = time.time()
        elapsed_time = now - self._last_time
        total_time = now - self._start_time
        self._last_time = now
        return elapsed_time, total_time

    def total_time(self):
        """Return seconds elapsed since construction."""
        return time.time() - self._start_time
class TruncatedNormal(pyd.Normal):
    """Normal distribution whose ``sample`` output is clamped to ``[low, high]``.

    Only ``sample`` is overridden; ``log_prob``, ``rsample`` and the other
    ``Normal`` methods are inherited unchanged (i.e. they ignore the bounds).
    Argument validation of the underlying ``Normal`` is disabled.

    Args:
        loc: mean tensor.
        scale: standard-deviation tensor.
        low, high: clamping bounds.
        eps: margin kept inside the bounds, so samples lie in
            ``[low + eps, high - eps]``.
    """

    def __init__(self, loc, scale, low=-1.0, high=1.0, eps=1e-6):
        super().__init__(loc, scale, validate_args=False)
        self.low = low
        self.high = high
        self.eps = eps

    def _clamp(self, x):
        clamped_x = torch.clamp(x, self.low + self.eps, self.high - self.eps)
        # ``x - x.detach()`` is zero-valued but carries x's gradient, so the
        # forward value is clamped while the backward pass is the identity
        # (straight-through estimator).
        x = x - x.detach() + clamped_x.detach()
        return x

    def sample(self, sample_shape=torch.Size(), stddev_clip=None):
        """Draw a clamped sample ``loc + scale * noise`` that is differentiable
        with respect to ``loc`` and ``scale``.

        Args:
            sample_shape: leading shape of the draw.
            stddev_clip: if not None, the scaled noise is clipped to
                ``[-stddev_clip, stddev_clip]`` before being added to ``loc``.

        Returns:
            Tensor of shape ``sample_shape + batch_shape + event_shape`` with
            the dtype and device of ``loc``.
        """
        shape = self._extended_shape(sample_shape)
        # Draw the noise on loc's dtype/device so it combines with the parameters.
        noise = _standard_normal(shape,
                                 dtype=self.loc.dtype,
                                 device=self.loc.device)
        noise *= self.scale
        if stddev_clip is not None:
            noise = torch.clamp(noise, -stddev_clip, stddev_clip)
        x = self.loc + noise
        return self._clamp(x)
class TanhTransform(pyd.transforms.Transform):
    """Bijective transform ``y = tanh(x)`` from the reals onto ``(-1, 1)``."""

    domain = pyd.constraints.real
    codomain = pyd.constraints.interval(-1.0, 1.0)
    bijective = True
    sign = +1

    def __init__(self, cache_size=1):
        # With cache_size=1 the inverse of the most recent forward output
        # returns the cached input instead of recomputing atanh near +/-1.
        super().__init__(cache_size=cache_size)

    @staticmethod
    def atanh(x):
        # atanh(x) = 0.5 * (log(1 + x) - log(1 - x)); log1p keeps precision near 0.
        return 0.5 * (x.log1p() - (-x).log1p())

    def __eq__(self, other):
        return isinstance(other, TanhTransform)

    def __hash__(self):
        # All instances compare equal, so they must share a hash; defining
        # __eq__ alone would otherwise make instances unhashable.
        return hash(TanhTransform)

    def _call(self, x):
        return x.tanh()

    def _inverse(self, y):
        # We do not clamp to the boundary here as it may degrade the performance of certain algorithms.
        # One should use `cache_size=1` instead.
        return self.atanh(y)

    def log_abs_det_jacobian(self, x, y):
        # log|d tanh(x)/dx| = log(1 - tanh(x)^2), rewritten as
        # 2 * (log 2 - x - softplus(-2x)) to avoid the cancellation in 1 - y^2
        # when |y| approaches 1.
        return 2. * (math.log(2.) - x - F.softplus(-2. * x))
class SquashedNormal(pyd.transformed_distribution.TransformedDistribution):
    """``Normal(loc, scale)`` pushed through ``TanhTransform`` (support ``(-1, 1)``).

    Args:
        loc: mean of the underlying Normal.
        scale: standard deviation of the underlying Normal.
    """

    def __init__(self, loc, scale):
        self.loc = loc
        self.scale = scale

        self.base_dist = pyd.Normal(loc, scale)
        transforms = [TanhTransform()]
        super().__init__(self.base_dist, transforms)

    @property
    def mean(self):
        """``loc`` pushed through the transforms, i.e. ``tanh(loc)``.

        Because tanh is monotonic this is the median of the squashed
        distribution; it generally differs from its true expectation.
        """
        mu = self.loc
        for tr in self.transforms:
            mu = tr(mu)
        return mu
def retry(func):
    """A decorator to retry a function for a certain amount of attempts.

    The wrapped call is attempted up to 1000 times, retrying immediately (no
    delay) whenever it raises ``OSError`` (which includes ``PermissionError``).
    The first successful return value is returned. Any other exception type
    propagates straight away.

    Raises:
        OSError: ``"Retry failed"`` once every attempt has failed, chained
            (``__cause__``) to the last underlying error.
    """

    @wraps(func)
    def wrapper(*args, **kwargs):
        max_attempts = 1000
        last_error = None
        for _ in range(max_attempts):
            try:
                return func(*args, **kwargs)
            except OSError as err:  # PermissionError is a subclass of OSError
                last_error = err
        raise OSError("Retry failed") from last_error

    return wrapper
def flatten_dict(dictionary, parent_key='', separator='_'):
    """Flatten nested mappings into a single-level dict with joined keys.

    Child keys are joined to their parent with ``separator``, e.g.
    ``{'a': {'b': 1}}`` -> ``{'a_b': 1}``. Only ``MutableMapping`` values are
    recursed into. If looking up a value raises, the placeholder string
    ``'??? <MISSING>'`` is stored for that key instead of aborting the whole
    flatten (presumably for config objects whose unset entries raise on
    access — confirm against callers).

    Args:
        dictionary: the mapping to flatten.
        parent_key: prefix for every produced key (used by the recursion).
        separator: string placed between a parent key and a child key.

    Returns:
        A new ``dict`` mapping flattened keys to leaf values.
    """
    items = []
    # Iterate keys and index explicitly (rather than .items()) so a failing
    # lookup for one key can be caught without losing the others.
    for key in dictionary.keys():
        try:
            value = dictionary[key]
        except Exception:
            value = '??? <MISSING>'
        # f-string so nested non-str keys (e.g. ints) join instead of raising.
        new_key = f"{parent_key}{separator}{key}" if parent_key else key
        if isinstance(value, MutableMapping):
            items.extend(flatten_dict(value, new_key, separator=separator).items())
        else:
            items.append((new_key, value))
    return dict(items)
def slerp(t, v0, v1, DOT_THRESHOLD=0.9995, device=None):
    """Spherical linear interpolation.

    Args:
        t (float/np.ndarray): Float value between 0.0 and 1.0
        v0 (np.ndarray or torch.Tensor): Starting vector(s), 1-D or 2-D
        v1 (np.ndarray or torch.Tensor): Final vector(s), 1-D or 2-D
        DOT_THRESHOLD (float): Threshold for considering the two vectors as
            colinear. Not recommended to alter this.
        device: device for the returned tensor when a torch input was given.
            Defaults to the device of the first torch input (``v0`` if it is
            a tensor, otherwise ``v1``).

    Returns:
        v2: 2-D interpolation between v0 and v1 (1-D inputs are promoted to a
        batch of one). A ``torch.Tensor`` if either input was a tensor,
        otherwise an ``np.ndarray``.

    Raises:
        NotImplementedError: if any pair is (anti-)colinear, i.e. the absolute
            cosine of their angle exceeds ``DOT_THRESHOLD``; a linear
            interpolation fallback is not implemented.
    """
    # Tensors are converted to NumPy for the math; remember where they came
    # from so the result can be returned to the same device.
    is_torch = False
    source_device = None
    if not isinstance(v0, np.ndarray):
        is_torch = True
        source_device = v0.device
        v0 = v0.detach().cpu().numpy()
    if not isinstance(v1, np.ndarray):
        is_torch = True
        if source_device is None:
            source_device = v1.device
        v1 = v1.detach().cpu().numpy()
    # Promote single vectors to a batch of one so the math below is 2-D.
    if v0.ndim == 1:
        v0 = v0.reshape(1, -1)
    if v1.ndim == 1:
        v1 = v1.reshape(1, -1)
    # Keep the unnormalized inputs: the result combines the original vectors
    # (magnitudes included); only the angle is computed from unit directions.
    v0_copy = np.copy(v0)
    v1_copy = np.copy(v1)
    v0 = v0 / np.linalg.norm(v0, axis=-1, keepdims=True)
    v1 = v1 / np.linalg.norm(v1, axis=-1, keepdims=True)
    # Row-wise cosine of the angle between the unit vectors.
    dot = np.sum(v0 * v1, axis=-1)
    # Nearly (anti-)parallel rows make sin(theta_0) ~ 0 below; a lerp fallback
    # would belong here.
    if (np.abs(dot) > DOT_THRESHOLD).any():
        raise NotImplementedError('lerp not implemented')
    theta_0 = np.arccos(dot)
    sin_theta_0 = np.sin(theta_0)
    theta_t = theta_0 * t
    sin_theta_t = np.sin(theta_t)
    s0 = np.sin(theta_0 - theta_t) / sin_theta_0
    s1 = sin_theta_t / sin_theta_0
    v2 = s0.reshape(-1, 1) * v0_copy + s1.reshape(-1, 1) * v1_copy
    if not is_torch:
        return v2
    target_device = device if device is not None else source_device
    return torch.from_numpy(v2).to(target_device)