|
from typing import Union, List, Dict |
|
from collections import namedtuple |
|
import numpy as np |
|
import math |
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
from ding.utils import list_split, MODEL_REGISTRY, squeeze, SequenceType |
|
|
|
|
|
def extract(a, t, x_shape):
    """
    Overview:
        Gather the entries of ``a`` at the indices ``t`` and reshape the result so \
        it broadcasts against a tensor of shape ``x_shape``.
    Arguments:
        - a (:obj:`torch.Tensor`): input tensor
        - t (:obj:`torch.Tensor`): index tensor
        - x_shape (:obj:`torch.Tensor`): shape of x
    """
    batch_size = t.shape[0]
    gathered = a.gather(-1, t)
    # append singleton dims so the output broadcasts over the trailing dims of x
    broadcast_shape = (batch_size, ) + (1, ) * (len(x_shape) - 1)
    return gathered.reshape(*broadcast_shape)
|
|
|
|
|
def cosine_beta_schedule(timesteps: int, s: float = 0.008, dtype=torch.float32): |
|
""" |
|
Overview: |
|
cosine schedule |
|
as proposed in https://openreview.net/forum?id=-NEXDKk8gZ |
|
Arguments: |
|
- timesteps (:obj:`int`): timesteps of diffusion step |
|
- s (:obj:`float`): s |
|
- dtype (:obj:`torch.dtype`): dtype of beta |
|
Return: |
|
Tensor of beta [timesteps,], computing by cosine. |
|
""" |
|
steps = timesteps + 1 |
|
x = np.linspace(0, steps, steps) |
|
alphas_cumprod = np.cos(((x / steps) + s) / (1 + s) * np.pi * 0.5) ** 2 |
|
alphas_cumprod = alphas_cumprod / alphas_cumprod[0] |
|
betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1]) |
|
betas_clipped = np.clip(betas, a_min=0, a_max=0.999) |
|
return torch.tensor(betas_clipped, dtype=dtype) |
|
|
|
|
|
def apply_conditioning(x, conditions, action_dim):
    """
    Overview:
        Write conditioning states into the trajectory tensor ``x`` in place and return it.
    Arguments:
        - x (:obj:`torch.Tensor`): input tensor
        - conditions (:obj:`dict`): condition dict, key is timestep, value is condition
        - action_dim (:obj:`int`): action dim
    """
    for timestep, state in conditions.items():
        # overwrite only the observation slice (columns after the action dims)
        x[:, timestep, action_dim:] = state.clone()
    return x
|
|
|
|
|
class DiffusionConv1d(nn.Module):
    """
    Overview:
        Conv1d with activation and normalization for diffusion models.
    Interfaces:
        ``__init__``, ``forward``
    """

    def __init__(
            self,
            in_channels: int,
            out_channels: int,
            kernel_size: int,
            padding: int,
            activation: nn.Module = None,
            n_groups: int = 8
    ) -> None:
        """
        Overview:
            Create a 1-dim convolution layer with activation and normalization. This Conv1d has GroupNorm,
            and adds one dim when computing the norm.
        Arguments:
            - in_channels (:obj:`int`): Number of channels in the input tensor
            - out_channels (:obj:`int`): Number of channels in the output tensor
            - kernel_size (:obj:`int`): Size of the convolving kernel
            - padding (:obj:`int`): Zero-padding added to both sides of the input
            - activation (:obj:`nn.Module`): the optional activation function; identity if ``None``
            - n_groups (:obj:`int`): number of groups for GroupNorm
        """
        super().__init__()
        self.conv1 = nn.Conv1d(in_channels, out_channels, kernel_size=kernel_size, padding=padding)
        self.norm = nn.GroupNorm(n_groups, out_channels)
        # Bug fix: the activation is documented as optional, but forward() always
        # calls self.act; fall back to a no-op so activation=None does not crash.
        self.act = activation if activation is not None else nn.Identity()

    def forward(self, inputs) -> torch.Tensor:
        """
        Overview:
            compute conv1d for inputs.
        Arguments:
            - inputs (:obj:`torch.Tensor`): input tensor
        Return:
            - out (:obj:`torch.Tensor`): output tensor
        """
        x = self.conv1(inputs)
        # GroupNorm is applied on a temporary 4-dim view [B, C, 1, T]
        # (numerically equivalent to normalizing the 3-dim tensor directly)
        x = x.unsqueeze(-2)
        x = self.norm(x)
        x = x.squeeze(-2)
        out = self.act(x)
        return out
