|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from typing import Any, Dict, Tuple, Union

import torch
import torch.nn as nn

from ...configuration_utils import ConfigMixin, register_to_config
from ...models.attention import FeedForward
from ...models.attention_processor import (
    Attention,
    AttentionProcessor,
    CogVideoXAttnProcessor2_0,
)
from ...models.modeling_utils import ModelMixin
from ...models.normalization import AdaLayerNormContinuous
from ...utils import is_torch_version, logging
from ..embeddings import CogView3CombinedTimestepSizeEmbeddings, CogView3PlusPatchEmbed
from ..modeling_outputs import Transformer2DModelOutput
from ..normalization import CogView3PlusAdaLayerNormZeroTextImage
|
|
|
|
|
# Module-level logger, named after this module.
logger = logging.get_logger(__name__)
|
|
|
|
|
class CogView3PlusTransformerBlock(nn.Module):
    r"""
    Transformer block used in [CogView](https://github.com/THUDM/CogView3) model.

    The block keeps two token streams (`hidden_states` and `encoder_hidden_states`). Both are passed together to a
    single attention layer and then, concatenated along the sequence dimension, through one shared feed-forward
    network. Each residual branch is gated, and the feed-forward input is shifted/scaled, by modulation tensors that
    `norm1` derives from the conditioning embedding.

    Args:
        dim (`int`):
            The number of channels in the input and output.
        num_attention_heads (`int`):
            The number of heads to use for multi-head attention.
        attention_head_dim (`int`):
            The number of channels in each head.
        time_embed_dim (`int`):
            The number of channels in timestep embedding.
    """

    def __init__(
        self,
        dim: int = 2560,
        num_attention_heads: int = 64,
        attention_head_dim: int = 40,
        time_embed_dim: int = 512,
    ):
        super().__init__()

        # Produces the normalized streams plus the gate/shift/scale tensors consumed in `forward`.
        self.norm1 = CogView3PlusAdaLayerNormZeroTextImage(embedding_dim=time_embed_dim, dim=dim)

        self.attn1 = Attention(
            query_dim=dim,
            heads=num_attention_heads,
            dim_head=attention_head_dim,
            out_dim=dim,
            bias=True,
            qk_norm="layer_norm",
            elementwise_affine=False,
            eps=1e-6,
            processor=CogVideoXAttnProcessor2_0(),
        )

        # Non-affine layer norms: the scale and shift are applied explicitly in `forward`.
        self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-5)
        self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-5)

        self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        emb: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        r"""
        Apply the gated attention and feed-forward residual updates to both token streams.

        Args:
            hidden_states (`torch.Tensor`):
                Primary token stream. Dimension 1 is treated as the sequence dimension.
            encoder_hidden_states (`torch.Tensor`):
                Context (text) token stream. Its size along dimension 1 is used to split the feed-forward output
                back into the two streams.
            emb (`torch.Tensor`):
                Conditioning embedding handed to `norm1` to produce the modulation tensors.

        Returns:
            `Tuple[torch.Tensor, torch.Tensor]`: The updated `(hidden_states, encoder_hidden_states)`.
        """
        text_seq_length = encoder_hidden_states.size(1)

        # norm & modulate
        # NOTE(review): the modulation tensors are unsqueezed at dim 1 below, so they are presumably
        # (batch, dim) - confirm against CogView3PlusAdaLayerNormZeroTextImage.
        (
            norm_hidden_states,
            gate_msa,
            shift_mlp,
            scale_mlp,
            gate_mlp,
            norm_encoder_hidden_states,
            c_gate_msa,
            c_shift_mlp,
            c_scale_mlp,
            c_gate_mlp,
        ) = self.norm1(hidden_states, encoder_hidden_states, emb)

        # attention over both streams; the layer returns one output per stream
        attn_hidden_states, attn_encoder_hidden_states = self.attn1(
            hidden_states=norm_hidden_states, encoder_hidden_states=norm_encoder_hidden_states
        )

        hidden_states = hidden_states + gate_msa.unsqueeze(1) * attn_hidden_states
        encoder_hidden_states = encoder_hidden_states + c_gate_msa.unsqueeze(1) * attn_encoder_hidden_states

        # norm & modulate for the feed-forward branch
        norm_hidden_states = self.norm2(hidden_states)
        norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]

        norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
        norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]

        # feed-forward on [context, primary] concatenated along the sequence dimension; the slices below rely
        # on the context tokens coming first.
        norm_hidden_states = torch.cat([norm_encoder_hidden_states, norm_hidden_states], dim=1)
        ff_output = self.ff(norm_hidden_states)

        hidden_states = hidden_states + gate_mlp.unsqueeze(1) * ff_output[:, text_seq_length:]
        encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * ff_output[:, :text_seq_length]

        # 65504 is the largest finite float16 value: clamp so overflowed activations do not stay at +/-inf.
        if hidden_states.dtype == torch.float16:
            hidden_states = hidden_states.clip(-65504, 65504)
        if encoder_hidden_states.dtype == torch.float16:
            encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)
        return hidden_states, encoder_hidden_states
