"""DeciDiffusion: a Stable Diffusion pipeline that swaps in a configurable UNet.

``FlexibleUNet2DConditionModel`` reuses diffusers' ``UNet2DConditionModel`` but
replaces its down/mid/up blocks with "flexible" blocks whose numbers of resnets
and cross-attention transformers are read from ``FlexibleUnetConfigurations``.
"""
import itertools
from typing import Any, Optional, Dict, Tuple

import torch
from diffusers import StableDiffusionPipeline, AutoencoderKL
from diffusers import Transformer2DModel, ModelMixin, ConfigMixin
from diffusers import UNet2DConditionModel
from diffusers.configuration_utils import register_to_config
from diffusers.models.attention import BasicTransformerBlock
from diffusers.models.resnet import ResnetBlock2D, Downsample2D, Upsample2D
from diffusers.models.transformer_2d import Transformer2DModelOutput
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
from diffusers.schedulers import KarrasDiffusionSchedulers
from torch import nn
from transformers import CLIPTextModel, CLIPTokenizer, CLIPImageProcessor

FlexibleUnetConfigurations = {
    # General parameters for all blocks
    'sample_size': 64,
    'temb_dim': 320 * 4,
    'resnet_eps': 1e-5,
    'resnet_act_fn': 'silu',
    'num_attention_heads': 8,
    'cross_attention_dim': 768,
    # Controls the order in which each block's modules execute in forward:
    # True keeps resnets and attentions interleaved; False runs all resnets
    # first, then all attentions (see ``custom_sort_order``).
    'mix_block_in_forward': True,
    # Down blocks parameters (one entry per down block)
    'down_blocks_in_channels': [320, 320, 640],
    'down_blocks_out_channels': [320, 640, 1280],
    'down_blocks_num_attentions': [0, 1, 3],
    'down_blocks_num_resnets': [2, 2, 1],
    'add_downsample': [True, True, True],
    # Middle block parameters
    'add_upsample_mid_block': True,
    'mid_num_resnets': 4,
    'mid_num_attentions': 2,
    # Up block parameters (one entry per up block; the first up block pairs
    # with the last down block, since the down channel lists are reversed)
    'prev_output_channels': [1280, 1280, 640],
    'up_blocks_num_attentions': [6, 3, 0],
    'up_blocks_num_resnets': [2, 3, 3],
    'add_upsample': [True, True, False],
}


def custom_sort_order(obj):
    """Key function for sorting the execution order of a block's modules.

    Resnets get key 0 and cross-attention transformers key 1, so sorting puts
    every resnet before every transformer. ``sorted`` is stable, so modules of
    the same kind keep their relative order.

    Raises:
        TypeError: if ``obj`` is neither a resnet nor a transformer module.
    """
    if isinstance(obj, ResnetBlock2D):
        return 0
    if isinstance(obj, (Transformer2DModel, FlexibleTransformer2DModel)):
        return 1
    # Previously an unknown type yielded None, which only failed later with an
    # opaque "'<' not supported" error inside sorted().
    raise TypeError(f'Cannot determine execution order for module of type {type(obj).__name__}')


class FlexibleIdentityBlock(nn.Module):
    """Pass-through mid block used when the mid block has no resnets and no attentions."""

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        temb: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
    ):
        """Return ``hidden_states`` unchanged; every other argument is accepted and ignored."""
        return hidden_states


