from typing import Dict, Union

import torch
import torch.nn.functional as F
from torch import nn
from torch.utils.checkpoint import checkpoint

from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import PeftAdapterMixin
from ..attention import BasicTransformerBlock, SkipFFTransformerBlock
from ..attention_processor import (
    ADDED_KV_ATTENTION_PROCESSORS,
    CROSS_ATTENTION_PROCESSORS,
    AttentionProcessor,
    AttnAddedKVProcessor,
    AttnProcessor,
)
from ..embeddings import TimestepEmbedding, get_timestep_embedding
from ..modeling_utils import ModelMixin
from ..normalization import GlobalResponseNorm, RMSNorm
from ..resnet import Downsample2D, Upsample2D


class UVit2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
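    """
    A U-ViT style transformer that operates on 2D maps of discrete VQ token ids. Token ids are embedded into a
    feature map, processed by a convolutional down block, flattened and run through a stack of
    `BasicTransformerBlock` layers conditioned on text and micro-conditioning embeddings, reshaped back to 2D for a
    convolutional up block, and finally projected by a 1x1 convolution to per-position logits over the VQ codebook.
    """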

    _supports_gradient_checkpointing = True

    @register_to_config
    def __init__(
        self,
        # global transformer width / bias / dropout
        hidden_size: int = 1024,
        use_bias: bool = False,
        hidden_dropout: float = 0.0,
        # conditioning (pooled text + micro-conditioning) embeddings
        cond_embed_dim: int = 768,
        micro_cond_encode_dim: int = 256,
        micro_cond_embed_dim: int = 1280,
        encoder_hidden_size: int = 768,
        # discrete token vocabulary
        vocab_size: int = 8256,
        codebook_size: int = 8192,
        # convolutional embedding and down/up blocks
        in_channels: int = 768,
        block_out_channels: int = 768,
        num_res_blocks: int = 3,
        downsample: bool = False,
        upsample: bool = False,
        block_num_heads: int = 12,
        # transformer stack
        num_hidden_layers: int = 22,
        num_attention_heads: int = 16,
        # attention
        attention_dropout: float = 0.0,
        # feed-forward
        intermediate_size: int = 2816,
        # normalization
        layer_norm_eps: float = 1e-6,
        ln_elementwise_affine: bool = True,
        sample_size: int = 64,
    ):
        super().__init__()

        self.encoder_proj = nn.Linear(encoder_hidden_size, hidden_size, bias=use_bias)
        self.encoder_proj_layer_norm = RMSNorm(hidden_size, layer_norm_eps, ln_elementwise_affine)

        self.embed = UVit2DConvEmbed(
            in_channels, block_out_channels, vocab_size, ln_elementwise_affine, layer_norm_eps, use_bias
        )

        self.cond_embed = TimestepEmbedding(
            micro_cond_embed_dim + cond_embed_dim, hidden_size, sample_proj_bias=use_bias
        )

        self.down_block = UVitBlock(
            block_out_channels,
            num_res_blocks,
            hidden_size,
            hidden_dropout,
            ln_elementwise_affine,
            layer_norm_eps,
            use_bias,
            block_num_heads,
            attention_dropout,
            downsample=downsample,
            upsample=False,
        )

        self.project_to_hidden_norm = RMSNorm(block_out_channels, layer_norm_eps, ln_elementwise_affine)
        self.project_to_hidden = nn.Linear(block_out_channels, hidden_size, bias=use_bias)

        self.transformer_layers = nn.ModuleList(
            [
                BasicTransformerBlock(
                    dim=hidden_size,
                    num_attention_heads=num_attention_heads,
                    attention_head_dim=hidden_size // num_attention_heads,
                    dropout=hidden_dropout,
                    cross_attention_dim=hidden_size,
                    attention_bias=use_bias,
                    norm_type="ada_norm_continuous",
                    # NOTE: the "continous" spelling matches the argument name defined by `BasicTransformerBlock`.
                    ada_norm_continous_conditioning_embedding_dim=hidden_size,
                    norm_elementwise_affine=ln_elementwise_affine,
                    norm_eps=layer_norm_eps,
                    ada_norm_bias=use_bias,
                    ff_inner_dim=intermediate_size,
                    ff_bias=use_bias,
                    attention_out_bias=use_bias,
                )
                for _ in range(num_hidden_layers)
            ]
        )

        self.project_from_hidden_norm = RMSNorm(hidden_size, layer_norm_eps, ln_elementwise_affine)
        self.project_from_hidden = nn.Linear(hidden_size, block_out_channels, bias=use_bias)

        self.up_block = UVitBlock(
            block_out_channels,
            num_res_blocks,
            hidden_size,
            hidden_dropout,
            ln_elementwise_affine,
            layer_norm_eps,
            use_bias,
            block_num_heads,
            attention_dropout,
            downsample=False,
            upsample=upsample,
        )

        self.mlm_layer = ConvMlmLayer(
            block_out_channels, in_channels, use_bias, ln_elementwise_affine, layer_norm_eps, codebook_size
        )

        self.gradient_checkpointing = False

    def _set_gradient_checkpointing(self, module, value: bool = False) -> None:
        # Propagate the flag so that `enable_gradient_checkpointing()` toggles
        # `self.gradient_checkpointing`, which `forward` checks before wrapping the transformer layers
        # in `torch.utils.checkpoint.checkpoint`.
        if hasattr(module, "gradient_checkpointing"):
            module.gradient_checkpointing = value

    def forward(self, input_ids, encoder_hidden_states, pooled_text_emb, micro_conds, cross_attention_kwargs=None):
        # Project the text encoder hidden states into the transformer width.
        encoder_hidden_states = self.encoder_proj(encoder_hidden_states)
        encoder_hidden_states = self.encoder_proj_layer_norm(encoder_hidden_states)

        # Sinusoidal embeddings of the micro-conditioning values, concatenated with the pooled text
        # embedding to form the global conditioning vector.
        micro_cond_embeds = get_timestep_embedding(
            micro_conds.flatten(), self.config.micro_cond_encode_dim, flip_sin_to_cos=True, downscale_freq_shift=0
        )
        micro_cond_embeds = micro_cond_embeds.reshape((input_ids.shape[0], -1))

        pooled_text_emb = torch.cat([pooled_text_emb, micro_cond_embeds], dim=1)
        pooled_text_emb = pooled_text_emb.to(dtype=self.dtype)
        pooled_text_emb = self.cond_embed(pooled_text_emb).to(encoder_hidden_states.dtype)

        # Embed the discrete token ids into a 2D feature map: (batch, block_out_channels, height, width).
        hidden_states = self.embed(input_ids)

        hidden_states = self.down_block(
            hidden_states,
            pooled_text_emb=pooled_text_emb,
            encoder_hidden_states=encoder_hidden_states,
            cross_attention_kwargs=cross_attention_kwargs,
        )

        # Flatten the spatial dimensions for the transformer stack: (batch, height * width, channels).
        batch_size, channels, height, width = hidden_states.shape
        hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch_size, height * width, channels)

        hidden_states = self.project_to_hidden_norm(hidden_states)
        hidden_states = self.project_to_hidden(hidden_states)

        for layer in self.transformer_layers:
            if self.training and self.gradient_checkpointing:
                # Non-reentrant checkpointing so keyword arguments can be forwarded to the layer.
                def layer_(*args, **kwargs):
                    return checkpoint(layer, *args, use_reentrant=False, **kwargs)

            else:
                layer_ = layer

            hidden_states = layer_(
                hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                cross_attention_kwargs=cross_attention_kwargs,
                added_cond_kwargs={"pooled_text_emb": pooled_text_emb},
            )

        hidden_states = self.project_from_hidden_norm(hidden_states)
        hidden_states = self.project_from_hidden(hidden_states)

        # Restore the spatial layout: (batch, channels, height, width).
        hidden_states = hidden_states.reshape(batch_size, height, width, channels).permute(0, 3, 1, 2)

        hidden_states = self.up_block(
            hidden_states,
            pooled_text_emb=pooled_text_emb,
            encoder_hidden_states=encoder_hidden_states,
            cross_attention_kwargs=cross_attention_kwargs,
        )

        # Per-position logits over the VQ codebook: (batch, codebook_size, height, width).
        logits = self.mlm_layer(hidden_states)

        return logits

    @property
    def attn_processors(self) -> Dict[str, AttentionProcessor]:
        r"""
        Returns:
            `dict` of attention processors: A dictionary containing all attention processors used in the model,
            indexed by their weight names.
        """
        processors = {}

        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
            if hasattr(module, "get_processor"):
                processors[f"{name}.processor"] = module.get_processor()

            for sub_name, child in module.named_children():
                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)

            return processors

        for name, module in self.named_children():
            fn_recursive_add_processors(name, module, processors)

        return processors
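
    # Example (sketch): `attn_processors` can be used to inspect which processor class backs each attention
    # layer, e.g. before swapping processors via `set_attn_processor`:
    #
    #     model = UVit2DModel()
    #     for name, proc in model.attn_processors.items():
    #         print(name, proc.__class__.__name__)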

    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
        r"""
        Sets the attention processor to use to compute attention.

        Parameters:
            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
                The instantiated processor class or a dictionary of processor classes that will be set as the
                processor for **all** `Attention` layers.

                If `processor` is a dict, the key needs to define the path to the corresponding cross attention
                processor. This is strongly recommended when setting trainable attention processors.
        """
        count = len(self.attn_processors.keys())

        if isinstance(processor, dict) and len(processor) != count:
            raise ValueError(
                f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
                f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
            )

        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
            if hasattr(module, "set_processor"):
                if not isinstance(processor, dict):
                    module.set_processor(processor)
                else:
                    module.set_processor(processor.pop(f"{name}.processor"))

            for sub_name, child in module.named_children():
                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)

        for name, module in self.named_children():
            fn_recursive_attn_processor(name, module, processor)
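
    # Example (sketch): a single processor instance is broadcast to every attention layer; per-layer control
    # uses a dict keyed like the output of `attn_processors` (`AttnProcessor` shown as an illustrative choice):
    #
    #     model.set_attn_processor(AttnProcessor())
    #     model.set_attn_processor({name: AttnProcessor() for name in model.attn_processors})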

    def set_default_attn_processor(self):
        """
        Disables custom attention processors and sets the default attention implementation.
        """
        if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
            processor = AttnAddedKVProcessor()
        elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
            processor = AttnProcessor()
        else:
            raise ValueError(
                "Cannot call `set_default_attn_processor` when attention processors are of type"
                f" {next(iter(self.attn_processors.values()))}"
            )

        self.set_attn_processor(processor)


class UVit2DConvEmbed(nn.Module):
    def __init__(self, in_channels, block_out_channels, vocab_size, elementwise_affine, eps, bias):
        super().__init__()
        self.embeddings = nn.Embedding(vocab_size, in_channels)
        self.layer_norm = RMSNorm(in_channels, eps, elementwise_affine)
        self.conv = nn.Conv2d(in_channels, block_out_channels, kernel_size=1, bias=bias)

    def forward(self, input_ids):
        # (batch, height, width) token ids -> (batch, height, width, in_channels) embeddings.
        embeddings = self.embeddings(input_ids)
        embeddings = self.layer_norm(embeddings)
        # Move channels first and project to the block width with a 1x1 convolution.
        embeddings = embeddings.permute(0, 3, 1, 2)
        embeddings = self.conv(embeddings)
        return embeddings


class UVitBlock(nn.Module):
    def __init__(
        self,
        channels,
        num_res_blocks: int,
        hidden_size,
        hidden_dropout,
        ln_elementwise_affine,
        layer_norm_eps,
        use_bias,
        block_num_heads,
        attention_dropout,
        downsample: bool,
        upsample: bool,
    ):
        super().__init__()

        if downsample:
            self.downsample = Downsample2D(
                channels,
                use_conv=True,
                padding=0,
                name="Conv2d_0",
                kernel_size=2,
                norm_type="rms_norm",
                eps=layer_norm_eps,
                elementwise_affine=ln_elementwise_affine,
                bias=use_bias,
            )
        else:
            self.downsample = None

        self.res_blocks = nn.ModuleList(
            [
                ConvNextBlock(
                    channels,
                    layer_norm_eps,
                    ln_elementwise_affine,
                    use_bias,
                    hidden_dropout,
                    hidden_size,
                )
                for _ in range(num_res_blocks)
            ]
        )

        self.attention_blocks = nn.ModuleList(
            [
                SkipFFTransformerBlock(
                    channels,
                    block_num_heads,
                    channels // block_num_heads,
                    hidden_size,
                    use_bias,
                    attention_dropout,
                    channels,
                    attention_bias=use_bias,
                    attention_out_bias=use_bias,
                )
                for _ in range(num_res_blocks)
            ]
        )

        if upsample:
            self.upsample = Upsample2D(
                channels,
                use_conv_transpose=True,
                kernel_size=2,
                padding=0,
                name="conv",
                norm_type="rms_norm",
                eps=layer_norm_eps,
                elementwise_affine=ln_elementwise_affine,
                bias=use_bias,
                interpolate=False,
            )
        else:
            self.upsample = None

    def forward(self, x, pooled_text_emb, encoder_hidden_states, cross_attention_kwargs):
        if self.downsample is not None:
            x = self.downsample(x)

        # Alternate ConvNext-style residual blocks with attention blocks applied to the flattened feature map.
        for res_block, attention_block in zip(self.res_blocks, self.attention_blocks):
            x = res_block(x, pooled_text_emb)

            batch_size, channels, height, width = x.shape
            x = x.view(batch_size, channels, height * width).permute(0, 2, 1)
            x = attention_block(
                x, encoder_hidden_states=encoder_hidden_states, cross_attention_kwargs=cross_attention_kwargs
            )
            x = x.permute(0, 2, 1).view(batch_size, channels, height, width)

        if self.upsample is not None:
            x = self.upsample(x)

        return x


class ConvNextBlock(nn.Module):
    def __init__(
        self, channels, layer_norm_eps, ln_elementwise_affine, use_bias, hidden_dropout, hidden_size, res_ffn_factor=4
    ):
        super().__init__()
        self.depthwise = nn.Conv2d(
            channels,
            channels,
            kernel_size=3,
            padding=1,
            groups=channels,
            bias=use_bias,
        )
        self.norm = RMSNorm(channels, layer_norm_eps, ln_elementwise_affine)
        self.channelwise_linear_1 = nn.Linear(channels, int(channels * res_ffn_factor), bias=use_bias)
        self.channelwise_act = nn.GELU()
        self.channelwise_norm = GlobalResponseNorm(int(channels * res_ffn_factor))
        self.channelwise_linear_2 = nn.Linear(int(channels * res_ffn_factor), channels, bias=use_bias)
        self.channelwise_dropout = nn.Dropout(hidden_dropout)
        self.cond_embeds_mapper = nn.Linear(hidden_size, channels * 2, bias=use_bias)

    def forward(self, x, cond_embeds):
        x_res = x

        x = self.depthwise(x)

        # Channels-last for the norm and the channelwise MLP.
        x = x.permute(0, 2, 3, 1)
        x = self.norm(x)

        x = self.channelwise_linear_1(x)
        x = self.channelwise_act(x)
        x = self.channelwise_norm(x)
        x = self.channelwise_linear_2(x)
        x = self.channelwise_dropout(x)

        x = x.permute(0, 3, 1, 2)

        x = x + x_res

        # FiLM-style conditioning: a per-channel scale and shift derived from the pooled conditioning embedding.
        scale, shift = self.cond_embeds_mapper(F.silu(cond_embeds)).chunk(2, dim=1)
        x = x * (1 + scale[:, :, None, None]) + shift[:, :, None, None]

        return x


class ConvMlmLayer(nn.Module):
    def __init__(
        self,
        block_out_channels: int,
        in_channels: int,
        use_bias: bool,
        ln_elementwise_affine: bool,
        layer_norm_eps: float,
        codebook_size: int,
    ):
        super().__init__()
        self.conv1 = nn.Conv2d(block_out_channels, in_channels, kernel_size=1, bias=use_bias)
        self.layer_norm = RMSNorm(in_channels, layer_norm_eps, ln_elementwise_affine)
        self.conv2 = nn.Conv2d(in_channels, codebook_size, kernel_size=1, bias=use_bias)

    def forward(self, hidden_states):
        hidden_states = self.conv1(hidden_states)
        # RMSNorm expects channels last, so permute around the normalization.
        hidden_states = self.layer_norm(hidden_states.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
        # Per-position logits over the VQ codebook: (batch, codebook_size, height, width).
        logits = self.conv2(hidden_states)
        return logits
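

# Usage sketch (illustrative only; kept in a comment because this module uses relative imports and is meant to be
# imported as part of the library). The shapes are assumptions chosen to satisfy the default config: `micro_conds`
# must flatten to `micro_cond_embed_dim / micro_cond_encode_dim` (= 5) values per sample, `input_ids` is a 2D map of
# VQ token ids, and `encoder_hidden_states` / `pooled_text_emb` come from a text encoder with `encoder_hidden_size`
# and `cond_embed_dim` features respectively.
#
#     model = UVit2DModel()
#     input_ids = torch.randint(0, model.config.codebook_size, (1, 16, 16))
#     encoder_hidden_states = torch.randn(1, 77, model.config.encoder_hidden_size)
#     pooled_text_emb = torch.randn(1, model.config.cond_embed_dim)
#     micro_conds = torch.randn(1, 5)
#     logits = model(input_ids, encoder_hidden_states, pooled_text_emb, micro_conds)
#     # logits: (1, codebook_size, 16, 16)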