|
|
|
|
|
class CogView3PlusTransformer2DModel(ModelMixin, ConfigMixin):
    r"""
    The Transformer model introduced in [CogView3: Finer and Faster Text-to-Image Generation via Relay
    Diffusion](https://huggingface.co/papers/2403.05121).

    Args:
        patch_size (`int`, defaults to `2`):
            The size of the patches to use in the patch embedding layer.
        in_channels (`int`, defaults to `16`):
            The number of channels in the input.
        num_layers (`int`, defaults to `30`):
            The number of layers of Transformer blocks to use.
        attention_head_dim (`int`, defaults to `40`):
            The number of channels in each head.
        num_attention_heads (`int`, defaults to `64`):
            The number of heads to use for multi-head attention.
        out_channels (`int`, defaults to `16`):
            The number of channels in the output.
        text_embed_dim (`int`, defaults to `4096`):
            Input dimension of text embeddings from the text encoder.
        time_embed_dim (`int`, defaults to `512`):
            Output dimension of timestep embeddings.
        condition_dim (`int`, defaults to `256`):
            The embedding dimension of the input SDXL-style resolution conditions (original_size, target_size,
            crop_coords).
        pos_embed_max_size (`int`, defaults to `128`):
            The maximum resolution of the positional embeddings, from which slices of shape `H x W` are taken and added
            to input patched latents, where `H` and `W` are the latent height and width respectively. A value of 128
            means that the maximum supported height and width for image generation is `128 * vae_scale_factor *
            patch_size => 128 * 8 * 2 => 2048`.
        sample_size (`int`, defaults to `128`):
            The base resolution of input latents. If height/width is not provided during generation, this value is used
            to determine the resolution as `sample_size * vae_scale_factor => 128 * 8 => 1024`
    """

    _supports_gradient_checkpointing = True

    @register_to_config
    def __init__(
        self,
        patch_size: int = 2,
        in_channels: int = 16,
        num_layers: int = 30,
        attention_head_dim: int = 40,
        num_attention_heads: int = 64,
        out_channels: int = 16,
        text_embed_dim: int = 4096,
        time_embed_dim: int = 512,
        condition_dim: int = 256,
        pos_embed_max_size: int = 128,
        sample_size: int = 128,
    ):
        super().__init__()
        self.out_channels = out_channels
        self.inner_dim = num_attention_heads * attention_head_dim

        # Three size conditions (original_size, target_size, crop_coords) times a factor of two, each of width
        # `condition_dim`.
        # NOTE(review): the factor 2 is presumably one embedding per coordinate of each condition - confirm
        # against CogView3CombinedTimestepSizeEmbeddings.
        self.pooled_projection_dim = 3 * 2 * condition_dim

        self.patch_embed = CogView3PlusPatchEmbed(
            in_channels=in_channels,
            hidden_size=self.inner_dim,
            patch_size=patch_size,
            text_hidden_size=text_embed_dim,
            pos_embed_max_size=pos_embed_max_size,
        )

        self.time_condition_embed = CogView3CombinedTimestepSizeEmbeddings(
            embedding_dim=time_embed_dim,
            condition_dim=condition_dim,
            pooled_projection_dim=self.pooled_projection_dim,
            timesteps_dim=self.inner_dim,
        )

        self.transformer_blocks = nn.ModuleList(
            [
                CogView3PlusTransformerBlock(
                    dim=self.inner_dim,
                    num_attention_heads=num_attention_heads,
                    attention_head_dim=attention_head_dim,
                    time_embed_dim=time_embed_dim,
                )
                for _ in range(num_layers)
            ]
        )

        self.norm_out = AdaLayerNormContinuous(
            embedding_dim=self.inner_dim,
            conditioning_embedding_dim=time_embed_dim,
            elementwise_affine=False,
            eps=1e-6,
        )
        # Projects every token back to one flattened `patch_size x patch_size` patch of `out_channels` channels.
        self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True)

        self.gradient_checkpointing = False

    @property
    def attn_processors(self) -> Dict[str, AttentionProcessor]:
        r"""
        Returns:
            `dict` of attention processors: A dictionary containing all attention processors used in the model,
            indexed by weight name (`"<module path>.processor"`).
        """
        processors = {}

        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
            # Any module exposing `get_processor` contributes one entry keyed by its dotted path.
            if hasattr(module, "get_processor"):
                processors[f"{name}.processor"] = module.get_processor()

            for sub_name, child in module.named_children():
                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)

            return processors

        for name, module in self.named_children():
            fn_recursive_add_processors(name, module, processors)

        return processors

    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
        r"""
        Sets the attention processor to use to compute attention.

        Parameters:
            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
                The instantiated processor class or a dictionary of processor classes that will be set as the processor
                for **all** `Attention` layers.

                If `processor` is a dict, the key needs to define the path to the corresponding cross attention
                processor. This is strongly recommended when setting trainable attention processors.

        Raises:
            `ValueError`: If `processor` is a dict whose length differs from the number of attention layers.
        """
        count = len(self.attn_processors)

        if isinstance(processor, dict) and len(processor) != count:
            raise ValueError(
                f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
                f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
            )

        if isinstance(processor, dict):
            # Entries are consumed with `pop` below; work on a shallow copy so the caller's dict is left intact.
            processor = dict(processor)

        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
            if hasattr(module, "set_processor"):
                if not isinstance(processor, dict):
                    module.set_processor(processor)
                else:
                    module.set_processor(processor.pop(f"{name}.processor"))

            for sub_name, child in module.named_children():
                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)

        for name, module in self.named_children():
            fn_recursive_attn_processor(name, module, processor)

    def _set_gradient_checkpointing(self, module, value=False):
        """Set `gradient_checkpointing` to `value` on `module` if it defines that attribute."""
        if hasattr(module, "gradient_checkpointing"):
            module.gradient_checkpointing = value

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        timestep: torch.LongTensor,
        original_size: torch.Tensor,
        target_size: torch.Tensor,
        crop_coords: torch.Tensor,
        return_dict: bool = True,
    ) -> Union[Tuple[torch.Tensor], Transformer2DModelOutput]:
        """
        The [`CogView3PlusTransformer2DModel`] forward method.

        Args:
            hidden_states (`torch.Tensor`):
                Input `hidden_states` of shape `(batch size, channel, height, width)`.
            encoder_hidden_states (`torch.Tensor`):
                Conditional embeddings (embeddings computed from the input conditions such as prompts) of shape
                `(batch_size, sequence_len, text_embed_dim)`
            timestep (`torch.LongTensor`):
                Used to indicate denoising step.
            original_size (`torch.Tensor`):
                CogView3 uses SDXL-like micro-conditioning for original image size as explained in section 2.2 of
                [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
            target_size (`torch.Tensor`):
                CogView3 uses SDXL-like micro-conditioning for target image size as explained in section 2.2 of
                [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
            crop_coords (`torch.Tensor`):
                CogView3 uses SDXL-like micro-conditioning for crop coordinates as explained in section 2.2 of
                [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
                tuple.

        Returns:
            `tuple` or [`~models.transformer_2d.Transformer2DModelOutput`]:
                The denoised latents using provided inputs as conditioning. If `return_dict` is `False`, a
                one-element tuple `(output,)` is returned instead of the output class.
        """
        height, width = hidden_states.shape[-2:]
        text_seq_length = encoder_hidden_states.shape[1]

        # The patch embedding returns a single sequence with the text tokens first, followed by the image patches.
        hidden_states = self.patch_embed(hidden_states, encoder_hidden_states)
        emb = self.time_condition_embed(timestep, original_size, target_size, crop_coords, hidden_states.dtype)

        encoder_hidden_states = hidden_states[:, :text_seq_length]
        hidden_states = hidden_states[:, text_seq_length:]

        # Both values are loop-invariant, so resolve them once instead of per block.
        use_checkpointing = torch.is_grad_enabled() and self.gradient_checkpointing
        ckpt_kwargs: Dict[str, Any] = {}
        if use_checkpointing and is_torch_version(">=", "1.11.0"):
            ckpt_kwargs["use_reentrant"] = False

        for block in self.transformer_blocks:
            if use_checkpointing:
                hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint(
                    block,
                    hidden_states,
                    encoder_hidden_states,
                    emb,
                    **ckpt_kwargs,
                )
            else:
                hidden_states, encoder_hidden_states = block(
                    hidden_states=hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    emb=emb,
                )

        hidden_states = self.norm_out(hidden_states, emb)
        hidden_states = self.proj_out(hidden_states)  # (batch_size, height*width, patch_size*patch_size*out_channels)

        # unpatchify: fold each token's flattened patch back into its spatial position
        patch_size = self.config.patch_size
        height = height // patch_size
        width = width // patch_size

        hidden_states = hidden_states.reshape(
            shape=(hidden_states.shape[0], height, width, self.out_channels, patch_size, patch_size)
        )
        hidden_states = torch.einsum("nhwcpq->nchpwq", hidden_states)
        output = hidden_states.reshape(
            shape=(hidden_states.shape[0], self.out_channels, height * patch_size, width * patch_size)
        )

        if not return_dict:
            return (output,)

        return Transformer2DModelOutput(sample=output)
|
|