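"""Transformer-based denoising network for diffusion over joint sequences.

The model embeds the diffusion timestep (and, optionally, a text feature and a
progress indicator) as an extra token prepended to the sequence, runs a
Transformer encoder, and projects back to the output dimension. With
``Disc=True`` it instead acts as a sequence-level discriminator.
"""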
import torch
import torch.nn as nn
import math
class Unet(nn.Module):
def __init__(
self,
dim_model,
num_heads,
num_layers,
dropout_p,
dim_input,
dim_output,
free_p=0.1,
text_emb=True,
device='cuda',
**kwargs
):
super().__init__()
# INFO
self.model_type = "Transformer"
self.dim_model = dim_model
self.text_emb = text_emb
self.dim_input = dim_input
self.device = device
        self.Disc = kwargs.get('Disc', False)  # discriminator mode, off unless passed via kwargs
# layers
self.free_p = free_p
self.positional_encoder = PositionalEncoding(
dim_model=dim_model, dropout_p=dropout_p, max_len=5000
)
self.embedding_input = nn.Linear(dim_input, dim_model)
self.embedding_original = nn.Linear(dim_input, dim_model)
encoder_layer = nn.TransformerEncoderLayer(d_model=dim_model,
nhead=num_heads,
dim_feedforward=dim_model*4,
dropout=dropout_p,
activation="gelu",
)
self.transformer = nn.TransformerEncoder(encoder_layer,
num_layers=num_layers,
)
if self.Disc:
            # discriminator head: maps pooled sequence features to a real/fake probability
self.pred = nn.Sequential(nn.Linear(dim_output, dim_output),
nn.SiLU(inplace=False),
nn.Linear(dim_output, 1),
nn.Sigmoid())
self.out = nn.Linear(dim_model, dim_output)
self.embed_timestep = TimestepEmbedder(self.dim_model, self.positional_encoder)
if self.text_emb:
            # separate positional table for the progress-indicator embedding
print("text embedding is enabled!")
self.positional_encoder_pi = PositionalEncoding(
dim_model=dim_model, dropout_p=dropout_p, max_len=5000
)
self.embed_prog_ind = ProgIndEmbedder(self.dim_model, self.positional_encoder_pi)
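    # Discriminator pass: prepend the timestep token, encode with the Transformer,
    # mean-pool the sequence tokens, and map to a real/fake probability via the sigmoid head.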
def forward_disc(self, x, timesteps):
t_emb = self.embed_timestep(timesteps) # t_emb refers to time embedding
x, t_emb = x.permute(1, 0, 2), t_emb.permute(1, 0, 2)
x = self.embedding_input(x) * math.sqrt(self.dim_model)
x = torch.cat((t_emb, x), dim=0)
x = self.positional_encoder(x)
x = self.transformer(x)
output = self.out(x)[1:]
output = output.permute(1, 0, 2)
output = output.mean(dim=1)
output = self.pred(output)
return output
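    # Denoiser pass: average the embeddings of the noisy input and the original joints,
    # prepend a conditioning token (timestep, optionally fused with the text embedding
    # and progress-indicator embedding), encode, and project each frame to dim_output.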
def forward_(self, x, timesteps, text_emb=None, prog_ind=None, joints_orig=None):
t_emb = self.embed_timestep(timesteps) # t_emb refers to time embedding
if self.text_emb:
text_emb = text_emb.unsqueeze(1) # batchsize, 1, 512
assert text_emb.shape == (x.shape[0], 1, self.dim_model), \
f'text_emb shape should be (batchsize, 1, {self.dim_model})'
x, joints_orig, t_emb = x.permute(1, 0, 2), joints_orig.permute(1, 0, 2), t_emb.permute(1, 0, 2)
x = self.embedding_input(x) * math.sqrt(self.dim_model)
joints_orig = self.embedding_original(joints_orig) * math.sqrt(self.dim_model)
x = (x + joints_orig) / 2.
if not self.text_emb:
x = torch.cat((t_emb, x), dim=0) # (seq_len+1), batchsize, dim_model
else:
text_emb = text_emb.permute(1, 0, 2)
prog_ind = (prog_ind*100).round().to(torch.int64)
prog_ind_emb = self.embed_prog_ind(prog_ind).permute(1, 0, 2)
t_emb = (t_emb + text_emb/10.0 + prog_ind_emb) * math.sqrt(self.dim_model)
x = torch.cat((t_emb, x), dim=0)
x = self.positional_encoder(x)
x = self.transformer(x)
output = self.out(x)[1:]
output = output.permute(1, 0, 2)
return output
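    # Dispatch to the discriminator or denoiser pass depending on the Disc flag.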
def forward(self, x, timesteps, text_emb=None, prog_ind=None, joints_orig=None):
if self.Disc:
return self.forward_disc(x, timesteps)
else:
return self.forward_(x, timesteps, text_emb, prog_ind, joints_orig)
class PositionalEncoding(nn.Module):
def __init__(self, dim_model, dropout_p, max_len):
super().__init__()
# Modified version from: https://pytorch.org/tutorials/beginner/transformer_tutorial.html
# max_len determines how far the position can have an effect on a token (window)
# Info
self.dropout = nn.Dropout(dropout_p)
# Encoding - From formula
pos_encoding = torch.zeros(max_len, dim_model)
        positions_list = torch.arange(0, max_len, dtype=torch.float).reshape(-1, 1)  # 0, 1, 2, 3, ...
        division_term = torch.exp(
            torch.arange(0, dim_model, 2).float() * (-math.log(10000.0)) / dim_model)  # 10000^(-2i/dim_model)
        # PE(pos, 2i) = sin(pos / 10000^(2i/dim_model))
        pos_encoding[:, 0::2] = torch.sin(positions_list * division_term)
        # PE(pos, 2i + 1) = cos(pos / 10000^(2i/dim_model))
        pos_encoding[:, 1::2] = torch.cos(positions_list * division_term)
        # Register as a buffer: saved with the module state but not a trainable parameter
pos_encoding = pos_encoding.unsqueeze(0).transpose(0, 1)
self.register_buffer("pos_encoding", pos_encoding)
    def forward(self, token_embedding: torch.Tensor) -> torch.Tensor:
        # Add the positional encoding to the token embeddings, then apply dropout
        return self.dropout(token_embedding + self.pos_encoding[:token_embedding.size(0), :])
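# TimestepEmbedder: looks up the sinusoidal encoding of each integer timestep and
# refines it with a small Linear-SiLU-Linear MLP to produce a latent_dim-sized token.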
class TimestepEmbedder(nn.Module):
def __init__(self, latent_dim, sequence_pos_encoder):
super().__init__()
self.latent_dim = latent_dim
self.sequence_pos_encoder = sequence_pos_encoder
time_embed_dim = self.latent_dim
self.time_embed = nn.Sequential(
nn.Linear(self.latent_dim, time_embed_dim),
nn.SiLU(inplace=False),
nn.Linear(time_embed_dim, time_embed_dim),
)
def forward(self, timesteps):
return self.time_embed(self.sequence_pos_encoder.pos_encoding[timesteps])#.permute(1, 0, 2)
# Identical to TimestepEmbedder, but used to embed the progress indicator
class ProgIndEmbedder(nn.Module):
def __init__(self, latent_dim, sequence_pos_encoder):
super().__init__()
self.latent_dim = latent_dim
self.sequence_pos_encoder = sequence_pos_encoder
time_embed_dim = self.latent_dim
self.time_embed = nn.Sequential(
nn.Linear(self.latent_dim, time_embed_dim),
nn.SiLU(inplace=False),
nn.Linear(time_embed_dim, time_embed_dim),
)
def forward(self, timesteps):
return self.time_embed(self.sequence_pos_encoder.pos_encoding[timesteps])#.permute(1, 0, 2)
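# Minimal usage sketch with illustrative, assumed hyperparameters (the batch size,
# sequence length, joint dimension, and diffusion-step range below are placeholders,
# not values from this repository). dim_model is set to 512 because the
# text-conditioning path asserts a 512-dimensional text feature.
if __name__ == "__main__":
    batch, seq_len, dim_input = 2, 60, 66
    model = Unet(dim_model=512, num_heads=8, num_layers=4, dropout_p=0.1,
                 dim_input=dim_input, dim_output=dim_input,
                 text_emb=True, device='cpu')
    x = torch.randn(batch, seq_len, dim_input)            # noisy joint sequence
    joints_orig = torch.randn(batch, seq_len, dim_input)  # original joints used as conditioning
    timesteps = torch.randint(0, 1000, (batch,))          # one diffusion-step index per sample
    text_feat = torch.randn(batch, 512)                   # per-sample text feature
    prog_ind = torch.rand(batch)                          # progress indicator in [0, 1]
    out = model(x, timesteps, text_emb=text_feat, prog_ind=prog_ind, joints_orig=joints_orig)
    print(out.shape)  # (batch, seq_len, dim_output)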