Dmmc committed on
Commit
c8ddb9b
1 Parent(s): aec1df6

three-model version

This view is limited to 50 files because it contains too many changes. See the raw diff.
Files changed (50)
  1. .gitignore +2 -0
  2. app.py +117 -0
  3. requirements.txt +5 -0
  4. src/__init__.py +2 -0
  5. src/__pycache__/__init__.cpython-39.pyc +0 -0
  6. src/__pycache__/config.cpython-39.pyc +0 -0
  7. src/config.py +47 -0
  8. src/data/.gitkeep +0 -0
  9. src/data/__init__.py +5 -0
  10. src/data/__pycache__/__init__.cpython-39.pyc +0 -0
  11. src/data/__pycache__/collate.cpython-39.pyc +0 -0
  12. src/data/__pycache__/datasets.cpython-39.pyc +0 -0
  13. src/data/__pycache__/tokenizer.cpython-39.pyc +0 -0
  14. src/data/collate.py +43 -0
  15. src/data/datasets.py +387 -0
  16. src/data/stubs/bird.jpg +0 -0
  17. src/data/stubs/pigeon.jpg +0 -0
  18. src/data/stubs/rohit.jpeg +0 -0
  19. src/data/tokenizer.py +23 -0
  20. src/features/.gitkeep +0 -0
  21. src/features/__init__.py +0 -0
  22. src/features/build_features.py +0 -0
  23. src/models/.gitkeep +0 -0
  24. src/models/__init__.py +4 -0
  25. src/models/__pycache__/__init__.cpython-39.pyc +0 -0
  26. src/models/__pycache__/losses.cpython-39.pyc +0 -0
  27. src/models/__pycache__/train_model.cpython-39.pyc +0 -0
  28. src/models/__pycache__/utils.cpython-39.pyc +0 -0
  29. src/models/losses.py +344 -0
  30. src/models/modules/__init__.py +12 -0
  31. src/models/modules/__pycache__/__init__.cpython-39.pyc +0 -0
  32. src/models/modules/__pycache__/acm.cpython-39.pyc +0 -0
  33. src/models/modules/__pycache__/attention.cpython-39.pyc +0 -0
  34. src/models/modules/__pycache__/cond_augment.cpython-39.pyc +0 -0
  35. src/models/modules/__pycache__/conv_utils.cpython-39.pyc +0 -0
  36. src/models/modules/__pycache__/discriminator.cpython-39.pyc +0 -0
  37. src/models/modules/__pycache__/downsample.cpython-39.pyc +0 -0
  38. src/models/modules/__pycache__/generator.cpython-39.pyc +0 -0
  39. src/models/modules/__pycache__/image_encoder.cpython-39.pyc +0 -0
  40. src/models/modules/__pycache__/residual.cpython-39.pyc +0 -0
  41. src/models/modules/__pycache__/text_encoder.cpython-39.pyc +0 -0
  42. src/models/modules/__pycache__/upsample.cpython-39.pyc +0 -0
  43. src/models/modules/acm.py +37 -0
  44. src/models/modules/attention.py +88 -0
  45. src/models/modules/cond_augment.py +57 -0
  46. src/models/modules/conv_utils.py +78 -0
  47. src/models/modules/discriminator.py +144 -0
  48. src/models/modules/downsample.py +14 -0
  49. src/models/modules/generator.py +300 -0
  50. src/models/modules/image_encoder.py +138 -0
.gitignore ADDED
@@ -0,0 +1,2 @@
+ __pycache__/*
+ .idea/*
app.py ADDED
@@ -0,0 +1,117 @@
+ import numpy as np  # this should come first to mitigate mkl-service bug
+ from src.models.utils import get_image_arr, load_model
+ from src.data import TAIMGANTokenizer
+ from torchvision import transforms
+ from src.config import config_dict
+ from pathlib import Path
+ from enum import IntEnum, auto
+ from PIL import Image
+ import gradio as gr
+ import torch
+ from src.models.modules import (
+     VGGEncoder,
+     InceptionEncoder,
+     TextEncoder,
+     Generator
+ )
+
+ ##########
+ # PARAMS #
+ ##########
+
+ IMG_CHANS = 3  # RGB channels for image
+ IMG_HW = 256  # height and width of images
+ HIDDEN_DIM = 128  # hidden dimensions of lstm cell in one direction
+ C = 2 * HIDDEN_DIM  # length of embeddings
+
+ Ng = config_dict["Ng"]
+ cond_dim = config_dict["condition_dim"]
+ z_dim = config_dict["noise_dim"]
+
+
+ ###############
+ # LOAD MODELS #
+ ###############
+
+ models = {
+     "COCO": {
+         "dir": "weights/coco"
+     },
+     "Bird": {
+         "dir": "weights/bird"
+     },
+     "UTKFace": {
+         "dir": "weights/utkface"
+     }
+ }
+
+ for model_name in models:
+     # create tokenizer
+     models[model_name]["tokenizer"] = TAIMGANTokenizer(captions_path=f"{models[model_name]['dir']}/captions.pickle")
+     vocab_size = len(models[model_name]["tokenizer"].word_to_ix)
+     # instantiate models
+     models[model_name]["generator"] = Generator(Ng=Ng, D=C, conditioning_dim=cond_dim, noise_dim=z_dim).eval()
+     models[model_name]["lstm"] = TextEncoder(vocab_size=vocab_size, emb_dim=C, hidden_dim=HIDDEN_DIM).eval()
+     models[model_name]["vgg"] = VGGEncoder().eval()
+     models[model_name]["inception"] = InceptionEncoder(D=C).eval()
+     # load models
+     load_model(
+         generator=models[model_name]["generator"],
+         discriminator=None,
+         image_encoder=models[model_name]["inception"],
+         text_encoder=models[model_name]["lstm"],
+         output_dir=Path(models[model_name]["dir"]),
+         device=torch.device("cpu")
+     )
+
+
+ def change_image_with_text(image: Image, text: str, model_name: str) -> Image:
+     """
+     Create an image modified by the given text from the original image
+     and return it.
+
+     :param Image image: Original image to be modified
+     :param str text: Desired caption
+     :param str model_name: Which model to use ("COCO", "Bird" or "UTKFace")
+     :return: Modified image in PIL format
+     """
+     global models
+     tokenizer = models[model_name]["tokenizer"]
+     G = models[model_name]["generator"]
+     lstm = models[model_name]["lstm"]
+     inception = models[model_name]["inception"]
+     vgg = models[model_name]["vgg"]
+     # generate some noise
+     noise = torch.rand(z_dim).unsqueeze(0)
+     # transform input text and get masks with embeddings
+     tokens = torch.tensor(tokenizer.encode(text)).unsqueeze(0)
+     mask = (tokens == tokenizer.pad_token_id)
+     word_embs, sent_embs = lstm(tokens)
+     # open the image and transform it to the tensor
+     image = transforms.Compose([
+         transforms.ToTensor(),
+         transforms.Resize((IMG_HW, IMG_HW)),
+         transforms.Normalize(
+             mean=(0.5, 0.5, 0.5),
+             std=(0.5, 0.5, 0.5)
+         )
+     ])(image).unsqueeze(0)
+     # obtain visual features of the image
+     vgg_features = vgg(image)
+     local_features, global_features = inception(image)
+     # generate new image from the old one
+     fake_image, _, _ = G(noise, sent_embs, word_embs, global_features,
+                          local_features, vgg_features, mask)
+     # denormalize the image
+     fake_image = Image.fromarray(get_image_arr(fake_image)[0])
+     # return image in gradio format
+     return fake_image
+
+
+ ##########
+ # GRADIO #
+ ##########
+ demo = gr.Interface(
+     fn=change_image_with_text,
+     inputs=[gr.Image(type="pil"), "text", gr.inputs.Dropdown(list(models.keys()))],
+     outputs=gr.Image(type="pil")
+ )
+ demo.launch(debug=True)
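
The Normalize(mean=0.5, std=0.5) transform above maps pixels into [-1, 1], so get_image_arr (defined in src/models/utils, which is outside this 50-file view) is expected to invert that before Image.fromarray. A minimal sketch of such a denormalization helper, written as an assumption rather than the project's actual implementation:

    import numpy as np
    import torch

    def denormalize_to_uint8(batch: torch.Tensor) -> np.ndarray:
        """Hypothetical inverse of Normalize(mean=0.5, std=0.5): [-1, 1] -> [0, 255]."""
        arr = batch.detach().cpu().numpy()                    # (B, 3, H, W) in [-1, 1]
        arr = np.clip((arr + 1.0) * 127.5, 0, 255)            # -> [0, 255]
        return arr.astype(np.uint8).transpose(0, 2, 3, 1)     # -> (B, H, W, 3) for PIL
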
requirements.txt ADDED
@@ -0,0 +1,5 @@
+ Pillow
+ torch
+ torchvision
+ torchaudio
+ nltk
src/__init__.py ADDED
@@ -0,0 +1,2 @@
+ """Config file for the project."""
+ from .config import config_dict, update_config
src/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (260 Bytes).
 
src/__pycache__/config.cpython-39.pyc ADDED
Binary file (1.17 kB).
 
src/config.py ADDED
@@ -0,0 +1,47 @@
+ """Configurations for the project."""
+ from pathlib import Path
+ from typing import Any, Dict
+
+ import torch
+
+ device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+
+ repo_path = Path(__file__).parent.parent.absolute()
+ output_path = repo_path / "models"
+
+ config_dict = {
+     "Ng": 32,
+     "D": 256,
+     "condition_dim": 100,
+     "noise_dim": 100,
+     "lr_config": {
+         "disc_lr": 2e-4,
+         "gen_lr": 2e-4,
+         "img_encoder_lr": 3e-3,
+         "text_encoder_lr": 3e-3,
+     },
+     "batch_size": 64,
+     "device": device,
+     "epochs": 200,
+     "output_dir": output_path,
+     "snapshot": 5,
+     "const_dict": {
+         "smooth_val_gen": 0.999,
+         "lambda1": 1,
+         "lambda2": 1,
+         "lambda3": 1,
+         "lambda4": 1,
+         "gamma1": 4,
+         "gamma2": 5,
+         "gamma3": 10,
+     },
+ }
+
+
+ def update_config(cfg_dict: Dict[str, Any], **kwargs: Any) -> Dict[str, Any]:
+     """
+     Function to update the configuration dictionary.
+     """
+     for key, value in kwargs.items():
+         cfg_dict[key] = value
+     return cfg_dict
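
A quick usage sketch for update_config; note that it mutates and returns the same dictionary:

    from src.config import config_dict, update_config

    cfg = update_config(config_dict, batch_size=16, epochs=10)
    assert cfg is config_dict                              # updated in place and returned
    print(cfg["batch_size"], cfg["lr_config"]["gen_lr"])   # 16 0.0002
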
src/data/.gitkeep ADDED
File without changes
src/data/__init__.py ADDED
@@ -0,0 +1,5 @@
+ """Dataset, tokenizer, and custom collate function for data loading."""
+
+ from .collate import custom_collate
+ from .datasets import TextImageDataset
+ from .tokenizer import TAIMGANTokenizer
src/data/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (372 Bytes).
 
src/data/__pycache__/collate.cpython-39.pyc ADDED
Binary file (1.3 kB).
 
src/data/__pycache__/datasets.cpython-39.pyc ADDED
Binary file (11.8 kB).
 
src/data/__pycache__/tokenizer.cpython-39.pyc ADDED
Binary file (1.55 kB).
 
src/data/collate.py ADDED
@@ -0,0 +1,43 @@
+ """Custom collate function for the data loader."""
+
+ from typing import Any, List
+
+ import torch
+ from torch.nn.utils.rnn import pad_sequence
+
+
+ def custom_collate(batch: List[Any], device: Any) -> Any:
+     """
+     Custom collate function to be used in the data loader.
+     :param batch: list of samples, with length equal to the batch size.
+     :param device: device to move the batched tensors to.
+     :return: processed batch of data [add padding to text, stack tensors in batch]
+     """
+     img, correct_capt, curr_class, word_labels = zip(*batch)
+     batched_img = torch.stack(img, dim=0).to(
+         device
+     )  # shape: (batch_size, 3, height, width)
+     correct_capt_len = torch.tensor(
+         [len(capt) for capt in correct_capt], dtype=torch.int64
+     ).unsqueeze(
+         1
+     )  # shape: (batch_size, 1)
+     batched_correct_capt = pad_sequence(
+         correct_capt, batch_first=True, padding_value=0
+     ).to(
+         device
+     )  # shape: (batch_size, max_seq_len)
+     batched_curr_class = torch.stack(curr_class, dim=0).to(
+         device
+     )  # shape: (batch_size, 1)
+     batched_word_labels = pad_sequence(
+         word_labels, batch_first=True, padding_value=0
+     ).to(
+         device
+     )  # shape: (batch_size, max_seq_len)
+     return (
+         batched_img,
+         batched_correct_capt,
+         correct_capt_len,
+         batched_curr_class,
+         batched_word_labels,
+     )
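
custom_collate takes a device argument, so it has to be bound before being handed to a DataLoader. A minimal sketch using functools.partial, assuming the CUB data is laid out under ./birds/ as TextImageDataset expects:

    from functools import partial

    import torch
    from torch.utils.data import DataLoader
    from torchvision import transforms

    from src.data import TextImageDataset, custom_collate

    dataset = TextImageDataset(
        data_path="./birds/", split="train", num_captions=10,
        transform=transforms.Resize((304, 304)),
    )
    loader = DataLoader(
        dataset,
        batch_size=64,
        shuffle=True,
        collate_fn=partial(custom_collate, device=torch.device("cpu")),
    )
    imgs, caps, cap_lens, class_ids, word_labels = next(iter(loader))
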
src/data/datasets.py ADDED
@@ -0,0 +1,387 @@
+ """PyTorch Dataset classes for the datasets used in the project."""
+
+ import os
+ import pickle
+ from collections import defaultdict
+ from typing import Any
+
+ import nltk
+ import numpy as np
+ import pandas as pd
+ import torch
+ import torchvision.transforms.functional as F
+ from nltk.tokenize import RegexpTokenizer
+ from PIL import Image
+ from torch.utils.data import Dataset
+ from torchvision import transforms
+
+
+ class TextImageDataset(Dataset):  # type: ignore
+     """Custom PyTorch Dataset class to load Image and Text data."""
+
+     # pylint: disable=too-many-instance-attributes
+     # pylint: disable=too-many-locals
+     # pylint: disable=too-many-function-args
+
+     def __init__(
+         self, data_path: str, split: str, num_captions: int, transform: Any = None
+     ):
+         """
+         :param data_path: Path to the data directory. [i.e. can be './birds/', or './coco/']
+         :param split: 'train' or 'test' split
+         :param num_captions: number of captions present per image.
+             [For birds, this is 10, for coco, this is 5]
+         :param transform: PyTorch transform to apply to the images.
+         """
+         self.transform = transform
+         self.bound_box_map = None
+         self.file_names = self.load_filenames(data_path, split)
+         self.data_path = data_path
+         self.num_captions_per_image = num_captions
+         (
+             self.captions,
+             self.ix_to_word,
+             self.word_to_ix,
+             self.vocab_len,
+         ) = self.get_capt_and_vocab(data_path, split)
+         self.normalize = transforms.Compose(
+             [
+                 transforms.ToTensor(),
+                 transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
+             ]
+         )
+         self.class_ids = self.get_class_id(data_path, split, len(self.file_names))
+         if self.data_path.endswith("birds/"):
+             self.bound_box_map = self.get_bound_box(data_path)
+
+         elif self.data_path.endswith("coco/"):
+             pass
+
+         else:
+             raise ValueError(
+                 "Invalid data path. Please ensure the data [CUB/COCO] is stored in correct folders."
+             )
+
+     def __len__(self) -> int:
+         """Return the length of the dataset."""
+         return len(self.file_names)
+
+     def __getitem__(self, idx: int) -> Any:
+         """
+         Return the item at index idx.
+         :param idx: index of the item to return
+         :return img_tensor: image tensor
+         :return correct_caption: correct caption for the image [list of word indices]
+         :return curr_class_id: class id of the image
+         :return word_labels: POS-tagged word labels [1 for noun and adjective, 0 else]
+         """
+         file_name = self.file_names[idx]
+         curr_class_id = self.class_ids[idx]
+
+         if self.bound_box_map is not None:
+             bbox = self.bound_box_map[file_name]
+             images_dir = os.path.join(self.data_path, "CUB_200_2011/images")
+         else:
+             bbox = None
+             images_dir = os.path.join(self.data_path, "images")
+
+         img_path = os.path.join(images_dir, file_name + ".jpg")
+         img_tensor = self.get_image(img_path, bbox, self.transform)
+
+         rand_sent_idx = np.random.randint(0, self.num_captions_per_image)
+         rand_sent_idx = idx * self.num_captions_per_image + rand_sent_idx
+
+         correct_caption = torch.tensor(self.captions[rand_sent_idx], dtype=torch.int64)
+         num_words = len(correct_caption)
+
+         capt_token_list = []
+         for i in range(num_words):
+             capt_token_list.append(self.ix_to_word[correct_caption[i].item()])
+
+         pos_tag_list = nltk.tag.pos_tag(capt_token_list)
+         word_labels = []
+
+         for pos_tag in pos_tag_list:
+             if (
+                 "NN" in pos_tag[1] or "JJ" in pos_tag[1]
+             ):  # check for Nouns and Adjective only
+                 word_labels.append(1)
+             else:
+                 word_labels.append(0)
+
+         word_labels = torch.tensor(word_labels).float()  # type: ignore
+
+         curr_class_id = torch.tensor(curr_class_id, dtype=torch.int64).unsqueeze(0)
+
+         return (
+             img_tensor,
+             correct_caption,
+             curr_class_id,
+             word_labels,
+         )
+
+     def get_capt_and_vocab(self, data_dir: str, split: str) -> Any:
+         """
+         Helper function to get the captions, vocab dict for each image.
+         :param data_dir: path to the data directory [i.e. './birds/' or './coco/']
+         :param split: 'train' or 'test' split
+         :return captions: list of all captions for each image
+         :return ix_to_word: dictionary mapping index to word
+         :return word_to_ix: dictionary mapping word to index
+         :return num_words: number of unique words in the vocabulary
+         """
+         captions_ckpt_path = os.path.join(data_dir, "stubs/captions.pickle")
+         if os.path.exists(
+             captions_ckpt_path
+         ):  # check if previously processed captions exist
+             with open(captions_ckpt_path, "rb") as ckpt_file:
+                 captions = pickle.load(ckpt_file)
+                 train_captions, test_captions = captions[0], captions[1]
+                 ix_to_word, word_to_ix = captions[2], captions[3]
+                 num_words = len(ix_to_word)
+                 del captions
+             if split == "train":
+                 return train_captions, ix_to_word, word_to_ix, num_words
+             return test_captions, ix_to_word, word_to_ix, num_words
+
+         else:  # if not, process the captions and save them
+             train_files = self.load_filenames(data_dir, "train")
+             test_files = self.load_filenames(data_dir, "test")
+
+             train_captions_tokenized = self.get_tokenized_captions(
+                 data_dir, train_files
+             )
+             test_captions_tokenized = self.get_tokenized_captions(
+                 data_dir, test_files
+             )  # we need both train and test captions to build the vocab
+
+             (
+                 train_captions,
+                 test_captions,
+                 ix_to_word,
+                 word_to_ix,
+                 num_words,
+             ) = self.build_vocab(  # type: ignore
+                 train_captions_tokenized, test_captions_tokenized
+             )
+             vocab_list = [train_captions, test_captions, ix_to_word, word_to_ix]
+             with open(captions_ckpt_path, "wb") as ckpt_file:
+                 pickle.dump(vocab_list, ckpt_file)
+
+             if split == "train":
+                 return train_captions, ix_to_word, word_to_ix, num_words
+             if split == "test":
+                 return test_captions, ix_to_word, word_to_ix, num_words
+             raise ValueError("Invalid split. Please use 'train' or 'test'")
+
+     def build_vocab(
+         self, tokenized_captions_train: list, tokenized_captions_test: list  # type: ignore
+     ) -> Any:
+         """
+         Helper function which builds the vocab dicts.
+         :param tokenized_captions_train: list containing all the
+             train tokenized captions in the dataset. This is list of lists.
+         :param tokenized_captions_test: list containing all the
+             test tokenized captions in the dataset. This is list of lists.
+         :return train_captions_int: list of all captions in training,
+             where each word is replaced by its index in the vocab
+         :return test_captions_int: list of all captions in test,
+             where each word is replaced by its index in the vocab
+         :return ix_to_word: dictionary mapping index to word
+         :return word_to_ix: dictionary mapping word to index
+         :return num_words: number of unique words in the vocabulary
+         """
+         vocab = defaultdict(int)  # type: ignore
+         total_captions = tokenized_captions_train + tokenized_captions_test
+         for caption in total_captions:
+             for word in caption:
+                 vocab[word] += 1
+
+         # sort vocab dict by frequency in descending order
+         vocab = sorted(vocab.items(), key=lambda x: x[1], reverse=True)  # type: ignore
+
+         ix_to_word = {}
+         word_to_ix = {}
+         ix_to_word[0] = "<end>"
+         word_to_ix["<end>"] = 0
+
+         word_idx = 1
+         for word, _ in vocab:
+             word_to_ix[word] = word_idx
+             ix_to_word[word_idx] = word
+             word_idx += 1
+
+         train_captions_int = []  # we want to convert words to indices in vocab.
+         for caption in tokenized_captions_train:
+             curr_caption_int = []
+             for word in caption:
+                 curr_caption_int.append(word_to_ix[word])
+
+             train_captions_int.append(curr_caption_int)
+
+         test_captions_int = []
+         for caption in tokenized_captions_test:
+             curr_caption_int = []
+             for word in caption:
+                 curr_caption_int.append(word_to_ix[word])
+
+             test_captions_int.append(curr_caption_int)
+
+         return (
+             train_captions_int,
+             test_captions_int,
+             ix_to_word,
+             word_to_ix,
+             len(ix_to_word),
+         )
+
+     def get_tokenized_captions(self, data_dir: str, filenames: list) -> Any:  # type: ignore
+         """
+         Helper function to tokenize and return captions for each image in filenames.
+         :param data_dir: path to the data directory [i.e. './birds/' or './coco/']
+         :param filenames: list of all filenames corresponding to the split
+         :return tokenized_captions: list of all tokenized captions for all files in filenames.
+             [this returns a list, where each element is again a list of tokens/words]
+         """
+
+         all_captions = []
+         for filename in filenames:
+             caption_path = os.path.join(data_dir, "text", filename + ".txt")
+             with open(caption_path, "r", encoding="utf8") as txt_file:
+                 captions = txt_file.readlines()
+                 count = 0
+                 for caption in captions:
+                     if len(caption) == 0:
+                         continue
+
+                     caption = caption.replace("\ufffd\ufffd", " ")
+                     tokenizer = RegexpTokenizer(r"\w+")
+                     tokens = tokenizer.tokenize(
+                         caption.lower()
+                     )  # splits current caption/line to list of words/tokens
+                     if len(tokens) == 0:
+                         continue
+
+                     tokens = [
+                         t.encode("ascii", "ignore").decode("ascii") for t in tokens
+                     ]
+                     tokens = [t for t in tokens if len(t) > 0]
+
+                     all_captions.append(tokens)
+                     count += 1
+                     if count == self.num_captions_per_image:
+                         break
+                 if count < self.num_captions_per_image:
+                     raise ValueError(
+                         f"Number of captions for {filename} is only {count},\
+                             which is less than {self.num_captions_per_image}."
+                     )
+
+         return all_captions
+
+     def get_image(self, img_path: str, bbox: list, transform: Any) -> Any:  # type: ignore
+         """
+         Helper function to load and transform an image.
+         :param img_path: path to the image
+         :param bbox: bounding box coordinates [x, y, width, height]
+         :param transform: PyTorch transform to apply to the image
+         :return img_tensor: transformed image tensor
+         """
+         img = Image.open(img_path).convert("RGB")
+         width, height = img.size
+
+         if bbox is not None:
+             r_val = int(np.maximum(bbox[2], bbox[3]) * 0.75)
+
+             center_x = int((2 * bbox[0] + bbox[2]) / 2)
+             center_y = int((2 * bbox[1] + bbox[3]) / 2)
+             y1_coord = np.maximum(0, center_y - r_val)
+             y2_coord = np.minimum(height, center_y + r_val)
+             x1_coord = np.maximum(0, center_x - r_val)
+             x2_coord = np.minimum(width, center_x + r_val)
+
+             img = img.crop(
+                 [x1_coord, y1_coord, x2_coord, y2_coord]
+             )  # This preprocessing step seems to follow from
+             # Stackgan: Text to photo-realistic image synthesis
+
+         if transform is not None:
+             img_tensor = transform(img)  # this scales to 304x304, i.e. 256 x (76/64).
+             x_val = np.random.randint(0, 48)  # 304 - 256 = 48
+             y_val = np.random.randint(0, 48)
+             flip = np.random.rand() > 0.5
+
+             # crop
+             img_tensor = img_tensor.crop(
+                 [x_val, y_val, x_val + 256, y_val + 256]
+             )  # this crops to 256x256
+             if flip:
+                 img_tensor = F.hflip(img_tensor)
+
+         img_tensor = self.normalize(img_tensor)
+
+         return img_tensor
+
+     def load_filenames(self, data_dir: str, split: str) -> Any:
+         """
+         Helper function to get list of all image filenames.
+         :param data_dir: path to the data directory [i.e. './birds/' or './coco/']
+         :param split: 'train' or 'test' split
+         :return filenames: list of all image filenames
+         """
+         filepath = f"{data_dir}{split}/filenames.pickle"
+         if os.path.isfile(filepath):
+             with open(filepath, "rb") as pick_file:
+                 filenames = pickle.load(pick_file)
+         else:
+             raise ValueError(
+                 "Invalid split. Please use 'train' or 'test',\
+                     or make sure the filenames.pickle file exists."
+             )
+         return filenames
+
+     def get_class_id(self, data_dir: str, split: str, total_elems: int) -> Any:
+         """
+         Helper function to get list of all image class ids.
+         :param data_dir: path to the data directory [i.e. './birds/' or './coco/']
+         :param split: 'train' or 'test' split
+         :param total_elems: total number of elements in the dataset
+         :return class_ids: list of all image class ids
+         """
+         filepath = f"{data_dir}{split}/class_info.pickle"
+         if os.path.isfile(filepath):
+             with open(filepath, "rb") as class_file:
+                 class_ids = pickle.load(class_file, encoding="latin1")
+         else:
+             class_ids = np.arange(total_elems)
+         return class_ids
+
+     def get_bound_box(self, data_path: str) -> Any:
+         """
+         Helper function to get the bounding box for birds dataset.
+         :param data_path: path to birds data directory [i.e. './data/birds/']
+         :return img_to_box: dictionary mapping image name to bounding box coordinates
+         """
+         bbox_path = os.path.join(data_path, "CUB_200_2011/bounding_boxes.txt")
+         df_bounding_boxes = pd.read_csv(
+             bbox_path, delim_whitespace=True, header=None
+         ).astype(int)
+
+         filepath = os.path.join(data_path, "CUB_200_2011/images.txt")
+         df_filenames = pd.read_csv(filepath, delim_whitespace=True, header=None)
+         filenames = df_filenames[
+             1
+         ].tolist()  # df_filenames[0] just contains the index or ID.
+
+         img_to_box = {  # type: ignore
+             img_file[:-4]: [] for img_file in filenames
+         }  # remove the .jpg extension from the names
+         num_imgs = len(filenames)
+
+         for i in range(0, num_imgs):
+             bbox = df_bounding_boxes.iloc[i][1:].tolist()
+             key = filenames[i][:-4]
+             img_to_box[key] = bbox
+
+         return img_to_box
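
get_image assumes transform returns a PIL image somewhat larger than 256x256 (the inline comment expects 304x304) so that a random 256x256 crop can be taken before normalization. A minimal transform consistent with that assumption, assuming the CUB data lives under ./birds/:

    from torchvision import transforms

    from src.data import TextImageDataset

    # 304 = int(256 * 76 / 64), matching the comment in get_image
    dataset = TextImageDataset(
        data_path="./birds/", split="train", num_captions=10,
        transform=transforms.Resize((304, 304)),
    )
    img, caption, class_id, word_labels = dataset[0]
    print(img.shape)        # torch.Size([3, 256, 256]), normalized to [-1, 1]
    print(caption.shape, class_id.shape, word_labels.shape)
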
src/data/stubs/bird.jpg ADDED
src/data/stubs/pigeon.jpg ADDED
src/data/stubs/rohit.jpeg ADDED
src/data/tokenizer.py ADDED
@@ -0,0 +1,23 @@
+ """Word-level tokenizer built on the pickled captions vocabulary."""
+ import pickle
+ import re
+ from typing import List
+
+
+ class TAIMGANTokenizer:
+     """Tokenizer that maps words to vocabulary indices and back."""
+
+     def __init__(self, captions_path):
+         with open(captions_path, "rb") as ckpt_file:
+             captions = pickle.load(ckpt_file)
+             self.ix_to_word = captions[2]
+             self.word_to_ix = captions[3]
+         self.token_regex = r'\w+'
+         self.pad_token_id = self.word_to_ix["<end>"]
+         self.pad_repr = "[PAD]"
+
+     def encode(self, text: str) -> List[int]:
+         return [self.word_to_ix.get(word, self.pad_token_id)
+                 for word in re.findall(self.token_regex, text.lower())]
+
+     def decode(self, tokens: List[int]) -> str:
+         return ' '.join([self.ix_to_word[token]
+                          if token != self.pad_token_id else self.pad_repr
+                          for token in tokens])
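
A round-trip sketch, assuming a captions.pickle produced by TextImageDataset.get_capt_and_vocab; the weights/bird path mirrors app.py and is only an example:

    from src.data import TAIMGANTokenizer

    tokenizer = TAIMGANTokenizer(captions_path="weights/bird/captions.pickle")
    ids = tokenizer.encode("this bird has a bright red head")   # unknown words map to pad_token_id
    print(ids)
    print(tokenizer.decode(ids))   # pad tokens are rendered as [PAD]
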
src/features/.gitkeep ADDED
File without changes
src/features/__init__.py ADDED
File without changes
src/features/build_features.py ADDED
File without changes
src/models/.gitkeep ADDED
File without changes
src/models/__init__.py ADDED
@@ -0,0 +1,4 @@
+ """Helper functions for training loop."""
+ from .losses import discriminator_loss, generator_loss, kl_loss
+ from .train_model import train
+ from .utils import copy_gen_params, define_optimizers, load_params, prepare_labels
src/models/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (461 Bytes).
 
src/models/__pycache__/losses.cpython-39.pyc ADDED
Binary file (8.36 kB).
 
src/models/__pycache__/train_model.cpython-39.pyc ADDED
Binary file (3.82 kB).
 
src/models/__pycache__/utils.cpython-39.pyc ADDED
Binary file (8.76 kB).
 
src/models/losses.py ADDED
@@ -0,0 +1,344 @@
+ """Module containing the loss functions for the GANs."""
+ from typing import Any, Dict
+
+ import torch
+ from torch import nn
+
+ # pylint: disable=too-many-arguments
+ # pylint: disable=too-many-locals
+
+
+ def generator_loss(
+     logits: Dict[str, Dict[str, torch.Tensor]],
+     local_fake_incept_feat: torch.Tensor,
+     global_fake_incept_feat: torch.Tensor,
+     real_labels: torch.Tensor,
+     words_emb: torch.Tensor,
+     sent_emb: torch.Tensor,
+     match_labels: torch.Tensor,
+     cap_lens: torch.Tensor,
+     class_ids: torch.Tensor,
+     real_vgg_feat: torch.Tensor,
+     fake_vgg_feat: torch.Tensor,
+     const_dict: Dict[str, float],
+ ) -> Any:
+     """Calculate the loss for the generator.
+
+     Args:
+         logits: Dictionary with fake/real and word-level/uncond/cond logits
+
+         local_fake_incept_feat: The local inception features for the fake images.
+
+         global_fake_incept_feat: The global inception features for the fake images.
+
+         real_labels: Label for "real" image as predicted by discriminator,
+             this is a tensor of ones. [shape: (batch_size, 1)].
+
+         words_emb: The embeddings for all the words in the captions.
+             shape: (batch_size, embedding_size, max_caption_length)
+
+         sent_emb: The embeddings for the sentences.
+             shape: (batch_size, embedding_size)
+
+         match_labels: Tensor of shape: (batch_size, 1).
+             This is of the form torch.tensor([0, 1, 2, ..., batch-1])
+
+         cap_lens: The length of the 'actual' captions in the batch [without padding]
+             shape: (batch_size, 1)
+
+         class_ids: The class ids for the instance. shape: (batch_size, 1)
+
+         real_vgg_feat: The vgg features for the real images. shape: (batch_size, 128, 128, 128)
+         fake_vgg_feat: The vgg features for the fake images. shape: (batch_size, 128, 128, 128)
+
+         const_dict: The dictionary containing the constants.
+     """
+     lambda1 = const_dict["lambda1"]
+     total_error_g = 0.0
+
+     cond_logits = logits["fake"]["cond"]
+     cond_err_g = nn.BCEWithLogitsLoss()(cond_logits, real_labels)
+
+     uncond_logits = logits["fake"]["uncond"]
+     uncond_err_g = nn.BCEWithLogitsLoss()(uncond_logits, real_labels)
+
+     # add up the conditional and unconditional losses
+     loss_g = cond_err_g + uncond_err_g
+     total_error_g += loss_g
+
+     # DAMSM Loss from attnGAN.
+     loss_damsm = damsm_loss(
+         local_fake_incept_feat,
+         global_fake_incept_feat,
+         words_emb,
+         sent_emb,
+         match_labels,
+         cap_lens,
+         class_ids,
+         const_dict,
+     )
+
+     total_error_g += loss_damsm
+
+     loss_per = 0.5 * nn.MSELoss()(real_vgg_feat, fake_vgg_feat)  # perceptual loss
+
+     total_error_g += lambda1 * loss_per
+
+     return total_error_g
+
+
+ def damsm_loss(
+     local_incept_feat: torch.Tensor,
+     global_incept_feat: torch.Tensor,
+     words_emb: torch.Tensor,
+     sent_emb: torch.Tensor,
+     match_labels: torch.Tensor,
+     cap_lens: torch.Tensor,
+     class_ids: torch.Tensor,
+     const_dict: Dict[str, float],
+ ) -> Any:
+     """Calculate the DAMSM loss from the attnGAN paper.
+
+     Args:
+         local_incept_feat: The local inception features. [shape: (batch, D, 17, 17)]
+
+         global_incept_feat: The global inception features. [shape: (batch, D)]
+
+         words_emb: The embeddings for all the words in the captions.
+             shape: (batch, D, max_caption_length)
+
+         sent_emb: The embeddings for the sentences. shape: (batch_size, D)
+
+         match_labels: Tensor of shape: (batch_size, 1).
+             This is of the form torch.tensor([0, 1, 2, ..., batch-1])
+
+         cap_lens: The length of the 'actual' captions in the batch [without padding]
+             shape: (batch_size, 1)
+
+         class_ids: The class ids for the instance. shape: (batch, 1)
+
+         const_dict: The dictionary containing the constants.
+     """
+     batch_size = match_labels.size(0)
+     # Mask mis-match samples, that come from the same class as the real sample
+     masks = []
+
+     match_scores = []
+     gamma1 = const_dict["gamma1"]
+     gamma2 = const_dict["gamma2"]
+     gamma3 = const_dict["gamma3"]
+     lambda3 = const_dict["lambda3"]
+
+     for i in range(batch_size):
+         mask = (class_ids == class_ids[i]).int()
+         # This ensures that "correct class" index is not included in the mask.
+         mask[i] = 0
+         masks.append(mask.reshape(1, -1))  # shape: (1, batch)
+
+         numb_words = int(cap_lens[i])
+         # shape: (1, D, L), this picks the caption at ith batch index.
+         query_words = words_emb[i, :, :numb_words].unsqueeze(0)
+         # shape: (batch, D, L), this expands the same caption for all batch indices.
+         query_words = query_words.repeat(batch_size, 1, 1)
+
+         c_i = compute_region_context_vector(
+             local_incept_feat, query_words, gamma1
+         )  # Taken from attnGAN paper. shape: (batch, D, L)
+
+         query_words = query_words.transpose(1, 2)  # shape: (batch, L, D)
+         c_i = c_i.transpose(1, 2)  # shape: (batch, L, D)
+         query_words = query_words.reshape(
+             batch_size * numb_words, -1
+         )  # shape: (batch * L, D)
+         c_i = c_i.reshape(batch_size * numb_words, -1)  # shape: (batch * L, D)
+
+         r_i = compute_relevance(
+             c_i, query_words
+         )  # cosine similarity, or R(c_i, e_i) from attnGAN paper. shape: (batch * L, 1)
+         r_i = r_i.view(batch_size, numb_words)  # shape: (batch, L)
+         r_i = torch.exp(r_i * gamma2)  # shape: (batch, L)
+         r_i = r_i.sum(dim=1, keepdim=True)  # shape: (batch, 1)
+         r_i = torch.log(
+             r_i
+         )  # This is image-text matching score b/w whole image and caption, shape: (batch, 1)
+         match_scores.append(r_i)
+
+     masks = torch.cat(masks, dim=0).bool()  # type: ignore
+     match_scores = torch.cat(match_scores, dim=1)  # type: ignore
+
+     # This corresponds to P(D|Q) from attnGAN.
+     match_scores = gamma3 * match_scores  # type: ignore
+     match_scores.data.masked_fill_(  # type: ignore
+         masks, -float("inf")
+     )  # mask out the scores for mis-matched samples
+
+     match_scores_t = match_scores.transpose(  # type: ignore
+         0, 1
+     )  # This corresponds to P(Q|D) from attnGAN.
+
+     # This corresponds to L1_w from attnGAN.
+     l1_w = nn.CrossEntropyLoss()(match_scores, match_labels)
+     # This corresponds to L2_w from attnGAN.
+     l2_w = nn.CrossEntropyLoss()(match_scores_t, match_labels)
+
+     incept_feat_norm = torch.linalg.norm(global_incept_feat, dim=1)
+     sent_emb_norm = torch.linalg.norm(sent_emb, dim=1)
+
+     # shape: (batch, batch)
+     global_match_score = global_incept_feat @ (sent_emb.T)
+
+     global_match_score = (
+         global_match_score / torch.outer(incept_feat_norm, sent_emb_norm)
+     ).clamp(min=1e-8)
+     global_match_score = gamma3 * global_match_score
+
+     # mask out the scores for mis-matched samples
+     global_match_score.data.masked_fill_(masks, -float("inf"))  # type: ignore
+
+     global_match_t = global_match_score.T  # shape: (batch, batch)
+
+     # This corresponds to L1_s from attnGAN.
+     l1_s = nn.CrossEntropyLoss()(global_match_score, match_labels)
+     # This corresponds to L2_s from attnGAN.
+     l2_s = nn.CrossEntropyLoss()(global_match_t, match_labels)
+
+     loss_damsm = lambda3 * (l1_w + l2_w + l1_s + l2_s)
+
+     return loss_damsm
+
+
+ def compute_relevance(c_i: torch.Tensor, query_words: torch.Tensor) -> Any:
+     """Computes the cosine similarity between the region context vector and the query words.
+
+     Args:
+         c_i: The region context vector. shape: (batch * L, D)
+         query_words: The query words. shape: (batch * L, D)
+     """
+     prod = c_i * query_words  # shape: (batch * L, D)
+     numr = torch.sum(prod, dim=1)  # shape: (batch * L, 1)
+     norm_c = torch.linalg.norm(c_i, ord=2, dim=1)
+     norm_q = torch.linalg.norm(query_words, ord=2, dim=1)
+     denr = norm_c * norm_q
+     r_i = (numr / denr).clamp(min=1e-8).squeeze()  # shape: (batch * L, 1)
+     return r_i
+
+
+ def compute_region_context_vector(
+     local_incept_feat: torch.Tensor, query_words: torch.Tensor, gamma1: float
+ ) -> Any:
+     """Compute the region context vector (c_i) from attnGAN paper.
+
+     Args:
+         local_incept_feat: The local inception features. [shape: (batch, D, 17, 17)]
+         query_words: The embeddings for all the words in the captions. shape: (batch, D, L)
+         gamma1: The gamma1 value from attnGAN paper.
+     """
+     batch, L = query_words.size(0), query_words.size(2)  # pylint: disable=invalid-name
+
+     feat_height, feat_width = local_incept_feat.size(2), local_incept_feat.size(3)
+     N = feat_height * feat_width  # pylint: disable=invalid-name
+
+     # Reshape the local inception features to (batch, D, N)
+     local_incept_feat = local_incept_feat.view(batch, -1, N)
+     # shape: (batch, N, D)
+     incept_feat_t = local_incept_feat.transpose(1, 2)
+
+     sim_matrix = incept_feat_t @ query_words  # shape: (batch, N, L)
+     sim_matrix = sim_matrix.view(batch * N, L)  # shape: (batch * N, L)
+
+     sim_matrix = nn.Softmax(dim=1)(sim_matrix)  # shape: (batch * N, L)
+     sim_matrix = sim_matrix.view(batch, N, L)  # shape: (batch, N, L)
+
+     sim_matrix = torch.transpose(sim_matrix, 1, 2)  # shape: (batch, L, N)
+     sim_matrix = sim_matrix.reshape(batch * L, N)  # shape: (batch * L, N)
+
+     alpha_j = gamma1 * sim_matrix  # shape: (batch * L, N)
+     alpha_j = nn.Softmax(dim=1)(alpha_j)  # shape: (batch * L, N)
+     alpha_j = alpha_j.view(batch, L, N)  # shape: (batch, L, N)
+     alpha_j_t = torch.transpose(alpha_j, 1, 2)  # shape: (batch, N, L)
+
+     c_i = (
+         local_incept_feat @ alpha_j_t
+     )  # shape: (batch, D, L) [summing over N dimension in paper, so we multiply like this]
+     return c_i
+
+
+ def discriminator_loss(
+     logits: Dict[str, Dict[str, torch.Tensor]],
+     labels: Dict[str, Dict[str, torch.Tensor]],
+ ) -> Any:
+     """
+     Calculate discriminator objective
+
+     :param dict[str, dict[str, torch.Tensor]] logits:
+         Dictionary with fake/real and word-level/uncond/cond logits
+
+         Example:
+
+         logits = {
+             "fake": {
+                 "word_level": torch.Tensor (BxL)
+                 "uncond": torch.Tensor (Bx1)
+                 "cond": torch.Tensor (Bx1)
+             },
+             "real": {
+                 "word_level": torch.Tensor (BxL)
+                 "uncond": torch.Tensor (Bx1)
+                 "cond": torch.Tensor (Bx1)
+             },
+         }
+     :param dict[str, dict[str, torch.Tensor]] labels:
+         Dictionary with fake/real and word-level/image labels
+
+         Example:
+
+         labels = {
+             "fake": {
+                 "word_level": torch.Tensor (BxL)
+                 "image": torch.Tensor (Bx1)
+             },
+             "real": {
+                 "word_level": torch.Tensor (BxL)
+                 "image": torch.Tensor (Bx1)
+             },
+         }
+     :return: Discriminator objective loss
+     :rtype: Any
+     """
+     # define main loss functions for logit losses
+     tot_loss = 0.0
+     bce_logits = nn.BCEWithLogitsLoss()
+     bce = nn.BCELoss()
+     # calculate word-level loss
+     word_loss = bce(logits["real"]["word_level"], labels["real"]["word_level"])
+     # calculate unconditional adversarial loss
+     uncond_loss = bce_logits(logits["real"]["uncond"], labels["real"]["image"])
+
+     # calculate conditional adversarial loss
+     cond_loss = bce_logits(logits["real"]["cond"], labels["real"]["image"])
+
+     tot_loss = (uncond_loss + cond_loss) / 2.0
+
+     fake_uncond_loss = bce_logits(logits["fake"]["uncond"], labels["fake"]["image"])
+     fake_cond_loss = bce_logits(logits["fake"]["cond"], labels["fake"]["image"])
+
+     tot_loss += (fake_uncond_loss + fake_cond_loss) / 3.0
+     tot_loss += word_loss
+
+     return tot_loss
+
+
+ def kl_loss(mu_tensor: torch.Tensor, logvar: torch.Tensor) -> Any:
+     """
+     Calculate KL loss
+
+     :param torch.Tensor mu_tensor: Mean of latent distribution
+     :param torch.Tensor logvar: Log variance of latent distribution
+     :return: KL loss [-0.5 * (1 + log(sigma) - mu^2 - sigma^2)]
+     :rtype: Any
+     """
+     return torch.mean(-0.5 * (1 + 0.5 * logvar - mu_tensor.pow(2) - torch.exp(logvar)))
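
A shape-level sketch of calling discriminator_loss and kl_loss with random tensors that follow the dictionary layout documented above; this is not the training wiring, only a sanity check of the expected shapes:

    import torch

    from src.models.losses import discriminator_loss, kl_loss

    B, L = 4, 12
    logits = {
        "fake": {"word_level": torch.rand(B, L),
                 "uncond": torch.randn(B, 1), "cond": torch.randn(B, 1)},
        "real": {"word_level": torch.rand(B, L),
                 "uncond": torch.randn(B, 1), "cond": torch.randn(B, 1)},
    }
    labels = {
        "fake": {"word_level": torch.zeros(B, L), "image": torch.zeros(B, 1)},
        "real": {"word_level": torch.randint(0, 2, (B, L)).float(), "image": torch.ones(B, 1)},
    }
    print(discriminator_loss(logits, labels).item())
    # zero mean / zero log-variance gives exactly 0 under the formula above
    print(kl_loss(torch.zeros(B, 100), torch.zeros(B, 100)).item())
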
src/models/modules/__init__.py ADDED
@@ -0,0 +1,12 @@
+ """All the modules used in creation of Generator and Discriminator"""
+ from .acm import ACM
+ from .attention import ChannelWiseAttention, SpatialAttention
+ from .cond_augment import CondAugmentation
+ from .conv_utils import calc_out_conv, conv1d, conv2d
+ from .discriminator import Discriminator, WordLevelLogits
+ from .downsample import down_sample
+ from .generator import Generator
+ from .image_encoder import InceptionEncoder, VGGEncoder
+ from .residual import ResidualBlock
+ from .text_encoder import TextEncoder
+ from .upsample import img_up_block, up_sample
src/models/modules/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (891 Bytes).
 
src/models/modules/__pycache__/acm.cpython-39.pyc ADDED
Binary file (1.66 kB).
 
src/models/modules/__pycache__/attention.cpython-39.pyc ADDED
Binary file (3.38 kB).
 
src/models/modules/__pycache__/cond_augment.cpython-39.pyc ADDED
Binary file (2.52 kB).
 
src/models/modules/__pycache__/conv_utils.cpython-39.pyc ADDED
Binary file (2.37 kB).
 
src/models/modules/__pycache__/discriminator.cpython-39.pyc ADDED
Binary file (5.1 kB).
 
src/models/modules/__pycache__/downsample.cpython-39.pyc ADDED
Binary file (598 Bytes).
 
src/models/modules/__pycache__/generator.cpython-39.pyc ADDED
Binary file (9.03 kB).
 
src/models/modules/__pycache__/image_encoder.cpython-39.pyc ADDED
Binary file (4.27 kB).
 
src/models/modules/__pycache__/residual.cpython-39.pyc ADDED
Binary file (1.31 kB).
 
src/models/modules/__pycache__/text_encoder.cpython-39.pyc ADDED
Binary file (1.92 kB).
 
src/models/modules/__pycache__/upsample.cpython-39.pyc ADDED
Binary file (983 Bytes).
 
src/models/modules/acm.py ADDED
@@ -0,0 +1,37 @@
+ """ACM and its variations"""
+
+ from typing import Any
+
+ import torch
+ from torch import nn
+
+ from .conv_utils import conv2d
+
+
+ class ACM(nn.Module):
+     """Affine Combination Module from ManiGAN"""
+
+     def __init__(self, img_chans: int, text_chans: int, inner_dim: int = 64) -> None:
+         """
+         Initialize the convolutional layers
+
+         :param int img_chans: Channels in visual input
+         :param int text_chans: Channels of textual input
+         :param int inner_dim: Hyperparameter for the inner dimensionality of features
+         """
+         super().__init__()
+         self.conv = conv2d(in_channels=img_chans, out_channels=inner_dim)
+         self.weights = conv2d(in_channels=inner_dim, out_channels=text_chans)
+         self.biases = conv2d(in_channels=inner_dim, out_channels=text_chans)
+
+     def forward(self, text: torch.Tensor, img: torch.Tensor) -> Any:
+         """
+         Propagate the textual and visual input through the ACM module
+
+         :param torch.Tensor text: Textual input (can be hidden features)
+         :param torch.Tensor img: Image input
+         :return: Affine combination of text and image
+         :rtype: torch.Tensor
+         """
+         img_features = self.conv(img)
+         return text * self.weights(img_features) + self.biases(img_features)
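
Shape sketch for ACM: the textual input must already share the spatial size of the image features.

    import torch

    from src.models.modules import ACM

    acm = ACM(img_chans=3, text_chans=128, inner_dim=64)
    img = torch.randn(2, 3, 64, 64)        # visual input
    text = torch.randn(2, 128, 64, 64)     # hidden textual features, same spatial size
    print(acm(text, img).shape)            # torch.Size([2, 128, 64, 64])
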
src/models/modules/attention.py ADDED
@@ -0,0 +1,88 @@
+ """Attention modules"""
+ from typing import Any, Optional
+
+ import torch
+ from torch import nn
+
+ from src.models.modules.conv_utils import conv1d
+
+
+ class ChannelWiseAttention(nn.Module):
+     """Channel-wise attention adapted from ControlGAN"""
+
+     def __init__(self, fm_size: int, text_d: int) -> None:
+         """
+         Initialize the Channel-Wise attention module
+
+         :param int fm_size:
+             Height and width of feature map on k-th iteration of forward-pass.
+             In paper, it's H_k * W_k
+         :param int text_d: Dimensionality of sentence. From paper, it's D
+         """
+         super().__init__()
+         # perception layer
+         self.text_conv = conv1d(text_d, fm_size)
+         # attention across channel dimension
+         self.softmax = nn.Softmax(2)
+
+     def forward(self, v_k: torch.Tensor, w_text: torch.Tensor) -> Any:
+         """
+         Apply attention to visual features taking into account features of words
+
+         :param torch.Tensor v_k: Visual context
+         :param torch.Tensor w_text: Textual features
+         :return: Fused hidden visual features and word features
+         :rtype: Any
+         """
+         w_hat = self.text_conv(w_text)
+         m_k = v_k @ w_hat
+         a_k = self.softmax(m_k)
+         w_hat = torch.transpose(w_hat, 1, 2)
+         return a_k @ w_hat
+
+
+ class SpatialAttention(nn.Module):
+     """Spatial attention module for attending textual context to visual features"""
+
+     def __init__(self, d: int, d_hat: int) -> None:
+         """
+         Set up softmax and conv layers
+
+         :param int d: Initial embedding size for textual features. D from paper
+         :param int d_hat: Dimensionality of the visual features. D_hat from paper
+         """
+         super().__init__()
+         self.softmax = nn.Softmax(2)
+         self.conv = conv1d(d, d_hat)
+
+     def forward(
+         self,
+         text_context: torch.Tensor,
+         image: torch.Tensor,
+         mask: Optional[torch.Tensor] = None,
+     ) -> Any:
+         """
+         Project image features into the latent space
+         of textual features and apply attention
+
+         :param torch.Tensor text_context: D x T tensor of hidden textual features
+         :param torch.Tensor image: D_hat x N visual features
+         :param Optional[torch.Tensor] mask:
+             Boolean tensor for masking the padded words. BxL
+         :return: Word features attended by visual features
+         :rtype: Any
+         """
+         # number of features on image feature map H * W
+         feature_num = image.size(2)
+         # number of words in caption
+         len_caption = text_context.size(2)
+         text_context = self.conv(text_context)
+         image = torch.transpose(image, 1, 2)
+         s_i_j = image @ text_context
+         if mask is not None:
+             # duplicating mask and aligning dims with s_i_j
+             mask = mask.repeat(1, feature_num).view(-1, feature_num, len_caption)
+             s_i_j[mask] = -float("inf")
+         b_i_j = self.softmax(s_i_j)
+         c_i_j = b_i_j @ torch.transpose(text_context, 1, 2)
+         return torch.transpose(c_i_j, 1, 2)
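
Shape sketch for the two attention modules, using an assumed 17x17 feature map flattened into N = 289 regions:

    import torch

    from src.models.modules import ChannelWiseAttention, SpatialAttention

    B, D, D_HAT, L, N = 2, 256, 32, 12, 17 * 17

    spatial = SpatialAttention(d=D, d_hat=D_HAT)
    words = torch.randn(B, D, L)                 # hidden word features
    regions = torch.randn(B, D_HAT, N)           # flattened visual features
    mask = torch.zeros(B, L, dtype=torch.bool)   # no padded words in this toy batch
    print(spatial(words, regions, mask).shape)   # torch.Size([2, 32, 289])

    channel = ChannelWiseAttention(fm_size=N, text_d=D)
    v_k = torch.randn(B, D_HAT, N)
    print(channel(v_k, words).shape)             # torch.Size([2, 32, 289])
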
src/models/modules/cond_augment.py ADDED
@@ -0,0 +1,57 @@
+ """Conditioning Augmentation Module"""
+
+ from typing import Any
+
+ import torch
+ from torch import nn
+
+
+ class CondAugmentation(nn.Module):
+     """Conditioning Augmentation Module"""
+
+     def __init__(self, D: int, conditioning_dim: int):
+         """
+         :param D: Dimension of the text embedding space [D from AttnGAN paper]
+         :param conditioning_dim: Dimension of the conditioning space
+         """
+         super().__init__()
+         self.cond_dim = conditioning_dim
+         self.cond_augment = nn.Linear(D, conditioning_dim * 4, bias=True)
+         self.glu = nn.GLU(dim=1)
+
+     def encode(self, text_embedding: torch.Tensor) -> Any:
+         """
+         This function encodes the text embedding into the conditioning space
+         :param text_embedding: Text embedding
+         :return: Conditioning embedding
+         """
+         x_tensor = self.glu(self.cond_augment(text_embedding))
+         mu_tensor = x_tensor[:, : self.cond_dim]
+         logvar = x_tensor[:, self.cond_dim :]
+         return mu_tensor, logvar
+
+     def sample(self, mu_tensor: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
+         """
+         This function samples from the Gaussian distribution
+         :param mu_tensor: Mean of the Gaussian distribution
+         :param logvar: Log variance of the Gaussian distribution
+         :return: Sample from the Gaussian distribution
+         """
+         std = torch.exp(0.5 * logvar)
+         eps = torch.randn_like(
+             std
+         )  # check if this should add requires_grad = True to this tensor?
+         return mu_tensor + eps * std
+
+     def forward(self, text_embedding: torch.Tensor) -> Any:
+         """
+         This function encodes the text embedding into the conditioning space,
+         and samples from the Gaussian distribution.
+         :param text_embedding: Text embedding
+         :return c_hat: Conditioning embedding (C^ from StackGAN++ paper)
+         :return mu: Mean of the Gaussian distribution
+         :return logvar: Log variance of the Gaussian distribution
+         """
+         mu_tensor, logvar = self.encode(text_embedding)
+         c_hat = self.sample(mu_tensor, logvar)
+         return c_hat, mu_tensor, logvar
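
Usage sketch: the module returns the sampled conditioning vector together with mu and logvar, which are typically fed to kl_loss for regularization.

    import torch

    from src.models.modules import CondAugmentation

    cond_aug = CondAugmentation(D=256, conditioning_dim=100)
    sent_emb = torch.randn(4, 256)               # sentence embeddings from the text encoder
    c_hat, mu, logvar = cond_aug(sent_emb)
    print(c_hat.shape, mu.shape, logvar.shape)   # all torch.Size([4, 100])
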
src/models/modules/conv_utils.py ADDED
@@ -0,0 +1,78 @@
+ """Frequently used convolution modules"""
+
+ from typing import Tuple
+
+ from torch import nn
+
+
+ def conv2d(
+     in_channels: int,
+     out_channels: int,
+     kernel_size: int = 3,
+     stride: int = 1,
+     padding: int = 1,
+ ) -> nn.Conv2d:
+     """
+     Template convolution which is typically used throughout the project
+
+     :param int in_channels: Number of input channels
+     :param int out_channels: Number of output channels
+     :param int kernel_size: Size of sliding kernel
+     :param int stride: How many steps kernel does when sliding
+     :param int padding: How many dimensions to pad
+     :return: Convolution layer with parameters
+     :rtype: nn.Conv2d
+     """
+     return nn.Conv2d(
+         in_channels=in_channels,
+         out_channels=out_channels,
+         kernel_size=kernel_size,
+         stride=stride,
+         padding=padding,
+     )
+
+
+ def conv1d(
+     in_channels: int,
+     out_channels: int,
+     kernel_size: int = 1,
+     stride: int = 1,
+     padding: int = 0,
+ ) -> nn.Conv1d:
+     """
+     Template 1d convolution which is typically used throughout the project
+
+     :param int in_channels: Number of input channels
+     :param int out_channels: Number of output channels
+     :param int kernel_size: Size of sliding kernel
+     :param int stride: How many steps kernel does when sliding
+     :param int padding: How many dimensions to pad
+     :return: Convolution layer with parameters
+     :rtype: nn.Conv1d
+     """
+     return nn.Conv1d(
+         in_channels=in_channels,
+         out_channels=out_channels,
+         kernel_size=kernel_size,
+         stride=stride,
+         padding=padding,
+     )
+
+
+ def calc_out_conv(
+     h_in: int, w_in: int, kernel_size: int = 3, stride: int = 1, padding: int = 0
+ ) -> Tuple[int, int]:
+     """
+     Calculate the dimensionalities of images propagated through conv layers
+
+     :param h_in: Height of the image
+     :param w_in: Width of the image
+     :param kernel_size: Size of sliding kernel
+     :param stride: How many steps kernel does when sliding
+     :param padding: How many dimensions to pad
+     :return: Height and width of image through convolution
+     :rtype: tuple[int, int]
+     """
+     h_out = int((h_in + 2 * padding - kernel_size) / stride + 1)
+     w_out = int((w_in + 2 * padding - kernel_size) / stride + 1)
+     return h_out, w_out
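
Worked examples for calc_out_conv with the two kernel configurations used in this repo (the 3x3/stride-1/pad-1 default above, and the 4x4/stride-2/pad-1 convolution used by down_sample):

    from src.models.modules import calc_out_conv

    print(calc_out_conv(256, 256, kernel_size=3, stride=1, padding=1))   # (256, 256): size preserved
    print(calc_out_conv(256, 256, kernel_size=4, stride=2, padding=1))   # (128, 128): halved
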
src/models/modules/discriminator.py ADDED
@@ -0,0 +1,144 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Discriminator providing word-level feedback"""
2
+ from typing import Any
3
+
4
+ import torch
5
+ from torch import nn
6
+
7
+ from src.models.modules.conv_utils import conv1d, conv2d
8
+ from src.models.modules.image_encoder import InceptionEncoder
9
+
10
+
11
+ class WordLevelLogits(nn.Module):
12
+ """API for converting regional feature maps into logits for multi-class classification"""
13
+
14
+ def __init__(self) -> None:
15
+ """
16
+ Instantiate the module with softmax on channel dimension
17
+ """
18
+ super().__init__()
19
+ self.softmax = nn.Softmax(dim=1)
20
+ # layer for flattening the feature maps
21
+ self.flat = nn.Flatten(start_dim=2)
22
+ # change dism of of textual embs to correlate with chans of inception
23
+ self.chan_reduction = conv1d(256, 128)
24
+
25
+ def forward(
26
+ self, visual_features: torch.Tensor, word_embs: torch.Tensor, mask: torch.Tensor
27
+ ) -> Any:
28
+ """
29
+ Fuse two types of features together to get output for feeding into the classification loss
30
+ :param torch.Tensor visual_features:
31
+ Feature maps of an image after being processed by Inception encoder. Bx128x17x17
32
+         :param torch.Tensor word_embs:
+             Word-level embeddings from the text encoder. Bx256xL
+         :return: Logits for each word in the picture. BxL
+         :rtype: Any
+         """
+         # make textual and visual features have the same number of channels
+         word_embs = self.chan_reduction(word_embs)
+         # flatten the feature maps
+         visual_features = self.flat(visual_features)
+         word_embs = torch.transpose(word_embs, 1, 2)
+         word_region_correlations = word_embs @ visual_features
+         # normalize across the L dimension
+         m_norm_l = nn.functional.normalize(word_region_correlations, dim=1)
+         # normalize across the H*W dimension
+         m_norm_hw = nn.functional.normalize(m_norm_l, dim=2)
+         m_norm_hw = torch.transpose(m_norm_hw, 1, 2)
+         weighted_img_feats = visual_features @ m_norm_hw
+         weighted_img_feats = torch.sum(weighted_img_feats, dim=1)
+         # mask padding tokens so they get ~zero probability after the softmax
+         weighted_img_feats[mask] = -float("inf")
+         deltas = self.softmax(weighted_img_feats)
+         return deltas
+
+
+ class UnconditionalLogits(nn.Module):
+     """Head for retrieving logits from an image"""
+
+     def __init__(self) -> None:
+         """Initialize modules that reduce the features down to a set of logits"""
+         super().__init__()
+         self.conv = nn.Conv2d(128, 1, kernel_size=17)
+         # flattening Bx1x1x1 into Bx1
+         self.flat = nn.Flatten()
+
+     def forward(self, visual_features: torch.Tensor) -> Any:
+         """
+         Compute logits for the unconditioned adversarial loss
+
+         :param visual_features: Local features from the Inception network. Bx128x17x17
+         :return: Logits for the unconditioned adversarial loss. Bx1
+         :rtype: Any
+         """
+         # reduce channels and feature maps for visual features
+         visual_features = self.conv(visual_features)
+         # flatten Bx1x1x1 into Bx1
+         logits = self.flat(visual_features)
+         return logits
+
+
+ class ConditionalLogits(nn.Module):
+     """Logits extractor for the conditioned adversarial loss"""
+
+     def __init__(self) -> None:
+         super().__init__()
+         # layer for forming the feature maps out of textual info
+         self.text_to_fm = conv1d(256, 17 * 17)
+         # fitting the number of text channels to the number of visual channels
+         self.chan_aligner = conv2d(1, 128)
+         # reduces the joint textual + visual features down to a 1x1 feature map
+         self.joint_conv = nn.Conv2d(2 * 128, 1, kernel_size=17)
+         # converting Bx1x1x1 into Bx1
+         self.flat = nn.Flatten()
+
+     def forward(self, visual_features: torch.Tensor, sent_embs: torch.Tensor) -> Any:
+         """
+         Compute logits for the conditional adversarial loss
+
+         :param torch.Tensor visual_features: Features from the Inception encoder. Bx128x17x17
+         :param torch.Tensor sent_embs: Sentence embeddings from the text encoder. Bx256
+         :return: Logits for the conditional adversarial loss. Bx1
+         :rtype: Any
+         """
+         # make text and visual features have the same feature-map size
+         # Bx256 -> Bx256x1 -> Bx289x1
+         sent_embs = sent_embs.view(-1, 256, 1)
+         sent_embs = self.text_to_fm(sent_embs)
+         # reshape textual info into the shape of the visual feature maps
+         # Bx289x1 -> Bx1x17x17
+         sent_embs = sent_embs.view(-1, 1, 17, 17)
+         # propagate text embs through a conv layer to
+         # align channels with the visual feature maps
+         sent_embs = self.chan_aligner(sent_embs)
+         # unite textual and visual features across the channel dimension
+         cross_features = torch.cat((visual_features, sent_embs), dim=1)
+         # reduce the joint features down to a 1x1 map and form raw logits
+         cross_features = self.joint_conv(cross_features)
+         # form logits from Bx1x1x1 into Bx1
+         logits = self.flat(cross_features)
+         return logits
+
+
+ class Discriminator(nn.Module):
+     """Simple CNN-based discriminator"""
+
+     def __init__(self) -> None:
+         """Use a pretrained InceptionNet to extract features"""
+         super().__init__()
+         self.encoder = InceptionEncoder(D=128)
+         # define different logit extractors for different losses
+         self.logits_word_level = WordLevelLogits()
+         self.logits_uncond = UnconditionalLogits()
+         self.logits_cond = ConditionalLogits()
+
+     def forward(self, images: torch.Tensor) -> Any:
+         """
+         Retrieve image features encoded by the image encoder
+
+         :param torch.Tensor images: Images to be analyzed. Bx3x256x256
+         :return: Image features encoded by the image encoder. Bx128x17x17
+         """
+         # only taking the local features from Inception
+         # Bx3x256x256 -> Bx128x17x17
+         img_features, _ = self.encoder(images)
+         return img_features
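
Editor's note: a minimal smoke test of the two image-level logit heads added above, using dummy tensors with the shapes stated in the docstrings. This sketch is not part of the commit; it assumes the module path `src/models/modules/discriminator.py` introduced in this diff and that the heads are importable from there.

```python
import torch

from src.models.modules.discriminator import ConditionalLogits, UnconditionalLogits

batch = 4
visual_features = torch.randn(batch, 128, 17, 17)  # local Inception-style features
sent_embs = torch.randn(batch, 256)                # sentence embeddings

uncond_head = UnconditionalLogits()
cond_head = ConditionalLogits()

print(uncond_head(visual_features).shape)           # torch.Size([4, 1])
print(cond_head(visual_features, sent_embs).shape)  # torch.Size([4, 1])
```

Both heads collapse the 17x17 map with a kernel-17 convolution, so each image yields a single logit.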
src/models/modules/downsample.py ADDED
@@ -0,0 +1,14 @@
+ """Downsample module."""
+
+ from torch import nn
+
+
+ def down_sample(in_planes: int, out_planes: int) -> nn.Module:
+     """DownSample block: a strided convolution that halves the spatial resolution."""
+     return nn.Sequential(
+         nn.Conv2d(
+             in_planes, out_planes, kernel_size=4, stride=2, padding=1, bias=False
+         ),
+         nn.BatchNorm2d(out_planes),
+         nn.LeakyReLU(0.2, inplace=True),
+     )
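
Editor's note: a quick shape check for `down_sample` (not part of the commit; the import path assumes the layout introduced here). The kernel-4, stride-2, padding-1 convolution halves both spatial dimensions.

```python
import torch

from src.models.modules.downsample import down_sample

block = down_sample(128, 64)
x = torch.randn(2, 128, 128, 128)
print(block(x).shape)  # torch.Size([2, 64, 64, 64])
```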
src/models/modules/generator.py ADDED
@@ -0,0 +1,300 @@
+ """Generator Module"""
+
+ from typing import Any, Optional
+
+ import torch
+ from torch import nn
+
+ from src.models.modules.acm import ACM
+ from src.models.modules.attention import ChannelWiseAttention, SpatialAttention
+ from src.models.modules.cond_augment import CondAugmentation
+ from src.models.modules.downsample import down_sample
+ from src.models.modules.residual import ResidualBlock
+ from src.models.modules.upsample import img_up_block, up_sample
+
+
+ class InitStageG(nn.Module):
+     """Initial Stage Generator Module"""
+
+     # pylint: disable=too-many-instance-attributes
+     # pylint: disable=too-many-arguments
+     # pylint: disable=invalid-name
+     # pylint: disable=too-many-locals
+
+     def __init__(
+         self, Ng: int, Ng_init: int, conditioning_dim: int, D: int, noise_dim: int
+     ):
+         """
+         :param Ng: Number of channels.
+         :param Ng_init: Initial value of Ng; the output channel count of the first image upsample.
+         :param conditioning_dim: Dimension of the conditioning space
+         :param D: Dimension of the text embedding space [D from AttnGAN paper]
+         :param noise_dim: Dimension of the noise space
+         """
+         super().__init__()
+         self.gf_dim = Ng
+         self.gf_init = Ng_init
+         self.in_dim = noise_dim + conditioning_dim + D
+         self.text_dim = D
+
+         self.define_module()
+
+     def define_module(self) -> None:
+         """Defines FC, Upsample, Residual, ACM, Attention modules"""
+         nz, ng = self.in_dim, self.gf_dim
+         self.fully_connect = nn.Sequential(
+             nn.Linear(nz, ng * 4 * 4 * 2, bias=False),
+             nn.BatchNorm1d(ng * 4 * 4 * 2),
+             nn.GLU(dim=1),  # we start from a 4 x 4 feat_map and return hidden_64.
+         )
+
+         self.upsample1 = up_sample(ng, ng // 2)
+         self.upsample2 = up_sample(ng // 2, ng // 4)
+         self.upsample3 = up_sample(ng // 4, ng // 8)
+         self.upsample4 = up_sample(
+             ng // 8 * 3, ng // 16
+         )  # channels multiplied by 3 because spatial and channel attention are concatenated
+
+         self.residual = self._make_layer(ResidualBlock, ng // 8 * 3)
+         self.acm_module = ACM(self.gf_init, ng // 8 * 3)
+
+         self.spatial_att = SpatialAttention(self.text_dim, ng // 8)
+         self.channel_att = ChannelWiseAttention(
+             32 * 32, self.text_dim
+         )  # 32 x 32 is the feature map size
+
+     def _make_layer(self, block: Any, channel_num: int) -> nn.Module:
+         layers = []
+         for _ in range(2):  # number of residual blocks hardcoded to 2
+             layers.append(block(channel_num))
+         return nn.Sequential(*layers)
+
+     def forward(
+         self,
+         noise: torch.Tensor,
+         condition: torch.Tensor,
+         global_inception: torch.Tensor,
+         local_upsampled_inception: torch.Tensor,
+         word_embeddings: torch.Tensor,
+         mask: Optional[torch.Tensor] = None,
+     ) -> Any:
+         """
+         :param noise: Noise tensor
+         :param condition: Condition tensor (c^ from the StackGAN++ paper)
+         :param global_inception: Global Inception feature
+         :param local_upsampled_inception: Local Inception feature, upsampled to 32 x 32
+         :param word_embeddings: Word embeddings [shape: D x L or D x T]
+         :param mask: Mask for padding tokens
+         :return: Hidden image feature map tensor of size 64 x 64
+         """
+         noise_concat = torch.cat((noise, condition), 1)
+         inception_concat = torch.cat((noise_concat, global_inception), 1)
+         hidden = self.fully_connect(inception_concat)
+         hidden = hidden.view(-1, self.gf_dim, 4, 4)  # convert to a 4x4 image feature map
+         hidden = self.upsample1(hidden)
+         hidden = self.upsample2(hidden)
+         hidden_32 = self.upsample3(hidden)  # shape: (batch_size, gf_dim // 8, 32, 32)
+         hidden_32_view = hidden_32.view(
+             hidden_32.shape[0], -1, hidden_32.shape[2] * hidden_32.shape[3]
+         )  # this reshaping is done because the attention module expects this shape.
+
+         spatial_att_feat = self.spatial_att(
+             word_embeddings, hidden_32_view, mask
+         )  # spatial att shape: (batch, D^, 32 * 32)
+         channel_att_feat = self.channel_att(
+             spatial_att_feat, word_embeddings
+         )  # channel att shape: (batch, D^, 32 * 32), or (batch, C, Hk * Wk) from the ControlGAN paper
+         spatial_att_feat = spatial_att_feat.view(
+             word_embeddings.shape[0], -1, hidden_32.shape[2], hidden_32.shape[3]
+         )  # reshape to (batch, D^, 32, 32)
+         channel_att_feat = channel_att_feat.view(
+             word_embeddings.shape[0], -1, hidden_32.shape[2], hidden_32.shape[3]
+         )  # reshape to (batch, D^, 32, 32)
+
+         spatial_concat = torch.cat(
+             (hidden_32, spatial_att_feat), 1
+         )  # concat spatial attention feature with hidden_32
+         attn_concat = torch.cat(
+             (spatial_concat, channel_att_feat), 1
+         )  # concat channel and spatial attention features
+
+         hidden_32 = self.acm_module(attn_concat, local_upsampled_inception)
+         hidden_32 = self.residual(hidden_32)
+         hidden_64 = self.upsample4(hidden_32)
+         return hidden_64
+
+
+ class NextStageG(nn.Module):
+     """Next Stage Generator Module"""
+
+     # pylint: disable=too-many-instance-attributes
+     # pylint: disable=too-many-arguments
+     # pylint: disable=invalid-name
+     # pylint: disable=too-many-locals
+
+     def __init__(self, Ng: int, Ng_init: int, D: int, image_size: int):
+         """
+         :param Ng: Number of channels.
+         :param Ng_init: Initial value of Ng.
+         :param D: Dimension of the text embedding space [D from AttnGAN paper]
+         :param image_size: Size of the output image from the previous generator stage.
+         """
+         super().__init__()
+         self.gf_dim = Ng
+         self.gf_init = Ng_init
+         self.text_dim = D
+         self.img_size = image_size
+
+         self.define_module()
+
+     def define_module(self) -> None:
+         """Defines FC, Upsample, Residual, ACM, Attention modules"""
+         ng = self.gf_dim
+         self.spatial_att = SpatialAttention(self.text_dim, ng)
+         self.channel_att = ChannelWiseAttention(
+             self.img_size * self.img_size, self.text_dim
+         )
+
+         self.residual = self._make_layer(ResidualBlock, ng * 3)
+         self.upsample = up_sample(ng * 3, ng)
+         self.acm_module = ACM(self.gf_init, ng * 3)
+         self.upsample2 = up_sample(ng, ng)
+
+     def _make_layer(self, block: Any, channel_num: int) -> nn.Module:
+         layers = []
+         for _ in range(2):  # number of residual blocks hardcoded to 2
+             layers.append(block(channel_num))
+         return nn.Sequential(*layers)
+
+     def forward(
+         self,
+         hidden_feat: Any,
+         word_embeddings: torch.Tensor,
+         vgg64_feat: torch.Tensor,
+         mask: Optional[torch.Tensor] = None,
+     ) -> Any:
+         """
+         :param hidden_feat: Hidden feature from the previous generator stage [i.e. hidden_64]
+         :param word_embeddings: Word embeddings
+         :param vgg64_feat: VGG feature map of size 64 x 64
+         :param mask: Mask for the padding tokens
+         :return: Image feature map of size 256 x 256
+         """
+         hidden_view = hidden_feat.view(
+             hidden_feat.shape[0], -1, hidden_feat.shape[2] * hidden_feat.shape[3]
+         )  # reshape to pass into the attention modules.
+         spatial_att_feat = self.spatial_att(
+             word_embeddings, hidden_view, mask
+         )  # spatial att shape: (batch, D^, 64 * 64), or D^ x N
+         channel_att_feat = self.channel_att(
+             spatial_att_feat, word_embeddings
+         )  # channel att shape: (batch, D^, 64 * 64), or (batch, C, Hk * Wk) from the ControlGAN paper
+         spatial_att_feat = spatial_att_feat.view(
+             word_embeddings.shape[0], -1, hidden_feat.shape[2], hidden_feat.shape[3]
+         )  # reshape to (batch, D^, 64, 64)
+         channel_att_feat = channel_att_feat.view(
+             word_embeddings.shape[0], -1, hidden_feat.shape[2], hidden_feat.shape[3]
+         )  # reshape to (batch, D^, 64, 64)
+
+         spatial_concat = torch.cat(
+             (hidden_feat, spatial_att_feat), 1
+         )  # concat spatial attention feature with hidden_64
+         attn_concat = torch.cat(
+             (spatial_concat, channel_att_feat), 1
+         )  # concat channel and spatial attention features
+
+         hidden_64 = self.acm_module(attn_concat, vgg64_feat)
+         hidden_64 = self.residual(hidden_64)
+         hidden_128 = self.upsample(hidden_64)
+         hidden_256 = self.upsample2(hidden_128)
+         return hidden_256
+
+
+ class GetImageG(nn.Module):
+     """Generates the final fake image from the image feature map"""
+
+     def __init__(self, Ng: int):
+         """
+         :param Ng: Number of channels.
+         """
+         super().__init__()
+         self.img = nn.Sequential(
+             nn.Conv2d(Ng, 3, kernel_size=3, stride=1, padding=1, bias=False), nn.Tanh()
+         )
+
+     def forward(self, hidden_feat: torch.Tensor) -> Any:
+         """
+         :param hidden_feat: Image feature map
+         :return: Final fake image
+         """
+         return self.img(hidden_feat)
+
+
+ class Generator(nn.Module):
+     """Generator Module"""
+
+     # pylint: disable=too-many-instance-attributes
+     # pylint: disable=too-many-arguments
+     # pylint: disable=invalid-name
+     # pylint: disable=too-many-locals
+
+     def __init__(self, Ng: int, D: int, conditioning_dim: int, noise_dim: int):
+         """
+         :param Ng: Number of channels. [Taken from the StackGAN++ paper]
+         :param D: Dimension of the text embedding space
+         :param conditioning_dim: Dimension of the conditioning space
+         :param noise_dim: Dimension of the noise space
+         """
+         super().__init__()
+         self.cond_augment = CondAugmentation(D, conditioning_dim)
+         self.hidden_net1 = InitStageG(Ng * 16, Ng, conditioning_dim, D, noise_dim)
+         self.inception_img_upsample = img_up_block(
+             D, Ng
+         )  # the channel size returned by the Inception encoder is D (default in the paper: 256)
+         self.hidden_net2 = NextStageG(Ng, Ng, D, 64)
+         self.generate_img = GetImageG(Ng)
+
+         self.acm_module = ACM(Ng, Ng)
+
+         self.vgg_downsample = down_sample(D // 2, Ng)
+         self.upsample1 = up_sample(Ng, Ng)
+         self.upsample2 = up_sample(Ng, Ng)
+
+     def forward(
+         self,
+         noise: torch.Tensor,
+         sentence_embeddings: torch.Tensor,
+         word_embeddings: torch.Tensor,
+         global_inception_feat: torch.Tensor,
+         local_inception_feat: torch.Tensor,
+         vgg_feat: torch.Tensor,
+         mask: Optional[torch.Tensor] = None,
+     ) -> Any:
+         """
+         :param noise: Noise vector [shape: (batch, noise_dim)]
+         :param sentence_embeddings: Sentence embeddings [shape: (batch, D)]
+         :param word_embeddings: Word embeddings [shape: D x L, where L is the sentence length]
+         :param global_inception_feat: Global Inception feature map [shape: (batch, D)]
+         :param local_inception_feat: Local Inception feature map [shape: (batch, D, 17, 17)]
+         :param vgg_feat: VGG feature map [shape: (batch, D // 2 = 128, 128, 128)]
+         :param mask: Mask for the padding tokens
+         :return: Final fake image, plus mu and logvar from the conditioning augmentation
+         """
+         c_hat, mu_tensor, logvar = self.cond_augment(sentence_embeddings)
+         hidden_32 = self.inception_img_upsample(local_inception_feat)
+
+         hidden_64 = self.hidden_net1(
+             noise, c_hat, global_inception_feat, hidden_32, word_embeddings, mask
+         )
+
+         vgg_64 = self.vgg_downsample(vgg_feat)
+
+         hidden_256 = self.hidden_net2(hidden_64, word_embeddings, vgg_64, mask)
+
+         vgg_128 = self.upsample1(vgg_64)
+         vgg_256 = self.upsample2(vgg_128)
+
+         hidden_256 = self.acm_module(hidden_256, vgg_256)
+         fake_img = self.generate_img(hidden_256)
+
+         return fake_img, mu_tensor, logvar
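
Editor's note: a small, verifiable sketch of the final to-RGB stage added above (not part of the commit; the import path assumes this diff's layout, and Ng=32 is an illustrative value, not the project's configured one). `GetImageG` keeps the spatial size with a 3x3 convolution and bounds the output to [-1, 1] via tanh, which is why the generated images need denormalization before display.

```python
import torch

from src.models.modules.generator import GetImageG

to_rgb = GetImageG(Ng=32)                 # Ng chosen for illustration only
feat = torch.randn(2, 32, 256, 256)       # hidden feature map from the last stage

img = to_rgb(feat)
print(img.shape)                          # torch.Size([2, 3, 256, 256])
print(img.min() >= -1, img.max() <= 1)    # True True -- tanh output range
```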
src/models/modules/image_encoder.py ADDED
@@ -0,0 +1,138 @@
+ """Image Encoder Module"""
+ from typing import Any
+
+ import torch
+ from torch import nn
+
+ from src.models.modules.conv_utils import conv2d
+
+ # build inception v3 image encoder
+
+
+ class InceptionEncoder(nn.Module):
+     """Image Encoder Module adapted from AttnGAN"""
+
+     def __init__(self, D: int):
+         """
+         :param D: Dimension of the text embedding space [D from AttnGAN paper]
+         """
+         super().__init__()
+
+         self.text_emb_dim = D
+
+         model = torch.hub.load(
+             "pytorch/vision:v0.10.0", "inception_v3", pretrained=True
+         )
+         for param in model.parameters():
+             param.requires_grad = False
+
+         self.define_module(model)
+         self.init_trainable_weights()
+
+     def define_module(self, model: nn.Module) -> None:
+         """
+         This function defines the modules of the image encoder
+         :param model: Pretrained Inception V3 model
+         """
+         model.cust_upsample = nn.Upsample(size=(299, 299), mode="bilinear")
+         model.cust_maxpool1 = nn.MaxPool2d(kernel_size=3, stride=2)
+         model.cust_maxpool2 = nn.MaxPool2d(kernel_size=3, stride=2)
+         model.cust_avgpool = nn.AvgPool2d(kernel_size=8)
+
+         attribute_list = [
+             "cust_upsample",
+             "Conv2d_1a_3x3",
+             "Conv2d_2a_3x3",
+             "Conv2d_2b_3x3",
+             "cust_maxpool1",
+             "Conv2d_3b_1x1",
+             "Conv2d_4a_3x3",
+             "cust_maxpool2",
+             "Mixed_5b",
+             "Mixed_5c",
+             "Mixed_5d",
+             "Mixed_6a",
+             "Mixed_6b",
+             "Mixed_6c",
+             "Mixed_6d",
+             "Mixed_6e",
+         ]
+
+         self.feature_extractor = nn.Sequential(
+             *[getattr(model, name) for name in attribute_list]
+         )
+
+         attribute_list2 = ["Mixed_7a", "Mixed_7b", "Mixed_7c", "cust_avgpool"]
+
+         self.feature_extractor2 = nn.Sequential(
+             *[getattr(model, name) for name in attribute_list2]
+         )
+
+         self.emb_features = conv2d(
+             768, self.text_emb_dim, kernel_size=1, stride=1, padding=0
+         )
+         self.emb_cnn_code = nn.Linear(2048, self.text_emb_dim)
+
+     def init_trainable_weights(self) -> None:
+         """
+         This function initializes the trainable weights of the image encoder
+         """
+         initrange = 0.1
+         self.emb_features.weight.data.uniform_(-initrange, initrange)
+         self.emb_cnn_code.weight.data.uniform_(-initrange, initrange)
+
+     def forward(self, image_tensor: torch.Tensor) -> Any:
+         """
+         :param image_tensor: Input image
+         :return: features: local feature matrix (v from the AttnGAN paper) [shape: (batch, D, 17, 17)]
+         :return: cnn_code: global image feature (v^ from the AttnGAN paper) [shape: (batch, D)]
+         """
+         # input image size
+         # image_tensor.shape: batch x 3 x 256 x 256
+
+         features = self.feature_extractor(image_tensor)
+         # 17 x 17 x 768
+
+         image_tensor = self.feature_extractor2(features)
+
+         image_tensor = image_tensor.view(image_tensor.size(0), -1)
+         # 2048
+
+         # global image features
+         cnn_code = self.emb_cnn_code(image_tensor)
+
+         if features is not None:
+             features = self.emb_features(features)
+
+         # features.shape: batch x D x 17 x 17
+         # cnn_code.shape: batch x D
+         return features, cnn_code
+
+
+ class VGGEncoder(nn.Module):
+     """Pre-trained VGG Encoder Module"""
+
+     def __init__(self) -> None:
+         """
+         Initialize the pre-trained VGG model with frozen parameters
+         """
+         super().__init__()
+         self.select = "8"  # we want the output of the 8th layer in VGG.
+
+         self.model = torch.hub.load("pytorch/vision:v0.10.0", "vgg16", pretrained=True)
+
+         for param in self.model.parameters():
+             param.requires_grad = False
+
+         self.vgg_modules = self.model.features._modules
+
+     def forward(self, image_tensor: torch.Tensor) -> Any:
+         """
+         :param image_tensor: Input image tensor [shape: (batch, 3, 256, 256)]
+         :return: VGG features [shape: (batch, 128, 128, 128)]
+         """
+         for name, layer in self.vgg_modules.items():
+             image_tensor = layer(image_tensor)
+             if name == self.select:
+                 return image_tensor
+         return None
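
Editor's note: a shape check for the two encoders, assuming the output shapes documented in the docstrings above (sketch, not part of the commit; the first run downloads pretrained weights via torch.hub, and the import path assumes this diff's layout).

```python
import torch

from src.models.modules.image_encoder import InceptionEncoder, VGGEncoder

images = torch.randn(2, 3, 256, 256)  # dummy batch; real inputs should be normalized

inception = InceptionEncoder(D=256).eval()
vgg = VGGEncoder().eval()

with torch.no_grad():
    local_feats, global_feat = inception(images)  # v and v^ from AttnGAN
    vgg_feats = vgg(images)                       # activations after VGG layer "8"

print(local_feats.shape)   # expected: torch.Size([2, 256, 17, 17])
print(global_feat.shape)   # expected: torch.Size([2, 256])
print(vgg_feats.shape)     # expected: torch.Size([2, 128, 128, 128])
```

Note that the original VGGEncoder spelled `requires_grad` as `resquires_grad`, which silently left the VGG parameters trainable; the version above fixes that typo so the backbone is actually frozen.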