sleepytaco committed on
Commit a4d851a
1 Parent(s): cc7aecd

initial commit

.gitignore ADDED
@@ -0,0 +1 @@
+.DS_Store
app.py ADDED
@@ -0,0 +1,31 @@
+import gradio as gr
+import os
+from model.model import TextureSynthesisCNN
+from model.utils import convert_tensor_to_PIL_image
+
+
+def image_mod(image):  # unused helper; not wired into the Interface below
+    return image.rotate(45)
+
+def synth_image(image):
+    synthesizer = TextureSynthesisCNN(tex_exemplar_image=image)
+    output_tensor = synthesizer.synthesize_texture(num_epochs=10)
+    output_image = convert_tensor_to_PIL_image(output_tensor)
+    return output_image
+
+
+demo = gr.Interface(
+    fn=synth_image,
+    inputs=[gr.Image(type="numpy")],
+    outputs=[gr.Image(type="pil")],
+    flagging_options=["blurry", "incorrect"],
+    examples=[
+        os.path.join(os.path.dirname(__file__), "images/blotchy_0025.png"),
+        os.path.join(os.path.dirname(__file__), "images/blotchy_0027.png"),
+        os.path.join(os.path.dirname(__file__), "images/cracked_0080.png"),
+        os.path.join(os.path.dirname(__file__), "images/scenery.png"),
+    ],
+)
+
+if __name__ == "__main__":
+    demo.launch()
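Since the interface declares inputs=[gr.Image(type="numpy")], Gradio hands synth_image a numpy array and expects a PIL image back. A minimal smoke test outside the UI might look like the sketch below (an assumption, run from the repo root; the 10-epoch L-BFGS loop is slow on CPU):

import os
from skimage import io
from app import synth_image  # importing app builds the Interface but does not launch it

# feed one of the bundled exemplars through the same function the demo calls
exemplar = io.imread(os.path.join("images", "blotchy_0025.png"))[:, :, :3]  # drop any alpha channel
result = synth_image(exemplar)  # returns a 256x256 PIL.Image
result.save("smoke_test.png")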
images/blotchy_0025.png ADDED
images/blotchy_0027.png ADDED
images/cracked_0080.png ADDED
images/scenery.png ADDED
model/__init__.py ADDED
File without changes
model/main.py ADDED
@@ -0,0 +1,16 @@
+from skimage import io
+from model.model import TextureSynthesisCNN
+
+
+def main():
+    # the constructor takes an image array (see model/model.py), so load the exemplar first
+    exemplar = io.imread("data/cracked_0063.png")
+    synthesizer = TextureSynthesisCNN(tex_exemplar_image=exemplar)
+    synthesizer.synthesize_texture(num_epochs=10)
+    # synthesizer.optimize(num_epochs=500)  # can be called on an existing model object to continue optimization
+    synthesizer.save_textures(output_dir="./results/",  # directory is created automatically if not found
+                              display_when_done=True)  # saves exemplar and synth into the output_dir folder
+
+
+if __name__ == '__main__':
+    main()
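Note: given the package-style imports, this script is meant to be run from the repository root (e.g. python -m model.main) so that model.model and friends resolve; data/cracked_0063.png is not part of this commit, so the path assumes a local data/ folder.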
model/model.py ADDED
@@ -0,0 +1,123 @@
+import torch
+from torch import fft
+from model.vgg19 import VGG19
+from tqdm import tqdm
+import model.utils as utils
+import os
+
+
+class TextureSynthesisCNN:
+    def __init__(self, tex_exemplar_image):
+        """
+        tex_exemplar_image: ideal texture image w.r.t. which we are synthesizing our textures
+        """
+        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+        self.tex_exemplar_name = "texture"  # default name used when saving; the app passes an image array, not a path
+
+        # init VGGs
+        vgg_exemplar = VGG19(freeze_weights=True)  # vgg to generate ideal feature maps
+        self.vgg_synthesis = VGG19(freeze_weights=False)  # vgg whose weights will be trained
+
+        # calculate and save gram matrices for the texture exemplar once (as this does not change)
+        self.tex_exemplar_image = utils.load_image_tensor(tex_exemplar_image).to(self.device)  # image array -> image tensor
+        self.gram_matrices_ideal = vgg_exemplar(self.tex_exemplar_image).get_gram_matrices()
+
+        # set up the initial random noise image output which the network will optimize
+        self.output_image = torch.sigmoid(torch.randn_like(self.tex_exemplar_image)).to(self.device)  # sigmoid to ensure values b/w 0 and 1
+        self.output_image.requires_grad = True  # set to True so that the rand noise image can be optimized
+
+        self.LBFGS = torch.optim.LBFGS([self.output_image])
+        self.layer_weights = [10**9] * len(vgg_exemplar.output_layers)  # output layer weights as per the paper
+        self.beta = 10**5  # spectrum-constraint weight as per the paper
+        self.losses = []
+
+    def synthesize_texture(self, num_epochs=250, display_when_done=False):
+        """
+        - Idea: each time the optimizer starts off from a random noise image, the network optimizes/synthesizes
+          the original tex exemplar in a slightly different way, i.e. it introduces variation in the synthesis.
+        - Can be called multiple times to generate different texture variations of the tex exemplar this model holds.
+        - Important: resets output_image to random noise each time it is called.
+        """
+        self.losses = []
+
+        # reset output image to random noise
+        self.output_image = torch.sigmoid(torch.randn_like(self.tex_exemplar_image)).to(self.device)
+        self.output_image.requires_grad = True
+        self.LBFGS = torch.optim.LBFGS([self.output_image])  # update LBFGS to hold the new output image
+
+        synthesized_texture = self.optimize(num_epochs=num_epochs)
+        if display_when_done:
+            utils.display_image_tensor(synthesized_texture)
+
+        return synthesized_texture
+
+    def optimize(self, num_epochs=250):
+        """
+        Perform num_epochs steps of the L-BFGS algorithm
+        """
+        progress_bar = tqdm(total=num_epochs, desc="Optimizing...")
+        epoch_offset = len(self.losses)
+
+        for epoch in range(num_epochs):
+            epoch_loss = self.get_loss().item()
+            progress_bar.update(1)
+            progress_bar.set_description(f"Loss @ Epoch {epoch_offset + epoch + 1} - {epoch_loss} ")
+            self.LBFGS.step(self.LBFGS_closure)  # the LBFGS optimizer expects the loss in the form of a closure function
+            self.losses.append(epoch_loss)
+
+        return self.output_image.detach().cpu()
+
+    def LBFGS_closure(self):
+        """
+        Closure function for LBFGS which passes the current output_image through vgg_synthesis, computes the predicted
+        gram matrices, and uses those to compute the loss for the network.
+        """
+        self.LBFGS.zero_grad()
+        loss = self.get_loss()
+        loss.backward()
+        return loss
+
+    def get_loss(self):
+        """
+        CNN loss: generates the feature maps for the current output synth image, and uses the ideal feature maps to come
+        up with a loss E_l at each layer l. All the E_l's are added up to give the total cnn loss.
+        Spectrum loss: projects tex synth onto tex exemplar to compute the spectrum constraint as per the paper.
+        Overall loss = loss_cnn + beta * loss_spec
+        """
+        # calculate spectrum constraint loss using the current output_image and tex_exemplar_image
+        # - projects image I_hat (tex_synth) onto image I (tex_exemplar) and returns I_proj (equation as per the paper)
+        I_hat = utils.get_grayscale(self.output_image)
+        I_fourier = fft.fft2(utils.get_grayscale(self.tex_exemplar_image))
+        I_hat_fourier = fft.fft2(I_hat)
+        I_fourier_conj = torch.conj(I_fourier)
+        epsilon = 10e-12  # epsilon to avoid division by 0 and nan values
+        I_proj = fft.ifft2((I_hat_fourier * I_fourier_conj) / (torch.abs(I_hat_fourier * I_fourier_conj) + epsilon) * I_fourier)
+        loss_spec = (0.5 * (I_hat - I_proj) ** 2.).sum().real
+
+        # get the gram matrices for the synth output_image by passing it through the second vgg network
+        gram_matrices_pred = self.vgg_synthesis(self.output_image).get_gram_matrices()
+
+        # calculate cnn loss
+        loss_cnn = 0.  # (w1*E1 + w2*E2 + ... + wl*El)
+        for i in range(len(self.layer_weights)):
+            # E_l = w_l * ||G_ideal_l - G_pred_l||^2
+            E = self.layer_weights[i] * ((self.gram_matrices_ideal[i] - gram_matrices_pred[i]) ** 2.).sum()
+            loss_cnn += E
+
+        return loss_cnn + (self.beta * loss_spec)
+
+    def save_textures(self, output_dir="./results/", display_when_done=False):
+        """
+        Saves (and optionally displays) the current tex_exemplar_image and output_image tensors that this model holds
+        into the results directory (creates it if it does not yet exist)
+        """
+        tex_exemplar = utils.save_image_tensor(self.tex_exemplar_image.cpu(),
+                                               output_dir=output_dir,
+                                               image_name=f"exemplar_{self.tex_exemplar_name}.png")
+        tex_synth = utils.save_image_tensor(self.output_image.detach().cpu(),
+                                            output_dir=output_dir,
+                                            image_name=f"synth_{self.tex_exemplar_name}.png")
+        if display_when_done:
+            tex_exemplar.show()
+            print()
+            tex_synth.show()
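For reference, the objective implemented in get_loss above can be written compactly (a restatement of the code, using the docstrings' notation; $\hat{I}$ and $I$ are the grayscale synthesis and exemplar):

$$
\mathcal{L} \;=\; \sum_{l} w_l \,\big\| G_l^{\text{ideal}} - G_l^{\text{pred}} \big\|_F^2 \;+\; \beta \cdot \tfrac{1}{2} \sum_{\text{pixels}} \big( \hat{I} - I_{\text{proj}} \big)^2,
\qquad
I_{\text{proj}} \;=\; \mathcal{F}^{-1}\!\left( \frac{\mathcal{F}(\hat{I})\, \overline{\mathcal{F}(I)}}{\big|\mathcal{F}(\hat{I})\, \overline{\mathcal{F}(I)}\big| + \epsilon}\; \mathcal{F}(I) \right)
$$

with $w_l = 10^9$ for every captured layer and $\beta = 10^5$, as set in __init__.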
model/utils.py ADDED
@@ -0,0 +1,59 @@
+"""
+Contains utility functions to work with images in tensor and jpg/png forms
+"""
+from PIL import Image
+from torchvision import transforms
+from skimage import transform, util
+import numpy as np
+import os
+
+
+def load_image_tensor(image):
+    """
+    Converts an image array into a PyTorch tensor of shape (3, 256, 256).
+    Values between 0 and 1.
+    """
+    img_size = (256, 256)
+    # crop the width down to the height (assumes a landscape or square input)
+    cropped_image = util.crop(image, ((0, 0), (0, image.shape[1] - image.shape[0]), (0, 0)))
+    resized_image = transform.resize(image=cropped_image, output_shape=img_size, anti_aliasing=True)
+    to_tensor = transforms.Compose([transforms.ToTensor()])
+    tensor = to_tensor(resized_image)
+    return tensor.float()
+
+
+def convert_tensor_to_PIL_image(image_tensor):
+    """
+    Converts a 3D (C, H, W) tensor with values in [0, 1] into a PIL image.
+    """
+    output_image = image_tensor.numpy().transpose(1, 2, 0)  # (C, H, W) -> (H, W, C)
+    output_image = np.clip(output_image, 0, 1) * 255
+    output_image = output_image.astype(np.uint8)
+    return Image.fromarray(output_image)
+
+
+def save_image_tensor(tensor, output_dir="./", image_name="output.png"):
+    """
+    Saves a 3D tensor as an image.
+    """
+    output_image = convert_tensor_to_PIL_image(tensor)
+
+    os.makedirs(output_dir, exist_ok=True)  # create the output directory if it does not exist
+    output_image.save(os.path.join(output_dir, image_name))
+
+    return output_image
+
+
+def display_image_tensor(tensor):
+    """
+    Displays the passed-in 3D image tensor
+    """
+    convert_tensor_to_PIL_image(tensor).show()
+
+
+def get_grayscale(tensor):
+    """
+    Converts a 3D image tensor to grayscale
+    """
+    grayscale_transform = transforms.Grayscale()
+    return grayscale_transform(tensor)
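A quick sketch of how these helpers compose (hypothetical values; assumes the repository root is on sys.path):

import torch
from model import utils

t = torch.rand(3, 256, 256)                     # fake (C, H, W) image tensor with values in [0, 1]
pil_img = utils.convert_tensor_to_PIL_image(t)  # tensor -> PIL.Image
utils.save_image_tensor(t, output_dir="./results/", image_name="noise.png")
gray = utils.get_grayscale(t)                   # (1, 256, 256) single-channel tensor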
model/vgg19.py ADDED
@@ -0,0 +1,52 @@
+import torch
+from torch import nn
+from torchvision.models import VGG19_Weights, vgg19
+
+
+class VGG19:
+    """
+    Custom version of VGG19 with the maxpool layers replaced by avgpool as per the paper
+    """
+    def __init__(self, freeze_weights):
+        """
+        freeze_weights: if True, gradients for the VGG params are turned off
+        """
+        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+        self.model = vgg19(weights=VGG19_Weights.DEFAULT).to(device)
+
+        # note: added one extra maxpool (layer 36) from the vgg... worked well so kept it in
+        self.output_layers = [0, 4, 9, 18, 27, 36]  # vgg19 layers [convlayer1, maxpool, ..., maxpool]
+        for layer in self.output_layers[1:]:  # convert the maxpool layers to avgpool
+            self.model.features[layer] = nn.AvgPool2d(kernel_size=2, stride=2)
+
+        self.feature_maps = []
+        for param in self.model.parameters():
+            param.requires_grad = not freeze_weights
+
+    def __call__(self, x):
+        """
+        Take in an image, pass it through the VGG, and capture feature maps at each of the output layers
+        """
+        self.feature_maps = []
+        for index, layer in enumerate(self.model.features):
+            x = layer(x)  # pass the image through the layer to get its feature maps
+            if index in self.output_layers:
+                self.feature_maps.append(x)
+            if index == self.output_layers[-1]:
+                # stop VGG execution as we've captured the feature maps from all the important layers
+                break
+
+        return self
+
+    def get_gram_matrices(self):
+        """
+        Convert the feature maps captured by the call method into gram matrices
+        """
+        gram_matrices = []
+        for fm in self.feature_maps:
+            n, x, y = fm.size()  # num filters n and filter dims (x, y); assumes an unbatched (C, H, W) input
+            F = fm.reshape(n, x * y)  # reshape the filter bank into a 2D matrix before the auto-correlation
+            gram_mat = (F @ F.t()) / (4. * n * x * y)  # auto-correlation + normalize by layer output dims
+            gram_matrices.append(gram_mat)
+
+        return gram_matrices
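A hedged sketch of what the wrapper yields: one square gram matrix per captured layer, with side length equal to that layer's channel count (the first run downloads the pretrained weights):

import torch
from model.vgg19 import VGG19

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
vgg = VGG19(freeze_weights=True)
image = torch.rand(3, 256, 256).to(device)  # unbatched (C, H, W), as produced by utils.load_image_tensor
grams = vgg(image).get_gram_matrices()
print([tuple(g.shape) for g in grams])      # [(64, 64), (64, 64), (128, 128), (256, 256), (512, 512), (512, 512)]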
requirements.txt ADDED
@@ -0,0 +1,7 @@
+gradio
+torch~=2.0.0
+torchvision~=0.15.1
+scikit-image~=0.20.0
+tqdm~=4.64.1
+numpy~=1.24.1
+pillow~=9.4.0
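With these pins, pip install -r requirements.txt followed by python app.py should be enough to reproduce the demo locally; note that gradio itself is left unpinned, so the interface code above assumes whichever version the hosting platform provides.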