zjowowen's picture
init space
079c32c
from typing import Union, List, Dict
from collections import namedtuple
import numpy as np
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from ding.utils import list_split, MODEL_REGISTRY, squeeze, SequenceType
def extract(a, t, x_shape):
    """
    Overview:
        Gather entries of ``a`` at the indices ``t`` (along the last dim of ``a``) and reshape the result \
        so it broadcasts against a tensor whose shape is ``x_shape``.
    Arguments:
        - a (:obj:`torch.Tensor`): source tensor, gathered along its last dim
        - t (:obj:`torch.Tensor`): index tensor; its first dim is used as the batch size
        - x_shape (:obj:`SequenceType`): shape of the tensor the result should broadcast against
    Return:
        - out (:obj:`torch.Tensor`): gathered values reshaped to ``[batch, 1, ..., 1]`` \
            (``len(x_shape)`` dims in total)
    """
    batch_size, *_ = t.shape
    gathered = torch.gather(a, -1, t)
    singleton_dims = [1] * (len(x_shape) - 1)
    return gathered.reshape(batch_size, *singleton_dims)
def cosine_beta_schedule(timesteps: int, s: float = 0.008, dtype=torch.float32):
    """
    Overview:
        Cosine noise schedule as proposed in https://openreview.net/forum?id=-NEXDKk8gZ
    Arguments:
        - timesteps (:obj:`int`): number of diffusion steps
        - s (:obj:`float`): small offset that keeps beta from being too small near t = 0
        - dtype (:obj:`torch.dtype`): dtype of the returned betas
    Return:
        Tensor of betas with shape [timesteps,], each clipped to the range [0, 0.999].
    """
    n_points = timesteps + 1
    grid = np.linspace(0, n_points, n_points)
    # f(t) = cos^2(((t / T) + s) / (1 + s) * pi / 2), normalised so that alpha_bar(0) == 1
    f_t = np.cos(((grid / n_points) + s) / (1 + s) * np.pi * 0.5) ** 2
    alpha_bar = f_t / f_t[0]
    # beta_t = 1 - alpha_bar_t / alpha_bar_{t-1}
    step_ratio = alpha_bar[1:] / alpha_bar[:-1]
    betas = np.clip(1 - step_ratio, 0, 0.999)
    return torch.tensor(betas, dtype=dtype)
def apply_conditioning(x, conditions, action_dim):
    """
    Overview:
        Write the conditioning values into ``x`` in place: for every timestep key ``t``, the channels \
        ``x[:, t, action_dim:]`` are overwritten with a copy of the corresponding value.
    Arguments:
        - x (:obj:`torch.Tensor`): trajectory tensor, indexed as ``x[:, t, action_dim:]``
        - conditions (:obj:`dict`): maps a timestep index to the value written at that step
        - action_dim (:obj:`int`): number of leading feature channels left untouched \
            (presumably the action part of each transition)
    Return:
        - x (:obj:`torch.Tensor`): the same tensor object that was passed in, modified in place
    """
    for step in conditions:
        x[:, step, action_dim:] = conditions[step].clone()
    return x