class FlexibleUNet2DConditionModel(UNet2DConditionModel, ModelMixin):
    """UNet2DConditionModel whose down/mid/up blocks are built from ``configurations``.

    The class attribute ``configurations`` (a dict shaped like
    ``FlexibleUnetConfigurations``) decides per-block channel widths, resnet and
    attention counts, and whether each block down-/up-samples.
    """

    configurations = FlexibleUnetConfigurations

    @register_to_config
    def __init__(self):
        # Build the base UNet2DConditionModel first, then overwrite its
        # down_blocks / mid_block / up_blocks with the flexible variants below.
        super().__init__(
            sample_size=self.configurations.get('sample_size', FlexibleUnetConfigurations['sample_size']),
            cross_attention_dim=self.configurations.get(
                "cross_attention_dim", FlexibleUnetConfigurations['cross_attention_dim']
            ),
        )

        num_attention_heads = self.configurations.get("num_attention_heads")
        cross_attention_dim = self.configurations.get("cross_attention_dim")
        mix_block_in_forward = self.configurations.get("mix_block_in_forward")
        resnet_act_fn = self.configurations.get("resnet_act_fn")
        resnet_eps = self.configurations.get("resnet_eps")
        temb_dim = self.configurations.get("temb_dim")

        ###############
        # Down blocks #
        ###############
        down_blocks_num_attentions = self.configurations.get("down_blocks_num_attentions")
        down_blocks_out_channels = self.configurations.get("down_blocks_out_channels")
        down_blocks_in_channels = self.configurations.get("down_blocks_in_channels")
        down_blocks_num_resnets = self.configurations.get("down_blocks_num_resnets")
        add_downsample = self.configurations.get("add_downsample")

        self.down_blocks = nn.ModuleList()
        for i, (in_c, out_c, n_res, n_att, add_down) in enumerate(
            zip(
                down_blocks_in_channels,
                down_blocks_out_channels,
                down_blocks_num_resnets,
                down_blocks_num_attentions,
                add_downsample,
            )
        ):
            # The last down block does not export its downsampled output as a skip connection.
            last_block = i == len(down_blocks_in_channels) - 1
            self.down_blocks.append(
                FlexibleCrossAttnDownBlock2D(
                    in_channels=in_c,
                    out_channels=out_c,
                    temb_channels=temb_dim,
                    num_resnets=n_res,
                    num_attentions=n_att,
                    resnet_eps=resnet_eps,
                    resnet_act_fn=resnet_act_fn,
                    num_attention_heads=num_attention_heads,
                    cross_attention_dim=cross_attention_dim,
                    add_downsample=add_down,
                    last_block=last_block,
                    mix_block_in_forward=mix_block_in_forward,
                )
            )

        ###############
        # Mid blocks  #
        ###############
        mid_block_add_upsample = self.configurations.get("add_upsample_mid_block")
        mid_num_attentions = self.configurations.get("mid_num_attentions")
        mid_num_resnets = self.configurations.get("mid_num_resnets")

        if mid_num_resnets == mid_num_attentions == 0:
            self.mid_block = FlexibleIdentityBlock()
        else:
            self.mid_block = FlexibleUNetMidBlock2DCrossAttn(
                in_channels=down_blocks_out_channels[-1],
                temb_channels=temb_dim,
                resnet_act_fn=resnet_act_fn,
                resnet_eps=resnet_eps,
                cross_attention_dim=cross_attention_dim,
                num_attention_heads=num_attention_heads,
                num_resnets=mid_num_resnets,
                num_attentions=mid_num_attentions,
                mix_block_in_forward=mix_block_in_forward,
                add_upsample=mid_block_add_upsample,
            )

        ###############
        # Up blocks   #
        ###############
        up_blocks_num_attentions = self.configurations.get("up_blocks_num_attentions")
        up_blocks_num_resnets = self.configurations.get("up_blocks_num_resnets")
        prev_output_channels = self.configurations.get("prev_output_channels")
        up_upsample = self.configurations.get("add_upsample")

        self.up_blocks = nn.ModuleList()
        # Up blocks mirror the down blocks, so the down channel lists are walked in reverse.
        for in_c, out_c, prev_out, n_res, n_att, add_up in zip(
            reversed(down_blocks_in_channels),
            reversed(down_blocks_out_channels),
            prev_output_channels,
            up_blocks_num_resnets,
            up_blocks_num_attentions,
            up_upsample,
        ):
            self.up_blocks.append(
                FlexibleCrossAttnUpBlock2D(
                    in_channels=in_c,
                    out_channels=out_c,
                    prev_output_channel=prev_out,
                    temb_channels=temb_dim,
                    num_resnets=n_res,
                    num_attentions=n_att,
                    resnet_eps=resnet_eps,
                    resnet_act_fn=resnet_act_fn,
                    num_attention_heads=num_attention_heads,
                    cross_attention_dim=cross_attention_dim,
                    add_upsample=add_up,
                    mix_block_in_forward=mix_block_in_forward,
                )
            )


