|
from dataclasses import dataclass |
|
from typing import Optional, Tuple, Union |
|
|
|
import torch |
|
import torch.nn as nn |
|
|
|
from ..configuration_utils import ConfigMixin, register_to_config |
|
from ..modeling_utils import ModelMixin |
|
from ..utils import BaseOutput |
|
from .embeddings import TimestepEmbedding, Timesteps |
|
from .unet_blocks import UNetMidBlock2DCrossAttn, get_down_block, get_up_block |
|
|
|
|
|
@dataclass
class UNet2DConditionOutput(BaseOutput):
    """
    Return container produced by [`UNet2DConditionModel.forward`] when `return_dict=True`.

    Args:
        sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            The model's final output: hidden states conditioned on the `encoder_hidden_states` input, taken from the
            last layer of the model.
    """

    sample: torch.FloatTensor
|
|
|
|
|
class UNet2DConditionModel(ModelMixin, ConfigMixin):
    r"""
    UNet2DConditionModel is a conditional 2D UNet model that takes in a noisy sample, conditional state, and a timestep
    and returns sample shaped output.

    This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library
    implements for all the model (such as downloading or saving, etc.)

    Parameters:
        sample_size (`int`, *optional*): The size of the input sample.
        in_channels (`int`, *optional*, defaults to 4): The number of channels in the input sample.
        out_channels (`int`, *optional*, defaults to 4): The number of channels in the output.
        center_input_sample (`bool`, *optional*, defaults to `False`):
            Whether to rescale the input as `2 * sample - 1` before the first convolution.
        flip_sin_to_cos (`bool`, *optional*, defaults to `True`):
            Whether to flip the sin to cos in the time embedding.
        freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
        down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
            The tuple of downsample blocks to use.
        up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D",)`):
            The tuple of upsample blocks to use.
        block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
            The tuple of output channels for each block.
        layers_per_block (`int`, *optional*, defaults to 2):
            The number of layers per down block; each up block uses `layers_per_block + 1` layers.
        downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution.
        mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block.
        act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
        norm_num_groups (`int`, *optional*, defaults to 32):
            The number of groups used by the mid block's normalization and by the output GroupNorm.
        norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
        cross_attention_dim (`int`, *optional*, defaults to 1280): The dimension of the cross attention features.
        attention_head_dim (`int`, *optional*, defaults to 8):
            Attention head setting forwarded to every block as `attn_num_head_channels`; `set_attention_slice`
            validates slice sizes against it.
    """

    @register_to_config
    def __init__(
        self,
        sample_size: Optional[int] = None,
        in_channels: int = 4,
        out_channels: int = 4,
        center_input_sample: bool = False,
        flip_sin_to_cos: bool = True,
        freq_shift: int = 0,
        down_block_types: Tuple[str] = (
            "CrossAttnDownBlock2D",
            "CrossAttnDownBlock2D",
            "CrossAttnDownBlock2D",
            "DownBlock2D",
        ),
        up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
        block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
        layers_per_block: int = 2,
        downsample_padding: int = 1,
        mid_block_scale_factor: float = 1,
        act_fn: str = "silu",
        norm_num_groups: int = 32,
        norm_eps: float = 1e-5,
        cross_attention_dim: int = 1280,
        attention_head_dim: int = 8,
    ):
        super().__init__()

        self.sample_size = sample_size
        time_embed_dim = block_out_channels[0] * 4

        # input projection
        self.conv_in = nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1))

        # time embedding: parameter-free projection followed by a learned embedding
        self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
        timestep_input_dim = block_out_channels[0]

        self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)

        self.down_blocks = nn.ModuleList([])
        self.mid_block = None
        self.up_blocks = nn.ModuleList([])

        # down path: every block except the last one downsamples
        output_channel = block_out_channels[0]
        for i, down_block_type in enumerate(down_block_types):
            input_channel = output_channel
            output_channel = block_out_channels[i]
            is_final_block = i == len(block_out_channels) - 1

            down_block = get_down_block(
                down_block_type,
                num_layers=layers_per_block,
                in_channels=input_channel,
                out_channels=output_channel,
                temb_channels=time_embed_dim,
                add_downsample=not is_final_block,
                resnet_eps=norm_eps,
                resnet_act_fn=act_fn,
                cross_attention_dim=cross_attention_dim,
                attn_num_head_channels=attention_head_dim,
                downsample_padding=downsample_padding,
            )
            self.down_blocks.append(down_block)

        # mid block
        self.mid_block = UNetMidBlock2DCrossAttn(
            in_channels=block_out_channels[-1],
            temb_channels=time_embed_dim,
            resnet_eps=norm_eps,
            resnet_act_fn=act_fn,
            output_scale_factor=mid_block_scale_factor,
            resnet_time_scale_shift="default",
            cross_attention_dim=cross_attention_dim,
            attn_num_head_channels=attention_head_dim,
            resnet_groups=norm_num_groups,
        )

        # up path: mirrors the down path with channels reversed; every block except the last one upsamples
        reversed_block_out_channels = list(reversed(block_out_channels))
        output_channel = reversed_block_out_channels[0]
        for i, up_block_type in enumerate(up_block_types):
            prev_output_channel = output_channel
            output_channel = reversed_block_out_channels[i]
            input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]

            is_final_block = i == len(block_out_channels) - 1

            up_block = get_up_block(
                up_block_type,
                num_layers=layers_per_block + 1,
                in_channels=input_channel,
                out_channels=output_channel,
                prev_output_channel=prev_output_channel,
                temb_channels=time_embed_dim,
                add_upsample=not is_final_block,
                resnet_eps=norm_eps,
                resnet_act_fn=act_fn,
                cross_attention_dim=cross_attention_dim,
                attn_num_head_channels=attention_head_dim,
            )
            self.up_blocks.append(up_block)

        # output head
        self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps)
        self.conv_act = nn.SiLU()
        self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1)

    def set_attention_slice(self, slice_size):
        """
        Forward `slice_size` to the mid block and to every down/up block that has non-`None` `attentions`.

        Args:
            slice_size (`int` or `None`): A positive divisor of `config.attention_head_dim`, or `None`.

        Raises:
            ValueError: If `slice_size` is not positive, does not divide `config.attention_head_dim`, or exceeds it.
        """
        head_dim = self.config.attention_head_dim
        if slice_size is not None:
            # `head_dim % 0` would raise ZeroDivisionError, and a negative divisor (e.g. -2 for 8) would pass the
            # divisibility check below, so reject non-positive values explicitly.
            if slice_size <= 0:
                raise ValueError(f"slice_size must be a positive integer, got {slice_size}")
            if head_dim % slice_size != 0:
                raise ValueError(
                    f"Make sure slice_size {slice_size} is a divisor of "
                    f"the number of heads used in cross_attention {head_dim}"
                )
            if slice_size > head_dim:
                raise ValueError(
                    f"Chunk_size {slice_size} has to be smaller or equal to "
                    f"the number of heads used in cross_attention {head_dim}"
                )

        for block in self.down_blocks:
            if hasattr(block, "attentions") and block.attentions is not None:
                block.set_attention_slice(slice_size)

        self.mid_block.set_attention_slice(slice_size)

        for block in self.up_blocks:
            if hasattr(block, "attentions") and block.attentions is not None:
                block.set_attention_slice(slice_size)

    def forward(
        self,
        sample: torch.FloatTensor,
        timestep: Union[torch.Tensor, float, int],
        encoder_hidden_states: torch.Tensor,
        return_dict: bool = True,
    ) -> Union[UNet2DConditionOutput, Tuple]:
        r"""
        Args:
            sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor
            timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps
            encoder_hidden_states (`torch.FloatTensor`):
                Conditioning states passed to every cross-attention down/up block and to the mid block. Presumably
                of shape (batch, sequence_length, cross_attention_dim) — TODO confirm against the attention blocks.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.

        Returns:
            [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
            [`~models.unet_2d_condition.UNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When
            returning a tuple, the first element is the sample tensor.
        """
        # 0. optionally rescale the input
        if self.config.center_input_sample:
            sample = 2 * sample - 1.0

        # 1. time
        timesteps = timestep
        if not torch.is_tensor(timesteps):
            # A Python scalar becomes a 1-element tensor. Float timesteps must stay fractional: a `long` tensor would
            # silently truncate e.g. 0.5 to 0. MPS has no float64 support, so use float32 there.
            if isinstance(timesteps, float):
                dtype = torch.float32 if sample.device.type == "mps" else torch.float64
            else:
                dtype = torch.long
            timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
        elif timesteps.ndim == 0:
            timesteps = timesteps.to(dtype=torch.float64)
            timesteps = timesteps[None].to(device=sample.device)

        # broadcast the timestep(s) over the batch dimension
        timesteps = timesteps.expand(sample.shape[0])

        t_emb = self.time_proj(timesteps)

        # time_proj's output dtype/device need not match the model's, so align it with the sample before the
        # learned time embedding.
        t_emb = t_emb.to(sample.dtype).to(sample.device)
        emb = self.time_embedding(t_emb)

        # 2. pre-process
        sample = self.conv_in(sample)

        # 3. down: collect every intermediate activation for the skip connections
        down_block_res_samples = (sample,)
        for downsample_block in self.down_blocks:
            if hasattr(downsample_block, "attentions") and downsample_block.attentions is not None:
                sample, res_samples = downsample_block(
                    hidden_states=sample, temb=emb, encoder_hidden_states=encoder_hidden_states
                )
            else:
                sample, res_samples = downsample_block(hidden_states=sample, temb=emb)

            down_block_res_samples += res_samples

        # 4. mid
        sample = self.mid_block(sample, emb, encoder_hidden_states=encoder_hidden_states)

        # 5. up: each block consumes as many skip activations as it has resnets, taken from the end
        for upsample_block in self.up_blocks:
            num_skips = len(upsample_block.resnets)
            res_samples = down_block_res_samples[-num_skips:]
            down_block_res_samples = down_block_res_samples[:-num_skips]

            if hasattr(upsample_block, "attentions") and upsample_block.attentions is not None:
                sample = upsample_block(
                    hidden_states=sample,
                    temb=emb,
                    res_hidden_states_tuple=res_samples,
                    encoder_hidden_states=encoder_hidden_states,
                )
            else:
                sample = upsample_block(hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples)

        # 6. post-process
        # Run the final GroupNorm in float64 and cast back to the working dtype. The norm's affine parameters must be
        # upcast together with the activations: feeding a float64 input to a GroupNorm whose weight/bias are still
        # float32/float16 is rejected by PyTorch as a dtype mismatch.
        norm = self.conv_norm_out
        sample = nn.functional.group_norm(
            sample.double(),
            norm.num_groups,
            weight=None if norm.weight is None else norm.weight.double(),
            bias=None if norm.bias is None else norm.bias.double(),
            eps=norm.eps,
        ).to(sample.dtype)
        sample = self.conv_act(sample)
        sample = self.conv_out(sample)

        if not return_dict:
            return (sample,)

        return UNet2DConditionOutput(sample=sample)
|
|