Nikhil Mudhalwadkar committed
Commit c6d5483
1 Parent(s): 77b7934

added other files

app.py ADDED
@@ -0,0 +1,58 @@
+ import gradio as gr
+ import torch
+ import matplotlib
+ matplotlib.use('Agg')
+ import numpy as np
+ from PIL import Image
+ import albumentations as A
+ import albumentations.pytorch as al_pytorch
+ import matplotlib.pyplot as plt
+ import torchvision
+
+ from app.model.lit_model import Pix2PixLitModule
+
+ """ Load the model """
+ model_checkpoint_path = "model/pix2pix_lightning_model/version_0/checkpoints/epoch=9-step=17780.ckpt"
+ model = Pix2PixLitModule.load_from_checkpoint(
+     model_checkpoint_path
+ )
+ model.eval()
+
+
+ def greet(name):
+     # unused demo handler
+     return "Hello " + name + "!!"
+
+
+ def predict(image: Image.Image):
+     # transforms used at inference time (must match the training normalisation)
+     inference_transform = A.Compose([
+         A.Resize(width=256, height=256),
+         A.Normalize(mean=[.5, .5, .5], std=[.5, .5, .5], max_pixel_value=255.0),
+         al_pytorch.ToTensorV2(),
+     ])
+     inference_img = inference_transform(
+         image=np.asarray(image)
+     )['image'].unsqueeze(0)
+     result = model(inference_img)
+     # make_grid expects CHW tensors; permute to HWC only when handing the array to matplotlib
+     result_grid = torchvision.utils.make_grid(
+         [result[0].detach()],
+         normalize=True
+     )
+     plt.imsave("coloured_grid.png", result_grid.permute(1, 2, 0).numpy())
+     torchvision.utils.save_image(result, "coloured_image.png", normalize=True)
+     return 'coloured_image.png', 'coloured_grid.png'
+
+
+ if __name__ == '__main__':
+     iface = gr.Interface(
+         fn=predict,
+         inputs=gr.inputs.Image(type="pil"),
+         examples=["examples/thesis_test.png", "examples/thesis_test2.png"],
+         outputs=["image", "image"],
+         title="Colour your sketches!",
+         description="Upload a sketch and the conditional GAN will colour it for you!",
+         article="WIP repo lives here - https://github.com/nmud19/thesisGAN"
+     )
+     iface.launch()
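A quick way to sanity-check the handler before launching the interface is to call it once on one of the bundled example sketches. This is only a sketch and not part of the commit; it assumes the definitions in app.py above (predict and the example files) are available, e.g. in an interactive session after running the script:

sample = Image.open("examples/thesis_test.png").convert("RGB")
image_path, grid_path = predict(sample)  # writes coloured_image.png and coloured_grid.png
print("wrote", image_path, "and", grid_path)
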
app/__init__.py ADDED
File without changes
app/config.py ADDED
@@ -0,0 +1,3 @@
+ num_workers = 4
+ train_batch_size = 32
+ val_batch_size = 1
app/consume_data/__init__.py ADDED
File without changes
app/consume_data/consume_data.py ADDED
@@ -0,0 +1,165 @@
+ import torch
+ import os
+ from typing import List, Optional, Tuple
+ from PIL import Image
+ import matplotlib.pyplot as plt
+ from torchvision import transforms
+ import albumentations as A
+ import numpy as np
+ import albumentations.pytorch as al_pytorch
+ from app import config
+ import pytorch_lightning as pl
+
+
+ class AnimeDataset(torch.utils.data.Dataset):
+     """ Sketch and coloured-image dataset """
+
+     def __init__(self, imgs_path: List[str], transforms: "Transforms") -> None:
+         """ Set the transforms and file paths """
+         self.list_files = imgs_path
+         self.transform = transforms
+
+     def __len__(self) -> int:
+         """ Return the number of files """
+         return len(self.list_files)
+
+     def __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor]:
+         """ Get sketch and coloured image by index """
+         # read image file
+         img_file = self.list_files[index]
+         # img_path = os.path.join(self.root_dir, img_file)
+         image = np.array(Image.open(img_file))
+
+         # split the combined image: the right half is the sketch, the left half the coloured image
+         sketchs = image[:, image.shape[1] // 2:, :]
+         colored_imgs = image[:, :image.shape[1] // 2, :]
+
+         # augmentations applied jointly to sketch and coloured image
+         augmentations = self.transform.both_transform(image=sketchs, image0=colored_imgs)
+         sketchs, colored_imgs = augmentations['image'], augmentations['image0']
+
+         # augmentations applied to each side separately
+         sketchs = self.transform.transform_only_input(image=sketchs)['image']
+         colored_imgs = self.transform.transform_only_mask(image=colored_imgs)['image']
+         return sketchs, colored_imgs
+
+
+ # Data augmentation
+ class Transforms:
+     def __init__(self):
+         # used on both sketches and coloured images
+         self.both_transform = A.Compose([
+             A.Resize(width=256, height=256),
+             A.HorizontalFlip(p=.5)
+         ], additional_targets={'image0': 'image'})
+
+         # used on sketches only
+         self.transform_only_input = A.Compose([
+             A.ColorJitter(p=.1),
+             A.Normalize(mean=[.5, .5, .5], std=[.5, .5, .5], max_pixel_value=255.0),
+             al_pytorch.ToTensorV2(),
+         ])
+
+         # used on coloured images only
+         self.transform_only_mask = A.Compose([
+             A.Normalize(mean=[.5, .5, .5], std=[.5, .5, .5], max_pixel_value=255.0),
+             al_pytorch.ToTensorV2(),
+         ])
+
+
+ class Transforms_v1:
+     """ Class to hold transforms """
+
+     def __init__(self):
+         # resizes used on both sketches and coloured images
+         self.resize_572 = A.Compose([
+             A.Resize(width=572, height=572)
+         ])
+
+         self.resize_388 = A.Compose([
+             A.Resize(width=388, height=388)
+         ])
+
+         self.resize_256 = A.Compose([
+             A.Resize(width=256, height=256)
+         ])
+
+         # used on sketches only
+         self.transform_only_input = A.Compose([
+             # A.ColorJitter(p=.1),
+             A.Normalize(mean=[.5, .5, .5], std=[.5, .5, .5], max_pixel_value=255.0),
+             al_pytorch.ToTensorV2(),
+         ])
+
+         # used on coloured images only
+         self.transform_only_mask = A.Compose([
+             A.Normalize(mean=[.5, .5, .5], std=[.5, .5, .5], max_pixel_value=255.0),
+             al_pytorch.ToTensorV2(),
+         ])
+
+
+ class AnimeSketchDataModule(pl.LightningDataModule):
+     """ LightningDataModule for the anime sketch data """
+
+     def __init__(
+             self,
+             data_dir: str,
+             train_folder_name: str = "train/",
+             val_folder_name: str = "val/",
+             train_batch_size: int = config.train_batch_size,
+             val_batch_size: int = config.val_batch_size,
+             train_num_images: int = 0,
+             val_num_images: int = 0,
+     ):
+         super().__init__()
+         self.val_dataset = None
+         self.train_dataset = None
+         self.data_dir: str = data_dir
+         # Set train and val image folders
+         train_path: str = f"{self.data_dir}{train_folder_name}"
+         train_images: List[str] = [f"{train_path}{x}" for x in os.listdir(train_path)]
+         val_path: str = f"{self.data_dir}{val_folder_name}"
+         val_images: List[str] = [f"{val_path}{x}" for x in os.listdir(val_path)]
+         # optionally cap the number of images used
+         self.train_images = train_images[:train_num_images] if train_num_images else train_images
+         self.val_images = val_images[:val_num_images] if val_num_images else val_images
+         #
+         self.train_batch_size = train_batch_size
+         self.val_batch_size = val_batch_size
+
+     def set_datasets(self) -> None:
+         """ Build the train and val datasets """
+         self.train_dataset = AnimeDataset(
+             imgs_path=self.train_images,
+             transforms=Transforms()
+         )
+         self.val_dataset = AnimeDataset(
+             imgs_path=self.val_images,
+             transforms=Transforms()
+         )
+         print("The train/val dataset lengths are:", len(self.train_dataset), len(self.val_dataset))
+         return None
+
+     def setup(self, stage: Optional[str] = None) -> None:
+         self.set_datasets()
+
+     def train_dataloader(self):
+         return torch.utils.data.DataLoader(
+             self.train_dataset,
+             batch_size=self.train_batch_size,
+             shuffle=False,
+             num_workers=2,
+             pin_memory=True
+         )
+
+     def val_dataloader(self):
+         return torch.utils.data.DataLoader(
+             self.val_dataset,
+             batch_size=self.val_batch_size,
+             shuffle=False,
+             num_workers=2,
+             pin_memory=True
+         )
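For context, the datamodule is meant to be wired to the Lightning module and a Trainer roughly as below. This is a sketch only: the data directory layout (data/train/ and data/val/ holding the paired sketch|colour images) and the trainer settings are assumptions; the class names come from the files in this commit.

import pytorch_lightning as pl

from app.consume_data.consume_data import AnimeSketchDataModule
from app.discriminator.patch_gan import Discriminator
from app.generator.unetGen import Generator
from app.model.lit_model import Pix2PixLitModule

datamodule = AnimeSketchDataModule(data_dir="data/")  # assumed dataset location
model = Pix2PixLitModule(
    generator=Generator(),
    discriminator=Discriminator(),
    use_gpu=False,
)
trainer = pl.Trainer(max_epochs=10)  # assumed settings
trainer.fit(model, datamodule=datamodule)
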
app/data.py ADDED
@@ -0,0 +1,69 @@
+ import torch
+ import os
+ from typing import List, Tuple
+ from PIL import Image
+ import matplotlib.pyplot as plt
+ from torchvision import transforms
+ import albumentations as A
+ import numpy as np
+ import albumentations.pytorch as al_pytorch
+
+
+ class AnimeDataset(torch.utils.data.Dataset):
+     """ Sketch and coloured-image dataset """
+
+     def __init__(self, imgs_path: List[str], transforms: "Transforms") -> None:
+         """ Set the transforms and file paths """
+         self.list_files = imgs_path
+         self.transform = transforms
+
+     def __len__(self) -> int:
+         """ Return the number of files """
+         return len(self.list_files)
+
+     def __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor]:
+         """ Get sketch and coloured image by index """
+         # read image file
+         img_path = self.list_files[index]
+         image = np.array(Image.open(img_path))
+
+         # split the combined image: as laid out in the dataset,
+         # the right half is the sketch and the left half the coloured image
+         sketchs = image[:, image.shape[1] // 2:, :]
+         colored_imgs = image[:, :image.shape[1] // 2, :]
+
+         # augmentations applied jointly to sketch and coloured image
+         augmentations = self.transform.both_transform(image=sketchs, image0=colored_imgs)
+         sketchs, colored_imgs = augmentations['image'], augmentations['image0']
+
+         # augmentations applied to each side separately
+         sketchs = self.transform.transform_only_input(image=sketchs)['image']
+         colored_imgs = self.transform.transform_only_mask(image=colored_imgs)['image']
+         return sketchs, colored_imgs
+
+
+ class Transforms:
+     """ Class to hold transforms """
+
+     def __init__(self):
+         # used on both sketches and coloured images
+         self.both_transform = A.Compose([
+             A.Resize(width=1024, height=1024),
+             A.HorizontalFlip(p=.5)
+         ],
+             additional_targets={'image0': 'image'}
+         )
+
+         # used on sketches only
+         self.transform_only_input = A.Compose([
+             # A.ColorJitter(p=.1),
+             A.Normalize(mean=[.5, .5, .5], std=[.5, .5, .5], max_pixel_value=255.0),
+             al_pytorch.ToTensorV2(),
+         ])
+
+         # used on coloured images only
+         self.transform_only_mask = A.Compose([
+             A.Normalize(mean=[.5, .5, .5], std=[.5, .5, .5], max_pixel_value=255.0),
+             al_pytorch.ToTensorV2(),
+         ])
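Standalone, this dataset can be exercised with a plain DataLoader. A minimal sketch, assuming the paired sketch|colour images live under data/train/ as .png files:

import glob
import torch

from app.data import AnimeDataset, Transforms

files = sorted(glob.glob("data/train/*.png"))  # assumed dataset location
dataset = AnimeDataset(imgs_path=files, transforms=Transforms())
loader = torch.utils.data.DataLoader(dataset, batch_size=4, shuffle=True)
sketches, coloured = next(iter(loader))
print(sketches.shape, coloured.shape)  # e.g. torch.Size([4, 3, 1024, 1024]) each
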
app/discriminator/__init__.py ADDED
File without changes
app/discriminator/patch_gan.py ADDED
@@ -0,0 +1,137 @@
+ import torch.nn as nn
+ import torch
+
+
+ # CNN block that is reused by the Discriminator below
+ class CNNBlock(nn.Module):
+     def __init__(self, in_channels, out_channels, stride=2):
+         super().__init__()
+         self.conv = nn.Sequential(
+             # note: with padding=0 the reflect padding_mode is a no-op; the canonical PatchGAN block uses padding=1
+             nn.Conv2d(in_channels, out_channels, 4, stride, bias=False, padding_mode='reflect'),
+             nn.BatchNorm2d(out_channels),
+             nn.LeakyReLU(0.2)
+         )
+
+     def forward(self, x):
+         return self.conv(x)
+
+
+ class PatchGan(torch.nn.Module):
+     """ PatchGAN architecture built from plain conv blocks """
+
+     @staticmethod
+     def create_contracting_block(in_channels: int, out_channels: int):
+         """
+         Create an encoding layer
+         :param in_channels:
+         :param out_channels:
+         :return:
+         """
+         conv_layer = torch.nn.Sequential(
+             torch.nn.Conv2d(
+                 in_channels=in_channels,
+                 out_channels=out_channels,
+                 kernel_size=3,
+                 padding=1,
+             ),
+             torch.nn.ReLU(),
+             torch.nn.Conv2d(
+                 in_channels=out_channels,
+                 out_channels=out_channels,
+                 kernel_size=3,
+                 padding=1,
+             ),
+             torch.nn.ReLU(),
+         )
+         max_pool = torch.nn.Sequential(
+             torch.nn.MaxPool2d(
+                 stride=2,
+                 kernel_size=2,
+             ),
+         )
+         layer = torch.nn.Sequential(
+             conv_layer,
+             max_pool,
+         )
+         return layer
+
+     def __init__(self, input_channels: int, hidden_channels: int) -> None:
+         super().__init__()
+         self.resize_channels = torch.nn.Conv2d(
+             in_channels=input_channels,
+             out_channels=hidden_channels,
+             kernel_size=1,
+         )
+
+         self.enc1 = self.create_contracting_block(
+             in_channels=hidden_channels,
+             out_channels=hidden_channels * 2
+         )
+
+         self.enc2 = self.create_contracting_block(
+             in_channels=hidden_channels * 2,
+             out_channels=hidden_channels * 4
+         )
+
+         self.enc3 = self.create_contracting_block(
+             in_channels=hidden_channels * 4,
+             out_channels=hidden_channels * 8
+         )
+         self.enc4 = self.create_contracting_block(
+             in_channels=hidden_channels * 8,
+             out_channels=hidden_channels * 16
+         )
+
+         self.final_layer = torch.nn.Conv2d(
+             in_channels=hidden_channels * 16,
+             out_channels=1,
+             kernel_size=1,
+         )
+
+     def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+         """ Forward pass of the PatchGAN """
+         inpt = torch.cat([x, y], dim=1)
+         resize_img = self.resize_channels(inpt)
+         enc1 = self.enc1(resize_img)
+         enc2 = self.enc2(enc1)
+         enc3 = self.enc3(enc2)
+         enc4 = self.enc4(enc3)
+         final_layer = self.final_layer(enc4)
+         return final_layer
+
+
+ # x, y <- the input sketch and the (real or generated) image are concatenated so the
+ # discriminator can judge whether the image is real for that sketch
+ class Discriminator(nn.Module):
+     def __init__(self, in_channels=3, features=[64, 128, 256, 512]):
+         super().__init__()
+         self.initial = nn.Sequential(
+             nn.Conv2d(in_channels * 2, features[0], kernel_size=4, stride=2, padding=1, padding_mode='reflect'),
+             nn.LeakyReLU(.2)
+         )
+
+         # save layers into a list
+         layers = []
+         in_channels = features[0]
+         for feature in features[1:]:
+             layers.append(
+                 CNNBlock(
+                     in_channels,
+                     feature,
+                     stride=1 if feature == features[-1] else 2
+                 ),
+             )
+             in_channels = feature
+
+         # append the final conv layer that maps to one logit per patch
+         layers.append(
+             nn.Conv2d(in_channels, 1, kernel_size=4, stride=1, padding=1, padding_mode='reflect')
+         )
+
+         # create a model using the list of layers
+         self.model = nn.Sequential(*layers)
+
+     def forward(self, x, y):
+         x = torch.cat([x, y], dim=1)
+         x = self.initial(x)
+         return self.model(x)
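A quick shape check of the discriminator with dummy tensors (a sketch only; the 256x256 size matches the training transforms, and the exact patch grid depends on the conv settings above):

import torch

from app.discriminator.patch_gan import Discriminator

disc = Discriminator()
sketch = torch.randn(1, 3, 256, 256)
colour = torch.randn(1, 3, 256, 256)
patches = disc(sketch, colour)
print(patches.shape)  # a grid of per-patch logits, e.g. torch.Size([1, 1, 26, 26])
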
app/generator/__init__.py ADDED
File without changes
app/generator/unetGen.py ADDED
@@ -0,0 +1,174 @@
+ import torch
+ import torch.nn as nn
+ from app.generator import unetParts
+
+
+ class UNET(torch.nn.Module):
+     """ Implementation of UNet """
+
+     def __init__(
+             self,
+     ) -> None:
+         """
+         Create the UNET here
+         """
+         super().__init__()
+         self.enc_layer1: unetParts.EncoderLayer = unetParts.EncoderLayer(
+             in_channels=3,
+             out_channels=64
+         )
+         self.enc_layer2: unetParts.EncoderLayer = unetParts.EncoderLayer(
+             in_channels=64,
+             out_channels=128
+         )
+         self.enc_layer3: unetParts.EncoderLayer = unetParts.EncoderLayer(
+             in_channels=128,
+             out_channels=256
+         )
+         self.enc_layer4: unetParts.EncoderLayer = unetParts.EncoderLayer(
+             in_channels=256,
+             out_channels=512
+         )
+         # Middle layer
+         self.middle_layer: unetParts.MiddleLayer = unetParts.MiddleLayer(
+             in_channels=512,
+             out_channels=1024,
+         )
+         # Decoding layers
+         self.dec_layer1: unetParts.DecoderLayer = unetParts.DecoderLayer(
+             in_channels=1024,
+             out_channels=512,
+         )
+         self.dec_layer2: unetParts.DecoderLayer = unetParts.DecoderLayer(
+             in_channels=512,
+             out_channels=256,
+         )
+
+         self.dec_layer3: unetParts.DecoderLayer = unetParts.DecoderLayer(
+             in_channels=256,
+             out_channels=128,
+         )
+         self.dec_layer4: unetParts.DecoderLayer = unetParts.DecoderLayer(
+             in_channels=128,
+             out_channels=64,
+         )
+         self.final_layer: torch.nn.Conv2d = torch.nn.Conv2d(
+             in_channels=64,
+             out_channels=3,
+             kernel_size=1
+         )
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         """
+         Forward pass
+         :param x:
+         :return:
+         """
+         # encoding layers
+         enc1, conv1 = self.enc_layer1(x=x)  # 64
+         enc2, conv2 = self.enc_layer2(x=enc1)  # 128
+         enc3, conv3 = self.enc_layer3(x=enc2)  # 256
+         enc4, conv4 = self.enc_layer4(x=enc3)  # 512
+         # middle layer
+         mid = self.middle_layer(x=enc4)  # 1024
+         # expanding layers
+         # 512
+         dec1 = self.dec_layer1(
+             input_layer=mid,
+             cropping_layer=conv4,
+         )
+         # 256
+         dec2 = self.dec_layer2(
+             input_layer=dec1,
+             cropping_layer=conv3,
+         )
+         # 128
+         dec3 = self.dec_layer3(
+             input_layer=dec2,
+             cropping_layer=conv2,
+         )
+         # 64
+         dec4 = self.dec_layer4(
+             input_layer=dec3,
+             cropping_layer=conv1,
+         )
+         # 3
+         fin_layer = self.final_layer(
+             dec4,
+         )
+         # Interpolate back to the 572x572 working size
+         fin_layer_resized = torch.nn.functional.interpolate(fin_layer, 572)
+         return fin_layer_resized
+
+
+ class Generator(nn.Module):
+     def __init__(self, in_channels=3, features=64):
+         super().__init__()
+         # Encoder
+         self.initial_down = nn.Sequential(
+             nn.Conv2d(in_channels, features, 4, 2, 1, padding_mode='reflect'),
+             nn.LeakyReLU(.2),
+         )
+         self.down1 = Block(features, features * 2, down=True, act='leaky', use_dropout=False)  # 64
+         self.down2 = Block(features * 2, features * 4, down=True, act='leaky', use_dropout=False)  # 32
+         self.down3 = Block(features * 4, features * 8, down=True, act='leaky', use_dropout=False)  # 16
+         self.down4 = Block(features * 8, features * 8, down=True, act='leaky', use_dropout=False)  # 8
+         self.down5 = Block(features * 8, features * 8, down=True, act='leaky', use_dropout=False)  # 4
+         self.down6 = Block(features * 8, features * 8, down=True, act='leaky', use_dropout=False)  # 2
+         self.bottleneck = nn.Sequential(
+             nn.Conv2d(features * 8, features * 8, 4, 2, 1, padding_mode='reflect'),
+             nn.ReLU(),  # 1x1
+         )
+         # Decoder
+         self.up1 = Block(features * 8, features * 8, down=False, act='relu', use_dropout=True)
+         self.up2 = Block(features * 8 * 2, features * 8, down=False, act='relu', use_dropout=True)
+         self.up3 = Block(features * 8 * 2, features * 8, down=False, act='relu', use_dropout=True)
+         self.up4 = Block(features * 8 * 2, features * 8, down=False, act='relu', use_dropout=False)
+         self.up5 = Block(features * 8 * 2, features * 4, down=False, act='relu', use_dropout=False)
+         self.up6 = Block(features * 4 * 2, features * 2, down=False, act='relu', use_dropout=False)
+         self.up7 = Block(features * 2 * 2, features, down=False, act='relu', use_dropout=False)
+         self.final_up = nn.Sequential(
+             nn.ConvTranspose2d(features * 2, in_channels, kernel_size=4, stride=2, padding=1),
+             nn.Tanh()
+         )
+
+     def forward(self, x):
+         # Encoder
+         d1 = self.initial_down(x)
+         d2 = self.down1(d1)
+         d3 = self.down2(d2)
+         d4 = self.down3(d3)
+         d5 = self.down4(d4)
+         d6 = self.down5(d5)
+         d7 = self.down6(d6)
+         bottleneck = self.bottleneck(d7)
+
+         # Decoder with skip connections from the matching encoder stage
+         u1 = self.up1(bottleneck)
+         u2 = self.up2(torch.cat([u1, d7], 1))
+         u3 = self.up3(torch.cat([u2, d6], 1))
+         u4 = self.up4(torch.cat([u3, d5], 1))
+         u5 = self.up5(torch.cat([u4, d4], 1))
+         u6 = self.up6(torch.cat([u5, d3], 1))
+         u7 = self.up7(torch.cat([u6, d2], 1))
+         return self.final_up(torch.cat([u7, d1], 1))
+
+
+ # Conv/ConvTranspose block reused by both halves of the Generator above
+ class Block(nn.Module):
+     def __init__(self, in_channels, out_channels, down=True, act='relu', use_dropout=False):
+         super().__init__()
+         self.conv = nn.Sequential(
+             # the block is used in both the encoder (down=True) and the decoder (down=False)
+             nn.Conv2d(in_channels, out_channels, 4, 2, 1, bias=False, padding_mode='reflect')
+             if down
+             else nn.ConvTranspose2d(in_channels, out_channels, 4, 2, 1, bias=False),
+             nn.BatchNorm2d(out_channels),
+             nn.ReLU() if act == 'relu' else nn.LeakyReLU(.2)
+         )
+         self.use_dropout = use_dropout
+         self.dropout = nn.Dropout(.5)
+
+     def forward(self, x):
+         x = self.conv(x)
+         return self.dropout(x) if self.use_dropout else x
@@ -0,0 +1,106 @@
+ import torch
+ from typing import Tuple
+
+
+ class DecoderLayer(torch.nn.Module):
+     """Decoder layer: upsample, concatenate the skip connection, then convolve"""
+
+     def __init__(self, in_channels: int, out_channels: int):
+         super().__init__()
+         self.up_sample_layer = torch.nn.Sequential(
+             torch.nn.ConvTranspose2d(
+                 in_channels=in_channels,
+                 out_channels=out_channels,
+                 kernel_size=2,
+                 stride=2,
+                 bias=False,
+             )
+         )
+         # reuse the encoder's double-conv stack for the post-concatenation convolutions
+         self.conv_layer = EncoderLayer(
+             in_channels=in_channels,
+             out_channels=out_channels,
+         ).conv_layer
+
+     @staticmethod
+     def _get_cropping_shape(previous_layer_shape: torch.Size, current_layer_shape: torch.Size) -> int:
+         """ Negative padding used to crop the skip connection to the current size """
+         return (previous_layer_shape[2] - current_layer_shape[2]) // 2 * -1
+
+     def forward(
+             self,
+             input_layer: torch.Tensor,
+             cropping_layer: torch.Tensor
+     ) -> torch.Tensor:
+         """
+         Upsample, crop the skip connection, then concatenate and convolve
+         :param cropping_layer:
+         :param input_layer:
+         :return:
+         """
+         input_layer = self.up_sample_layer(input_layer)
+
+         cropping_shape = self._get_cropping_shape(
+             current_layer_shape=input_layer.shape,
+             previous_layer_shape=cropping_layer.shape,
+         )
+
+         cropping_layer = torch.nn.functional.pad(
+             input=cropping_layer,
+             pad=[cropping_shape for _ in range(4)]
+         )
+         combined_layer = torch.cat(
+             tensors=[input_layer, cropping_layer],
+             dim=1
+         )
+         result = self.conv_layer(combined_layer)
+         return result
+
+
+ class EncoderLayer(torch.nn.Module):
+     """Encoder layer: double conv followed by max pooling"""
+
+     def __init__(self, in_channels: int, out_channels: int) -> None:
+         super().__init__()
+         self.conv_layer = torch.nn.Sequential(
+             torch.nn.Conv2d(
+                 in_channels=in_channels,
+                 out_channels=out_channels,
+                 kernel_size=3,
+                 stride=2,
+                 padding=1,
+             ),
+             torch.nn.LeakyReLU(),
+             torch.nn.Conv2d(
+                 in_channels=out_channels,
+                 out_channels=out_channels,
+                 kernel_size=3,
+                 stride=2,
+                 padding=1,
+             ),
+             torch.nn.LeakyReLU(),
+         )
+         self.max_pool = torch.nn.Sequential(
+             torch.nn.MaxPool2d(2),
+         )
+         self.layer = torch.nn.Sequential(
+             self.conv_layer,
+             self.max_pool,
+         )
+
+     def get_conv_layers(self, x: torch.Tensor) -> torch.Tensor:
+         """Return only the conv output (used for the skip connections)"""
+         return self.conv_layer(x)
+
+     def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+         """Forward pass returning the pooled output and the pre-pool conv output"""
+         conv_output: torch.Tensor = self.conv_layer(x)
+         fin_out: torch.Tensor = self.max_pool(conv_output)
+         return fin_out, conv_output
+
+
+ class MiddleLayer(EncoderLayer):
+     """Bottleneck layer: only the conv stack, no pooling"""
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         """Forward pass"""
+         return self.conv_layer(x)
app/model/__init__.py ADDED
File without changes
app/model/lit_model.py ADDED
@@ -0,0 +1,145 @@
+ import matplotlib.pyplot as plt
+ import pytorch_lightning as pl
+ import torch
+ import torch.nn as nn
+ import torchvision
+
+
+ class Pix2PixLitModule(pl.LightningModule):
+     """ Lightning module for pix2pix """
+
+     @staticmethod
+     def _weights_init(m):
+         if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
+             torch.nn.init.normal_(m.weight, 0.0, 0.02)
+         if isinstance(m, nn.BatchNorm2d):
+             torch.nn.init.normal_(m.weight, 0.0, 0.02)
+             torch.nn.init.constant_(m.bias, 0)
+
+     def __init__(
+             self,
+             generator,
+             discriminator,
+             use_gpu: bool,
+             lambda_recon=100
+     ):
+         super().__init__()
+         self.save_hyperparameters()
+
+         self.gen = generator
+         self.disc = discriminator
+
+         # initialise weights
+         self.gen = self.gen.apply(self._weights_init)
+         self.disc = self.disc.apply(self._weights_init)
+         # losses: BCE for the adversarial part, L1 for reconstruction
+         self.adversarial_criterion = nn.BCEWithLogitsLoss()
+         self.recon_criterion = nn.L1Loss()
+         self.lambda_l1 = lambda_recon
+
+     def _gen_step(self, sketch, coloured_sketches):
+         # Pix2Pix combines an adversarial loss with an L1 reconstruction loss.
+         # First the adversarial loss:
+         gen_coloured_sketches = self.gen(sketch)
+         # disc_logits = self.disc(gen_coloured_sketches, coloured_sketches)
+         disc_logits = self.disc(sketch, gen_coloured_sketches)
+         adversarial_loss = self.adversarial_criterion(disc_logits, torch.ones_like(disc_logits))
+         # then the reconstruction loss, weighted by lambda
+         recon_loss = self.recon_criterion(gen_coloured_sketches, coloured_sketches) * self.lambda_l1
+         #
+         self.log("Gen recon_loss", recon_loss)
+         self.log("Gen adversarial_loss", adversarial_loss)
+         #
+         return adversarial_loss + recon_loss
+
+     def _disc_step(self, sketch, coloured_sketches):
+         gen_coloured_sketches = self.gen(sketch).detach()
+         #
+         # fake_logits = self.disc(gen_coloured_sketches, coloured_sketches)
+         fake_logits = self.disc(sketch, gen_coloured_sketches)
+         real_logits = self.disc(sketch, coloured_sketches)
+         #
+         fake_loss = self.adversarial_criterion(fake_logits, torch.zeros_like(fake_logits))
+         real_loss = self.adversarial_criterion(real_logits, torch.ones_like(real_logits))
+         #
+         self.log("PatchGAN fake_loss", fake_loss)
+         self.log("PatchGAN real_loss", real_loss)
+         return (real_loss + fake_loss) / 2
+
+     def forward(self, x):
+         return self.gen(x)
+
+     def training_step(self, batch, batch_idx, optimizer_idx):
+         sketch, colour = batch
+         loss = None
+         if optimizer_idx == 0:
+             loss = self._disc_step(sketch, colour)
+             self.log("TRAIN_PatchGAN Loss", loss)
+         elif optimizer_idx == 1:
+             loss = self._gen_step(sketch, colour)
+             self.log("TRAIN_Generator Loss", loss)
+         return loss
+
+     def validation_epoch_end(self, outputs) -> None:
+         """ Log a sketch / ground truth / generated image grid """
+         sketch = outputs[0]['sketch']
+         colour = outputs[0]['colour']
+         gen_coloured = self.gen(sketch)
+         grid_image = torchvision.utils.make_grid(
+             [sketch[0], colour[0], gen_coloured[0]],
+             normalize=True
+         )
+         self.logger.experiment.add_image(f'Image Grid {str(self.current_epoch)}', grid_image, self.current_epoch)
+         # plt.imshow(grid_image.permute(1, 2, 0))
+
+     def validation_step(self, batch, batch_idx):
+         """ Validation step """
+         sketch, colour = batch
+         return {
+             'sketch': sketch,
+             'colour': colour
+         }
+
+     def configure_optimizers(self, lr=2e-4):
+         gen_opt = torch.optim.Adam(self.gen.parameters(), lr=lr, betas=(0.5, 0.999))
+         disc_opt = torch.optim.Adam(self.disc.parameters(), lr=lr, betas=(0.5, 0.999))
+         # the discriminator optimizer comes first, matching optimizer_idx == 0 above
+         return disc_opt, gen_opt
+
+ # class EpochInference(pl.Callback):
+ #     """
+ #     Callback on the end of each training epoch.
+ #     The callback runs inference on the test dataloader using the corresponding checkpoint.
+ #     The results are saved as an image with 4 rows:
+ #     1 - Input image, e.g. grayscale edged input
+ #     2 - Ground truth
+ #     3 - Single inference
+ #     4 - Mean of a hundred accumulated inferences
+ #     Note that the inference has a noise factor that generates a different output on each execution.
+ #     """
+ #
+ #     def __init__(self, dataloader, use_gpu: bool, *args, **kwargs):
+ #         super().__init__(*args, **kwargs)
+ #         self.dataloader = dataloader
+ #         self.use_gpu = use_gpu
+ #
+ #     def on_train_epoch_end(self, trainer, pl_module):
+ #         super().on_train_epoch_end(trainer, pl_module)
+ #         data = next(iter(self.dataloader))
+ #         image, target = data
+ #         if self.use_gpu:
+ #             image = image.cuda()
+ #             target = target.cuda()
+ #         with torch.no_grad():
+ #             # Take the average of multiple inferences as there is random noise
+ #             # Single
+ #             reconstruction_init = pl_module(image)
+ #             reconstruction_init = torch.clip(reconstruction_init, 0, 1)
+ #             # # Mean
+ #             # reconstruction_mean = torch.stack([pl_module(image) for _ in range(10)])
+ #             # reconstruction_mean = torch.clip(reconstruction_mean, 0, 1)
+ #             # reconstruction_mean = torch.mean(reconstruction_mean, dim=0)
+ #             # Grayscale 1-D to 3-D
+ #             # image = torch.stack([image for _ in range(3)], dim=1)
+ #             # image = torch.squeeze(image)
+ #             grid_image = torchvision.utils.make_grid([image[0], target[0], reconstruction_init[0]])
+ #             torchvision.utils.save_image(grid_image, fp=f'{trainer.default_root_dir}/epoch-{trainer.current_epoch:04}.png')
app/scratch.py ADDED
@@ -0,0 +1,34 @@
+ import numpy as np
+ import torch
+ import torchvision
+ import albumentations as A
+ import albumentations.pytorch as al_pytorch
+ from PIL import Image
+
+ from app.model.lit_model import Pix2PixLitModule
+
+
+ class GANInference:
+     """ Scratch helper to run the generator on a single sketch file """
+
+     def __init__(
+             self,
+             model: Pix2PixLitModule,
+             img_file: str = "/Users/nimud/Downloads/thesis_test2.png",
+     ) -> None:
+         self.img_file = img_file
+         self.model = model
+
+     def _get_image_from_path(self) -> torch.Tensor:
+         """ Load the file and return a normalised tensor batch """
+         image = np.array(Image.open(self.img_file))
+         # same transforms as used at inference time in app.py
+         inference_transform = A.Compose([
+             A.Resize(width=256, height=256),
+             A.Normalize(mean=[.5, .5, .5], std=[.5, .5, .5], max_pixel_value=255.0),
+             al_pytorch.ToTensorV2(),
+         ])
+         inference_img = inference_transform(image=image)['image'].unsqueeze(0)
+         return inference_img
+
+     def _create_grid(self, result: torch.Tensor) -> torch.Tensor:
+         # make_grid expects CHW tensors; permute to HWC afterwards when plotting
+         return torchvision.utils.make_grid(
+             [result[0].detach()],
+             normalize=True
+         )
+
+     def run(self) -> torch.Tensor:
+         """ Return a normalised image grid (CHW) ready for plotting """
+         inference_img = self._get_image_from_path()
+         result = self.model(inference_img)
+         adjusted_result = self._create_grid(result=result)
+         return adjusted_result
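Driven from a notebook, the scratch helper might be used roughly like this (hypothetical; the checkpoint path mirrors the one hard-coded in app.py and the example sketch ships with this commit):

import matplotlib.pyplot as plt

from app.model.lit_model import Pix2PixLitModule
from app.scratch import GANInference

model = Pix2PixLitModule.load_from_checkpoint(
    "model/pix2pix_lightning_model/version_0/checkpoints/epoch=9-step=17780.ckpt"
)
model.eval()

viewer = GANInference(model=model, img_file="examples/thesis_test.png")
grid = viewer.run()  # CHW tensor grid
plt.imshow(grid.permute(1, 2, 0))
plt.axis("off")
plt.show()
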
examples/__init__.py ADDED
File without changes
examples/thesis_test.png ADDED
examples/thesis_test2.png ADDED
requirements.txt ADDED
@@ -0,0 +1,7 @@
+ gradio
+ torch
+ torchvision
+ pytorch_lightning
+ matplotlib
+ albumentations
+ pillow
+ numpy