class FlexibleCrossAttnDownBlock2D(nn.Module):
    """Down block built from ``num_resnets`` resnets and ``num_attentions`` transformers.

    Modules are interleaved as resnet_0, attn_0, resnet_1, attn_1, ...; surplus
    modules of either kind follow the interleaved ones. When
    ``mix_block_in_forward`` is False all resnets are moved before all
    attentions. An optional ``Downsample2D`` runs after the modules.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_resnets: int = 1,
        num_attentions: int = 1,
        transformer_layers_per_block: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        num_attention_heads: int = 1,
        cross_attention_dim: int = 1280,
        output_scale_factor: float = 1.0,
        downsample_padding: int = 1,
        add_downsample: bool = True,
        use_linear_projection: bool = False,
        only_cross_attention: bool = False,
        upcast_attention: bool = False,
        last_block: bool = False,
        mix_block_in_forward: bool = True,
    ):
        super().__init__()
        self.last_block = last_block
        self.mix_block_in_forward = mix_block_in_forward
        self.has_cross_attention = True
        self.num_attention_heads = num_attention_heads

        modules = []
        add_resnets = [True] * num_resnets
        add_cross_attentions = [True] * num_attentions
        # zip_longest pads the shorter list with False, so surplus modules of
        # either kind are appended after the interleaved pairs.
        for i, (add_resnet, add_cross_attention) in enumerate(
            itertools.zip_longest(add_resnets, add_cross_attentions, fillvalue=False)
        ):
            # Only the first step sees the block's input width; later steps run at out_channels.
            in_channels = in_channels if i == 0 else out_channels
            if add_resnet:
                modules.append(
                    ResnetBlock2D(
                        in_channels=in_channels,
                        out_channels=out_channels,
                        temb_channels=temb_channels,
                        eps=resnet_eps,
                        groups=resnet_groups,
                        dropout=dropout,
                        time_embedding_norm=resnet_time_scale_shift,
                        non_linearity=resnet_act_fn,
                        output_scale_factor=output_scale_factor,
                        pre_norm=resnet_pre_norm,
                    )
                )
            if add_cross_attention:
                modules.append(
                    FlexibleTransformer2DModel(
                        num_attention_heads=num_attention_heads,
                        attention_head_dim=out_channels // num_attention_heads,
                        in_channels=out_channels,
                        num_layers=transformer_layers_per_block,
                        cross_attention_dim=cross_attention_dim,
                        norm_num_groups=resnet_groups,
                        use_linear_projection=use_linear_projection,
                        only_cross_attention=only_cross_attention,
                        upcast_attention=upcast_attention,
                    )
                )

        if not mix_block_in_forward:
            modules = sorted(modules, key=custom_sort_order)

        self.modules_list = nn.ModuleList(modules)

        if add_downsample:
            self.downsamplers = nn.ModuleList(
                [
                    Downsample2D(
                        out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
                    )
                ]
            )
        else:
            self.downsamplers = None

        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        temb: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
    ):
        """Run the block and collect skip connections.

        Returns:
            ``(hidden_states, output_states)``. ``output_states`` holds the output
            of every ResnetBlock2D (taken right after the resnet, before any
            following attention). It also holds the downsampled output, unless
            this is the last block.

        Raises:
            ValueError: if ``modules_list`` contains an unexpected module type.
        """
        output_states = ()

        for module in self.modules_list:
            if isinstance(module, ResnetBlock2D):
                hidden_states = module(hidden_states, temb)
            elif isinstance(module, (FlexibleTransformer2DModel, Transformer2DModel)):
                hidden_states = module(
                    hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    cross_attention_kwargs=cross_attention_kwargs,
                    attention_mask=attention_mask,
                    encoder_attention_mask=encoder_attention_mask,
                    return_dict=False,
                )[0]
            else:
                raise ValueError(f'Got an unexpected module in modules list! {type(module)}')
            if isinstance(module, ResnetBlock2D):
                output_states = output_states + (hidden_states,)

        if self.downsamplers is not None:
            for downsampler in self.downsamplers:
                hidden_states = downsampler(hidden_states)
            if not self.last_block:
                output_states = output_states + (hidden_states,)

        return hidden_states, output_states


class FlexibleCrossAttnUpBlock2D(nn.Module):
    """Up block built from ``num_resnets`` resnets and ``num_attentions`` transformers.

    Each resnet consumes one skip connection, which is concatenated to its input
    along dim 1. Module ordering follows the same interleave / sort rules as
    ``FlexibleCrossAttnDownBlock2D``. An optional ``Upsample2D`` runs last.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        prev_output_channel: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_resnets: int = 1,
        num_attentions: int = 1,
        transformer_layers_per_block: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        num_attention_heads: int = 1,
        cross_attention_dim: int = 1280,
        output_scale_factor: float = 1.0,
        add_upsample: bool = True,
        use_linear_projection: bool = False,
        only_cross_attention: bool = False,
        upcast_attention: bool = False,
        mix_block_in_forward: bool = True
    ):
        super().__init__()
        modules = []

        # WARNING: This parameter is filled with number of resnets and used within StableDiffusionPipeline
        # (presumably via len(...) to decide how many skip connections this block consumes).
        self.resnets = []

        self.has_cross_attention = True
        self.num_attention_heads = num_attention_heads

        add_resnets = [True] * num_resnets
        add_cross_attentions = [True] * num_attentions
        for i, (add_resnet, add_cross_attention) in enumerate(
            itertools.zip_longest(add_resnets, add_cross_attentions, fillvalue=False)
        ):
            # The last resnet's skip comes from the matching down block's input width;
            # the others come from its output width.
            res_skip_channels = in_channels if (i == len(add_resnets) - 1) else out_channels
            resnet_in_channels = prev_output_channel if i == 0 else out_channels
            if add_resnet:
                self.resnets += [True]
                modules.append(
                    ResnetBlock2D(
                        in_channels=resnet_in_channels + res_skip_channels,
                        out_channels=out_channels,
                        temb_channels=temb_channels,
                        eps=resnet_eps,
                        groups=resnet_groups,
                        dropout=dropout,
                        time_embedding_norm=resnet_time_scale_shift,
                        non_linearity=resnet_act_fn,
                        output_scale_factor=output_scale_factor,
                        pre_norm=resnet_pre_norm,
                    )
                )
            if add_cross_attention:
                modules.append(
                    FlexibleTransformer2DModel(
                        num_attention_heads,
                        out_channels // num_attention_heads,
                        in_channels=out_channels,
                        num_layers=transformer_layers_per_block,
                        cross_attention_dim=cross_attention_dim,
                        norm_num_groups=resnet_groups,
                        use_linear_projection=use_linear_projection,
                        only_cross_attention=only_cross_attention,
                        upcast_attention=upcast_attention,
                    )
                )

        if not mix_block_in_forward:
            modules = sorted(modules, key=custom_sort_order)

        self.modules_list = nn.ModuleList(modules)

        self.upsamplers = None
        if add_upsample:
            self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])

        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
        temb: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        upsample_size: Optional[int] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
    ):
        """Run the block, popping one skip connection (from the end of the tuple) per resnet."""
        for module in self.modules_list:
            if isinstance(module, ResnetBlock2D):
                res_hidden_states = res_hidden_states_tuple[-1]
                res_hidden_states_tuple = res_hidden_states_tuple[:-1]
                hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
                hidden_states = module(hidden_states, temb)
            elif isinstance(module, (FlexibleTransformer2DModel, Transformer2DModel)):
                hidden_states = module(
                    hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    cross_attention_kwargs=cross_attention_kwargs,
                    attention_mask=attention_mask,
                    encoder_attention_mask=encoder_attention_mask,
                    return_dict=False,
                )[0]

        if self.upsamplers is not None:
            for upsampler in self.upsamplers:
                hidden_states = upsampler(hidden_states, upsample_size)

        return hidden_states