|
|
|
|
|
class SinusoidalPosEmb(nn.Module):
    """
    Overview:
        Sinusoidal positional embedding, as used for diffusion timestep encoding.
    Interfaces:
        ``__init__``, ``forward``
    """

    def __init__(self, dim: int) -> None:
        """
        Overview:
            Initialization of SinusoidalPosEmb class
        Arguments:
            - dim (:obj:`int`): dimension of embedding
        """
        super().__init__()
        self.dim = dim

    def forward(self, x) -> torch.Tensor:
        """
        Overview:
            compute sin position embedding
        Arguments:
            - x (:obj:`torch.Tensor`): input tensor
        Return:
            - emb (:obj:`torch.Tensor`): output tensor
        """
        half = self.dim // 2
        # geometric progression of frequencies from 1 down to 1/10000
        freq_scale = math.log(10000) / (half - 1)
        freqs = torch.exp(-freq_scale * torch.arange(half, device=x.device))
        angles = x[:, None] * freqs[None, :]
        return torch.cat((angles.sin(), angles.cos()), dim=1)
|
|
|
|
|
class Residual(nn.Module):
    """
    Overview:
        Skip-connection wrapper: ``out = fn(x) + x``.
    Interfaces:
        ``__init__``, ``forward``
    """

    def __init__(self, fn):
        """
        Overview:
            Initialization of Residual class
        Arguments:
            - fn (:obj:`nn.Module`): the wrapped residual function
        """
        super().__init__()
        self.fn = fn

    def forward(self, x, *arg, **kwargs):
        """
        Overview:
            compute residual block
        Arguments:
            - x (:obj:`torch.Tensor`): input tensor
        """
        skip = x
        return skip + self.fn(x, *arg, **kwargs)
|
|
|
|
|
class LayerNorm(nn.Module):
    """
    Overview:
        LayerNorm, compute dim = 1, because Temporal input x [batch, dim, horizon]
    Interfaces:
        ``__init__``, ``forward``
    """

    def __init__(self, dim, eps=1e-5) -> None:
        """
        Overview:
            Initialization of LayerNorm class
        Arguments:
            - dim (:obj:`int`): dimension of input
            - eps (:obj:`float`): eps of LayerNorm
        """
        super().__init__()
        self.eps = eps
        # learnable per-channel scale and shift, broadcast over batch and horizon dims
        self.g = nn.Parameter(torch.ones(1, dim, 1))
        self.b = nn.Parameter(torch.zeros(1, dim, 1))

    def forward(self, x):
        """
        Overview:
            compute LayerNorm over the channel dim (dim=1)
        Arguments:
            - x (:obj:`torch.Tensor`): input tensor
        """
        # Bug fix: removed leftover debug print of x.shape that spammed stdout on every call.
        var = torch.var(x, dim=1, unbiased=False, keepdim=True)
        mean = torch.mean(x, dim=1, keepdim=True)
        return (x - mean) / (var + self.eps).sqrt() * self.g + self.b
|
|
|
|
|
class PreNorm(nn.Module):
    """
    Overview:
        PreNorm, compute dim = 1, because Temporal input x [batch, dim, horizon]
    Interfaces:
        ``__init__``, ``forward``
    """

    def __init__(self, dim, fn) -> None:
        """
        Overview:
            Initialization of PreNorm class
        Arguments:
            - dim (:obj:`int`): dimension of input
            - fn (:obj:`nn.Module`): function applied after normalization
        """
        super().__init__()
        self.fn = fn
        self.norm = LayerNorm(dim)

    def forward(self, x):
        """
        Overview:
            normalize the input, then apply the wrapped function
        Arguments:
            - x (:obj:`torch.Tensor`): input tensor
        """
        return self.fn(self.norm(x))
