import math
from typing import Dict, Optional, Tuple

import torch
import torch.nn as nn
from torch import Tensor

from comfy.ldm.modules.diffusionmodules.mmdit import (
    DismantledBlock,
    PatchEmbed,
    VectorEmbedder,
    TimestepEmbedder,
    get_2d_sincos_pos_embed_torch,
)


class ControlNetEmbedder(nn.Module):
    """DiT-style ControlNet embedder.

    Patch-embeds the latent and the control hint, runs them through a stack of
    DismantledBlocks, and maps each block's output through a per-block linear
    projection to produce the control residuals consumed by the main MMDiT.
    """

    def __init__(
        self,
        img_size: int,
        patch_size: int,
        in_chans: int,
        attention_head_dim: int,
        num_attention_heads: int,
        adm_in_channels: int,
        num_layers: int,
        main_model_double: int,
        double_y_emb: bool,
        device: torch.device,
        dtype: torch.dtype,
        pos_embed_max_size: Optional[int] = None,
        operations=None,
    ):
        super().__init__()
        self.main_model_double = main_model_double
        self.dtype = dtype
        self.hidden_size = num_attention_heads * attention_head_dim
        self.patch_size = patch_size
        self.x_embedder = PatchEmbed(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=in_chans,
            embed_dim=self.hidden_size,
            strict_img_size=pos_embed_max_size is None,
            device=device,
            dtype=dtype,
            operations=operations,
        )

        self.t_embedder = TimestepEmbedder(self.hidden_size, dtype=dtype, device=device, operations=operations)

        self.double_y_emb = double_y_emb
        if self.double_y_emb:
            self.orig_y_embedder = VectorEmbedder(
                adm_in_channels, self.hidden_size, dtype, device, operations=operations
            )
            self.y_embedder = VectorEmbedder(
                self.hidden_size, self.hidden_size, dtype, device, operations=operations
            )
        else:
            self.y_embedder = VectorEmbedder(
                adm_in_channels, self.hidden_size, dtype, device, operations=operations
            )

        self.transformer_blocks = nn.ModuleList(
            DismantledBlock(
                hidden_size=self.hidden_size, num_heads=num_attention_heads, qkv_bias=True,
                dtype=dtype, device=device, operations=operations
            )
            for _ in range(num_layers)
        )

        self.use_y_embedder = True

        # One linear projection per transformer block; these produce the
        # per-block control residuals.
        self.controlnet_blocks = nn.ModuleList([])
        for _ in range(len(self.transformer_blocks)):
            controlnet_block = operations.Linear(self.hidden_size, self.hidden_size, dtype=dtype, device=device)
            self.controlnet_blocks.append(controlnet_block)

        # Separate patch embedding for the control hint.
        self.pos_embed_input = PatchEmbed(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=in_chans,
            embed_dim=self.hidden_size,
            strict_img_size=False,
            device=device,
            dtype=dtype,
            operations=operations,
        )

    def forward(
        self,
        x: torch.Tensor,
        timesteps: torch.Tensor,
        y: Optional[torch.Tensor] = None,
        context: Optional[torch.Tensor] = None,
        hint=None,
    ) -> Dict[str, Tuple[Tensor, ...]]:
        x_shape = list(x.shape)
        # Patch-embed the noisy latent.
        x = self.x_embedder(x)
        if not self.double_y_emb:
            # Add a fixed 2D sin-cos positional embedding over the patch grid.
            h = (x_shape[-2] + 1) // self.patch_size
            w = (x_shape[-1] + 1) // self.patch_size
            x += get_2d_sincos_pos_embed_torch(self.hidden_size, w, h, device=x.device)
        # Timestep conditioning, optionally combined with the pooled vector y.
        c = self.t_embedder(timesteps, dtype=x.dtype)
        if y is not None and self.y_embedder is not None:
            if self.double_y_emb:
                y = self.orig_y_embedder(y)
            y = self.y_embedder(y)
            c = c + y

        # Inject the control hint through its own patch embedding.
        x = x + self.pos_embed_input(hint)

        block_out = ()

        # Each block's projected output is repeated so the tuple covers all of
        # the main model's double blocks.
        repeat = math.ceil(self.main_model_double / len(self.transformer_blocks))
        for i in range(len(self.transformer_blocks)):
            out = self.transformer_blocks[i](x, c)
            if not self.double_y_emb:
                x = out
            block_out += (self.controlnet_blocks[i](out),) * repeat

        return {"output": block_out}
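

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original module). The values
# below are assumed, roughly SD3-sized hyperparameters, and
# `comfy.ops.disable_weight_init` is assumed to be the `operations` namespace
# supplying the Linear/Conv layer factories; swap in whatever your setup uses.
#
#   import comfy.ops
#
#   embedder = ControlNetEmbedder(
#       img_size=None,
#       patch_size=2,
#       in_chans=16,
#       attention_head_dim=64,
#       num_attention_heads=24,
#       adm_in_channels=2048,
#       num_layers=12,
#       main_model_double=24,
#       double_y_emb=False,
#       device=torch.device("cpu"),
#       dtype=torch.float32,
#       operations=comfy.ops.disable_weight_init,
#   )
#
#   x = torch.zeros(1, 16, 64, 64)      # noisy latent
#   hint = torch.zeros(1, 16, 64, 64)   # control hint already in latent space
#   t = torch.zeros(1)                  # timesteps
#   y = torch.zeros(1, 2048)            # pooled conditioning vector
#   controls = embedder(x, t, y=y, hint=hint)["output"]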