Spaces:
Runtime error
Runtime error
File size: 3,840 Bytes
f631117 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 |
import torch
from torch.nn import functional as F
from .dit import DiffusionTransformer
from .adp import UNet1d
from .sampling import sample
import math
from model.base import BaseModule
import pdb
# Module-level sequence length; presumably intended to be passed as the
# ``target_length`` argument of ``pad_and_create_mask`` (confirm with callers).
target_length = 1_536
def pad_and_create_mask(matrix, target_length):
    """Right-pad the last dimension of ``matrix`` with zeros up to ``target_length``.

    Args:
        matrix: tensor whose last dimension is the length to pad
            (for example ``(B, C, T)``).
        target_length: length of the last dimension after padding.

    Returns:
        ``(padded_matrix, mask)``. ``padded_matrix`` is ``matrix`` with its last
        dimension zero-padded on the right to ``target_length``. ``mask`` is a
        default-dtype ``(1, target_length)`` tensor on ``matrix.device`` holding
        ones over the original positions and zeros over the padding.

    Raises:
        ValueError: if the last dimension is already longer than ``target_length``.
    """
    # F.pad's (left, right) pair applies to the LAST dimension, so measure that
    # same axis; for 3-D inputs this equals the previous shape[2].
    length = matrix.shape[-1]
    if length > target_length:
        raise ValueError(
            "The last dimension length %s should not exceed %s" % (length, target_length)
        )
    padded_matrix = F.pad(matrix, (0, target_length - length), "constant", 0)
    # Allocate on the input's device directly rather than on CPU followed by a copy.
    mask = torch.zeros((1, target_length), device=matrix.device)
    mask[:, :length] = 1
    return padded_matrix, mask
class Stable_Diffusion(BaseModule):
    """Diffusion decoder wrapping an 80-channel ``DiffusionTransformer``.

    Training (``compute_loss``) uses the schedule ``alpha = cos(t * pi / 2)``,
    ``sigma = sin(t * pi / 2)`` with one ``t`` per batch element drawn from a
    scrambled 1-D Sobol sequence. The network sees
    ``x_t = alpha * x0 + sigma * mu`` and regresses ``alpha * mu - sigma * x0``,
    which is ``d x_t / dt`` up to the constant factor ``pi / 2``.

    Inference (``forward``) passes ``mu`` as the second positional argument of
    ``sample``. That is presumably the initial state at the ``t = 1`` end of the
    same path; confirm against ``.sampling``.
    """

    def __init__(self):
        super().__init__()
        self.diffusion = DiffusionTransformer(
            io_channels=80,
            embed_dim=768,
            depth=24,
            num_heads=24,
            project_cond_tokens=False,
            transformer_type="continuous_transformer",
        )
        # Alternative UNet1d backbone, kept for reference (not used):
        # self.diffusion = UNet1d(
        #     in_channels=80,
        #     channels=256,
        #     resnet_groups=16,
        #     kernel_multiplier_downsample=2,
        #     multipliers=[4, 4, 4, 5, 5],
        #     factors=[1, 2, 2, 4],  # conv downsampling ratios (note: input lengths inconsistent)
        #     num_blocks=[2, 2, 2, 2],
        #     attentions=[1, 3, 3, 3, 3],
        #     attention_heads=16,
        #     attention_multiplier=4,
        #     use_nearest_upsample=False,
        #     use_skip_scale=True,
        #     use_context_time=True
        # )
        # Low-discrepancy (Sobol) source of training timesteps in [0, 1).
        self.rng = torch.quasirandom.SobolEngine(1, scramble=True)

    @torch.no_grad()
    def forward(self, mu, mask, n_timesteps):
        """Run ``sample`` on ``self.diffusion`` starting from ``mu``, without gradients.

        Args:
            mu: tensor given to ``sample`` as its second positional argument
                (presumably ``(B, 80, T)``; confirm).
            mask: sequence mask whose dim 1 is squeezed away before use
                (presumably ``(B, 1, T)`` -> ``(B, T)``; confirm).
            n_timesteps: number of sampling steps forwarded to ``sample``.

        Returns:
            Whatever ``sample`` returns.
        """
        mask = mask.squeeze(1)
        # The fourth positional argument 0 is passed through unchanged
        # (NOTE(review): presumably the sampler's eta; confirm in .sampling).
        return sample(self.diffusion, mu, n_timesteps, 0, mask=mask)

    def compute_loss(self, x0, mask, mu):
        """Compute the masked velocity-regression loss for one batch.

        Args:
            x0: clean target tensor. Per-sample scalars are broadcast as
                ``(B, 1, 1)``, so it is presumably ``(B, C, T)``.
            mask: sequence mask whose dim 1 is squeezed away before use.
            mu: tensor at the ``t = 1`` end of the path; must broadcast with ``x0``.

        Returns:
            ``(loss, output)``: the scalar masked MSE from ``mse_loss`` and the raw
            network prediction.
        """
        t = self.rng.draw(x0.shape[0])[:, 0].to(x0.device)
        angle = t * math.pi / 2
        alphas = torch.cos(angle)[:, None, None]
        sigmas = torch.sin(angle)[:, None, None]
        # Interpolate x0 -> mu. The previous version mixed x0 with fresh Gaussian
        # noise here, which matched neither the target below nor the `mu`
        # starting point used by `forward`. The network receives no other view
        # of `mu`, so it could not predict a mu-dependent target from those inputs.
        noised_inputs = x0 * alphas + mu * sigmas
        # Velocity of the path above (the constant pi/2 factor is dropped).
        targets = mu * alphas - x0 * sigmas
        mask = mask.squeeze(1)
        output = self.diffusion(noised_inputs, t, mask=mask, cfg_dropout_prob=0.1)
        return self.mse_loss(output, targets, mask), output

    def mse_loss(self, output, targets, mask):
        """Mean squared error over the elements selected by ``mask``.

        Args:
            output: prediction tensor.
            targets: tensor with the same shape as ``output``.
            mask: boolean or 0/1 mask. A 2-D mask gets a channel axis when the
                loss is 3-D, and is then broadcast to the loss shape.

        Returns:
            Scalar mean of the selected squared errors. An all-zero mask selects
            nothing and yields NaN, as ``mean`` of an empty tensor does.
        """
        loss = F.mse_loss(output, targets, reduction='none')
        if mask.ndim == 2 and loss.ndim == 3:
            mask = mask.unsqueeze(1)
        # Boolean indexing needs a bool mask (a float 0/1 mask would raise).
        # expand_as broadcasts a singleton channel dim without materializing copies.
        mask = mask.bool().expand_as(loss)
        return loss[mask].mean()