Spaces:
Paused
Paused
Update models/controlnet.py
Browse files- models/controlnet.py +2 -44
models/controlnet.py
CHANGED
@@ -43,10 +43,9 @@ class ControlNetOutput(BaseOutput):
|
|
43 |
down_block_res_samples: Tuple[torch.Tensor]
|
44 |
mid_block_res_sample: torch.Tensor
|
45 |
|
46 |
-
|
47 |
class ControlNetConditioningEmbedding(nn.Module):
|
48 |
"""
|
49 |
-
"""
|
50 |
Quoting from https://arxiv.org/abs/2302.05543: "Stable Diffusion uses a pre-processing method similar to VQ-GAN
|
51 |
[11] to convert the entire dataset of 512 × 512 images into smaller 64 × 64 “latent images” for stabilized
|
52 |
training. This requires ControlNets to convert image-based conditions to 64 × 64 feature space to match the
|
@@ -54,7 +53,7 @@ class ControlNetConditioningEmbedding(nn.Module):
|
|
54 |
(activated by ReLU, channels are 16, 32, 64, 128, initialized with Gaussian weights, trained jointly with the full
|
55 |
model) to encode image-space conditions ... into feature maps ..."
|
56 |
"""
|
57 |
-
|
58 |
|
59 |
def __init__(
|
60 |
self,
|
@@ -89,48 +88,7 @@ class ControlNetConditioningEmbedding(nn.Module):
|
|
89 |
embedding = self.conv_out(embedding)
|
90 |
|
91 |
return embedding
|
92 |
-
"""
|
93 |
-
|
94 |
-
class ControlNetConditioningEmbedding(nn.Module):
|
95 |
-
def __init__(
|
96 |
-
self,
|
97 |
-
conditioning_embedding_channels: int,
|
98 |
-
conditioning_channels: int = 3,
|
99 |
-
block_out_channels: Tuple[int] = (16, 32, 96, 256),
|
100 |
-
):
|
101 |
-
super().__init__()
|
102 |
|
103 |
-
self.conv_in = InflatedConv3d(conditioning_channels, block_out_channels[0], kernel_size=3, padding=1)
|
104 |
-
self.bn_in = nn.BatchNorm3d(block_out_channels[0])
|
105 |
-
|
106 |
-
self.blocks = nn.ModuleList([])
|
107 |
-
self.bns = nn.ModuleList([])
|
108 |
-
|
109 |
-
for i in range(len(block_out_channels) - 1):
|
110 |
-
channel_in = block_out_channels[i]
|
111 |
-
channel_out = block_out_channels[i + 1]
|
112 |
-
self.blocks.append(InflatedConv3d(channel_in, channel_in, kernel_size=3, padding=1))
|
113 |
-
self.bns.append(nn.BatchNorm3d(channel_in))
|
114 |
-
self.blocks.append(InflatedConv3d(channel_in, channel_out, kernel_size=3, padding=1, stride=2))
|
115 |
-
self.bns.append(nn.BatchNorm3d(channel_out))
|
116 |
-
|
117 |
-
self.conv_out = zero_module(
|
118 |
-
InflatedConv3d(block_out_channels[-1], conditioning_embedding_channels, kernel_size=3, padding=1)
|
119 |
-
)
|
120 |
-
|
121 |
-
def forward(self, conditioning):
|
122 |
-
embedding = self.conv_in(conditioning)
|
123 |
-
embedding = self.bn_in(embedding)
|
124 |
-
embedding = F.silu(embedding)
|
125 |
-
|
126 |
-
for block, bn in zip(self.blocks, self.bns):
|
127 |
-
embedding = block(embedding)
|
128 |
-
embedding = bn(embedding)
|
129 |
-
embedding = F.silu(embedding)
|
130 |
-
|
131 |
-
embedding = self.conv_out(embedding)
|
132 |
-
|
133 |
-
return embedding
|
134 |
|
135 |
class ControlNetModel3D(ModelMixin, ConfigMixin):
|
136 |
_supports_gradient_checkpointing = True
|
|
|
43 |
down_block_res_samples: Tuple[torch.Tensor]
|
44 |
mid_block_res_sample: torch.Tensor
|
45 |
|
46 |
+
|
47 |
class ControlNetConditioningEmbedding(nn.Module):
|
48 |
"""
|
|
|
49 |
Quoting from https://arxiv.org/abs/2302.05543: "Stable Diffusion uses a pre-processing method similar to VQ-GAN
|
50 |
[11] to convert the entire dataset of 512 × 512 images into smaller 64 × 64 “latent images” for stabilized
|
51 |
training. This requires ControlNets to convert image-based conditions to 64 × 64 feature space to match the
|
|
|
53 |
(activated by ReLU, channels are 16, 32, 64, 128, initialized with Gaussian weights, trained jointly with the full
|
54 |
model) to encode image-space conditions ... into feature maps ..."
|
55 |
"""
|
56 |
+
|
57 |
|
58 |
def __init__(
|
59 |
self,
|
|
|
88 |
embedding = self.conv_out(embedding)
|
89 |
|
90 |
return embedding
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
91 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
92 |
|
93 |
class ControlNetModel3D(ModelMixin, ConfigMixin):
|
94 |
_supports_gradient_checkpointing = True
|