from typing import Any, Dict, List, Optional, Tuple, Union import copy import torch from torch import nn, svd_lowrank from peft.tuners.lora import LoraLayer, Conv2d as PeftConv2d from diffusers.configuration_utils import register_to_config from diffusers.models.unets.unet_2d_condition import UNet2DConditionOutput, UNet2DConditionModel as UNet2DConditionModel class UNet2DConditionModelEx(UNet2DConditionModel): @register_to_config def __init__( self, sample_size: Optional[int] = None, in_channels: int = 4, out_channels: int = 4, center_input_sample: bool = False, flip_sin_to_cos: bool = True, freq_shift: int = 0, down_block_types: Tuple[str] = ( "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D", ), mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn", up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"), only_cross_attention: Union[bool, Tuple[bool]] = False, block_out_channels: Tuple[int] = (320, 640, 1280, 1280), layers_per_block: Union[int, Tuple[int]] = 2, downsample_padding: int = 1, mid_block_scale_factor: float = 1, dropout: float = 0.0, act_fn: str = "silu", norm_num_groups: Optional[int] = 32, norm_eps: float = 1e-5, cross_attention_dim: Union[int, Tuple[int]] = 1280, transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1, reverse_transformer_layers_per_block: Optional[Tuple[Tuple[int]]] = None, encoder_hid_dim: Optional[int] = None, encoder_hid_dim_type: Optional[str] = None, attention_head_dim: Union[int, Tuple[int]] = 8, num_attention_heads: Optional[Union[int, Tuple[int]]] = None, dual_cross_attention: bool = False, use_linear_projection: bool = False, class_embed_type: Optional[str] = None, addition_embed_type: Optional[str] = None, addition_time_embed_dim: Optional[int] = None, num_class_embeds: Optional[int] = None, upcast_attention: bool = False, resnet_time_scale_shift: str = "default", resnet_skip_time_act: bool = False, resnet_out_scale_factor: float = 1.0, time_embedding_type: str = "positional", time_embedding_dim: Optional[int] = None, time_embedding_act_fn: Optional[str] = None, timestep_post_act: Optional[str] = None, time_cond_proj_dim: Optional[int] = None, conv_in_kernel: int = 3, conv_out_kernel: int = 3, projection_class_embeddings_input_dim: Optional[int] = None, attention_type: str = "default", class_embeddings_concat: bool = False, mid_block_only_cross_attention: Optional[bool] = None, cross_attention_norm: Optional[str] = None, addition_embed_type_num_heads: int = 64, extra_condition_names: List[str] = [], ): num_extra_conditions = len(extra_condition_names) super().__init__( sample_size=sample_size, in_channels=in_channels * (1 + num_extra_conditions), out_channels=out_channels, center_input_sample=center_input_sample, flip_sin_to_cos=flip_sin_to_cos, freq_shift=freq_shift, down_block_types=down_block_types, mid_block_type=mid_block_type, up_block_types=up_block_types, only_cross_attention=only_cross_attention, block_out_channels=block_out_channels, layers_per_block=layers_per_block, downsample_padding=downsample_padding, mid_block_scale_factor=mid_block_scale_factor, dropout=dropout, act_fn=act_fn, norm_num_groups=norm_num_groups, norm_eps=norm_eps, cross_attention_dim=cross_attention_dim, transformer_layers_per_block=transformer_layers_per_block, reverse_transformer_layers_per_block=reverse_transformer_layers_per_block, encoder_hid_dim=encoder_hid_dim, encoder_hid_dim_type=encoder_hid_dim_type, attention_head_dim=attention_head_dim, num_attention_heads=num_attention_heads, dual_cross_attention=dual_cross_attention, use_linear_projection=use_linear_projection, class_embed_type=class_embed_type, addition_embed_type=addition_embed_type, addition_time_embed_dim=addition_time_embed_dim, num_class_embeds=num_class_embeds, upcast_attention=upcast_attention, resnet_time_scale_shift=resnet_time_scale_shift, resnet_skip_time_act=resnet_skip_time_act, resnet_out_scale_factor=resnet_out_scale_factor, time_embedding_type=time_embedding_type, time_embedding_dim=time_embedding_dim, time_embedding_act_fn=time_embedding_act_fn, timestep_post_act=timestep_post_act, time_cond_proj_dim=time_cond_proj_dim, conv_in_kernel=conv_in_kernel, conv_out_kernel=conv_out_kernel, projection_class_embeddings_input_dim=projection_class_embeddings_input_dim, attention_type=attention_type, class_embeddings_concat=class_embeddings_concat, mid_block_only_cross_attention=mid_block_only_cross_attention, cross_attention_norm=cross_attention_norm, addition_embed_type_num_heads=addition_embed_type_num_heads,) self._internal_dict = copy.deepcopy(self._internal_dict) self.config.in_channels = in_channels self.config.extra_condition_names = extra_condition_names @property def extra_condition_names(self) -> List[str]: return self.config.extra_condition_names def add_extra_conditions(self, extra_condition_names: Union[str, List[str]]): if isinstance(extra_condition_names, str): extra_condition_names = [extra_condition_names] conv_in_kernel = self.config.conv_in_kernel conv_in_weight = self.conv_in.weight self.config.extra_condition_names += extra_condition_names full_in_channels = self.config.in_channels * (1 + len(self.config.extra_condition_names)) new_conv_in_weight = torch.zeros( conv_in_weight.shape[0], full_in_channels, conv_in_kernel, conv_in_kernel, dtype=conv_in_weight.dtype, device=conv_in_weight.device,) new_conv_in_weight[:,:conv_in_weight.shape[1]] = conv_in_weight self.conv_in.weight = nn.Parameter( new_conv_in_weight.data, requires_grad=conv_in_weight.requires_grad,) self.conv_in.in_channels = full_in_channels return self def activate_extra_condition_adapters(self): lora_layers = [layer for layer in self.modules() if isinstance(layer, LoraLayer)] if len(lora_layers) > 0: self._hf_peft_config_loaded = True for lora_layer in lora_layers: adapter_names = [k for k in lora_layer.scaling.keys() if k in self.config.extra_condition_names] adapter_names += lora_layer.active_adapters adapter_names = list(set(adapter_names)) lora_layer.set_adapter(adapter_names) def set_extra_condition_scale(self, scale: Union[float, List[float]] = 1.0): if isinstance(scale, float): scale = [scale] * len(self.config.extra_condition_names) lora_layers = [layer for layer in self.modules() if isinstance(layer, LoraLayer)] for s, n in zip(scale, self.config.extra_condition_names): for lora_layer in lora_layers: lora_layer.set_scale(n, s) @property def default_half_lora_target_modules(self) -> List[str]: module_names = [] for name, module in self.named_modules(): if "conv_out" in name or "up_blocks" in name: continue if isinstance(module, (nn.Linear, nn.Conv2d)): module_names.append(name) return list(set(module_names)) @property def default_full_lora_target_modules(self) -> List[str]: module_names = [] for name, module in self.named_modules(): if isinstance(module, (nn.Linear, nn.Conv2d)): module_names.append(name) return list(set(module_names)) @property def default_half_skip_attn_lora_target_modules(self) -> List[str]: return [ module_name for module_name in self.default_half_lora_target_modules if all( not module_name.endswith(attn_name) for attn_name in ["to_k", "to_q", "to_v", "to_out.0"] ) ] @property def default_full_skip_attn_lora_target_modules(self) -> List[str]: return [ module_name for module_name in self.default_full_lora_target_modules if all( not module_name.endswith(attn_name) for attn_name in ["to_k", "to_q", "to_v", "to_out.0"] ) ] def forward( self, sample: torch.Tensor, timestep: Union[torch.Tensor, float, int], encoder_hidden_states: torch.Tensor, class_labels: Optional[torch.Tensor] = None, timestep_cond: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, cross_attention_kwargs: Optional[Dict[str, Any]] = None, added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None, down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None, mid_block_additional_residual: Optional[torch.Tensor] = None, down_intrablock_additional_residuals: Optional[Tuple[torch.Tensor]] = None, encoder_attention_mask: Optional[torch.Tensor] = None, extra_conditions: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None, return_dict: bool = True, ) -> Union[UNet2DConditionOutput, Tuple]: if extra_conditions is not None: if isinstance(extra_conditions, list): extra_conditions = torch.cat(extra_conditions, dim=1) sample = torch.cat([sample, extra_conditions], dim=1) return super().forward( sample=sample, timestep=timestep, encoder_hidden_states=encoder_hidden_states, class_labels=class_labels, timestep_cond=timestep_cond, attention_mask=attention_mask, cross_attention_kwargs=cross_attention_kwargs, added_cond_kwargs=added_cond_kwargs, down_block_additional_residuals=down_block_additional_residuals, mid_block_additional_residual=mid_block_additional_residual, down_intrablock_additional_residuals=down_intrablock_additional_residuals, encoder_attention_mask=encoder_attention_mask, return_dict=return_dict,) class PeftConv2dEx(PeftConv2d): def reset_lora_parameters(self, adapter_name, init_lora_weights): if init_lora_weights is False: return if isinstance(init_lora_weights, str) and "pissa" in init_lora_weights.lower(): if self.conv2d_pissa_init(adapter_name, init_lora_weights): return # Failed init_lora_weights = "gaussian" super(PeftConv2d, self).reset_lora_parameters(adapter_name, init_lora_weights) def conv2d_pissa_init(self, adapter_name, init_lora_weights): weight = weight_ori = self.get_base_layer().weight weight = weight.flatten(start_dim=1) if self.r[adapter_name] > weight.shape[0]: return False dtype = weight.dtype if dtype not in [torch.float32, torch.float16, torch.bfloat16]: raise TypeError( "Please initialize PiSSA under float32, float16, or bfloat16. " "Subsequently, re-quantize the residual model to help minimize quantization errors." ) weight = weight.to(torch.float32) if init_lora_weights == "pissa": # USV^T = W <-> VSU^T = W^T, where W^T = weight.data in R^{out_channel, in_channel}, V, S, Uh = torch.linalg.svd(weight.data, full_matrices=False) Vr = V[:, : self.r[adapter_name]] Sr = S[: self.r[adapter_name]] Sr /= self.scaling[adapter_name] Uhr = Uh[: self.r[adapter_name]] elif len(init_lora_weights.split("_niter_")) == 2: Vr, Sr, Ur = svd_lowrank( weight.data, self.r[adapter_name], niter=int(init_lora_weights.split("_niter_")[-1]) ) Sr /= self.scaling[adapter_name] Uhr = Ur.t() else: raise ValueError( f"init_lora_weights should be 'pissa' or 'pissa_niter_[number of iters]', got {init_lora_weights} instead." ) lora_A = torch.diag(torch.sqrt(Sr)) @ Uhr lora_B = Vr @ torch.diag(torch.sqrt(Sr)) self.lora_A[adapter_name].weight.data = lora_A.view([-1] + list(weight_ori.shape[1:])) self.lora_B[adapter_name].weight.data = lora_B.view([-1, self.r[adapter_name]] + [1] * (weight_ori.ndim - 2)) weight = weight.data - self.scaling[adapter_name] * lora_B @ lora_A weight = weight.to(dtype) self.get_base_layer().weight.data = weight.view_as(weight_ori) return True # Patch peft conv2d PeftConv2d.reset_lora_parameters = PeftConv2dEx.reset_lora_parameters PeftConv2d.conv2d_pissa_init = PeftConv2dEx.conv2d_pissa_init