class FlexibleUNetMidBlock2DCrossAttn(nn.Module):
    """Mid block: one leading resnet plus ``num_resnets`` resnets and ``num_attentions`` transformers.

    After the leading resnet, modules are interleaved as attn_0, resnet_0,
    attn_1, resnet_1, ... (surplus of either kind last). All modules keep
    ``in_channels`` channels. The upsampler is ``Upsample2D`` when
    ``add_upsample`` is set, otherwise ``nn.Identity``.
    """

    def __init__(
        self,
        in_channels: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_resnets: int = 1,
        num_attentions: int = 1,
        transformer_layers_per_block: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        num_attention_heads: int = 1,
        output_scale_factor: float = 1.0,
        cross_attention_dim: int = 1280,
        use_linear_projection: bool = False,
        upcast_attention: bool = False,
        mix_block_in_forward: bool = True,
        add_upsample: bool = True,
    ):
        super().__init__()

        self.has_cross_attention = True
        self.num_attention_heads = num_attention_heads
        resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)

        # There is always at least one resnet
        modules = [
            ResnetBlock2D(
                in_channels=in_channels,
                out_channels=in_channels,
                temb_channels=temb_channels,
                eps=resnet_eps,
                groups=resnet_groups,
                dropout=dropout,
                time_embedding_norm=resnet_time_scale_shift,
                non_linearity=resnet_act_fn,
                output_scale_factor=output_scale_factor,
                pre_norm=resnet_pre_norm,
            )
        ]

        add_resnets = [True] * num_resnets
        add_cross_attentions = [True] * num_attentions
        for i, (add_resnet, add_cross_attention) in enumerate(
            itertools.zip_longest(add_resnets, add_cross_attentions, fillvalue=False)
        ):
            if add_cross_attention:
                modules.append(
                    FlexibleTransformer2DModel(
                        num_attention_heads,
                        in_channels // num_attention_heads,
                        in_channels=in_channels,
                        num_layers=transformer_layers_per_block,
                        cross_attention_dim=cross_attention_dim,
                        norm_num_groups=resnet_groups,
                        use_linear_projection=use_linear_projection,
                        upcast_attention=upcast_attention,
                    )
                )

            if add_resnet:
                modules.append(
                    ResnetBlock2D(
                        in_channels=in_channels,
                        out_channels=in_channels,
                        temb_channels=temb_channels,
                        eps=resnet_eps,
                        groups=resnet_groups,
                        dropout=dropout,
                        time_embedding_norm=resnet_time_scale_shift,
                        non_linearity=resnet_act_fn,
                        output_scale_factor=output_scale_factor,
                        pre_norm=resnet_pre_norm,
                    )
                )

        if not mix_block_in_forward:
            modules = sorted(modules, key=custom_sort_order)

        self.modules_list = nn.ModuleList(modules)

        self.upsamplers = nn.ModuleList([nn.Identity()])
        if add_upsample:
            self.upsamplers = nn.ModuleList([Upsample2D(in_channels, use_conv=True, out_channels=in_channels)])

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        temb: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
    ) -> torch.FloatTensor:
        """Run the leading resnet, then every module in ``modules_list``, then the upsampler."""
        # NOTE(review): modules_list[0] (a ResnetBlock2D) is applied here and then
        # again as the first iteration of the loop below, so the first resnet runs
        # twice. This looks unintended, but it is kept as-is because released
        # weights may have been trained with it. Confirm before changing.
        hidden_states = self.modules_list[0](hidden_states, temb)

        for module in self.modules_list:
            if isinstance(module, (FlexibleTransformer2DModel, Transformer2DModel)):
                hidden_states = module(
                    hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    cross_attention_kwargs=cross_attention_kwargs,
                    attention_mask=attention_mask,
                    encoder_attention_mask=encoder_attention_mask,
                    return_dict=False,
                )[0]
            elif isinstance(module, ResnetBlock2D):
                hidden_states = module(hidden_states, temb)

        for upsampler in self.upsamplers:
            hidden_states = upsampler(hidden_states)

        return hidden_states