|
|
|
|
|
class LinearAttention(nn.Module):
    """
    Overview:
        Linear Attention head
    Interfaces:
        ``__init__``, ``forward``
    """

    def __init__(self, dim, heads=4, dim_head=32) -> None:
        """
        Overview:
            Initialization of LinearAttention class
        Arguments:
            - dim (:obj:`int`): dimension of input
            - heads (:obj:`int`): heads of attention
            - dim_head (:obj:`int`): dim of head
        """
        super().__init__()
        self.scale = dim_head ** -0.5
        self.heads = heads
        hidden_dim = dim_head * heads
        # single conv produces q, k, v stacked along the channel dim
        self.to_qkv = nn.Conv1d(dim, hidden_dim * 3, 1, bias=False)
        self.to_out = nn.Conv1d(hidden_dim, dim, 1)

    def forward(self, x):
        """
        Overview:
            compute LinearAttention
        Arguments:
            - x (:obj:`torch.Tensor`): input tensor
        """
        batch = x.shape[0]
        projected = self.to_qkv(x)
        # split into per-head q, k, v of shape [batch, heads, dim_head, length]
        q, k, v = [
            part.reshape(part.shape[0], self.heads, -1, part.shape[-1]) for part in projected.chunk(3, dim=1)
        ]
        q = self.scale * q
        k = F.softmax(k, dim=-1)
        context = torch.einsum('b h d n, b h e n -> b h d e', k, v)

        attended = torch.einsum('b h d e, b h d n -> b h e n', context, q)
        attended = attended.reshape(batch, -1, attended.shape[-1])
        return self.to_out(attended)
|
|
|
|
|
class ResidualTemporalBlock(nn.Module):
    """
    Overview:
        Residual block of temporal
    Interfaces:
        ``__init__``, ``forward``
    """

    def __init__(
            self, in_channels: int, out_channels: int, embed_dim: int, kernel_size: int = 5, mish: bool = True
    ) -> None:
        """
        Overview:
            Initialization of ResidualTemporalBlock class
        Arguments:
            - in_channels (:obj:`int`): dim of in_channels
            - out_channels (:obj:`int`): dim of out_channels
            - embed_dim (:obj:`int`): dim of embedding layer
            - kernel_size (:obj:`int`): kernel_size of conv1d
            - mish (:obj:`bool`): whether to use Mish (otherwise SiLU) as the activation
        """
        super().__init__()
        act = nn.Mish() if mish else nn.SiLU()
        pad = kernel_size // 2
        self.blocks = nn.ModuleList(
            [
                DiffusionConv1d(in_channels, out_channels, kernel_size, pad, act),
                DiffusionConv1d(out_channels, out_channels, kernel_size, pad, act),
            ]
        )
        # projects the time embedding to a per-channel bias
        self.time_mlp = nn.Sequential(
            act,
            nn.Linear(embed_dim, out_channels),
        )
        # 1x1 conv on the skip path only when channel counts differ
        if in_channels == out_channels:
            self.residual_conv = nn.Identity()
        else:
            self.residual_conv = nn.Conv1d(in_channels, out_channels, 1)

    def forward(self, x, t):
        """
        Overview:
            compute residual block
        Arguments:
            - x (:obj:`torch.Tensor`): input tensor
            - t (:obj:`torch.Tensor`): time embedding tensor
        """
        time_bias = self.time_mlp(t).unsqueeze(-1)
        hidden = self.blocks[0](x) + time_bias
        hidden = self.blocks[1](hidden)
        return hidden + self.residual_conv(x)
