import torch
import torch.nn as nn
import math
class Unet(nn.Module):
    def __init__(
        self,
        dim_model,
        num_heads,
        num_layers,
        dropout_p,
        dim_input,
        dim_output,
        free_p=0.1,
        text_emb=True,
        device='cuda',
        **kwargs
    ):
        super().__init__()
        # INFO
        self.model_type = "Transformer"
        self.dim_model = dim_model
        self.text_emb = text_emb
        self.dim_input = dim_input
        self.device = device
        # 'Disc' (passed through **kwargs) switches the module into discriminator mode.
        self.Disc = kwargs.get('Disc', False)
        # layers
        self.free_p = free_p
        self.positional_encoder = PositionalEncoding(
            dim_model=dim_model, dropout_p=dropout_p, max_len=5000
        )
        self.embedding_input = nn.Linear(dim_input, dim_model)
        self.embedding_original = nn.Linear(dim_input, dim_model)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=dim_model,
            nhead=num_heads,
            dim_feedforward=dim_model * 4,
            dropout=dropout_p,
            activation="gelu",
        )
        self.transformer = nn.TransformerEncoder(
            encoder_layer,
            num_layers=num_layers,
        )
        if self.Disc:
            # Discriminator head: pooled sequence features -> scalar score in (0, 1).
            self.pred = nn.Sequential(
                nn.Linear(dim_output, dim_output),
                nn.SiLU(inplace=False),
                nn.Linear(dim_output, 1),
                nn.Sigmoid(),
            )
        self.out = nn.Linear(dim_model, dim_output)
        self.embed_timestep = TimestepEmbedder(self.dim_model, self.positional_encoder)
        if self.text_emb:
            # For embedding the progress indicator alongside the text embedding.
            print("text embedding is enabled!")
            self.positional_encoder_pi = PositionalEncoding(
                dim_model=dim_model, dropout_p=dropout_p, max_len=5000
            )
            self.embed_prog_ind = ProgIndEmbedder(self.dim_model, self.positional_encoder_pi)

    def forward_disc(self, x, timesteps):
        t_emb = self.embed_timestep(timesteps)  # t_emb refers to time embedding
        x, t_emb = x.permute(1, 0, 2), t_emb.permute(1, 0, 2)
        x = self.embedding_input(x) * math.sqrt(self.dim_model)
        # Prepend the time-embedding token: (seq_len+1), batchsize, dim_model
        x = torch.cat((t_emb, x), dim=0)
        x = self.positional_encoder(x)
        x = self.transformer(x)
        output = self.out(x)[1:]  # drop the prepended time token
        output = output.permute(1, 0, 2)
        output = output.mean(dim=1)  # pool over the sequence
        output = self.pred(output)
        return output

    def forward_(self, x, timesteps, text_emb=None, prog_ind=None, joints_orig=None):
        t_emb = self.embed_timestep(timesteps)  # t_emb refers to time embedding
        if self.text_emb:
            text_emb = text_emb.unsqueeze(1)  # batchsize, 1, 512
            assert text_emb.shape == (x.shape[0], 1, self.dim_model), \
                f'text_emb shape should be (batchsize, 1, {self.dim_model})'
        x, joints_orig, t_emb = x.permute(1, 0, 2), joints_orig.permute(1, 0, 2), t_emb.permute(1, 0, 2)
        x = self.embedding_input(x) * math.sqrt(self.dim_model)
        joints_orig = self.embedding_original(joints_orig) * math.sqrt(self.dim_model)
        x = (x + joints_orig) / 2.
        if not self.text_emb:
            x = torch.cat((t_emb, x), dim=0)  # (seq_len+1), batchsize, dim_model
        else:
            text_emb = text_emb.permute(1, 0, 2)
            # Quantize the progress indicator in [0, 1] to an integer index in [0, 100]
            # so it can index the positional-encoding table inside ProgIndEmbedder.
            prog_ind = (prog_ind * 100).round().to(torch.int64)
            prog_ind_emb = self.embed_prog_ind(prog_ind).permute(1, 0, 2)
            t_emb = (t_emb + text_emb / 10.0 + prog_ind_emb) * math.sqrt(self.dim_model)
            x = torch.cat((t_emb, x), dim=0)
        x = self.positional_encoder(x)
        x = self.transformer(x)
        output = self.out(x)[1:]  # drop the prepended conditioning token
        output = output.permute(1, 0, 2)
        return output

    def forward(self, x, timesteps, text_emb=None, prog_ind=None, joints_orig=None):
        if self.Disc:
            return self.forward_disc(x, timesteps)
        else:
            return self.forward_(x, timesteps, text_emb, prog_ind, joints_orig)
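
# Note: the same backbone serves two roles. With Disc=True it acts as a
# sequence-level discriminator (forward_disc pools over time and returns a
# sigmoid score per sample); otherwise it returns a per-frame prediction of
# shape (batchsize, seq_len, dim_output). The timestep conditioning suggests
# it is used as the denoiser of a diffusion model, though that usage is not
# shown in this file.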


class PositionalEncoding(nn.Module):
    def __init__(self, dim_model, dropout_p, max_len):
        super().__init__()
        # Modified version from: https://pytorch.org/tutorials/beginner/transformer_tutorial.html
        # max_len determines how far the position can have an effect on a token (window)
        # Info
        self.dropout = nn.Dropout(dropout_p)
        # Encoding - from the sinusoidal formula
        pos_encoding = torch.zeros(max_len, dim_model)
        positions_list = torch.arange(0, max_len, dtype=torch.float).reshape(-1, 1)  # 0, 1, 2, 3, 4, 5, ...
        division_term = torch.exp(
            torch.arange(0, dim_model, 2).float() * (-math.log(10000.0)) / dim_model)  # 1 / 10000^(2i/dim_model)
        # PE(pos, 2i) = sin(pos / 10000^(2i/dim_model))
        pos_encoding[:, 0::2] = torch.sin(positions_list * division_term)
        # PE(pos, 2i + 1) = cos(pos / 10000^(2i/dim_model))
        pos_encoding[:, 1::2] = torch.cos(positions_list * division_term)
        # Saving buffer (same as parameter without gradients needed)
        pos_encoding = pos_encoding.unsqueeze(0).transpose(0, 1)
        self.register_buffer("pos_encoding", pos_encoding)

    def forward(self, token_embedding: torch.Tensor) -> torch.Tensor:
        # Residual connection + pos encoding
        return self.dropout(token_embedding + self.pos_encoding[:token_embedding.size(0), :])
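
# Shape note: after registration the buffer has shape (max_len, 1, dim_model),
# so pos_encoding[:L] broadcasts over the batch dimension of an
# (L, batchsize, dim_model) token tensor, while integer indexing
# pos_encoding[t] with t of shape (batchsize,) yields (batchsize, 1, dim_model).
# TimestepEmbedder and ProgIndEmbedder below rely on the latter behaviour.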


class TimestepEmbedder(nn.Module):
    def __init__(self, latent_dim, sequence_pos_encoder):
        super().__init__()
        self.latent_dim = latent_dim
        self.sequence_pos_encoder = sequence_pos_encoder
        time_embed_dim = self.latent_dim
        self.time_embed = nn.Sequential(
            nn.Linear(self.latent_dim, time_embed_dim),
            nn.SiLU(inplace=False),
            nn.Linear(time_embed_dim, time_embed_dim),
        )

    def forward(self, timesteps):
        # Look up the sinusoidal encoding for each timestep and project it with a small MLP.
        return self.time_embed(self.sequence_pos_encoder.pos_encoding[timesteps])  # .permute(1, 0, 2)


# Structurally identical to TimestepEmbedder, but used to embed the progress indicator.
class ProgIndEmbedder(nn.Module):
    def __init__(self, latent_dim, sequence_pos_encoder):
        super().__init__()
        self.latent_dim = latent_dim
        self.sequence_pos_encoder = sequence_pos_encoder
        time_embed_dim = self.latent_dim
        self.time_embed = nn.Sequential(
            nn.Linear(self.latent_dim, time_embed_dim),
            nn.SiLU(inplace=False),
            nn.Linear(time_embed_dim, time_embed_dim),
        )

    def forward(self, prog_ind):
        # prog_ind arrives as an integer index (see Unet.forward_), so it can index
        # the positional-encoding table directly.
        return self.time_embed(self.sequence_pos_encoder.pos_encoding[prog_ind])  # .permute(1, 0, 2)
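

if __name__ == "__main__":
    # Minimal smoke test / usage sketch. The hyperparameters below (dim_model=512,
    # 8 heads, 4 layers, dim_input=66) are illustrative assumptions, not values
    # taken from the original training configuration.
    batch, seq_len, dim_input, dim_model = 2, 16, 66, 512
    model = Unet(
        dim_model=dim_model,
        num_heads=8,
        num_layers=4,
        dropout_p=0.1,
        dim_input=dim_input,
        dim_output=dim_input,
        device='cpu',
    )
    x = torch.randn(batch, seq_len, dim_input)             # noised input sequence
    joints_orig = torch.randn(batch, seq_len, dim_input)   # conditioning sequence
    timesteps = torch.randint(0, 1000, (batch,))           # one diffusion timestep per sample
    text_emb = torch.randn(batch, dim_model)               # text embedding; must match dim_model
    prog_ind = torch.rand(batch)                           # progress indicator in [0, 1]
    out = model(x, timesteps, text_emb=text_emb, prog_ind=prog_ind, joints_orig=joints_orig)
    print(out.shape)  # torch.Size([2, 16, 66])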