fffiloni commited on
Commit
d89e409
1 Parent(s): 0a986cd

Update models/controlnet.py

Browse files
Files changed (1) hide show
  1. models/controlnet.py +44 -1
models/controlnet.py CHANGED
@@ -43,9 +43,10 @@ class ControlNetOutput(BaseOutput):
43
  down_block_res_samples: Tuple[torch.Tensor]
44
  mid_block_res_sample: torch.Tensor
45
 
46
-
47
  class ControlNetConditioningEmbedding(nn.Module):
48
  """
 
49
  Quoting from https://arxiv.org/abs/2302.05543: "Stable Diffusion uses a pre-processing method similar to VQ-GAN
50
  [11] to convert the entire dataset of 512 × 512 images into smaller 64 × 64 “latent images” for stabilized
51
  training. This requires ControlNets to convert image-based conditions to 64 × 64 feature space to match the
@@ -53,6 +54,7 @@ class ControlNetConditioningEmbedding(nn.Module):
53
  (activated by ReLU, channels are 16, 32, 64, 128, initialized with Gaussian weights, trained jointly with the full
54
  model) to encode image-space conditions ... into feature maps ..."
55
  """
 
56
 
57
  def __init__(
58
  self,
@@ -87,7 +89,48 @@ class ControlNetConditioningEmbedding(nn.Module):
87
  embedding = self.conv_out(embedding)
88
 
89
  return embedding
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
90
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
91
 
92
  class ControlNetModel3D(ModelMixin, ConfigMixin):
93
  _supports_gradient_checkpointing = True
 
43
  down_block_res_samples: Tuple[torch.Tensor]
44
  mid_block_res_sample: torch.Tensor
45
 
46
+ """
47
  class ControlNetConditioningEmbedding(nn.Module):
48
  """
49
+ """
50
  Quoting from https://arxiv.org/abs/2302.05543: "Stable Diffusion uses a pre-processing method similar to VQ-GAN
51
  [11] to convert the entire dataset of 512 × 512 images into smaller 64 × 64 “latent images” for stabilized
52
  training. This requires ControlNets to convert image-based conditions to 64 × 64 feature space to match the
 
54
  (activated by ReLU, channels are 16, 32, 64, 128, initialized with Gaussian weights, trained jointly with the full
55
  model) to encode image-space conditions ... into feature maps ..."
56
  """
57
+ """
58
 
59
  def __init__(
60
  self,
 
89
  embedding = self.conv_out(embedding)
90
 
91
  return embedding
92
+ """
93
+
94
+ class ControlNetConditioningEmbedding(nn.Module):
95
+ def __init__(
96
+ self,
97
+ conditioning_embedding_channels: int,
98
+ conditioning_channels: int = 3,
99
+ block_out_channels: Tuple[int] = (16, 32, 96, 256),
100
+ ):
101
+ super().__init__()
102
+
103
+ self.conv_in = InflatedConv3d(conditioning_channels, block_out_channels[0], kernel_size=3, padding=1)
104
+ self.bn_in = nn.BatchNorm3d(block_out_channels[0])
105
+
106
+ self.blocks = nn.ModuleList([])
107
+ self.bns = nn.ModuleList([])
108
+
109
+ for i in range(len(block_out_channels) - 1):
110
+ channel_in = block_out_channels[i]
111
+ channel_out = block_out_channels[i + 1]
112
+ self.blocks.append(InflatedConv3d(channel_in, channel_in, kernel_size=3, padding=1))
113
+ self.bns.append(nn.BatchNorm3d(channel_in))
114
+ self.blocks.append(InflatedConv3d(channel_in, channel_out, kernel_size=3, padding=1, stride=2))
115
+ self.bns.append(nn.BatchNorm3d(channel_out))
116
 
117
+ self.conv_out = zero_module(
118
+ InflatedConv3d(block_out_channels[-1], conditioning_embedding_channels, kernel_size=3, padding=1)
119
+ )
120
+
121
+ def forward(self, conditioning):
122
+ embedding = self.conv_in(conditioning)
123
+ embedding = self.bn_in(embedding)
124
+ embedding = F.silu(embedding)
125
+
126
+ for block, bn in zip(self.blocks, self.bns):
127
+ embedding = block(embedding)
128
+ embedding = bn(embedding)
129
+ embedding = F.silu(embedding)
130
+
131
+ embedding = self.conv_out(embedding)
132
+
133
+ return embedding
134
 
135
  class ControlNetModel3D(ModelMixin, ConfigMixin):
136
  _supports_gradient_checkpointing = True