class FlexibleTransformer2DModel(ModelMixin, ConfigMixin):
    """Continuous-input 2D transformer.

    The block applies GroupNorm, then ``proj_in``, then ``num_layers`` stacked
    ``BasicTransformerBlock`` layers, then ``proj_out``, and finally adds a
    residual connection.
    """

    @register_to_config
    def __init__(
        self,
        num_attention_heads: int = 16,
        attention_head_dim: int = 88,
        in_channels: Optional[int] = None,
        out_channels: Optional[int] = None,
        num_layers: int = 1,
        dropout: float = 0.0,
        norm_num_groups: int = 32,
        cross_attention_dim: Optional[int] = None,
        attention_bias: bool = False,
        activation_fn: str = "geglu",
        num_embeds_ada_norm: Optional[int] = None,
        only_cross_attention: bool = False,
        use_linear_projection: bool = False,
        upcast_attention: bool = False,
        norm_type: str = "layer_norm",
        norm_elementwise_affine: bool = True,
    ):
        super().__init__()
        self.num_attention_heads = num_attention_heads
        self.attention_head_dim = attention_head_dim
        self.in_channels = in_channels
        inner_dim = num_attention_heads * attention_head_dim

        # Define input layers
        self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)

        self.use_linear_projection = use_linear_projection
        if self.use_linear_projection:
            self.proj_in = nn.Linear(in_channels, inner_dim)
        else:
            self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)

        # Define transformers blocks
        self.transformer_blocks = nn.ModuleList(
            [
                BasicTransformerBlock(
                    inner_dim,
                    num_attention_heads,
                    attention_head_dim,
                    dropout=dropout,
                    cross_attention_dim=cross_attention_dim,
                    activation_fn=activation_fn,
                    num_embeds_ada_norm=num_embeds_ada_norm,
                    attention_bias=attention_bias,
                    only_cross_attention=only_cross_attention,
                    upcast_attention=upcast_attention,
                    norm_type=norm_type,
                    norm_elementwise_affine=norm_elementwise_affine,
                )
                for _ in range(num_layers)
            ]
        )

        # Define output layers
        self.out_channels = in_channels if out_channels is None else out_channels
        # proj_out maps back to in_channels so the residual addition in forward is shape-compatible.
        if self.use_linear_projection:
            self.proj_out = nn.Linear(inner_dim, in_channels)
        else:
            self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        timestep: Optional[torch.LongTensor] = None,
        class_labels: Optional[torch.LongTensor] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        return_dict: bool = False
    ):
        """Apply the transformer to a 4D ``(batch, channels, height, width)`` input.

        Returns:
            ``Transformer2DModelOutput(sample=...)`` when ``return_dict`` is True,
            otherwise the one-element tuple ``(sample,)``. In both cases ``[0]``
            yields the output tensor.
        """
        # 1. Input
        batch, _, height, width = hidden_states.shape
        residual = hidden_states

        hidden_states = self.norm(hidden_states)
        if not self.use_linear_projection:
            hidden_states = self.proj_in(hidden_states)
            inner_dim = hidden_states.shape[1]
            hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
        else:
            inner_dim = hidden_states.shape[1]
            hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
            hidden_states = self.proj_in(hidden_states)

        # 2. Blocks
        for block in self.transformer_blocks:
            hidden_states = block(
                hidden_states,
                attention_mask=attention_mask,
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=encoder_attention_mask,
                timestep=timestep,
                cross_attention_kwargs=cross_attention_kwargs,
                class_labels=class_labels,
            )

        # 3. Output
        if not self.use_linear_projection:
            hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
            hidden_states = self.proj_out(hidden_states)
        else:
            hidden_states = self.proj_out(hidden_states)
            hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()

        output = hidden_states + residual

        # Fixed: this check was inverted (it returned a tuple when return_dict=True).
        if not return_dict:
            return (output,)

        return Transformer2DModelOutput(sample=output)


