VinayHajare committed on
Commit
c5025e3
·
verified ·
1 Parent(s): 8ff3e05

Upload 5 files

Browse files
Files changed (5) hide show
  1. README.md +7 -6
  2. app.py +56 -0
  3. inference.py +113 -0
  4. model.py +924 -0
  5. utils.py +43 -0
README.md CHANGED
@@ -1,13 +1,14 @@
1
  ---
2
- title: WordCraft
3
- emoji: 🐠
4
- colorFrom: green
5
- colorTo: gray
6
  sdk: gradio
7
- sdk_version: 4.26.0
8
  app_file: app.py
9
  pinned: false
10
  license: apache-2.0
 
11
  ---
12
 
13
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
  ---
2
+ title: Text To Image EfficientCLIP GAN
3
+ emoji: 📸
4
+ colorFrom: indigo
5
+ colorTo: pink
6
  sdk: gradio
7
+ sdk_version: 4.22.0
8
  app_file: app.py
9
  pinned: false
10
  license: apache-2.0
11
+ short_description: Create images from text utilizing the EfficientCLIP-GAN
12
  ---
13
 
14
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import random
3
+ import gradio as gr
4
+ import requests
5
+ from PIL import Image
6
+
7
+ from utils import read_css_from_file
8
+ from inference import generate_image_from_text, generate_image_from_text_with_persistent_storage
9
+
10
+ # Read CSS from file
11
+ css = read_css_from_file("style.css")
12
+
13
+ DESCRIPTION = '''
14
+ <div id="content_align">
15
+ <span style="color:darkred;font-size:32px;font-weight:bold">
16
+ WordCraft : Visuals from Verbs
17
+ </span>
18
+ </div>
19
+ <div id="content_align">
20
+ <span style="color:blue;font-size:18px;font-weight:bold;">
21
+ <br>A small, lighting fast efficient AI image generator
22
+ </span>
23
+ </div>
24
+ <div id="content_align" style="margin-top: 10px;font-weight:bold;">
25
+ <br>This 💻 demo uses the EfficientCLIP-GAN model trained on CUB 🐦🐥 and CC12M 📸🌃🌉 dataset.
26
+ <br>Keep your prompt coherent to domain of the selected model.
27
+ <br>If you like the demo, don't forget to click on the like 💖 button.
28
+ </div>
29
+ '''
30
+ available_models = [
31
+ ("EfficientCLIP-GAN trained on CUB dataset (Restricted to birds)", "CUB"),
32
+ ("EfficientCLIP-GAN trained on CC12M dataset (More flexible)", "CC12M")
33
+ ]
34
+
35
+ # Creating Gradio interface
36
+ with gr.Blocks(css=css) as app:
37
+ gr.Markdown(DESCRIPTION)
38
+ with gr.Row():
39
+ with gr.Column():
40
+ text_prompt = gr.Textbox(label="Input Prompt", value="this tiny bird has a very small bill, a belly covered with white delicate feathers and has a set of black rounded eyes.", lines=3)
41
+ model_selector = gr.Dropdown(choices=available_models, value="CUB", label="Select Model", info="Select a model with which you want to generate images")
42
+ generate_button = gr.Button("Generate Images", variant='primary')
43
+
44
+ with gr.Row():
45
+ with gr.Column():
46
+ image_output1 = gr.Image(label="Generated Image 1")
47
+ image_output2 = gr.Image(label="Generated Image 2")
48
+
49
+ with gr.Column():
50
+ image_output3 = gr.Image(label="Generated Image 3")
51
+ image_output4 = gr.Image(label="Generated Image 4")
52
+
53
+ generate_button.click(generate_image_from_text, inputs=[text_prompt, model_selector], outputs=[image_output1, image_output2, image_output3, image_output4])
54
+
55
+ # Launch the app
56
+ app.launch()
inference.py ADDED
@@ -0,0 +1,113 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ import io
4
+ import torch
5
+ import torchvision
6
+ import clip
7
+ import numpy as np
8
+ from huggingface_hub import hf_hub_download
9
+ from PIL import Image
10
+ from torchvision.transforms.functional import to_pil_image
11
+
12
+ from utils import load_model_weights
13
+ from model import NetG, CLIP_TXT_ENCODER
14
+
15
+ # checking the device
16
+ device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
17
+
18
+ # Getting the HF token
19
+ HF_TOKEN = os.getenv("HF_TOKEN")
20
+
21
+ # repository of the model
22
+ repo_id = "VinayHajare/EfficientCLIP-GAN"
23
+ cub_model = "saved_models/state_epoch_1480.pth"
24
+ cc12m_model = "saved_models/EfficientCLIP-GAN-CC12M.pth"
25
+
26
+ # clip model wrapped with the custom encoder
27
+ clip_text = "ViT-B/32"
28
+ clip_model, preprocessor = clip.load(clip_text, device=device)
29
+ clip_model = clip_model.eval()
30
+ text_encoder = CLIP_TXT_ENCODER(clip_model).to(device)
31
+
32
+ # loading the models from the repository and extracting the generator model
33
+ cub_model_path = hf_hub_download(repo_id = repo_id, filename = cub_model, token = HF_TOKEN)
34
+ checkpoint_cub = torch.load(cub_model_path, map_location=torch.device(device))
35
+ cc12m_model_path = hf_hub_download(repo_id = repo_id, filename = cc12m_model, token = HF_TOKEN)
36
+ checkpoint_cc12m = torch.load(cc12m_model_path, map_location=torch.device(device))
37
+
38
+ # Create a new Generator model and initialize it with the pre-trained weights
39
+ netG = NetG(64, 100, 512, 256, 3, False, clip_model).to(device)
40
+ #cub = load_model_weights(netG, checkpoint_cub['model']['netG'], multi_gpus=False)
41
+ #cc12m = load_model_weights(netG, checkpoint_cc12m['model']['netG'], multi_gpus=False)
42
+
43
+ # Function to generate images from text
44
+ def generate_image_from_text(caption, model, batch_size=4):
45
+ if model == "CUB":
46
+ generator = load_model_weights(netG, checkpoint_cub['model']['netG'], multi_gpus=False)
47
+ else:
48
+ generator = load_model_weights(netG, checkpoint_cc12m['model']['netG'], multi_gpus=False)
49
+
50
+ # Create the noise tensor
51
+ noise = torch.randn((batch_size, 100)).to(device)
52
+ with torch.no_grad():
53
+ # Tokenize caption
54
+ tokenized_text = clip.tokenize([caption]).to(device)
55
+ # Extract the sentence and word embedding from Custom CLIP ENCODER
56
+ sent_emb, word_emb = text_encoder(tokenized_text)
57
+ # Repeat the sentence embedding to match the batch size
58
+ sent_emb = sent_emb.repeat(batch_size, 1)
59
+ # generate the images
60
+ generated_images = generator(noise, sent_emb, eval=True).float()
61
+
62
+ # Convert the tensor images to PIL format
63
+ pil_images = []
64
+ for image_tensor in generated_images.unbind(0):
65
+ # Rescale tensor values to [0, 1]
66
+ image_tensor = image_tensor.data.clamp(-1, 1)
67
+ image_tensor = (image_tensor + 1.0) / 2.0
68
+
69
+ # Convert tensor to numpy array
70
+ image_numpy = image_tensor.permute(1, 2, 0).cpu().numpy()
71
+
72
+ # Clip numpy array values to [0, 1]
73
+ image_numpy = np.clip(image_numpy, 0, 1)
74
+
75
+ # Create a PIL image from the numpy array
76
+ pil_image = Image.fromarray((image_numpy * 255).astype(np.uint8))
77
+
78
+ pil_images.append(pil_image)
79
+
80
+ return pil_images
81
+
82
+ # Function to generate images from text and save them as PNG files in the generated_images directory
83
+ def generate_image_from_text_with_persistent_storage(caption, model, batch_size=4):
84
+ if model == "CUB":
85
+ generator = load_model_weights(netG, checkpoint_cub['model']['netG'], multi_gpus=False)
86
+ else:
87
+ generator = load_model_weights(netG, checkpoint_cc12m['model']['netG'], multi_gpus=False)
88
+
89
+ # Create the noise tensor
90
+ noise = torch.randn((batch_size, 100)).to(device)
91
+ with torch.no_grad():
92
+ # Tokenize caption
93
+ tokenized_text = clip.tokenize([caption]).to(device)
94
+ # Extract the sentence and word embedding from Custom CLIP ENCODER
95
+ sent_emb, word_emb = text_encoder(tokenized_text)
96
+ # Repeat the sentence embedding to match the batch size
97
+ sent_emb = sent_emb.repeat(batch_size, 1)
98
+ # generate the images
99
+ generated_images = generator(noise, sent_emb, eval=True).float()
100
+
101
+ # Create a permanent directory if it doesn't exist
102
+ permanent_dir = "generated_images"
103
+ if not os.path.exists(permanent_dir):
104
+ os.makedirs(permanent_dir)
105
+
106
+ image_paths = []
107
+ for idx, image_tensor in enumerate(generated_images.unbind(0)):
108
+ # Save the image tensor to a permanent file
109
+ image_path = os.path.join(permanent_dir, f"image_{idx}.png")
110
+ torchvision.utils.save_image(image_tensor.data, image_path, value_range=(-1, 1), normalize=True)
111
+ image_paths.append(image_path)
112
+
113
+ return image_paths
model.py ADDED
@@ -0,0 +1,924 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import numpy as np
4
+ import torch.nn.functional as F
5
+ from collections import OrderedDict
6
+ from utils import dummy_context_mgr
7
+
8
+
9
+ class CLIP_IMG_ENCODER(nn.Module):
10
+ """
11
+ CLIP_IMG_ENCODER module for encoding images using CLIP's visual transformer.
12
+ """
13
+
14
+ def __init__(self, CLIP):
15
+ """
16
+ Initialize the CLIP_IMG_ENCODER module.
17
+
18
+ Args:
19
+ CLIP (CLIP): Pre-trained CLIP model.
20
+ """
21
+ super(CLIP_IMG_ENCODER, self).__init__()
22
+ model = CLIP.visual
23
+ self.define_module(model)
24
+ # freeze the parameters of the CLIP model
25
+ for param in self.parameters():
26
+ param.requires_grad = False
27
+
28
+ def define_module(self, model):
29
+ """
30
+ Define the individual layers and modules of the CLIP visual transformer model.
31
+ Args:
32
+ model (nn.Module): CLIP visual transformer model.
33
+ """
34
+ # Extract required modules from the CLIP model
35
+ self.conv1 = model.conv1 # Convolutional layer
36
+ self.class_embedding = model.class_embedding # Class embedding layer
37
+ self.positional_embedding = model.positional_embedding # Positional embedding layer
38
+ self.ln_pre = model.ln_pre # Layer normalization applied before the transformer
39
+ self.transformer = model.transformer # Transformer block
40
+ self.ln_post = model.ln_post # Layer normalization applied after the transformer
41
+ self.proj = model.proj # projection matrix
42
+
43
+ @property
44
+ def dtype(self):
45
+ """
46
+ Get the data type of the convolutional layer weights.
47
+ """
48
+ return self.conv1.weight.dtype
49
+
50
+ def transf_to_CLIP_input(self, inputs):
51
+ """
52
+ Transform input images to the format expected by CLIP.
53
+
54
+ Args:
55
+ inputs (torch.Tensor): Input images.
56
+
57
+ Returns:
58
+ torch.Tensor: Transformed images.
59
+ """
60
+ device = inputs.device
61
+ # Check the size of the input image tensor
62
+ if len(inputs.size()) != 4:
63
+ raise ValueError('Expect the (B, C, X, Y) tensor.')
64
+ else:
65
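+ # Resize to 224x224 and normalize per channel (var is used as the divisor). NOTE(review): the rescale x * 0.5 + 0.5 is applied twice, before interpolate and again as (inputs + 1) * 0.5 - confirm this is intended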
66
+ mean = torch.tensor([0.48145466, 0.4578275, 0.40821073]).unsqueeze(-1).unsqueeze(-1).unsqueeze(0).to(device)
67
+ var = torch.tensor([0.26862954, 0.26130258, 0.27577711]).unsqueeze(-1).unsqueeze(-1).unsqueeze(0).to(device)
68
+ inputs = F.interpolate(inputs * 0.5 + 0.5, size=(224, 224))
69
+ inputs = ((inputs + 1) * 0.5 - mean) / var
70
+ return inputs
71
+
72
+ def forward(self, img: torch.Tensor):
73
+ """
74
+ Forward pass of the CLIP_IMG_ENCODER module.
75
+
76
+ Args:
77
+ img (torch.Tensor): Input images.
78
+
79
+ Returns:
80
+ torch.Tensor: Local features extracted from the image.
81
+ torch.Tensor: Encoded image embeddings.
82
+ """
83
+ # Transform input images to the format expected by CLIP and set its datatype appropriately
84
+ x = self.transf_to_CLIP_input(img)
85
+ x = x.type(self.dtype)
86
+
87
+ # Pass the image through Convolutional layer
88
+ x = self.conv1(x) # shape = [*, width, grid, grid]
89
+ grid = x.size(-1)
90
+
91
+ # Reshape and permute the tensor for transformer input
92
+ x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
93
+ x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
94
+
95
+ # Add class and positional embeddings
96
+ x = torch.cat(
97
+ [self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device),
98
+ x], dim=1) # shape = [*, grid ** 2 + 1, width]
99
+ x = x + self.positional_embedding.to(x.dtype)
100
+ x = self.ln_pre(x)
101
+
102
+ # NLD (Batch Size - Length - Dimension) -> LND (Length - Batch Size - Dimension)
103
+ x = x.permute(1, 0, 2)
104
+
105
+ # Extract local features using transformer blocks
106
+ selected = [1, 4, 8]
107
+ local_features = []
108
+ for i in range(12):
109
+ x = self.transformer.resblocks[i](x)
110
+ if i in selected:
111
+ local_features.append(
112
+ x.permute(1, 0, 2)[:, 1:, :].permute(0, 2, 1).reshape(-1, 768, grid, grid).contiguous().type(
113
+ img.dtype))
114
+ x = x.permute(1, 0, 2) # LND -> NLD
115
+ x = self.ln_post(x[:, 0, :])
116
+ if self.proj is not None:
117
+ x = x @ self.proj # Perform matrix multiplication with projection matrix and tensor
118
+ return torch.stack(local_features, dim=1), x.type(img.dtype)
119
+
120
+
121
+ class CLIP_TXT_ENCODER(nn.Module):
122
+ """
123
+ CLIP_TXT_ENCODER module for encoding text inputs using CLIP's transformer.
124
+ """
125
+
126
+ def __init__(self, CLIP):
127
+ """
128
+ Initialize the CLIP_TXT_ENCODER module.
129
+
130
+ Args:
131
+ CLIP (CLIP): Pre-trained CLIP model.
132
+ """
133
+ super(CLIP_TXT_ENCODER, self).__init__()
134
+ self.define_module(CLIP)
135
+ # Freeze the parameters of the CLIP model
136
+ for param in self.parameters():
137
+ param.requires_grad = False
138
+
139
+ def define_module(self, CLIP):
140
+ """
141
+ Define the individual modules of the CLIP transformer model.
142
+
143
+ Args:
144
+ CLIP (CLIP): Pre-trained CLIP model.
145
+ """
146
+ self.transformer = CLIP.transformer # Transformer block
147
+ self.vocab_size = CLIP.vocab_size # Size of the vocabulary of the transformer
148
+ self.token_embedding = CLIP.token_embedding # token embedding block
149
+ self.positional_embedding = CLIP.positional_embedding # positional embedding block
150
+ self.ln_final = CLIP.ln_final # Final layer normalization, applied to the transformer output
151
+ self.text_projection = CLIP.text_projection # Projection matrix for text
152
+
153
+ @property
154
+ def dtype(self):
155
+ """
156
+ Get the data type of the first layer's weights in the transformer.
157
+ """
158
+ return self.transformer.resblocks[0].mlp.c_fc.weight.dtype
159
+
160
+ def forward(self, text):
161
+ """
162
+ Forward pass of the CLIP_TXT_ENCODER module.
163
+
164
+ Args:
165
+ text (torch.Tensor): Input text tokens.
166
+
167
+ Returns:
168
+ torch.Tensor: Encoded sentence embeddings.
169
+ torch.Tensor: Transformer output for the input text.
170
+ """
171
+ # Embed input text tokens
172
+ x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model]
173
+ # Add positional embeddings
174
+ x = x + self.positional_embedding.type(self.dtype)
175
+ # Permute dimensions for transformer input
176
+ x = x.permute(1, 0, 2) # NLD -> LND
177
+ # Pass input through the transformer
178
+ x = self.transformer(x)
179
+ # Permute dimensions back to original shape
180
+ x = x.permute(1, 0, 2) # LND -> NLD
181
+ # Apply layer normalization
182
+ x = self.ln_final(x).type(self.dtype) # shape = [batch_size, n_ctx, transformer.width]
183
+ # Take the features at the end-of-text position (the EOT token has the highest id in each sequence, so argmax locates it) and project them
184
+ sent_emb = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
185
+
186
+ # Return the sentence embedding and the transformer output
187
+ return sent_emb, x
188
+
189
+
190
+ class CLIP_Mapper(nn.Module):
191
+ """
192
+ CLIP_Mapper module for mapping images with prompts using CLIP's transformer.
193
+ """
194
+
195
+ def __init__(self, CLIP):
196
+ """
197
+ Initialize the CLIP_Mapper module.
198
+
199
+ Args:
200
+ CLIP (CLIP): Pre-trained CLIP model.
201
+ """
202
+ super(CLIP_Mapper, self).__init__()
203
+ model = CLIP.visual
204
+ self.define_module(model)
205
+ # Freeze the parameters of the CLIP visual model
206
+ for param in model.parameters():
207
+ param.requires_grad = False
208
+
209
+ def define_module(self, model):
210
+ """
211
+ Define the individual modules of the CLIP visual model.
212
+
213
+ Args:
214
+ model: Pre-trained CLIP visual model.
215
+ """
216
+ self.conv1 = model.conv1
217
+ self.class_embedding = model.class_embedding
218
+ self.positional_embedding = model.positional_embedding
219
+ self.ln_pre = model.ln_pre
220
+ self.transformer = model.transformer
221
+
222
+ @property
223
+ def dtype(self):
224
+ """
225
+ Get the data type of the weights of the first convolutional layer.
226
+ """
227
+ return self.conv1.weight.dtype
228
+
229
+ def forward(self, img: torch.Tensor, prompts: torch.Tensor):
230
+ """
231
+ Forward pass of the CLIP_Mapper module.
232
+
233
+ Args:
234
+ img (torch.Tensor): Input feature map; it is reshaped into a token sequence directly (conv1 is not applied here).
235
+ prompts (torch.Tensor): Prompt tokens for mapping.
236
+
237
+ Returns:
238
+ torch.Tensor: Mapped features from the CLIP model.
239
+ """
240
+
241
+ # Convert input image and prompts to the appropriate data type
242
+ x = img.type(self.dtype)
243
+ prompts = prompts.type(self.dtype)
244
+ grid = x.size(-1)
245
+
246
+ # Reshape the input image tensor
247
+ x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
248
+ x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
249
+
250
+ # Prepend the class embedding to the token sequence
251
+ x = torch.cat(
252
+ [self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device),
253
+ x],
254
+ dim=1
255
+ ) # shape = [*, grid ** 2 + 1, width]
256
+
257
+ # Add the positional embeddings to the input tensor
258
+ x = x + self.positional_embedding.to(x.dtype)
259
+
260
+ # Perform the layer normalization
261
+ x = self.ln_pre(x)
262
+ # NLD -> LND
263
+ x = x.permute(1, 0, 2)
264
+ # Indices of the transformer blocks that receive an extra prompt token
265
+ selected = [1, 2, 3, 4, 5, 6, 7, 8]
266
+ begin, end = 0, 12
267
+ prompt_idx = 0
268
+ for i in range(begin, end):
269
+ # Add prompt to the input tensor
270
+ if i in selected:
271
+ prompt = prompts[:, prompt_idx, :].unsqueeze(0)
272
+ prompt_idx = prompt_idx + 1
273
+ x = torch.cat((x, prompt), dim=0)
274
+ x = self.transformer.resblocks[i](x)
275
+ x = x[:-1, :, :]
276
+ else:
277
+ x = self.transformer.resblocks[i](x)
278
+ # Reshape and return mapped features
279
+ return x.permute(1, 0, 2)[:, 1:, :].permute(0, 2, 1).reshape(-1, 768, grid, grid).contiguous().type(img.dtype)
280
+
281
+
282
+ class CLIP_Adapter(nn.Module):
283
+ """
284
+ CLIP_Adapter module for adapting features from a generator to match the CLIP model's input requirements.
285
+ """
286
+
287
+ def __init__(self, in_ch, mid_ch, out_ch, G_ch, CLIP_ch, cond_dim, k, s, p, map_num, CLIP):
288
+ """
289
+ Initialize the CLIP_Adapter module.
290
+
291
+ Args:
292
+ in_ch (int): Number of input channels.
293
+ mid_ch (int): Number of channels in the intermediate layers.
294
+ out_ch (int): Number of output channels.
295
+ G_ch (int): Number of channels in the generator's output.
296
+ CLIP_ch (int): Number of channels in the CLIP model's input.
297
+ cond_dim (int): Dimension of the conditioning vector.
298
+ k (int): Kernel size for convolutional layers.
299
+ s (int): Stride for convolutional layers.
300
+ p (int): Padding for convolutional layers.
301
+ map_num (int): Number of mapping blocks.
302
+ CLIP: Pre-trained CLIP model.
303
+ """
304
+ super(CLIP_Adapter, self).__init__()
305
+ self.CLIP_ch = CLIP_ch
306
+ self.FBlocks = nn.ModuleList([])
307
+ # Build map_num mapping blocks (M_Block) and add them to the FBlocks list
308
+ self.FBlocks.append(M_Block(in_ch, mid_ch, out_ch, cond_dim, k, s, p))
309
+ for i in range(map_num - 1):
310
+ self.FBlocks.append(M_Block(out_ch, mid_ch, out_ch, cond_dim, k, s, p))
311
+ # Convolutional layer to fuse adapted features
312
+ self.conv_fuse = nn.Conv2d(out_ch, CLIP_ch, 5, 1, 2)
313
+ # CLIP Mapper module to map adapted features to CLIP's input space
314
+ self.CLIP_ViT = CLIP_Mapper(CLIP)
315
+ # Convolutional layer to further process mapped features
316
+ self.conv = nn.Conv2d(768, G_ch, 5, 1, 2)
317
+ # Fully connected layer that turns the conditioning vector into 8 prompt vectors of size CLIP_ch
318
+ self.fc_prompt = nn.Linear(cond_dim, CLIP_ch * 8)
319
+
320
+ def forward(self, out, c):
321
+ """
322
+ Forward pass of the CLIP_Adapter module. Takes output features from the generator and conditioning vector
323
+ as input, adapts features using the Feature block having multiple mapping blocks, fuses them, map them to
324
+ CLIPs input space and returns the processed features
325
+
326
+ Args:
327
+ out (torch.Tensor): Output features from the generator.
328
+ c (torch.Tensor): Conditioning vector.
329
+
330
+ Returns:
331
+ torch.Tensor: Adapted and mapped features for the generator.
332
+ """
333
+
334
+ # Generate prompts from the conditioning vector
335
+ prompts = self.fc_prompt(c).view(c.size(0), -1, self.CLIP_ch)
336
+
337
+ # Pass features through feature block consisting of multiple mapping blocks
338
+ for FBlock in self.FBlocks:
339
+ out = FBlock(out, c)
340
+ # Fuse adapted features
341
+ fuse_feat = self.conv_fuse(out)
342
+ # Map fused features to CLIP's input space
343
+ map_feat = self.CLIP_ViT(fuse_feat, prompts)
344
+ # Further process mapped features and return
345
+ return self.conv(fuse_feat + 0.1 * map_feat)
346
+
347
+
348
+ class NetG(nn.Module):
349
+ """
350
+ Generator network for synthesizing images conditioned on text and noise
351
+ """
352
+
353
+ def __init__(self, ngf, nz, cond_dim, imsize, ch_size, mixed_precision, CLIP):
354
+ """
355
+ Initializes the Generator network.
356
+
357
+ Parameters:
358
+ ngf (int): Number of generator filters.
359
+ nz (int): Dimensionality of the input noise vector.
360
+ cond_dim (int): Dimensionality of the conditioning vector.
361
+ imsize (int): Size of the generated images.
362
+ ch_size (int): Number of output channels for the generated images.
363
+ mixed_precision (bool): Whether to use mixed precision training.
364
+ CLIP: CLIP model for feature adaptation.
365
+
366
+ """
367
+ super(NetG, self).__init__()
368
+ # Define attributes
369
+ self.ngf = ngf
370
+ self.mixed_precision = mixed_precision
371
+
372
+ # Build CLIP Mapper
373
+ self.code_sz, self.code_ch, self.mid_ch = 7, 64, 32
374
+ self.CLIP_ch = 768
375
+ # fully connected layer to convert the noise vector into a feature map of dimensions (code_sz * code_sz * code_ch)
376
+ self.fc_code = nn.Linear(nz, self.code_sz * self.code_sz * self.code_ch)
377
+ self.mapping = CLIP_Adapter(self.code_ch, self.mid_ch, self.code_ch, ngf * 8, self.CLIP_ch, cond_dim + nz, 3, 1,
378
+ 1, 4, CLIP)
379
+ # Build GBlocks
380
+ self.GBlocks = nn.ModuleList([])
381
+ in_out_pairs = list(get_G_in_out_chs(ngf, imsize))
382
+ imsize = 4
383
+ for idx, (in_ch, out_ch) in enumerate(in_out_pairs):
384
+ if idx < (len(in_out_pairs) - 1):
385
+ imsize = imsize * 2
386
+ else:
387
+ imsize = 224
388
+ self.GBlocks.append(G_Block(cond_dim + nz, in_ch, out_ch, imsize))
389
+
390
+ # To RGB image conversion using the sequential layers having leakyReLU activation function
391
+ self.to_rgb = nn.Sequential(
392
+ nn.LeakyReLU(0.2, inplace=True),
393
+ nn.Conv2d(out_ch, ch_size, 3, 1, 1),
394
+ )
395
+
396
+ def forward(self, noise, c, eval=False): # noise = latent vector, c = sentence embedding (sent_emb)
397
+ """
398
+ Forward pass of the generator network.
399
+
400
+ Args:
401
+ noise (torch.Tensor): Input noise vector.
402
+ c (torch.Tensor): Conditioning information, typically an embedding representing attributes of the output.
403
+ eval (bool, optional): Flag indicating whether the network is in evaluation mode. Defaults to False.
404
+
405
+ Returns:
406
+ torch.Tensor: Generated RGB images.
407
+ """
408
+ # Use CUDA autocast only when mixed_precision is set and eval is False; otherwise fall back to dummy_context_mgr()
409
+ with torch.cuda.amp.autocast() if self.mixed_precision and not eval else dummy_context_mgr() as mp:
410
+ # Concatenate noise and conditioning information
411
+ cond = torch.cat((noise, c), dim=1)
412
+
413
+ # Pass noise through fully connected layer to generate feature map and adapt features using CLIP Mapper
414
+ out = self.mapping(self.fc_code(noise).view(noise.size(0), self.code_ch, self.code_sz, self.code_sz), cond)
415
+
416
+ # Apply GBlocks to progressively upsample feature representation, fuse text and visual features
417
+ for GBlock in self.GBlocks:
418
+ out = GBlock(out, cond)
419
+
420
+ # Convert final feature representation to RGB images
421
+ out = self.to_rgb(out)
422
+
423
+ return out
424
+
425
+
426
+ class NetD(nn.Module):
427
+ """
428
+ Discriminator network for evaluating the realism of images.
429
+ Attributes:
430
+ DBlocks (nn.ModuleList): List of D_Block modules for processing feature maps.
431
+ main (D_Block): Main D_Block module for final processing.
432
+ """
433
+
434
+ def __init__(self, ndf, imsize, ch_size, mixed_precision):
435
+ """
436
+ Initializes the Discriminator network
437
+
438
+ Args:
439
+ ndf (int): Number of channels in the initial features (not referenced by this constructor).
440
+ imsize (int): Size of the input images, assumed square (not referenced by this constructor).
441
+ ch_size (int): Number of channels in the output feature maps (not referenced by this constructor).
442
+ mixed_precision (bool): Flag indicating whether to use mixed precision training.
443
+ """
444
+ super(NetD, self).__init__()
445
+ self.mixed_precision = mixed_precision
446
+ # Define the DBlock
447
+ self.DBlocks = nn.ModuleList([
448
+ D_Block(768, 768, 3, 1, 1, res=True, CLIP_feat=True),
449
+ D_Block(768, 768, 3, 1, 1, res=True, CLIP_feat=True),
450
+ ])
451
+ # Define the main DBlock for the final processing
452
+ self.main = D_Block(768, 512, 3, 1, 1, res=True, CLIP_feat=False)
453
+
454
+ def forward(self, h):
455
+ """
456
+ Forward pass of the discriminator network.
457
+ Args:
458
+ h (torch.Tensor): Input feature maps.
459
+ Returns:
460
+ torch.Tensor: Discriminator output.
461
+ """
462
+ with torch.cuda.amp.autocast() if self.mixed_precision else dummy_context_mgr() as mpc:
463
+ # Initial feature map
464
+ out = h[:, 0]
465
+ # Pass the input feature through each DBlock
466
+ for idx in range(len(self.DBlocks)):
467
+ out = self.DBlocks[idx](out, h[:, idx + 1])
468
+ # Final processing through the main DBlock
469
+ out = self.main(out)
470
+ return out
471
+
472
+
473
+ class NetC(nn.Module):
474
+ """
475
+ Classifier / Comparator network for classifying the joint features of the generator output and condition text.
476
+ Attributes:
477
+ cond_dim (int): Dimensionality of the conditioning information.
478
+ mixed_precision (bool): Flag indicating whether to use mixed precision training.
479
+ joint_conv (nn.Sequential): Sequential module defining the classifier layers.
480
+ """
481
+ def __init__(self, ndf, cond_dim, mixed_precision):
482
+ """
483
+ Initialize the classifier network (ndf is accepted but not referenced in this constructor).
484
+ """
485
+ super(NetC, self).__init__()
486
+ self.cond_dim = cond_dim
487
+ self.mixed_precision = mixed_precision
488
+ # Define the classifier layers, sequential convolutional 2D layer with LeakyReLU as the activation function
489
+ self.joint_conv = nn.Sequential(
490
+ nn.Conv2d(512 + 512, 128, 4, 1, 0, bias=False),
491
+ nn.LeakyReLU(0.2, inplace=True),
492
+ nn.Conv2d(128, 1, 4, 1, 0, bias=False),
493
+ )
494
+
495
+ def forward(self, out, cond):
496
+ """
497
+ Forward pass of the classifier network.
498
+
499
+ Args:
500
+ out (torch.Tensor): Generator output feature map.
501
+ cond (torch.Tensor): Conditioning information vector
502
+ """
503
+ with torch.cuda.amp.autocast() if self.mixed_precision else dummy_context_mgr() as mpc:
504
+ # Reshape and repeat conditioning information vector to match the feature map size
505
+ cond = cond.view(-1, self.cond_dim, 1, 1)
506
+ cond = cond.repeat(1, 1, 7, 7)
507
+
508
+ # Concatenate feature map and conditioned information
509
+ h_c_code = torch.cat((out, cond), 1)
510
+
511
+ # Pass through the classifier layers
512
+ out = self.joint_conv(h_c_code)
513
+ return out
514
+
515
+
516
+ class M_Block(nn.Module):
517
+ """
518
+ Mapping block consisting of convolutional layers and conditioning.
519
+
520
+ Attributes:
521
+ conv1 (nn.Conv2d): First convolutional layer.
522
+ fuse1 (DFBLK): Conditioning block applied after the first convolutional layer.
523
+ conv2 (nn.Conv2d): Second convolutional layer.
524
+ fuse2 (DFBLK): Conditioning block applied after the second convolutional layer.
525
+ learnable_sc (bool): Flag indicating whether the shortcut connection is learnable.
526
+ c_sc (nn.Conv2d): Convolutional layer for the shortcut connection.
527
+
528
+ """
529
+ def __init__(self, in_ch, mid_ch, out_ch, cond_dim, k, s, p):
530
+ """
531
+ Initializes the mapping block.
532
+
533
+ Args:
534
+ in_ch (int): Number of input channels.
535
+ mid_ch (int): Number of channels in the intermediate layers.
536
+ out_ch (int): Number of output channels.
537
+ cond_dim (int): Dimensionality of the conditioning information.
538
+ k (int): Kernel size for convolutional layers.
539
+ s (int): Stride for convolutional layers.
540
+ p (int): Padding for convolutional layers.
541
+
542
+ """
543
+ super(M_Block, self).__init__()
544
+
545
+ # Define convolutional layers and conditioning blocks
546
+ self.conv1 = nn.Conv2d(in_ch, mid_ch, k, s, p)
547
+ self.fuse1 = DFBLK(cond_dim, mid_ch)
548
+ self.conv2 = nn.Conv2d(mid_ch, out_ch, k, s, p)
549
+ self.fuse2 = DFBLK(cond_dim, out_ch)
550
+
551
+ # Learnable shortcut connection
552
+ self.learnable_sc = in_ch != out_ch
553
+ if self.learnable_sc:
554
+ self.c_sc = nn.Conv2d(in_ch, out_ch, 1, stride=1, padding=0)
555
+
556
+ def shortcut(self, x):
557
+ """
558
+ Defines the shortcut connection.
559
+
560
+ Args:
561
+ x (torch.Tensor): Input tensor.
562
+
563
+ Returns:
564
+ torch.Tensor: Shortcut connection output.
565
+ """
566
+ if self.learnable_sc:
567
+ x = self.c_sc(x)
568
+ return x
569
+
570
+ def residual(self, h, text):
571
+ """
572
+ Defines the residual path with conditioning.
573
+
574
+ Args:
575
+ h (torch.Tensor): Input tensor.
576
+ text (torch.Tensor): Conditioning information.
577
+
578
+ Returns:
579
+ torch.Tensor: Residual path output.
580
+ """
581
+ h = self.conv1(h)
582
+ h = self.fuse1(h, text)
583
+ h = self.conv2(h)
584
+ h = self.fuse2(h, text)
585
+ return h
586
+
587
+ def forward(self, h, c):
588
+ """
589
+ Forward pass of the mapping block.
590
+
591
+ Args:
592
+ h (torch.Tensor): Input tensor.
593
+ c (torch.Tensor): Conditioning information.
594
+
595
+ Returns:
596
+ torch.Tensor: Output tensor.
597
+ """
598
+ return self.shortcut(h) + self.residual(h, c)
599
+
600
+
601
class G_Block(nn.Module):
    """
    Generator block: resize, then add a shortcut path to a conditioned residual path.

    Attributes:
        imsize (int): Spatial size (height and width) the input is resized to.
        learnable_sc (bool): True when in/out channel counts differ and a 1x1
            projection is used on the shortcut.
        c1 (nn.Conv2d): First 3x3 convolution (in_ch -> out_ch).
        c2 (nn.Conv2d): Second 3x3 convolution (out_ch -> out_ch).
        fuse1 (DFBLK): Conditioning block applied before ``c1``.
        fuse2 (DFBLK): Conditioning block applied before ``c2``.
        c_sc (nn.Conv2d): 1x1 shortcut projection (only when ``learnable_sc``).
    """

    def __init__(self, cond_dim, in_ch, out_ch, imsize):
        """
        Create the layers of the generator block.

        Args:
            cond_dim (int): Dimensionality of the conditioning information.
            in_ch (int): Number of input channels.
            out_ch (int): Number of output channels.
            imsize (int): Target spatial size of the block output.
        """
        super().__init__()

        self.imsize = imsize
        self.learnable_sc = in_ch != out_ch

        # Layer creation order is kept (c1, c2, fuse1, fuse2, c_sc) so that
        # parameter ordering and random initialisation stay the same.
        self.c1 = nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=1, padding=1)
        self.c2 = nn.Conv2d(out_ch, out_ch, kernel_size=3, stride=1, padding=1)
        self.fuse1 = DFBLK(cond_dim, in_ch)
        self.fuse2 = DFBLK(cond_dim, out_ch)

        if self.learnable_sc:
            self.c_sc = nn.Conv2d(in_ch, out_ch, kernel_size=1, stride=1, padding=0)

    def shortcut(self, x):
        """
        Skip path.

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            torch.Tensor: ``c_sc(x)`` when the shortcut is learnable, else ``x``.
        """
        return self.c_sc(x) if self.learnable_sc else x

    def residual(self, h, y):
        """
        Residual path: fuse1 -> c1 -> fuse2 -> c2.

        Args:
            h (torch.Tensor): Input tensor.
            y (torch.Tensor): Conditioning information.

        Returns:
            torch.Tensor: Residual path output.
        """
        conditioned = self.fuse1(h, y)
        conditioned = self.c1(conditioned)
        conditioned = self.fuse2(conditioned, y)
        return self.c2(conditioned)

    def forward(self, h, y):
        """
        Resize the input to ``imsize`` x ``imsize`` and apply the block.

        Args:
            h (torch.Tensor): Input tensor.
            y (torch.Tensor): Conditioning information.

        Returns:
            torch.Tensor: Sum of the shortcut and residual outputs.
        """
        target = (self.imsize, self.imsize)
        resized = F.interpolate(h, size=target)
        skip = self.shortcut(resized)
        return skip + self.residual(resized, y)
685
+
686
+
687
class D_Block(nn.Module):
    """
    Discriminator block.

    Output is ``shortcut(x)``, optionally plus ``gamma * conv_r(x)`` (when
    ``res`` is enabled) and optionally plus ``beta * CLIP_feat`` (when
    ``CLIP_feat`` is enabled). ``gamma`` and ``beta`` are learnable scalars
    initialised to zero.
    """

    def __init__(self, fin, fout, k, s, p, res, CLIP_feat):
        """
        Initializes Discriminator block.

        Args:
            - fin (int): Number of input channels.
            - fout (int): Number of output channels.
            - k (int): Kernel size for convolutional layers.
            - s (int): Stride for convolutional layers.
            - p (int): Padding for convolutional layers.
            - res (bool): Whether to add the scaled residual path.
            - CLIP_feat (bool): Whether to add scaled CLIP features.
        """
        super().__init__()
        self.res, self.CLIP_feat = res, CLIP_feat
        self.learned_shortcut = (fin != fout)

        # Convolutional layers for residual path
        self.conv_r = nn.Sequential(
            nn.Conv2d(fin, fout, k, s, p, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(fout, fout, k, s, p, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
        )

        # Shortcut projection. It is created even when fin == fout (where it is
        # never applied) so the set of state_dict keys does not change.
        self.conv_s = nn.Conv2d(fin, fout, 1, stride=1, padding=0)

        # Learnable scalar weights for the optional terms
        if self.res:
            self.gamma = nn.Parameter(torch.zeros(1))
        if self.CLIP_feat:
            self.beta = nn.Parameter(torch.zeros(1))

    def forward(self, x, CLIP_feat=None):
        """
        Forward pass of the discriminator block.

        Args:
            - x (torch.Tensor): Input tensor.
            - CLIP_feat (torch.Tensor): CLIP features tensor; required when the
              block was built with ``CLIP_feat=True``, ignored otherwise.

        Returns:
            - torch.Tensor: Output tensor.

        Raises:
            TypeError: If the block was built with ``CLIP_feat=True`` but no
                ``CLIP_feat`` tensor is supplied.
        """
        if self.CLIP_feat and CLIP_feat is None:
            raise TypeError(
                "D_Block was built with CLIP_feat=True; forward() requires a CLIP_feat tensor"
            )

        # Shortcut connection (1x1 projection only when channel counts differ)
        out = self.conv_s(x) if self.learned_shortcut else x

        # The residual convolutions are only evaluated when their result is used.
        if self.res:
            out = out + self.gamma * self.conv_r(x)
        if self.CLIP_feat:
            out = out + self.beta * CLIP_feat
        return out
752
+
753
+
754
class DFBLK(nn.Module):
    """
    Conditional fusion block used by the generator.

    Applies two conditional affine transformations in sequence, each followed
    by a LeakyReLU with negative slope 0.2.
    """

    def __init__(self, cond_dim, in_ch):
        """
        Create the two conditional affine transformations.

        Args:
            - cond_dim (int): Dimensionality of the conditional input.
            - in_ch (int): Number of input channels.
        """
        super().__init__()
        self.affine0 = Affine(cond_dim, in_ch)
        self.affine1 = Affine(cond_dim, in_ch)

    def forward(self, x, y=None):
        """
        Forward pass of the conditional feature block.

        Args:
            - x (torch.Tensor): Input tensor.
            - y (torch.Tensor, optional): Conditional input tensor, forwarded
              to both affine transformations.

        Returns:
            - torch.Tensor: Output tensor.
        """
        # The functional form avoids building a new nn.LeakyReLU module on
        # every call; the in-place activation acts on the affine output.
        h = self.affine0(x, y)
        h = F.leaky_relu(h, negative_slope=0.2, inplace=True)
        h = self.affine1(h, y)
        return F.leaky_relu(h, negative_slope=0.2, inplace=True)
789
+
790
+
791
class QuickGELU(nn.Module):
    """
    Sigmoid-based approximation of GELU: ``x * sigmoid(1.702 * x)``.
    """

    def forward(self, x: torch.Tensor):
        """
        Apply the QuickGELU activation element-wise.

        Args:
            - x (torch.Tensor): Input tensor.

        Returns:
            - torch.Tensor: Tensor of the same shape as ``x``.
        """
        gate = torch.sigmoid(1.702 * x)
        return x * gate
808
+
809
+
810
+ # Taken from the RAT-GAN repository
811
class Affine(nn.Module):
    """
    Conditional affine transformation: ``gamma(y) * x + beta(y)``.

    ``gamma`` and ``beta`` are each predicted from the conditioning input by a
    small network (Linear -> ReLU -> Linear) and broadcast over ``x``.
    """

    def __init__(self, cond_dim, num_features):
        """
        Build the gamma and beta prediction networks.

        Args:
            cond_dim (int): Dimensionality of the conditioning information.
            num_features (int): Number of features produced for gamma and beta.
        """
        super().__init__()

        def make_mlp():
            # Linear -> ReLU -> Linear, with the layer names used by _initialize().
            return nn.Sequential(OrderedDict([
                ('linear1', nn.Linear(cond_dim, num_features)),
                ('relu1', nn.ReLU(inplace=True)),
                ('linear2', nn.Linear(num_features, num_features)),
            ]))

        # gamma is built before beta so random initialisation order is unchanged.
        self.fc_gamma = make_mlp()
        self.fc_beta = make_mlp()
        self._initialize()

    def _initialize(self):
        """
        Reset the last linear layers so the module starts as the identity:
        gamma outputs ones (zero weight, unit bias) and beta outputs zeros.
        """
        gamma_out = self.fc_gamma.linear2
        beta_out = self.fc_beta.linear2
        nn.init.zeros_(gamma_out.weight.data)
        nn.init.ones_(gamma_out.bias.data)
        nn.init.zeros_(beta_out.weight.data)
        nn.init.zeros_(beta_out.bias.data)

    def forward(self, x, y=None):
        """
        Scale and shift ``x`` using parameters predicted from ``y``.

        Args:
            x (torch.Tensor): Input tensor. The predicted parameters get two
                trailing singleton dimensions and are expanded to ``x.size()``.
            y (torch.Tensor, optional): Conditioning tensor fed to both
                prediction networks; a 1-D tensor is treated as a batch of one.

        Returns:
            torch.Tensor: ``gamma * x + beta``, same shape as ``x``.
        """
        gamma = self.fc_gamma(y)
        beta = self.fc_beta(y)

        # Promote un-batched conditioning to a batch of one.
        if gamma.dim() == 1:
            gamma = gamma.unsqueeze(0)
        if beta.dim() == 1:
            beta = beta.unsqueeze(0)

        # Add two trailing axes, then broadcast to the shape of x.
        target_shape = x.size()
        gamma = gamma[..., None, None].expand(target_shape)
        beta = beta[..., None, None].expand(target_shape)

        return gamma * x + beta
876
+
877
+
878
def get_G_in_out_chs(nf, imsize):
    """
    Compute (in_channels, out_channels) pairs for the generator blocks.

    Channel counts are ``nf * min(2 ** idx, 8)`` for ``idx`` in
    ``range(int(log2(imsize)) - 1)``, listed from the widest to the narrowest;
    consecutive counts form one pair.

    Args:
        nf (int): Base channel count that every layer width is a multiple of.
        imsize (int): Image size used to derive the number of layers.

    Returns:
        list: List of ``(in_ch, out_ch)`` tuples. A real list is returned (not
        a one-shot ``zip`` iterator) so it can be indexed, measured with
        ``len`` and iterated more than once.
    """
    # Determine the number of layers based on image size
    layer_num = int(np.log2(imsize)) - 1

    # Channel widths, widest first
    channel_nums = [nf * min(2 ** idx, 8) for idx in reversed(range(layer_num))]

    return list(zip(channel_nums[:-1], channel_nums[1:]))
902
+
903
+
904
def get_D_in_out_chs(nf, imsize):
    """
    Compute (in_channels, out_channels) pairs for the discriminator blocks.

    Channel counts are ``nf * min(2 ** idx, 8)`` for ``idx`` in
    ``range(int(log2(imsize)) - 1)``, listed from the narrowest to the widest;
    consecutive counts form one pair.

    Args:
        nf (int): Base channel count that every layer width is a multiple of.
        imsize (int): Image size used to derive the number of layers.

    Returns:
        list: List of ``(in_ch, out_ch)`` tuples. A real list is returned (not
        a one-shot ``zip`` iterator) so it can be indexed, measured with
        ``len`` and iterated more than once.
    """
    # Determine the number of layers based on image size
    layer_num = int(np.log2(imsize)) - 1

    # Channel widths, narrowest first
    channel_nums = [nf * min(2 ** idx, 8) for idx in range(layer_num)]

    return list(zip(channel_nums[:-1], channel_nums[1:]))
utils.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
def load_model_weights(model, weights, multi_gpus, train=True):
    """
    Load a checkpoint state dict into ``model``.

    Keys with a leading ``"module."`` prefix have that prefix removed unless
    the model is being loaded for multi-GPU training (``multi_gpus`` and
    ``train`` both truthy), in which case the keys are passed through as-is.

    Args:
        model: Object exposing ``load_state_dict``.
        weights (Mapping[str, Any]): Checkpoint state dict.
        multi_gpus (bool): Whether the target model runs on multiple GPUs.
        train (bool): Whether the model is loaded for training.

    Returns:
        The same ``model`` instance, after ``load_state_dict`` was called.
    """
    prefix = 'module.'
    keep_prefix = bool(multi_gpus) and bool(train)

    # Only a real leading "module." counts; a key that merely contains the
    # substring "module" must be left untouched.
    has_prefixed_keys = any(key.startswith(prefix) for key in weights)

    if keep_prefix or not has_prefixed_keys:
        # Pass the original mapping through unchanged (nothing to strip).
        state_dict = weights
    else:
        state_dict = {
            (key[len(prefix):] if key.startswith(prefix) else key): value
            for key, value in weights.items()
        }

    # load the model from the state_dict
    model.load_state_dict(state_dict)
    return model
26
+
27
+
28
+ # Class to work with if mixed precision is failing
29
class dummy_context_mgr:
    """No-op context manager: yields ``None`` on entry and never suppresses exceptions."""

    def __enter__(self):
        # Nothing to set up; the ``as`` target is bound to None.
        return None

    def __exit__(self, exc_type, exc_value, traceback):
        # Returning False lets any exception raised in the body propagate.
        return False
38
+
39
+
40
+ # Function to read CSS from file
41
def read_css_from_file(filename):
    """
    Return the full text content of a CSS file.

    Args:
        filename (str): Path of the file to read.

    Returns:
        str: File content.
    """
    # Decode explicitly as UTF-8 so the result does not depend on the
    # platform's locale encoding.
    with open(filename, 'r', encoding='utf-8') as file:
        return file.read()