import math
import time
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
def init_weights(m):
    def truncated_normal_init(t, mean=0.0, std=0.01):
        # Sample from a normal distribution, then resample any entries that
        # fall outside two standard deviations until all values are in range.
        torch.nn.init.normal_(t, mean=mean, std=std)
        while True:
            cond = torch.logical_or(t < mean - 2 * std, t > mean + 2 * std)
            if not torch.sum(cond):
                break
            # Resample on the same device/dtype as t and write back in place,
            # so the parameter itself is updated (a plain rebind of `t` would
            # leave the module's weight untouched after the first draw).
            t.data = torch.where(
                cond,
                torch.nn.init.normal_(torch.empty_like(t), mean=mean, std=std),
                t,
            )
        return t

    if isinstance(m, nn.Linear):
        input_dim = m.in_features
        truncated_normal_init(m.weight, std=1 / (2 * np.sqrt(input_dim)))
        m.bias.data.fill_(0.0)
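
# Usage sketch (illustrative, not part of the original module): `init_weights`
# is written to be passed to nn.Module.apply, which calls it on every
# submodule. The network below is a made-up example.
def _example_init_weights():
    net = nn.Sequential(nn.Linear(4, 64), nn.ReLU(), nn.Linear(64, 2))
    net.apply(init_weights)
    return net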
class SinusoidalPosEmb(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        device = x.device
        half_dim = self.dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
        emb = x[:, None] * emb[None, :]
        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
        return emb
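
# Usage sketch (illustrative): embed a batch of integer diffusion timesteps
# into `dim`-dimensional sinusoidal features, Transformer-style. The batch
# size and dimension below are arbitrary.
def _example_pos_emb():
    t = torch.randint(0, 1000, (8,))     # [batch_size]
    emb = SinusoidalPosEmb(dim=32)(t)    # [batch_size, 32]
    return emb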
#-----------------------------------------------------------------------------#
#---------------------------------- sampling ---------------------------------#
#-----------------------------------------------------------------------------#
def extract(a, t, x_shape):
    b, *_ = t.shape
    out = a.gather(-1, t)
    return out.reshape(b, *((1,) * (len(x_shape) - 1)))
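
# Usage sketch (illustrative): `extract` gathers one scalar per batch element
# from a 1-D table of schedule coefficients `a`, indexed by timestep `t`, and
# reshapes the result so it broadcasts against samples of shape `x_shape`.
def _example_extract():
    a = torch.linspace(0.1, 0.9, 100)    # per-timestep coefficients
    t = torch.randint(0, 100, (8,))      # one timestep per batch element
    x = torch.randn(8, 3)                # [batch_size, action_dim]
    coef = extract(a, t, x.shape)        # [8, 1], broadcastable with x
    return coef * x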
def cosine_beta_schedule(timesteps, s=0.008, dtype=torch.float32):
    """
    cosine schedule
    as proposed in https://openreview.net/forum?id=-NEXDKk8gZ
    """
    steps = timesteps + 1
    x = np.linspace(0, steps, steps)
    alphas_cumprod = np.cos(((x / steps) + s) / (1 + s) * np.pi * 0.5) ** 2
    alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
    betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
    betas_clipped = np.clip(betas, a_min=0, a_max=0.999)
    return torch.tensor(betas_clipped, dtype=dtype)
def linear_beta_schedule(timesteps, beta_start=1e-4, beta_end=2e-2, dtype=torch.float32):
    betas = np.linspace(beta_start, beta_end, timesteps)
    return torch.tensor(betas, dtype=dtype)
def vp_beta_schedule(timesteps, dtype=torch.float32):
    # Variance-preserving (VP) schedule: betas derived from the discretized
    # VP-SDE, with b_min/b_max controlling the noise range.
    t = np.arange(1, timesteps + 1)
    T = timesteps
    b_max = 10.
    b_min = 0.1
    alpha = np.exp(-b_min / T - 0.5 * (b_max - b_min) * (2 * t - 1) / T ** 2)
    betas = 1 - alpha
    return torch.tensor(betas, dtype=dtype)
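
# Usage sketch (illustrative): any of the three schedules yields per-step
# betas, from which the standard DDPM forward-process quantities follow.
# The derivation below mirrors Ho et al. (2020) and is not part of the
# original module.
def _example_schedule():
    betas = vp_beta_schedule(1000)                      # [1000]
    alphas = 1. - betas
    alphas_cumprod = torch.cumprod(alphas, dim=0)
    sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)    # scales x_0
    sqrt_one_minus = torch.sqrt(1. - alphas_cumprod)    # scales the noise
    return sqrt_alphas_cumprod, sqrt_one_minus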
#-----------------------------------------------------------------------------#
#---------------------------------- losses -----------------------------------#
#-----------------------------------------------------------------------------#
class WeightedLoss(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, pred, targ, weights=1.0):
        '''
        pred, targ : tensor [ batch_size x action_dim ]
        '''
        loss = self._loss(pred, targ)
        weighted_loss = (loss * weights).mean()
        return weighted_loss
class WeightedL1(WeightedLoss):
    def _loss(self, pred, targ):
        return torch.abs(pred - targ)

class WeightedL2(WeightedLoss):
    def _loss(self, pred, targ):
        return F.mse_loss(pred, targ, reduction='none')

Losses = {
    'l1': WeightedL1,
    'l2': WeightedL2,
}
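
# Usage sketch (illustrative): look up a loss class by name and apply
# per-sample weights; with the default weights=1.0 this reduces to a plain
# mean L1/L2 loss. Shapes below are arbitrary.
def _example_weighted_loss():
    loss_fn = Losses['l2']()
    pred, targ = torch.randn(8, 3), torch.randn(8, 3)
    weights = torch.rand(8, 1)    # e.g. per-sample importance weights
    return loss_fn(pred, targ, weights)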
class EMA():
    '''
    exponential moving average of model parameters
    '''
    def __init__(self, beta):
        super().__init__()
        self.beta = beta

    def update_model_average(self, ma_model, current_model):
        for current_params, ma_params in zip(current_model.parameters(), ma_model.parameters()):
            old_weight, up_weight = ma_params.data, current_params.data
            ma_params.data = self.update_average(old_weight, up_weight)

    def update_average(self, old, new):
        if old is None:
            return new
        return old * self.beta + (1 - self.beta) * new
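
# Usage sketch (illustrative): maintain a slow-moving copy of a model, as is
# common in diffusion training. `copy.deepcopy` and the update cadence are
# assumptions here, not part of the original module.
def _example_ema():
    import copy
    model = nn.Linear(4, 4)
    ema_model = copy.deepcopy(model)
    ema = EMA(beta=0.995)
    # call after each optimizer step:
    ema.update_model_average(ema_model, model)
    return ema_model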