import math

import torch
import torch.nn as nn


class Unet(nn.Module):
    def __init__(
        self,
        dim_model,
        num_heads,
        num_layers,
        dropout_p,
        dim_input,
        dim_output,
        free_p=0.1,
        text_emb=True,
        device='cuda',
        **kwargs
    ):
        super().__init__()

        # INFO
        self.model_type = "Transformer"
        self.dim_model = dim_model
        self.text_emb = text_emb
        self.dim_input = dim_input
        self.device = device
        # When Disc=True the model acts as a discriminator instead of a denoiser.
        self.Disc = kwargs.get('Disc', False)

        # layers
        self.free_p = free_p
        self.positional_encoder = PositionalEncoding(
            dim_model=dim_model, dropout_p=dropout_p, max_len=5000
        )
        self.embedding_input = nn.Linear(dim_input, dim_model)
        self.embedding_original = nn.Linear(dim_input, dim_model)

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=dim_model,
            nhead=num_heads,
            dim_feedforward=dim_model * 4,
            dropout=dropout_p,
            activation="gelu",
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

        if self.Disc:
            # discriminator head: pools the sequence and predicts a real/fake probability
            self.pred = nn.Sequential(
                nn.Linear(dim_output, dim_output),
                nn.SiLU(inplace=False),
                nn.Linear(dim_output, 1),
                nn.Sigmoid(),
            )

        self.out = nn.Linear(dim_model, dim_output)
        self.embed_timestep = TimestepEmbedder(self.dim_model, self.positional_encoder)

        if self.text_emb:
            # embedding for the progress indicator
            print("text embedding is enabled!")
            self.positional_encoder_pi = PositionalEncoding(
                dim_model=dim_model, dropout_p=dropout_p, max_len=5000
            )
            self.embed_prog_ind = ProgIndEmbedder(self.dim_model, self.positional_encoder_pi)

    def forward_disc(self, x, timesteps):
        t_emb = self.embed_timestep(timesteps)  # diffusion-timestep embedding

        # (batch, seq, feat) -> (seq, batch, feat) for the transformer
        x, t_emb = x.permute(1, 0, 2), t_emb.permute(1, 0, 2)
        x = self.embedding_input(x) * math.sqrt(self.dim_model)

        x = torch.cat((t_emb, x), dim=0)  # prepend the timestep token
        x = self.positional_encoder(x)
        x = self.transformer(x)

        output = self.out(x)[1:]          # drop the timestep token
        output = output.permute(1, 0, 2)  # back to (batch, seq, dim_output)
        output = output.mean(dim=1)       # pool over the sequence
        output = self.pred(output)        # real/fake probability
        return output

    def forward_(self, x, timesteps, text_emb=None, prog_ind=None, joints_orig=None):
        t_emb = self.embed_timestep(timesteps)  # diffusion-timestep embedding

        if self.text_emb:
            text_emb = text_emb.unsqueeze(1)  # (batchsize, 1, dim_model)
            assert text_emb.shape == (x.shape[0], 1, self.dim_model), \
                f'text_emb shape should be (batchsize, 1, {self.dim_model})'

        # (batch, seq, feat) -> (seq, batch, feat) for the transformer
        x, joints_orig, t_emb = x.permute(1, 0, 2), joints_orig.permute(1, 0, 2), t_emb.permute(1, 0, 2)
        x = self.embedding_input(x) * math.sqrt(self.dim_model)
        joints_orig = self.embedding_original(joints_orig) * math.sqrt(self.dim_model)
        x = (x + joints_orig) / 2.

        if not self.text_emb:
            x = torch.cat((t_emb, x), dim=0)  # (seq_len + 1, batchsize, dim_model)
        else:
            text_emb = text_emb.permute(1, 0, 2)
            # quantise the progress indicator in [0, 1] to integer steps 0..100
            prog_ind = (prog_ind * 100).round().to(torch.int64)
            prog_ind_emb = self.embed_prog_ind(prog_ind).permute(1, 0, 2)
            # fuse timestep, text, and progress embeddings into a single conditioning token
            t_emb = (t_emb + text_emb / 10.0 + prog_ind_emb) * math.sqrt(self.dim_model)
            x = torch.cat((t_emb, x), dim=0)  # (seq_len + 1, batchsize, dim_model)

        x = self.positional_encoder(x)
        x = self.transformer(x)

        output = self.out(x)[1:]          # drop the conditioning token
        output = output.permute(1, 0, 2)  # back to (batch, seq, dim_output)
        return output

    def forward(self, x, timesteps, text_emb=None, prog_ind=None, joints_orig=None):
        if self.Disc:
            return self.forward_disc(x, timesteps)
        else:
            return self.forward_(x, timesteps, text_emb, prog_ind, joints_orig)


class PositionalEncoding(nn.Module):
    def __init__(self, dim_model, dropout_p, max_len):
        super().__init__()
        # Modified version from: https://pytorch.org/tutorials/beginner/transformer_tutorial.html
        # max_len determines how far the position can have an effect on a token (window)

        # Info
        self.dropout = nn.Dropout(dropout_p)

        # Encoding - from the sinusoidal formula
        pos_encoding = torch.zeros(max_len, dim_model)
        positions_list = torch.arange(0, max_len, dtype=torch.float).reshape(-1, 1)  # 0, 1, 2, 3, 4, 5, ...
        division_term = torch.exp(
            torch.arange(0, dim_model, 2).float() * (-math.log(10000.0)) / dim_model
        )  # 1 / 10000^(2i/dim_model)

        # PE(pos, 2i) = sin(pos / 10000^(2i/dim_model))
        pos_encoding[:, 0::2] = torch.sin(positions_list * division_term)
        # PE(pos, 2i + 1) = cos(pos / 10000^(2i/dim_model))
        pos_encoding[:, 1::2] = torch.cos(positions_list * division_term)

        # Saving buffer (same as parameter without gradients needed)
        pos_encoding = pos_encoding.unsqueeze(0).transpose(0, 1)
        self.register_buffer("pos_encoding", pos_encoding)

    def forward(self, token_embedding: torch.Tensor) -> torch.Tensor:
        # Residual connection + pos encoding
        return self.dropout(token_embedding + self.pos_encoding[:token_embedding.size(0), :])


class TimestepEmbedder(nn.Module):
    def __init__(self, latent_dim, sequence_pos_encoder):
        super().__init__()
        self.latent_dim = latent_dim
        self.sequence_pos_encoder = sequence_pos_encoder

        time_embed_dim = self.latent_dim
        self.time_embed = nn.Sequential(
            nn.Linear(self.latent_dim, time_embed_dim),
            nn.SiLU(inplace=False),
            nn.Linear(time_embed_dim, time_embed_dim),
        )

    def forward(self, timesteps):
        # Look up the sinusoidal encoding at the given timestep and project it.
        return self.time_embed(self.sequence_pos_encoder.pos_encoding[timesteps])  # .permute(1, 0, 2)


# Identical architecture to TimestepEmbedder, but used to embed the progress indicator.
class ProgIndEmbedder(nn.Module):
    def __init__(self, latent_dim, sequence_pos_encoder):
        super().__init__()
        self.latent_dim = latent_dim
        self.sequence_pos_encoder = sequence_pos_encoder

        time_embed_dim = self.latent_dim
        self.time_embed = nn.Sequential(
            nn.Linear(self.latent_dim, time_embed_dim),
            nn.SiLU(inplace=False),
            nn.Linear(time_embed_dim, time_embed_dim),
        )

    def forward(self, timesteps):
        return self.time_embed(self.sequence_pos_encoder.pos_encoding[timesteps])  # .permute(1, 0, 2)
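

if __name__ == "__main__":
    # ----------------------------------------------------------------------
    # Minimal smoke-test sketch (not part of the original module). All sizes
    # below (batch=4, seq_len=60, dim_input=135, dim_model=512, 1000 diffusion
    # steps) are illustrative assumptions, not values mandated by this code.
    # ----------------------------------------------------------------------
    device = "cuda" if torch.cuda.is_available() else "cpu"
    batch, seq_len, dim_input, dim_model = 4, 60, 135, 512

    # Denoiser path (text_emb=True): expects text_emb, prog_ind and joints_orig.
    model = Unet(
        dim_model=dim_model,
        num_heads=8,
        num_layers=4,
        dropout_p=0.1,
        dim_input=dim_input,
        dim_output=dim_input,
        text_emb=True,
        device=device,
    ).to(device)

    x = torch.randn(batch, seq_len, dim_input, device=device)
    joints_orig = torch.randn(batch, seq_len, dim_input, device=device)
    timesteps = torch.randint(0, 1000, (batch,), device=device)
    text_emb_vec = torch.randn(batch, dim_model, device=device)  # e.g. a sentence embedding
    prog_ind = torch.rand(batch, device=device)                  # progress indicator in [0, 1]

    out = model(x, timesteps, text_emb=text_emb_vec, prog_ind=prog_ind, joints_orig=joints_orig)
    print(out.shape)  # expected: (batch, seq_len, dim_input)

    # Discriminator path (Disc=True): only x and timesteps are used.
    disc = Unet(
        dim_model=dim_model,
        num_heads=8,
        num_layers=4,
        dropout_p=0.1,
        dim_input=dim_input,
        dim_output=dim_input,
        text_emb=False,
        device=device,
        Disc=True,
    ).to(device)

    score = disc(x, timesteps)
    print(score.shape)  # expected: (batch, 1)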