class DiffusionConv1d(nn.Module):
    """
    Overview:
        Conv1d followed by GroupNorm and an optional activation, used by the diffusion models.
    Interfaces:
        ``__init__``, ``forward``
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        padding: int,
        activation: nn.Module = None,
        n_groups: int = 8
    ) -> None:
        """
        Overview:
            Create a 1-dim convolution layer with GroupNorm normalization and an optional activation. \
            The normalization is applied on a temporarily added singleton dim (see ``forward``).
        Arguments:
            - in_channels (:obj:`int`): Number of channels in the input tensor
            - out_channels (:obj:`int`): Number of channels in the output tensor
            - kernel_size (:obj:`int`): Size of the convolving kernel
            - padding (:obj:`int`): Zero-padding added to both sides of the input
            - activation (:obj:`nn.Module`): the optional activation function; ``None`` means identity
            - n_groups (:obj:`int`): number of groups for GroupNorm; must divide ``out_channels``
        """
        super().__init__()
        self.conv1 = nn.Conv1d(in_channels, out_channels, kernel_size=kernel_size, padding=padding)
        self.norm = nn.GroupNorm(n_groups, out_channels)
        # ``activation`` is optional: fall back to identity instead of later calling ``None``.
        self.act = activation if activation is not None else nn.Identity()

    def forward(self, inputs) -> torch.Tensor:
        """
        Overview:
            Apply conv -> GroupNorm -> activation to ``inputs``.
        Arguments:
            - inputs (:obj:`torch.Tensor`): input tensor of shape [batch, in_channels, horizon]
        Return:
            - out (:obj:`torch.Tensor`): output tensor of shape [batch, out_channels, horizon'] \
                (horizon' depends on kernel_size / padding)
        """
        x = self.conv1(inputs)
        # [batch, channels, horizon] -> [batch, channels, 1, horizon]
        x = x.unsqueeze(-2)
        x = self.norm(x)
        # [batch, channels, 1, horizon] -> [batch, channels, horizon]
        x = x.squeeze(-2)
        out = self.act(x)
        return out
class SinusoidalPosEmb(nn.Module):
    """
    Overview:
        Sinusoidal position (timestep) embedding: ``[sin(x * w_i), cos(x * w_i)]`` with geometrically \
        spaced frequencies ``w_i``.
    Interfaces:
        ``__init__``, ``forward``
    """

    def __init__(self, dim: int) -> None:
        """
        Overview:
            Initialization of SinusoidalPosEmb class
        Arguments:
            - dim (:obj:`int`): dimension of the embedding; the output always has exactly ``dim`` features
        """
        super().__init__()
        self.dim = dim

    def forward(self, x) -> torch.Tensor:
        """
        Overview:
            Compute the sinusoidal embedding of ``x``.
        Arguments:
            - x (:obj:`torch.Tensor`): 1-dim tensor of positions / timesteps, shape [batch]
        Return:
            - emb (:obj:`torch.Tensor`): output tensor of shape [batch, dim]
        """
        half_dim = self.dim // 2
        # Guard the denominator so that dim in {2, 3} does not raise ZeroDivisionError.
        scale = math.log(10000) / max(half_dim - 1, 1)
        freqs = torch.exp(torch.arange(half_dim, device=x.device) * -scale)
        emb = x[:, None] * freqs[None, :]
        emb = torch.cat((emb.sin(), emb.cos()), dim=1)
        if self.dim % 2 == 1:
            # sin/cos halves produce only 2 * (dim // 2) features; pad one zero column so the width
            # always equals ``dim`` (consumers such as ``nn.Linear(dim, ...)`` rely on it).
            emb = F.pad(emb, (0, 1))
        return emb
class Residual(nn.Module):
    """
    Overview:
        Skip-connection wrapper: returns ``fn(x, ...) + x``.
    Interfaces:
        ``__init__``, ``forward``
    """

    def __init__(self, fn):
        """
        Overview:
            Wrap ``fn`` with an identity shortcut.
        Arguments:
            - fn (:obj:`nn.Module`): the wrapped branch; its output must be addable to its input
        """
        super().__init__()
        self.fn = fn

    def forward(self, x, *arg, **kwargs):
        """
        Overview:
            Evaluate the wrapped branch and add the input back onto its result.
        Arguments:
            - x (:obj:`torch.Tensor`): input tensor, also used as the shortcut term
            - arg / kwargs: extra arguments forwarded unchanged to ``fn``
        """
        branch_out = self.fn(x, *arg, **kwargs)
        return branch_out + x
class LayerNorm(nn.Module):
    """
    Overview:
        LayerNorm over the channel dim (dim = 1), because the temporal input ``x`` is laid out as \
        [batch, dim, horizon]. Has a learnable per-channel gain ``g`` and bias ``b``.
    Interfaces:
        ``__init__``, ``forward``
    """

    def __init__(self, dim, eps=1e-5) -> None:
        """
        Overview:
            Initialization of LayerNorm class
        Arguments:
            - dim (:obj:`int`): number of channels of the input
            - eps (:obj:`float`): value added to the variance for numerical stability
        """
        super().__init__()
        self.eps = eps
        # Shaped [1, dim, 1] so they broadcast over batch and horizon.
        self.g = nn.Parameter(torch.ones(1, dim, 1))
        self.b = nn.Parameter(torch.zeros(1, dim, 1))

    def forward(self, x):
        """
        Overview:
            Normalize ``x`` across dim 1 and apply the affine transform.
        Arguments:
            - x (:obj:`torch.Tensor`): input tensor of shape [batch, dim, horizon]
        Return:
            - out (:obj:`torch.Tensor`): normalized tensor with the same shape as ``x``
        """
        # Biased variance (unbiased=False), as in the standard LayerNorm definition.
        var = torch.var(x, dim=1, unbiased=False, keepdim=True)
        mean = torch.mean(x, dim=1, keepdim=True)
        return (x - mean) / (var + self.eps).sqrt() * self.g + self.b
class PreNorm(nn.Module):
    """
    Overview:
        Pre-normalization wrapper: applies the channel-wise ``LayerNorm`` (dim = 1 of a temporal \
        input [batch, dim, horizon]) before calling the wrapped module.
    Interfaces:
        ``__init__``, ``forward``
    """

    def __init__(self, dim, fn) -> None:
        """
        Overview:
            Initialization of PreNorm class
        Arguments:
            - dim (:obj:`int`): number of channels of the input
            - fn (:obj:`nn.Module`): module applied to the normalized input
        """
        super().__init__()
        self.fn = fn
        self.norm = LayerNorm(dim)

    def forward(self, x):
        """
        Overview:
            Return ``fn(norm(x))``.
        Arguments:
            - x (:obj:`torch.Tensor`): input tensor
        """
        return self.fn(self.norm(x))
class LinearAttention(nn.Module):
    """
    Overview:
        Multi-head linear attention over the last (horizon) dim of a [batch, dim, horizon] input.
    Interfaces:
        ``__init__``, ``forward``
    """

    def __init__(self, dim, heads=4, dim_head=32) -> None:
        """
        Overview:
            Initialization of LinearAttention class
        Arguments:
            - dim (:obj:`int`): number of input / output channels
            - heads (:obj:`int`): number of attention heads
            - dim_head (:obj:`int`): channels per head
        """
        super().__init__()
        self.scale = dim_head ** -0.5
        self.heads = heads
        inner_dim = heads * dim_head
        # 1x1 convolutions act as per-position linear projections.
        self.to_qkv = nn.Conv1d(dim, inner_dim * 3, 1, bias=False)
        self.to_out = nn.Conv1d(inner_dim, dim, 1)

    def forward(self, x):
        """
        Overview:
            Compute linear attention: softmax over positions on the keys, aggregate values into a \
            per-head context, then project queries through that context.
        Arguments:
            - x (:obj:`torch.Tensor`): input tensor of shape [batch, dim, horizon]
        Return:
            - out (:obj:`torch.Tensor`): tensor of shape [batch, dim, horizon]
        """

        def split_heads(tensor):
            # [batch, heads * dim_head, horizon] -> [batch, heads, dim_head, horizon]
            return tensor.reshape(tensor.shape[0], self.heads, -1, tensor.shape[-1])

        query, key, value = (split_heads(part) for part in self.to_qkv(x).chunk(3, dim=1))
        query = query * self.scale
        key = key.softmax(dim=-1)
        context = torch.einsum('b h d n, b h e n -> b h d e', key, value)
        attended = torch.einsum('b h d e, b h d n -> b h e n', context, query)
        # [batch, heads, dim_head, horizon] -> [batch, heads * dim_head, horizon]
        merged = attended.reshape(attended.shape[0], -1, attended.shape[-1])
        return self.to_out(merged)
class ResidualTemporalBlock(nn.Module):
    """
    Overview:
        Two ``DiffusionConv1d`` layers with a time-embedding bias injected after the first one, plus a \
        residual shortcut (1x1 conv when the channel count changes).
    Interfaces:
        ``__init__``, ``forward``
    """

    def __init__(
        self, in_channels: int, out_channels: int, embed_dim: int, kernel_size: int = 5, mish: bool = True
    ) -> None:
        """
        Overview:
            Initialization of ResidualTemporalBlock class
        Arguments:
            - in_channels (:obj:'int'): number of input channels
            - out_channels (:obj:'int'): number of output channels
            - embed_dim (:obj:'int'): width of the time embedding ``t``
            - kernel_size (:obj:'int'): kernel size of both conv layers ("same" padding of kernel_size // 2)
            - mish (:obj:'bool'): use Mish as the activation if True, otherwise SiLU
        """
        super().__init__()
        # A single (stateless) activation instance is shared by all sub-layers.
        act = nn.Mish() if mish else nn.SiLU()
        pad = kernel_size // 2
        self.blocks = nn.ModuleList(
            [
                DiffusionConv1d(in_channels, out_channels, kernel_size, pad, act),
                DiffusionConv1d(out_channels, out_channels, kernel_size, pad, act),
            ]
        )
        self.time_mlp = nn.Sequential(act, nn.Linear(embed_dim, out_channels))
        if in_channels != out_channels:
            self.residual_conv = nn.Conv1d(in_channels, out_channels, 1)
        else:
            self.residual_conv = nn.Identity()

    def forward(self, x, t):
        """
        Overview:
            Compute the residual block output.
        Arguments:
            - x (:obj:'tensor'): input tensor of shape [batch, in_channels, horizon]
            - t (:obj:'tensor'): time embedding of shape [batch, embed_dim]
        Return:
            - out (:obj:'tensor'): tensor of shape [batch, out_channels, horizon]
        """
        first_conv, second_conv = self.blocks
        # [batch, out_channels] -> [batch, out_channels, 1] so it broadcasts over the horizon.
        time_bias = self.time_mlp(t)[..., None]
        hidden = second_conv(first_conv(x) + time_bias)
        return hidden + self.residual_conv(x)
class DiffusionUNet1d(nn.Module):
    """
    Overview:
        Diffusion U-Net for 1d vector (trajectory) data laid out as [batch, horizon, transition].
    Interfaces:
        ``__init__``, ``forward``, ``get_pred``
    """

    def __init__(
        self,
        transition_dim: int,
        dim: int = 32,
        dim_mults: SequenceType = (1, 2, 4, 8),
        returns_condition: bool = False,
        condition_dropout: float = 0.1,
        calc_energy: bool = False,
        kernel_size: int = 5,
        attention: bool = False,
    ) -> None:
        """
        Overview:
            Initialization of DiffusionUNet1d class
        Arguments:
            - transition_dim (:obj:'int'): dim of transition, it is obs_dim + action_dim
            - dim (:obj:'int'): base channel width; also the width of the time embedding
            - dim_mults (:obj:'SequenceType'): per-resolution channel multipliers of ``dim``
            - returns_condition (:obj:'bool'): whether to use returns as an extra condition
            - condition_dropout (:obj:'float'): probability of dropping the returns condition
            - calc_energy (:obj:'bool'): if True, ``forward`` returns the gradient of an energy \
                w.r.t. the input instead of the prediction (SiLU activations are used in this mode)
            - kernel_size (:obj:'int'): kernel_size of conv1d
            - attention (:obj:'bool'): whether to use linear attention layers
        """
        super().__init__()
        dims = [transition_dim, *map(lambda m: dim * m, dim_mults)]
        in_out = list(zip(dims[:-1], dims[1:]))
        if calc_energy:
            mish = False
            act = nn.SiLU()
        else:
            mish = True
            act = nn.Mish()
        self.time_dim = dim
        self.returns_dim = dim
        self.time_mlp = nn.Sequential(
            SinusoidalPosEmb(dim),
            nn.Linear(dim, dim * 4),
            act,
            nn.Linear(dim * 4, dim),
        )
        self.returns_condition = returns_condition
        self.condition_dropout = condition_dropout
        # NOTE: attribute name (sic) kept for backward compatibility with existing code.
        self.cale_energy = calc_energy
        if self.returns_condition:
            self.returns_mlp = nn.Sequential(
                nn.Linear(1, dim),
                act,
                nn.Linear(dim, dim * 4),
                act,
                nn.Linear(dim * 4, dim),
            )
            # Each sample keeps its returns embedding with probability 1 - condition_dropout.
            self.mask_dist = torch.distributions.Bernoulli(probs=1 - self.condition_dropout)
            embed_dim = 2 * dim
        else:
            embed_dim = dim
        self.downs = nn.ModuleList([])
        self.ups = nn.ModuleList([])
        num_resolution = len(in_out)
        for ind, (dim_in, dim_out) in enumerate(in_out):
            is_last = ind >= (num_resolution - 1)
            self.downs.append(
                nn.ModuleList(
                    [
                        ResidualTemporalBlock(dim_in, dim_out, embed_dim, kernel_size, mish=mish),
                        ResidualTemporalBlock(dim_out, dim_out, embed_dim, kernel_size, mish=mish),
                        Residual(PreNorm(dim_out, LinearAttention(dim_out))) if attention else nn.Identity(),
                        nn.Conv1d(dim_out, dim_out, 3, 2, 1) if not is_last else nn.Identity()
                    ]
                )
            )
        mid_dim = dims[-1]
        self.mid_block1 = ResidualTemporalBlock(mid_dim, mid_dim, embed_dim, kernel_size, mish)
        self.mid_atten = Residual(PreNorm(mid_dim, LinearAttention(mid_dim))) if attention else nn.Identity()
        self.mid_block2 = ResidualTemporalBlock(mid_dim, mid_dim, embed_dim, kernel_size, mish)
        for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
            # NOTE: iterating ``in_out[1:]`` means this is never True, so every up stage upsamples;
            # that balances the ``num_resolution - 1`` downsamplings of the encoder.
            is_last = ind >= (num_resolution - 1)
            self.ups.append(
                nn.ModuleList(
                    [
                        ResidualTemporalBlock(dim_out * 2, dim_in, embed_dim, kernel_size, mish=mish),
                        ResidualTemporalBlock(dim_in, dim_in, embed_dim, kernel_size, mish=mish),
                        Residual(PreNorm(dim_in, LinearAttention(dim_in))) if attention else nn.Identity(),
                        nn.ConvTranspose1d(dim_in, dim_in, 4, 2, 1) if not is_last else nn.Identity()
                    ]
                )
            )
        # The decoder (or the bottleneck, when there is a single resolution) emits dims[1] channels.
        # This equals ``dim`` whenever dim_mults[0] == 1, which was the only previously working case.
        final_dim = dims[1]
        self.final_conv = nn.Sequential(
            DiffusionConv1d(final_dim, final_dim, kernel_size=kernel_size, padding=kernel_size // 2, activation=act),
            nn.Conv1d(final_dim, transition_dim, 1),
        )

    def _condition_embedding(self, time, returns, use_dropout: bool, force_dropout: bool) -> torch.Tensor:
        """
        Overview:
            Embed the diffusion timestep and, if returns conditioning is enabled, concatenate the \
            (possibly dropped-out) returns embedding.
        """
        t = self.time_mlp(time)
        if self.returns_condition:
            assert returns is not None, "returns must be provided when returns_condition is True"
            returns_embed = self.returns_mlp(returns)
            if use_dropout:
                mask = self.mask_dist.sample(sample_shape=(returns_embed.size(0), 1)).to(returns_embed.device)
                returns_embed = mask * returns_embed
            if force_dropout:
                returns_embed = 0 * returns_embed
            t = torch.cat([t, returns_embed], dim=-1)
        return t

    def _unet(self, x, t) -> torch.Tensor:
        """
        Overview:
            Run encoder, bottleneck and decoder on ``x`` ([batch, horizon, transition]) conditioned on ``t``.
        """
        # [batch, horizon, transition ] -> [batch, transition , horizon]
        x = x.transpose(1, 2)
        h = []
        for resnet, resnet2, atten, downsample in self.downs:
            x = resnet(x, t)
            x = resnet2(x, t)
            x = atten(x)
            h.append(x)
            x = downsample(x)
        x = self.mid_block1(x, t)
        x = self.mid_atten(x)
        x = self.mid_block2(x, t)
        for resnet, resnet2, atten, upsample in self.ups:
            x = torch.cat((x, h.pop()), dim=1)
            x = resnet(x, t)
            x = resnet2(x, t)
            x = atten(x)
            x = upsample(x)
        x = self.final_conv(x)
        # [batch, transition , horizon] -> [batch, horizon, transition ]
        return x.transpose(1, 2)

    def forward(self, x, cond, time, returns=None, use_dropout: bool = True, force_dropout: bool = False):
        """
        Overview:
            Compute the diffusion U-Net forward pass.
        Arguments:
            - x (:obj:'tensor'): noisy trajectory of shape [batch, horizon, transition]; must have \
                ``requires_grad=True`` when ``calc_energy`` is enabled
            - cond (:obj:'tuple'): [ (time, state), ... ] state is init state of env, time = 0 (unused here)
            - time (:obj:'tensor'): diffusion timesteps, one per sample
            - returns (:obj:'tensor'): returns condition of the trajectory; required if returns_condition
            - use_dropout (:obj:'bool'): whether to randomly mask the returns condition
            - force_dropout (:obj:'bool'): whether to zero the returns condition entirely
        Return:
            - out (:obj:'tensor'): prediction of shape [batch, horizon, transition]; with ``calc_energy``, \
                the gradient of ``mean((pred - x) ** 2)`` w.r.t. ``x`` (same shape)
        """
        x_inp = x
        t = self._condition_embedding(time, returns, use_dropout, force_dropout)
        x = self._unet(x, t)
        if self.cale_energy:
            # Energy function
            energy = ((x - x_inp) ** 2).mean()
            grad = torch.autograd.grad(outputs=energy, inputs=x_inp, create_graph=True)
            return grad[0]
        return x

    def get_pred(
        self,
        x,
        cond,
        time,
        returns: torch.Tensor = None,
        use_dropout: bool = True,
        force_dropout: bool = False
    ):
        """
        Overview:
            Compute the raw U-Net prediction (the same pass as ``forward`` but without the energy gradient).
        Arguments:
            - x (:obj:'tensor'): noisy trajectory of shape [batch, horizon, transition]
            - cond (:obj:'tuple'): [ (time, state), ... ] state is init state of env, time = 0 (unused here)
            - time (:obj:'tensor'): diffusion timesteps, one per sample
            - returns (:obj:'tensor'): returns condition of the trajectory; required if returns_condition
            - use_dropout (:obj:'bool'): whether to randomly mask the returns condition
            - force_dropout (:obj:'bool'): whether to zero the returns condition entirely
        Return:
            - out (:obj:'tensor'): prediction of shape [batch, horizon, transition]
        """
        t = self._condition_embedding(time, returns, use_dropout, force_dropout)
        return self._unet(x, t)
class TemporalValue(nn.Module):
    """
    Overview:
        Temporal convolutional network that maps a (noisy) trajectory and a diffusion timestep to a value.
    Interfaces:
        ``__init__``, ``forward``
    """

    def __init__(
        self,
        horizon: int,
        transition_dim: int,
        dim: int = 32,
        time_dim: int = None,
        out_dim: int = 1,
        kernel_size: int = 5,
        dim_mults: SequenceType = (1, 2, 4, 8),
    ) -> None:
        """
        Overview:
            Initialization of TemporalValue class
        Arguments:
            - horizon (:obj:'int'): horizon of trajectory
            - transition_dim (:obj:'int'): dim of transition, it is obs_dim + action_dim
            - dim (:obj:'int'): base channel width
            - time_dim (:obj:'int'): width of the time embedding; defaults to ``dim``
            - out_dim (:obj:'int'): dim of output
            - kernel_size (:obj:'int'): kernel_size of conv1d
            - dim_mults (:obj:'SequenceType'): per-stage channel multipliers of ``dim``
        """
        super().__init__()
        dims = [transition_dim, *map(lambda m: dim * m, dim_mults)]
        in_out = list(zip(dims[:-1], dims[1:]))
        time_dim = time_dim or dim
        self.time_mlp = nn.Sequential(
            SinusoidalPosEmb(dim),
            nn.Linear(dim, dim * 4),
            nn.Mish(),
            # Project to ``time_dim``: the residual blocks and ``final_block`` consume that width.
            nn.Linear(dim * 4, time_dim),
        )
        self.blocks = nn.ModuleList([])
        for dim_in, dim_out in in_out:
            self.blocks.append(
                nn.ModuleList(
                    [
                        ResidualTemporalBlock(dim_in, dim_out, kernel_size=kernel_size, embed_dim=time_dim),
                        ResidualTemporalBlock(dim_out, dim_out, kernel_size=kernel_size, embed_dim=time_dim),
                        nn.Conv1d(dim_out, dim_out, 3, 2, 1)
                    ]
                )
            )
            horizon = self._downsampled_length(horizon)
        mid_dim = dims[-1]
        mid_dim_2 = mid_dim // 2
        mid_dim_3 = mid_dim // 4
        self.mid_block1 = ResidualTemporalBlock(mid_dim, mid_dim_2, kernel_size=kernel_size, embed_dim=time_dim)
        self.mid_down1 = nn.Conv1d(mid_dim_2, mid_dim_2, 3, 2, 1)
        horizon = self._downsampled_length(horizon)
        self.mid_block2 = ResidualTemporalBlock(mid_dim_2, mid_dim_3, kernel_size=kernel_size, embed_dim=time_dim)
        self.mid_down2 = nn.Conv1d(mid_dim_3, mid_dim_3, 3, 2, 1)
        horizon = self._downsampled_length(horizon)
        # Flattened feature size after the last downsampling.
        fc_dim = mid_dim_3 * max(horizon, 1)
        self.final_block = nn.Sequential(
            nn.Linear(fc_dim + time_dim, fc_dim // 2),
            nn.Mish(),
            nn.Linear(fc_dim // 2, out_dim),
        )

    @staticmethod
    def _downsampled_length(length: int) -> int:
        """
        Overview:
            Output length of ``nn.Conv1d(kernel_size=3, stride=2, padding=1)``: floor((L - 1) / 2) + 1, \
            i.e. ceil(L / 2) for L >= 1 (plain ``L // 2`` undercounts odd lengths).
        """
        return (length - 1) // 2 + 1

    def forward(self, x, cond, time, *args):
        """
        Overview:
            Compute the temporal value forward pass.
        Arguments:
            - x (:obj:'tensor'): noisy trajectory of shape [batch, horizon, transition]
            - cond (:obj:'tuple'): [ (time, state), ... ] state is init state of env, time = 0 (unused here)
            - time (:obj:'tensor'): diffusion timesteps, one per sample
        Return:
            - out (:obj:'tensor'): value prediction of shape [batch, out_dim]
        """
        # [batch, horizon, transition ] -> [batch, transition , horizon]
        x = x.transpose(1, 2)
        t = self.time_mlp(time)
        for resnet, resnet2, downsample in self.blocks:
            x = resnet(x, t)
            x = resnet2(x, t)
            x = downsample(x)
        x = self.mid_block1(x, t)
        x = self.mid_down1(x)
        x = self.mid_block2(x, t)
        x = self.mid_down2(x)
        x = x.view(len(x), -1)
        out = self.final_block(torch.cat([x, t], dim=-1))
        return out