Spaces:
Running
on
Zero
Running
on
Zero
| import torch | |
| from typing import Optional | |
| import comfy.ldm.modules.diffusionmodules.mmdit | |
class ControlNet(comfy.ldm.modules.diffusionmodules.mmdit.MMDiT):
    """ControlNet variant of MMDiT.

    Reuses the MMDiT trunk (constructed with ``final_layer=False``) and adds:

    * ``controlnet_blocks`` -- one ``operations.Linear(hidden_size, hidden_size)``
      per joint block, applied to that block's ``x`` output in ``forward``;
    * ``pos_embed_input`` -- a ``PatchEmbed`` for the control ``hint`` whose
      output is added to the patch-embedded latent before the joint blocks.
    """

    def __init__(
        self,
        num_blocks=None,
        control_latent_channels=None,
        dtype=None,
        device=None,
        operations=None,
        **kwargs,
    ):
        """Build the MMDiT trunk plus the ControlNet-specific layers.

        Args:
            num_blocks: Forwarded to ``MMDiT`` as ``num_blocks``.
            control_latent_channels: Channel count of the ``hint`` input;
                defaults to ``self.in_channels`` when ``None``.
            dtype, device: Forwarded to every layer constructed here.
            operations: Namespace providing layer classes (``Linear`` is used
                here); also forwarded to ``MMDiT`` and ``PatchEmbed``.
            **kwargs: Forwarded unchanged to ``MMDiT``.
        """
        super().__init__(dtype=dtype, device=device, operations=operations, final_layer=False, num_blocks=num_blocks, **kwargs)

        # One projection per joint block; their outputs form forward()'s result.
        self.controlnet_blocks = torch.nn.ModuleList(
            [
                operations.Linear(self.hidden_size, self.hidden_size, device=device, dtype=dtype)
                for _ in range(len(self.joint_blocks))
            ]
        )

        if control_latent_channels is None:
            control_latent_channels = self.in_channels

        # Its output is added to the x_embedder output in forward(), so it
        # must produce tokens of the same shape.
        self.pos_embed_input = comfy.ldm.modules.diffusionmodules.mmdit.PatchEmbed(
            None,
            self.patch_size,
            control_latent_channels,
            self.hidden_size,
            bias=True,
            strict_img_size=False,
            dtype=dtype,
            device=device,
            operations=operations,
        )

    def forward(
        self,
        x: torch.Tensor,
        timesteps: torch.Tensor,
        y: Optional[torch.Tensor] = None,
        context: Optional[torch.Tensor] = None,
        hint=None,
    ) -> dict:
        """Run the trunk and collect per-block ControlNet outputs.

        Args:
            x: Latent input; its last two dims are passed to
                ``cropped_pos_embed`` as the spatial size.
            timesteps: Passed to ``t_embedder``.
            y: Optional conditioning vector. When given, it is replaced by
                zeros of the same shape before use. When ``None``, the
                ``y_embedder`` path is skipped.
            context: Optional context tokens. They are passed through
                ``context_processor`` (if set) and then ``context_embedder``.
            hint: Control input embedded by ``pos_embed_input``; required.

        Returns:
            ``{"output": [...]}``. Each joint block's projected output is
            repeated ``self.depth // len(self.joint_blocks)`` times, except the
            last block's, which is repeated one time fewer.
        """
        # SD3-ControlNet-specific quirk: the conditioning vector is zeroed
        # rather than dropped. Guarded so the declared default y=None no
        # longer crashes inside torch.zeros_like.
        if y is not None:
            y = torch.zeros_like(y)

        if self.context_processor is not None:
            context = self.context_processor(context)

        hw = x.shape[-2:]
        x = self.x_embedder(x) + self.cropped_pos_embed(hw, device=x.device).to(dtype=x.dtype, device=x.device)
        x += self.pos_embed_input(hint)

        c = self.t_embedder(timesteps, dtype=x.dtype)
        if y is not None and self.y_embedder is not None:
            c = c + self.y_embedder(y)

        if context is not None:
            context = self.context_embedder(context)

        num_joint = len(self.joint_blocks)
        # Loop-invariant repeat count. The guard keeps an empty trunk
        # returning [] (as the original loop did) instead of dividing by zero.
        repeats = self.depth // num_joint if num_joint else 0

        output = []
        for i, joint_block in enumerate(self.joint_blocks):
            context, x = joint_block(
                context,
                x,
                c=c,
                use_checkpoint=self.use_checkpoint,
            )
            out = self.controlnet_blocks[i](x)
            # NOTE(review): the repetition presumably maps these outputs onto
            # the base model's `depth` blocks, with the final base block getting
            # no entry -- confirm against the consumer of "output".
            count = repeats - 1 if i == num_joint - 1 else repeats
            output.extend([out] * count)

        return {"output": output}