from copy import deepcopy
import torch
import torch.nn as nn
import numpy as np
import math
import collections.abc
from itertools import repeat
from einops import rearrange

from ldm.modules.new_attention import (
    PositionEmbedding,
    CrossAttention,
    Conv1dFeedForward,
    checkpoint,
    Normalize,
    zero_module,
)


def modulate(x, shift, scale):
    """AdaLN-style modulation: scale and shift x along the channel dimension."""
    return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)


def to_2tuple(x):
    """Return x unchanged if it is already an iterable, otherwise repeat it into a 2-tuple."""
    if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
        return x
    return tuple(repeat(x, 2))


class TimestepEmbedder(nn.Module):
    """
    Embeds scalar timesteps into vector representations.
    """
    def __init__(self, hidden_size, frequency_embedding_size=256):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, hidden_size, bias=True),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size, bias=True),
        )
        # Optional projection for an extra conditioning vector (w_cond) added to the frequency embedding.
        self.proj_w = nn.Linear(frequency_embedding_size, frequency_embedding_size, bias=False)
        self.frequency_embedding_size = frequency_embedding_size

    @staticmethod
    def timestep_embedding(t, dim, max_period=10000):
        """
        Create sinusoidal timestep embeddings.
        :param t: a 1-D Tensor of N indices, one per batch element.
                  These may be fractional.
        :param dim: the dimension of the output.
        :param max_period: controls the minimum frequency of the embeddings.
        :return: an (N, D) Tensor of positional embeddings.
        """
        half = dim // 2
        freqs = torch.exp(
            -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
        ).to(device=t.device)
        args = t[:, None].float() * freqs[None]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if dim % 2:
            embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
        return embedding

    def forward(self, t, w_cond=None):
        t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
        if w_cond is not None:
            t_freq = t_freq + self.proj_w(w_cond)
        t_emb = self.mlp(t_freq)
        return t_emb


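# Illustrative usage sketch (not called anywhere in this module); the shapes below
# follow the docstrings above and are assumptions, not guarantees.
def _demo_timestep_embedder():
    embedder = TimestepEmbedder(hidden_size=128)
    t = torch.randint(0, 1000, (4,))                           # one diffusion step per sample
    sin_emb = TimestepEmbedder.timestep_embedding(t, dim=256)  # (4, 256) sinusoidal features
    t_emb = embedder(t)                                        # (4, 128) after the MLP
    return sin_emb.shape, t_emb.shape

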
class Conv1DFinalLayer(nn.Module):
    """
    The final layer of CrossAttnDiT: GroupNorm followed by a pointwise Conv1d
    that maps the hidden channels back to the output channels.
    """
    def __init__(self, hidden_size, out_channels):
        super().__init__()
        self.norm_final = nn.GroupNorm(16, hidden_size)
        self.conv1d = nn.Conv1d(hidden_size, out_channels, kernel_size=1)

    def forward(self, x):
        x = self.norm_final(x)
        x = self.conv1d(x)
        return x


class ConditionEmbedder(nn.Module):
    """Projects context embeddings into the model's hidden size with a small MLP."""
    def __init__(self, hidden_size, context_dim):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(context_dim, hidden_size, bias=True),
            nn.GELU(approximate='tanh'),
            nn.Linear(hidden_size, hidden_size, bias=True),
            nn.LayerNorm(hidden_size),
        )

    def forward(self, x):
        return self.mlp(x)


class BasicTransformerBlock(nn.Module):
    """Self-attention -> self-attention -> feed-forward, each with a pre-norm residual."""
    def __init__(self, dim, n_heads, d_head, dropout=0., gated_ff=True, checkpoint=True):
        super().__init__()
        self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout)
        self.ff = Conv1dFeedForward(dim, dropout=dropout, glu=gated_ff)
        self.attn2 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout)
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        self.norm3 = nn.LayerNorm(dim)
        self.checkpoint = checkpoint

    def forward(self, x):
        return checkpoint(self._forward, (x,), self.parameters(), self.checkpoint)

    def _forward(self, x):
        # Without a context argument, both CrossAttention layers act as self-attention.
        x = self.attn1(self.norm1(x)) + x
        x = self.attn2(self.norm2(x)) + x
        # The feed-forward is convolutional, so switch to (b, c, t) and back.
        x = self.ff(self.norm3(x).permute(0, 2, 1)).permute(0, 2, 1) + x
        return x


class TemporalTransformer(nn.Module):
    """
    Transformer block for 1-D temporal data.
    First, project the input (aka embedding) and reshape to (b, t, d).
    Then apply standard transformer blocks.
    Finally, reshape back to (b, c, t) and add a residual connection to the input.
    """
    def __init__(self, in_channels, n_heads, d_head,
                 depth=1, dropout=0., context_dim=None):
        super().__init__()
        self.in_channels = in_channels
        inner_dim = n_heads * d_head
        self.norm = Normalize(in_channels)

        self.proj_in = nn.Conv1d(in_channels,
                                 inner_dim,
                                 kernel_size=1,
                                 stride=1,
                                 padding=0)

        self.transformer_blocks = nn.ModuleList(
            [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout)
             for d in range(depth)]
        )

        # Zero-initialized so the block starts as an identity mapping.
        self.proj_out = zero_module(nn.Conv1d(inner_dim,
                                              in_channels,
                                              kernel_size=1,
                                              stride=1,
                                              padding=0))

    def forward(self, x):
        x_in = x
        x = self.norm(x)
        x = self.proj_in(x)
        x = rearrange(x, 'b c t -> b t c')
        for block in self.transformer_blocks:
            x = block(x)
        x = rearrange(x, 'b t c -> b c t')
        x = self.proj_out(x)
        x = x + x_in
        return x


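# Illustrative shape check (assumed toy sizes): TemporalTransformer keeps the
# (B, C, T) interface and attends over the T axis internally; never called by this module.
def _demo_temporal_transformer():
    block = TemporalTransformer(in_channels=64, n_heads=4, d_head=16)
    x = torch.randn(2, 64, 100)   # (batch, channels, time)
    y = block(x)                  # residual output with the same shape
    return y.shape                # torch.Size([2, 64, 100])

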
class ConcatDiT(nn.Module):
    """
    Diffusion model with a Transformer backbone.
    The timestep and context embeddings are concatenated with the latent sequence
    along the time axis and stripped off again after the transformer blocks.
    """
    def __init__(
        self,
        in_channels,
        context_dim,
        hidden_size=1152,
        depth=28,
        num_heads=16,
        max_len=1000,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = in_channels
        self.num_heads = num_heads
        kernel_size = 5
        self.t_embedder = TimestepEmbedder(hidden_size)
        self.c_embedder = ConditionEmbedder(hidden_size, context_dim)
        self.proj_in = nn.Conv1d(in_channels, hidden_size, kernel_size=kernel_size, padding=kernel_size // 2)

        self.pos_emb = PositionEmbedding(num_embeddings=max_len, embedding_dim=hidden_size)
        self.blocks = nn.ModuleList([
            TemporalTransformer(hidden_size, num_heads, d_head=hidden_size // num_heads,
                                depth=1, context_dim=context_dim)
            for _ in range(depth)
        ])

        self.final_layer = Conv1DFinalLayer(hidden_size, self.out_channels)
        self.initialize_weights()

    def initialize_weights(self):
        def _basic_init(module):
            if isinstance(module, nn.Linear):
                torch.nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
        self.apply(_basic_init)

        nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
        nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)

    def forward(self, x, t, context, w_cond=None):
        """
        Forward pass of DiT.
        x: (N, C, T) tensor of temporal inputs (latent representations of melspec)
        t: (N,) tensor of diffusion timesteps
        context: (N, max_tokens_len=77, context_dim) tensor of conditioning embeddings
        """
        t = self.t_embedder(t, w_cond=w_cond).unsqueeze(1)
        c = self.c_embedder(context)
        extra_len = c.shape[1] + 1
        x = self.proj_in(x)
        x = rearrange(x, 'b c t -> b t c')
        x = torch.concat([t, c, x], dim=1)
        x = self.pos_emb(x)
        x = rearrange(x, 'b t c -> b c t')
        for block in self.blocks:
            x = block(x)
        # Drop the prepended timestep/context tokens before the output projection.
        x = x[..., extra_len:]
        x = self.final_layer(x)
        return x


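# Minimal usage sketch for ConcatDiT (toy sizes, purely illustrative; hidden_size
# must be divisible by num_heads and by 16 for the GroupNorm in the final layer).
def _demo_concat_dit():
    model = ConcatDiT(in_channels=8, context_dim=32, hidden_size=64, depth=2, num_heads=4)
    x = torch.randn(2, 8, 50)          # latent mel-spectrogram: (N, C, T)
    t = torch.randint(0, 1000, (2,))   # diffusion timesteps
    context = torch.randn(2, 77, 32)   # text token embeddings
    out = model(x, t, context)         # (2, 8, 50): the prepended t/context tokens are stripped
    return out.shape

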
class ConcatDiT2MLP(nn.Module):
    """
    Diffusion model with a Transformer backbone.
    Same as ConcatDiT, but the context is split into two halves along the token
    axis and each half is projected by its own ConditionEmbedder.
    """
    def __init__(
        self,
        in_channels,
        context_dim,
        hidden_size=1152,
        depth=28,
        num_heads=16,
        max_len=1000,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = in_channels
        self.num_heads = num_heads
        kernel_size = 5
        self.t_embedder = TimestepEmbedder(hidden_size)
        self.c1_embedder = ConditionEmbedder(hidden_size, context_dim)
        self.c2_embedder = ConditionEmbedder(hidden_size, context_dim)
        self.proj_in = nn.Conv1d(in_channels, hidden_size, kernel_size=kernel_size, padding=kernel_size // 2)

        self.pos_emb = PositionEmbedding(num_embeddings=max_len, embedding_dim=hidden_size)
        self.blocks = nn.ModuleList([
            TemporalTransformer(hidden_size, num_heads, d_head=hidden_size // num_heads,
                                depth=1, context_dim=context_dim)
            for _ in range(depth)
        ])

        self.final_layer = Conv1DFinalLayer(hidden_size, self.out_channels)
        self.initialize_weights()

    def initialize_weights(self):
        def _basic_init(module):
            if isinstance(module, nn.Linear):
                torch.nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
        self.apply(_basic_init)

        nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
        nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)

    def forward(self, x, t, context, w_cond=None):
        """
        Forward pass of DiT.
        x: (N, C, T) tensor of temporal inputs (latent representations of melspec)
        t: (N,) tensor of diffusion timesteps
        context: (N, max_tokens_len=77, context_dim) tensor holding two conditions
                 stacked along the token axis
        """
        t = self.t_embedder(t, w_cond=w_cond).unsqueeze(1)
        c1, c2 = context.chunk(2, dim=1)
        c1 = self.c1_embedder(c1)
        c2 = self.c2_embedder(c2)
        c = torch.cat((c1, c2), dim=1)
        extra_len = c.shape[1] + 1
        x = self.proj_in(x)
        x = rearrange(x, 'b c t -> b t c')
        x = torch.concat([t, c, x], dim=1)
        x = self.pos_emb(x)
        x = rearrange(x, 'b t c -> b c t')
        for block in self.blocks:
            x = block(x)
        x = x[..., extra_len:]
        x = self.final_layer(x)
        return x


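# Sketch (illustrative toy sizes): ConcatDiT2MLP expects the two conditions stacked
# along the token axis; chunk(2, dim=1) sends each half through its own ConditionEmbedder.
def _demo_concat_dit_2mlp():
    model = ConcatDiT2MLP(in_channels=8, context_dim=32, hidden_size=64, depth=2, num_heads=4)
    cond_a = torch.randn(2, 20, 32)    # first condition, e.g. 20 tokens
    cond_b = torch.randn(2, 20, 32)    # second condition, same token count
    context = torch.cat([cond_a, cond_b], dim=1)
    x = torch.randn(2, 8, 50)
    t = torch.randint(0, 1000, (2,))
    return model(x, t, context).shape  # torch.Size([2, 8, 50])

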
class ConcatOrderDiT(nn.Module):
    """
    Diffusion model with a Transformer backbone.
    Adds a learned order embedding to the token embeddings of each object in the text.
    """
    def __init__(
        self,
        in_channels,
        context_dim,
        hidden_size=1152,
        depth=28,
        num_heads=16,
        max_len=1000,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = in_channels
        self.num_heads = num_heads
        kernel_size = 5
        self.t_embedder = TimestepEmbedder(hidden_size)
        self.c_embedder = ConditionEmbedder(hidden_size, context_dim)
        self.proj_in = nn.Conv1d(in_channels, hidden_size, kernel_size=kernel_size, padding=kernel_size // 2)

        self.pos_emb = PositionEmbedding(num_embeddings=max_len, embedding_dim=hidden_size)
        self.order_embedding = nn.Embedding(num_embeddings=100, embedding_dim=hidden_size)
        self.blocks = nn.ModuleList([
            TemporalTransformer(hidden_size, num_heads, d_head=hidden_size // num_heads,
                                depth=1, context_dim=context_dim)
            for _ in range(depth)
        ])

        self.final_layer = Conv1DFinalLayer(hidden_size, self.out_channels)
        self.initialize_weights()

    def initialize_weights(self):
        def _basic_init(module):
            if isinstance(module, nn.Linear):
                torch.nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
        self.apply(_basic_init)

        nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
        nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)

    def add_order_embedding(self, token_emb, token_ids, orders_list):
        """
        token_emb: shape (N, max_tokens_len=77, hidden_size)
        token_ids: shape (N, max_tokens)
        orders_list: list of N lists; len(orders_list[i]) == number of objects in text[i]
        """
        for b, orderl in enumerate(orders_list):
            orderl = torch.LongTensor(orderl).to(device=self.order_embedding.weight.device)
            order_emb = self.order_embedding(orderl)
            # Map every token position to the index of the object it belongs to
            # (-1 for the special token ids 101, 102, 0 and the separator id 1064,
            # which increments the object counter).
            obj2index = []
            cur_obj = 0
            for i in range(token_ids.shape[1]):
                token_id = token_ids[b][i]
                if token_id in [101, 102, 0, 1064]:
                    obj2index.append(-1)
                    if token_id == 1064:
                        cur_obj += 1
                else:
                    obj2index.append(cur_obj)
            for i, order_index in enumerate(obj2index):
                if order_index != -1:
                    token_emb[b][i] += order_emb[order_index]
        return token_emb

    def forward(self, x, t, context):
        """
        Forward pass of DiT.
        x: (N, C, T) tensor of temporal inputs (latent representations of melspec)
        t: (N,) tensor of diffusion timesteps
        context: dict{'token_embedding': (N, max_tokens_len=77, context_dim),
                      'token_ids': (N, max_tokens_len=77),
                      'orders': orders_list}
        """
        token_embedding = context['token_embedding']
        token_ids = context['token_ids']
        orders = context['orders']
        t = self.t_embedder(t).unsqueeze(1)
        c = self.c_embedder(token_embedding)
        c = self.add_order_embedding(c, token_ids, orders)
        extra_len = c.shape[1] + 1
        x = self.proj_in(x)
        x = rearrange(x, 'b c t -> b t c')
        x = torch.concat([t, c, x], dim=1)
        x = self.pos_emb(x)
        x = rearrange(x, 'b t c -> b c t')
        for block in self.blocks:
            x = block(x)
        x = x[..., extra_len:]
        x = self.final_layer(x)
        return x


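# Sketch of the order conditioning (assumptions: ids 101/102/0 are BERT-style
# [CLS]/[SEP]/[PAD] and 1064 is the "|" separator; 'orders' gives one index per object).
def _demo_concat_order_dit():
    model = ConcatOrderDiT(in_channels=8, context_dim=32, hidden_size=64, depth=2, num_heads=4)
    token_ids = torch.tensor([[101, 5, 1064, 6, 102, 0, 0, 0, 0, 0],
                              [101, 7, 8, 1064, 9, 102, 0, 0, 0, 0]])
    context = {'token_embedding': torch.randn(2, 10, 32),
               'token_ids': token_ids,
               'orders': [[0, 1], [1, 0]]}   # temporal order index of each object
    x = torch.randn(2, 8, 50)
    t = torch.randint(0, 1000, (2,))
    return model(x, t, context).shape        # torch.Size([2, 8, 50])

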
class ConcatOrderDiT2(nn.Module):
    """
    Diffusion model with a Transformer backbone.
    Instead of adding order embeddings to existing tokens, this variant concatenates
    them into the token sequence (one order token per object).
    """
    def __init__(
        self,
        in_channels,
        context_dim,
        hidden_size=1152,
        depth=28,
        num_heads=16,
        max_len=1000,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = in_channels
        self.num_heads = num_heads
        kernel_size = 5
        self.t_embedder = TimestepEmbedder(hidden_size)
        self.c_embedder = ConditionEmbedder(hidden_size, context_dim)
        self.proj_in = nn.Conv1d(in_channels, hidden_size, kernel_size=kernel_size, padding=kernel_size // 2)

        self.pos_emb = PositionEmbedding(num_embeddings=max_len, embedding_dim=hidden_size)
        self.max_objs = 10
        self.max_objs_order = 100  # reserved id for the padding order embedding
        self.order_embedding = nn.Embedding(num_embeddings=self.max_objs_order + 1, embedding_dim=hidden_size)
        self.blocks = nn.ModuleList([
            TemporalTransformer(hidden_size, num_heads, d_head=hidden_size // num_heads,
                                depth=1, context_dim=context_dim)
            for _ in range(depth)
        ])

        self.final_layer = Conv1DFinalLayer(hidden_size, self.out_channels)
        self.initialize_weights()

    def initialize_weights(self):
        def _basic_init(module):
            if isinstance(module, nn.Linear):
                torch.nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
        self.apply(_basic_init)

        nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
        nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)

    def concat_order_embedding(self, token_emb, token_ids, orders_list):
        """
        token_emb: shape (N, max_tokens_len=77, hidden_size)
        token_ids: shape (N, max_tokens)
        orders_list: list of N lists; len(orders_list[i]) == number of objects in text[i]
        return token_emb: shape (N, max_tokens_len + self.max_objs, hidden_size)
        """
        bsz, t, c = token_emb.shape
        token_emb = list(torch.tensor_split(token_emb, bsz))
        orders_list = deepcopy(orders_list)
        for i in range(bsz):
            token_emb[i] = list(torch.tensor_split(token_emb[i].squeeze(0), t))
        for b, orderl in enumerate(orders_list):
            # Append the reserved padding order id, then look up all order embeddings.
            orderl.append(self.max_objs_order)
            orderl = torch.LongTensor(orderl).to(device=self.order_embedding.weight.device)
            order_emb = self.order_embedding(orderl)
            order_emb = torch.tensor_split(order_emb, len(orderl))
            # Insert one order embedding right before each separator token (id 1064).
            obj_insert_index = []
            for i in range(token_ids.shape[1]):
                token_id = token_ids[b][i]
                if token_id == 1064:
                    obj_insert_index.append(i + len(obj_insert_index))
            for i, index in enumerate(obj_insert_index):
                token_emb[b].insert(index, order_emb[i])
            # Pad with the reserved order embedding so the batch stacks to a fixed length.
            for i in range(self.max_objs - len(orderl) + 1):
                token_emb[b].append(order_emb[-1])
            token_emb[b] = torch.concat(token_emb[b])

        token_emb = torch.stack(token_emb)
        return token_emb

    def forward(self, x, t, context):
        """
        Forward pass of DiT.
        x: (N, C, T) tensor of temporal inputs (latent representations of melspec)
        t: (N,) tensor of diffusion timesteps
        context: dict{'token_embedding': (N, max_tokens_len=77, context_dim),
                      'token_ids': (N, max_tokens_len=77),
                      'orders': orders_list}
        """
        token_embedding = context['token_embedding']
        token_ids = context['token_ids']
        orders = context['orders']
        t = self.t_embedder(t).unsqueeze(1)
        c = self.c_embedder(token_embedding)
        c = self.concat_order_embedding(c, token_ids, orders)
        extra_len = c.shape[1] + 1
        x = self.proj_in(x)
        x = rearrange(x, 'b c t -> b t c')
        x = torch.concat([t, c, x], dim=1)
        x = self.pos_emb(x)
        x = rearrange(x, 'b t c -> b c t')
        for block in self.blocks:
            x = block(x)
        x = x[..., extra_len:]
        x = self.final_layer(x)
        return x


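# Sketch (illustrative toy sizes): ConcatOrderDiT2 inserts a learned order token before
# every "|" separator (id 1064, assumed) and pads with the reserved order id; with one
# separator per object, every sample gains exactly max_objs extra context tokens.
def _demo_concat_order_dit2():
    model = ConcatOrderDiT2(in_channels=8, context_dim=32, hidden_size=64, depth=2, num_heads=4)
    token_ids = torch.tensor([[101, 5, 1064, 6, 1064, 102, 0, 0, 0, 0],
                              [101, 7, 1064, 8, 1064, 102, 0, 0, 0, 0]])
    context = {'token_embedding': torch.randn(2, 10, 32),
               'token_ids': token_ids,
               'orders': [[0, 1], [1, 0]]}   # each object is followed by a "|" token
    x = torch.randn(2, 8, 50)
    t = torch.randint(0, 1000, (2,))
    return model(x, t, context).shape        # torch.Size([2, 8, 50])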