class DeciDiffusionPipeline(StableDiffusionPipeline):
    """StableDiffusionPipeline that always uses ``FlexibleUNet2DConditionModel`` as its UNet.

    ``__call__`` defaults ``num_inference_steps`` to 30 and ``guidance_rescale``
    to 0.7 when they are not passed as keyword arguments.
    """

    deci_default_number_of_iterations = 30
    deci_default_guidance_rescale = 0.7

    def __init__(self,
                 vae: AutoencoderKL,
                 text_encoder: CLIPTextModel,
                 tokenizer: CLIPTokenizer,
                 unet: UNet2DConditionModel,
                 scheduler: KarrasDiffusionSchedulers,
                 safety_checker: StableDiffusionSafetyChecker,
                 feature_extractor: CLIPImageProcessor,
                 requires_safety_checker: bool = True
                 ):
        # Replace UNet with Deci's unet: the ``unet`` argument is discarded and a
        # freshly constructed (freshly initialized) FlexibleUNet2DConditionModel is
        # used instead.
        # NOTE(review): presumably callers load the flexible UNet's weights afterwards
        # (e.g. via from_pretrained on ``pipeline.unet``). Confirm.
        del unet
        unet = FlexibleUNet2DConditionModel()

        super().__init__(vae=vae,
                         text_encoder=text_encoder,
                         tokenizer=tokenizer,
                         unet=unet,
                         scheduler=scheduler,
                         safety_checker=safety_checker,
                         feature_extractor=feature_extractor,
                         requires_safety_checker=requires_safety_checker
                         )

        # NOTE(review): the parent __init__ likely registers these modules already;
        # this re-registration is kept unchanged.
        self.register_modules(vae=vae,
                              text_encoder=text_encoder,
                              tokenizer=tokenizer,
                              unet=unet,
                              scheduler=scheduler,
                              safety_checker=safety_checker,
                              feature_extractor=feature_extractor)

    def __call__(self, *args, **kwargs):
        """Run the parent pipeline after filling in Deci's inference defaults.

        ``guidance_rescale`` and ``num_inference_steps`` default to
        ``deci_default_guidance_rescale`` and ``deci_default_number_of_iterations``
        when they are not given as keyword arguments. Explicit keyword arguments win.
        """
        # Inference (not training) defaults.
        kwargs.setdefault('guidance_rescale', self.deci_default_guidance_rescale)
        kwargs.setdefault('num_inference_steps', self.deci_default_number_of_iterations)
        return super().__call__(*args, **kwargs)