# Spaces:
# Build error
# Build error
# -*- coding: utf-8 -*- | |
import math
from typing import Optional

import torch
import torch.nn as nn
from diffusers.models.embeddings import Timesteps

from michelangelo.models.modules.transformer_blocks import MLP
from michelangelo.models.modules.diffusion_transformer import UNetDiffusionTransformer
class ConditionalASLUDTDenoiser(nn.Module):
    """Context-conditioned denoiser built around a ``UNetDiffusionTransformer``.

    The forward pass projects the input tokens to ``width`` channels, prepends
    one timestep token and the projected context tokens along the sequence
    axis, runs the combined sequence through the backbone, and projects the
    trailing ``n_data`` tokens to ``output_channels``.
    """

    def __init__(self, *,
                 device: Optional[torch.device],
                 dtype: Optional[torch.dtype],
                 input_channels: int,
                 output_channels: int,
                 n_ctx: int,
                 width: int,
                 layers: int,
                 heads: int,
                 context_dim: int,
                 context_ln: bool = True,
                 skip_ln: bool = False,
                 init_scale: float = 0.25,
                 flip_sin_to_cos: bool = False,
                 use_checkpoint: bool = False):
        """Build the backbone and the input/output/time/context projections.

        Args:
            device: Device passed to every parameterised submodule.
            dtype: Dtype passed to every parameterised submodule.
            input_channels: Feature size of ``model_input``.
            output_channels: Feature size of the returned sample.
            n_ctx: Forwarded unchanged to ``UNetDiffusionTransformer``.
            width: Hidden size shared by the backbone and all projections.
            layers: Forwarded unchanged to ``UNetDiffusionTransformer``.
            heads: Forwarded unchanged to ``UNetDiffusionTransformer``.
            context_dim: Feature size of the conditioning ``context``.
            context_ln: If True, apply a ``LayerNorm`` over ``context_dim``
                before the linear context projection; otherwise use the
                linear projection alone.
            skip_ln: Forwarded unchanged to ``UNetDiffusionTransformer``.
            init_scale: Base init scale; multiplied by ``sqrt(1 / width)``
                before being handed to the backbone and the time MLP.
            flip_sin_to_cos: Forwarded unchanged to ``Timesteps``.
            use_checkpoint: Stored on the module and forwarded to the backbone.
        """
        super().__init__()

        self.use_checkpoint = use_checkpoint

        # Scale the user-supplied init factor by 1/sqrt(width).
        init_scale = init_scale * math.sqrt(1.0 / width)

        self.backbone = UNetDiffusionTransformer(
            device=device,
            dtype=dtype,
            n_ctx=n_ctx,
            width=width,
            layers=layers,
            heads=heads,
            skip_ln=skip_ln,
            init_scale=init_scale,
            use_checkpoint=use_checkpoint
        )
        self.ln_post = nn.LayerNorm(width, device=device, dtype=dtype)
        self.input_proj = nn.Linear(input_channels, width, device=device, dtype=dtype)
        self.output_proj = nn.Linear(width, output_channels, device=device, dtype=dtype)

        # timestep embedding
        self.time_embed = Timesteps(width, flip_sin_to_cos=flip_sin_to_cos, downscale_freq_shift=0)
        self.time_proj = MLP(
            device=device, dtype=dtype, width=width, init_scale=init_scale
        )

        # Context projector. It is built exactly once: an unconditional
        # LayerNorm+Linear used to be constructed here and then always
        # overwritten by the branch below.
        if context_ln:
            self.context_embed = nn.Sequential(
                nn.LayerNorm(context_dim, device=device, dtype=dtype),
                nn.Linear(context_dim, width, device=device, dtype=dtype),
            )
        else:
            self.context_embed = nn.Linear(context_dim, width, device=device, dtype=dtype)

    def forward(self,
                model_input: torch.FloatTensor,
                timestep: torch.LongTensor,
                context: torch.FloatTensor):
        r"""Predict the denoising output for ``model_input``.

        Args:
            model_input (torch.FloatTensor): [bs, n_data, input_channels]
            timestep (torch.LongTensor): [bs,]
            context (torch.FloatTensor): [bs, context_tokens, context_dim]

        Returns:
            sample (torch.FloatTensor): [bs, n_data, output_channels]
        """
        # Remember how many data tokens there are so they can be sliced
        # back out after the backbone.
        _, n_data, _ = model_input.shape

        # 1. time: embed the timestep and add a sequence axis at dim 1 so it
        # can be concatenated as a token.
        t_emb = self.time_proj(self.time_embed(timestep)).unsqueeze(dim=1)

        # 2. conditions projector: map context features to `width` channels.
        context = self.context_embed(context)

        # 3. denoiser: sequence layout is [time token, context tokens, data tokens].
        x = self.input_proj(model_input)
        x = torch.cat([t_emb, context, x], dim=1)

        x = self.backbone(x)
        x = self.ln_post(x)
        # Data tokens were concatenated last, so keep only the trailing n_data.
        x = x[:, -n_data:]
        sample = self.output_proj(x)

        return sample