|
|
|
|
|
class DiffusionUNet1d(nn.Module):
    """
    Overview:
        Diffusion unet for 1d vector data
    Interfaces:
        ``__init__``, ``forward``, ``get_pred``
    """

    def __init__(
            self,
            transition_dim: int,
            dim: int = 32,
            dim_mults: SequenceType = [1, 2, 4, 8],
            returns_condition: bool = False,
            condition_dropout: float = 0.1,
            calc_energy: bool = False,
            kernel_size: int = 5,
            attention: bool = False,
    ) -> None:
        """
        Overview:
            Initialization of DiffusionUNet1d class
        Arguments:
            - transition_dim (:obj:`int`): dim of transition, it is obs_dim + action_dim
            - dim (:obj:`int`): dim of layer
            - dim_mults (:obj:`SequenceType`): mults of dim
            - returns_condition (:obj:`bool`): whether use return as a condition
            - condition_dropout (:obj:`float`): dropout of returns condition
            - calc_energy (:obj:`bool`): whether use calc_energy
            - kernel_size (:obj:`int`): kernel_size of conv1d
            - attention (:obj:`bool`): whether use attention
        """
        super().__init__()
        # channel widths per resolution: transition_dim -> dim*m for each mult
        dims = [transition_dim, *map(lambda m: dim * m, dim_mults)]
        in_out = list(zip(dims[:-1], dims[1:]))

        if calc_energy:
            mish = False
            act = nn.SiLU()
        else:
            mish = True
            act = nn.Mish()

        self.time_dim = dim
        self.returns_dim = dim

        self.time_mlp = nn.Sequential(
            SinusoidalPosEmb(dim),
            nn.Linear(dim, dim * 4),
            act,
            nn.Linear(dim * 4, dim),
        )

        self.returns_condition = returns_condition
        self.condition_dropout = condition_dropout
        # NOTE(review): attribute keeps the original misspelled name ('cale_energy')
        # for backward compatibility with external code that may read it.
        self.cale_energy = calc_energy

        if self.returns_condition:
            self.returns_mlp = nn.Sequential(
                nn.Linear(1, dim),
                act,
                nn.Linear(dim, dim * 4),
                act,
                nn.Linear(dim * 4, dim),
            )
            # Bernoulli mask for randomly dropping the returns condition during training
            self.mask_dist = torch.distributions.Bernoulli(probs=1 - self.condition_dropout)
            embed_dim = 2 * dim
        else:
            embed_dim = dim

        self.downs = nn.ModuleList([])
        self.ups = nn.ModuleList([])
        num_resolution = len(in_out)

        # encoder path: two residual blocks, optional attention, stride-2 downsample
        for ind, (dim_in, dim_out) in enumerate(in_out):
            is_last = ind >= (num_resolution - 1)
            self.downs.append(
                nn.ModuleList(
                    [
                        ResidualTemporalBlock(dim_in, dim_out, embed_dim, kernel_size, mish=mish),
                        ResidualTemporalBlock(dim_out, dim_out, embed_dim, kernel_size, mish=mish),
                        Residual(PreNorm(dim_out, LinearAttention(dim_out))) if attention else nn.Identity(),
                        nn.Conv1d(dim_out, dim_out, 3, 2, 1) if not is_last else nn.Identity()
                    ]
                )
            )

        mid_dim = dims[-1]
        self.mid_block1 = ResidualTemporalBlock(mid_dim, mid_dim, embed_dim, kernel_size, mish)
        self.mid_atten = Residual(PreNorm(mid_dim, LinearAttention(mid_dim))) if attention else nn.Identity()
        self.mid_block2 = ResidualTemporalBlock(mid_dim, mid_dim, embed_dim, kernel_size, mish)

        # decoder path: first block takes dim_out * 2 channels because the
        # corresponding encoder feature map is concatenated as a skip connection
        for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
            is_last = ind >= (num_resolution - 1)
            self.ups.append(
                nn.ModuleList(
                    [
                        ResidualTemporalBlock(dim_out * 2, dim_in, embed_dim, kernel_size, mish=mish),
                        ResidualTemporalBlock(dim_in, dim_in, embed_dim, kernel_size, mish=mish),
                        Residual(PreNorm(dim_in, LinearAttention(dim_in))) if attention else nn.Identity(),
                        nn.ConvTranspose1d(dim_in, dim_in, 4, 2, 1) if not is_last else nn.Identity()
                    ]
                )
            )

        self.final_conv = nn.Sequential(
            DiffusionConv1d(dim, dim, kernel_size=kernel_size, padding=kernel_size // 2, activation=act),
            nn.Conv1d(dim, transition_dim, 1),
        )

    def forward(self, x, cond, time, returns=None, use_dropout: bool = True, force_dropout: bool = False):
        """
        Overview:
            compute diffusion unet forward
        Arguments:
            - x (:obj:`torch.Tensor`): noise trajectory
            - cond (:obj:`tuple`): [ (time, state), ... ] state is init state of env, time = 0
            - time (:obj:`int`): timestep of diffusion step
            - returns (:obj:`torch.Tensor`): condition returns of trajectory, returns is normal return
            - use_dropout (:obj:`bool`): Whether use returns condition mask
            - force_dropout (:obj:`bool`): Whether use returns condition
        """
        if self.cale_energy:
            # keep a handle on the input to differentiate the energy w.r.t. it below
            x_inp = x

        # [batch, horizon, transition] -> [batch, transition, horizon] for Conv1d
        x = x.transpose(1, 2)
        t = self.time_mlp(time)

        if self.returns_condition:
            assert returns is not None
            returns_embed = self.returns_mlp(returns)
            if use_dropout:
                # randomly zero the condition embedding (classifier-free-guidance-style dropout)
                mask = self.mask_dist.sample(sample_shape=(returns_embed.size(0), 1)).to(returns_embed.device)
                returns_embed = mask * returns_embed
            if force_dropout:
                returns_embed = 0 * returns_embed
            t = torch.cat([t, returns_embed], dim=-1)

        h = []

        for resnet, resnet2, atten, downsample in self.downs:
            x = resnet(x, t)
            x = resnet2(x, t)
            x = atten(x)
            h.append(x)
            x = downsample(x)

        x = self.mid_block1(x, t)
        x = self.mid_atten(x)
        x = self.mid_block2(x, t)

        for resnet, resnet2, atten, upsample in self.ups:
            # concatenate the matching encoder feature map (skip connection)
            x = torch.cat((x, h.pop()), dim=1)
            x = resnet(x, t)
            x = resnet2(x, t)
            x = atten(x)
            x = upsample(x)

        x = self.final_conv(x)

        x = x.transpose(1, 2)

        if self.cale_energy:
            # energy-based variant: return the gradient of the MSE energy w.r.t. the input
            energy = ((x - x_inp) ** 2).mean()
            grad = torch.autograd.grad(outputs=energy, inputs=x_inp, create_graph=True)
            return grad[0]
        else:
            return x

    def get_pred(self, x, cond, time, returns=None, use_dropout: bool = True, force_dropout: bool = False):
        """
        Overview:
            compute diffusion unet forward (prediction only, without the energy-gradient branch)
        Arguments:
            - x (:obj:`torch.Tensor`): noise trajectory
            - cond (:obj:`tuple`): [ (time, state), ... ] state is init state of env, time = 0
            - time (:obj:`int`): timestep of diffusion step
            - returns (:obj:`torch.Tensor`): condition returns of trajectory, returns is normal return
            - use_dropout (:obj:`bool`): Whether use returns condition mask
            - force_dropout (:obj:`bool`): Whether use returns condition
        """
        x = x.transpose(1, 2)
        t = self.time_mlp(time)

        if self.returns_condition:
            assert returns is not None
            returns_embed = self.returns_mlp(returns)
            if use_dropout:
                mask = self.mask_dist.sample(sample_shape=(returns_embed.size(0), 1)).to(returns_embed.device)
                returns_embed = mask * returns_embed
            if force_dropout:
                returns_embed = 0 * returns_embed
            t = torch.cat([t, returns_embed], dim=-1)

        h = []

        # Bug fix: each entry of self.downs/self.ups is a 4-element ModuleList
        # (resnet, resnet2, atten, down/upsample); the old 3-element unpacking
        # raised "too many values to unpack" at runtime. Now mirrors forward().
        for resnet, resnet2, atten, downsample in self.downs:
            x = resnet(x, t)
            x = resnet2(x, t)
            x = atten(x)
            h.append(x)
            x = downsample(x)

        x = self.mid_block1(x, t)
        x = self.mid_atten(x)
        x = self.mid_block2(x, t)

        for resnet, resnet2, atten, upsample in self.ups:
            x = torch.cat((x, h.pop()), dim=1)
            x = resnet(x, t)
            x = resnet2(x, t)
            x = atten(x)
            x = upsample(x)

        x = self.final_conv(x)

        x = x.transpose(1, 2)
        return x
|
|
|
|
|
class TemporalValue(nn.Module):
    """
    Overview:
        temporal net for value function
    Interfaces:
        ``__init__``, ``forward``
    """

    def __init__(
            self,
            horizon: int,
            transition_dim: int,
            dim: int = 32,
            time_dim: int = None,
            out_dim: int = 1,
            kernel_size: int = 5,
            dim_mults: SequenceType = [1, 2, 4, 8],
    ) -> None:
        """
        Overview:
            Initialization of TemporalValue class
        Arguments:
            - horizon (:obj:`int`): horizon of trajectory
            - transition_dim (:obj:`int`): dim of transition, it is obs_dim + action_dim
            - dim (:obj:`int`): dim of layer
            - time_dim (:obj:`int`): dim of time
            - out_dim (:obj:`int`): dim of output
            - kernel_size (:obj:`int`): kernel_size of conv1d
            - dim_mults (:obj:`SequenceType`): mults of dim
        """
        super().__init__()
        channel_sizes = [transition_dim, *(dim * mult for mult in dim_mults)]
        stage_pairs = list(zip(channel_sizes[:-1], channel_sizes[1:]))

        time_dim = time_dim or dim
        self.time_mlp = nn.Sequential(
            SinusoidalPosEmb(dim),
            nn.Linear(dim, dim * 4),
            nn.Mish(),
            nn.Linear(dim * 4, dim),
        )
        self.blocks = nn.ModuleList([])

        # downsampling stages: two residual blocks followed by a stride-2 conv
        for stage_in, stage_out in stage_pairs:
            self.blocks.append(
                nn.ModuleList(
                    [
                        ResidualTemporalBlock(stage_in, stage_out, kernel_size=kernel_size, embed_dim=time_dim),
                        ResidualTemporalBlock(stage_out, stage_out, kernel_size=kernel_size, embed_dim=time_dim),
                        nn.Conv1d(stage_out, stage_out, 3, 2, 1)
                    ]
                )
            )

        horizon = horizon // 2

        mid_dim = channel_sizes[-1]
        mid_dim_2 = mid_dim // 2
        mid_dim_3 = mid_dim // 4

        self.mid_block1 = ResidualTemporalBlock(mid_dim, mid_dim_2, kernel_size=kernel_size, embed_dim=time_dim)
        self.mid_down1 = nn.Conv1d(mid_dim_2, mid_dim_2, 3, 2, 1)

        horizon = horizon // 2
        self.mid_block2 = ResidualTemporalBlock(mid_dim_2, mid_dim_3, kernel_size=kernel_size, embed_dim=time_dim)
        self.mid_down2 = nn.Conv1d(mid_dim_3, mid_dim_3, 3, 2, 1)
        horizon = horizon // 2

        # flattened feature size feeding the final MLP (horizon floored at 1)
        fc_dim = mid_dim_3 * max(horizon, 1)
        self.final_block = nn.Sequential(
            nn.Linear(fc_dim + time_dim, fc_dim // 2),
            nn.Mish(),
            nn.Linear(fc_dim // 2, out_dim),
        )

    def forward(self, x, cond, time, *args):
        """
        Overview:
            compute temporal value forward
        Arguments:
            - x (:obj:`torch.Tensor`): noise trajectory
            - cond (:obj:`tuple`): [ (time, state), ... ] state is init state of env, time = 0
            - time (:obj:`int`): timestep of diffusion step
        """
        # [batch, horizon, transition] -> [batch, transition, horizon] for Conv1d
        x = x.transpose(1, 2)
        t_embed = self.time_mlp(time)

        for block_a, block_b, down in self.blocks:
            x = block_a(x, t_embed)
            x = block_b(x, t_embed)
            x = down(x)

        x = self.mid_block1(x, t_embed)
        x = self.mid_down1(x)

        x = self.mid_block2(x, t_embed)
        x = self.mid_down2(x)

        flat = x.view(len(x), -1)
        return self.final_block(torch.cat([flat, t_embed], dim=-1))
|
|