# File-viewer header residue (author: sleepytaco; commit a4d851a, "initial commit"; 6.16 kB).
import torch
from torch import fft
from model.vgg19 import VGG19
from tqdm import tqdm
import model.utils as utils
import os
class TextureSynthesisCNN:
    """Synthesize textures by optimizing a noise image against an exemplar.

    The loss combines two terms:
      - a CNN term: weighted squared differences between the Gram matrices of the
        synthesized image and those of the exemplar (both obtained from VGG19 instances);
      - a spectrum term: distance between the grayscale synthesized image and its
        projection onto the exemplar's Fourier spectrum, scaled by ``beta``.

    Only the pixels of ``output_image`` are optimized (with L-BFGS).
    """

    def __init__(self, tex_exemplar_image, tex_exemplar_name=None):
        """
        tex_exemplar_image: ideal texture w.r.t. which we are synthesizing our textures. It is passed
            unchanged to utils.load_image_tensor (described by the original code as an image path).
        tex_exemplar_name: name used in the file names written by save_textures. When None it is the
            file stem of tex_exemplar_image if that is a str/os.PathLike path, otherwise "texture".
        """
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        if tex_exemplar_name is None:
            tex_exemplar_name = self._default_exemplar_name(tex_exemplar_image)
        self.tex_exemplar_name = tex_exemplar_name

        # init VGGs
        # NOTE(review): the VGG19 modules are never moved to self.device here — presumably VGG19
        # handles device placement itself; confirm in model/vgg19.py.
        vgg_exemplar = VGG19(freeze_weights=True)  # vgg to generate ideal feature maps
        # vgg applied to the synthesized image; its parameters are NOT given to the optimizer,
        # only output_image is optimized (see _reset_output_image)
        self.vgg_synthesis = VGG19(freeze_weights=False)

        # calculate and save gram matrices for the texture exemplar once (as this does not change)
        self.tex_exemplar_image = utils.load_image_tensor(tex_exemplar_image).to(self.device)
        with torch.no_grad():  # the targets are constants; never record an autograd graph for them
            self.gram_matrices_ideal = vgg_exemplar(self.tex_exemplar_image).get_gram_matrices()

        self.layer_weights = [10**9] * len(vgg_exemplar.output_layers)  # output layer weights as per paper
        self.beta = 10**5  # beta (spectrum loss weight) as per paper
        self.losses = []

        # set up the initial random noise image which the network will optimize (+ its optimizer)
        self._reset_output_image()

    @staticmethod
    def _default_exemplar_name(tex_exemplar_image):
        """Return the file stem when tex_exemplar_image is a path, else the generic name "texture"."""
        if isinstance(tex_exemplar_image, (str, os.PathLike)):
            return os.path.splitext(os.path.basename(tex_exemplar_image))[0]
        return "texture"

    def _reset_output_image(self):
        """Replace output_image with fresh random noise and rebuild the L-BFGS optimizer around it."""
        # randn_like inherits the exemplar's shape and device; sigmoid keeps values in (0, 1)
        self.output_image = torch.sigmoid(torch.randn_like(self.tex_exemplar_image))
        self.output_image.requires_grad_()  # leaf tensor whose pixels are what gets optimized
        self.LBFGS = torch.optim.LBFGS([self.output_image])

    def synthesize_texture(self, num_epochs=250, display_when_done=False):
        """
        - Idea: Each time the optimizer starts off from a random noise image, the network optimizes/synthesizes
          the original tex exemplar in a slightly different way - i.e. introduce variation in the synthesis.
        - Can be called multiple times to generate different texture variations of the tex exemplar this model holds
        - IMPT: resets the output_image to random noise (and clears self.losses) each time this is called
        Returns the synthesized texture as a detached CPU tensor.
        """
        self.losses = []
        self._reset_output_image()
        synthesized_texture = self.optimize(num_epochs=num_epochs)
        if display_when_done:
            utils.display_image_tensor(synthesized_texture)
        return synthesized_texture

    def optimize(self, num_epochs=250):
        """
        Perform num_epochs steps of the L-BFGS algorithm on output_image.

        For every epoch the loss evaluated *before* that step is appended to self.losses. Epoch numbers
        shown in the progress bar continue from any losses already recorded.
        Returns a detached CPU copy of output_image.
        """
        epoch_offset = len(self.losses)
        # context manager guarantees the progress bar is closed, even if a step raises
        with tqdm(total=num_epochs, desc="Optimizing...") as progress_bar:
            for epoch in range(num_epochs):
                # LBFGS expects the loss as a closure; step() returns the loss of its first closure
                # evaluation (i.e. before updating), so no extra forward pass is needed for logging
                epoch_loss = self.LBFGS.step(self.LBFGS_closure).item()
                self.losses.append(epoch_loss)
                progress_bar.set_description(f"Loss @ Epoch {epoch_offset + epoch + 1} - {epoch_loss} ")
                progress_bar.update(1)
        return self.output_image.detach().cpu()

    def LBFGS_closure(self):
        """
        Closure function for LBFGS: zeroes the gradients, evaluates the loss for the current output_image
        (via get_loss), backpropagates it, and returns the loss.
        """
        self.LBFGS.zero_grad()
        loss = self.get_loss()
        loss.backward()
        return loss

    def _spectrum_loss(self):
        """
        Spectrum constraint (as per paper): project the grayscale synth image I_hat onto the exemplar's
        spectrum, I_proj = ifft2( F(I_hat) F(I)^* / |F(I_hat) F(I)^*| * F(I) ), and return
        0.5 * sum((I_hat - I_proj)^2).
        """
        I_hat = utils.get_grayscale(self.output_image)
        I_fourier = fft.fft2(utils.get_grayscale(self.tex_exemplar_image))
        I_hat_fourier = fft.fft2(I_hat)
        cross_spectrum = I_hat_fourier * torch.conj(I_fourier)
        epsilon = 10e-12  # (== 1e-11) avoids div by 0 and nan values where the cross spectrum vanishes
        # For real images the projected spectrum is Hermitian-symmetric, so its inverse FFT is real up
        # to round-off; keep only the real part. Squaring the complex difference directly would give
        # Re(d)^2 - Im(d)^2 rather than a squared distance.
        I_proj = fft.ifft2(cross_spectrum / (torch.abs(cross_spectrum) + epsilon) * I_fourier).real
        return (0.5 * (I_hat - I_proj) ** 2.).sum()

    def get_loss(self):
        """
        CNN loss: Generates the Gram matrices for the current output synth image and compares them with the
        ideal (exemplar) ones to get a loss E_l at each layer l. The weighted E_l's are summed.
        Spectrum loss: project tex synth to tex exemplar to come up with the spectrum constraint as per paper.
        Overall loss = loss_cnn + beta * loss_spec
        """
        loss_spec = self._spectrum_loss()

        # get the gram mats for the synth output_image by passing it to the second vgg network
        gram_matrices_pred = self.vgg_synthesis(self.output_image).get_gram_matrices()

        # loss_cnn = w1*E1 + w2*E2 + ... + wl*El, with E_l = ||G_ideal_l - G_pred_l||^2
        loss_cnn = sum(
            weight * ((gram_ideal - gram_pred) ** 2.).sum()
            for weight, gram_ideal, gram_pred in zip(
                self.layer_weights, self.gram_matrices_ideal, gram_matrices_pred
            )
        )
        return loss_cnn + (self.beta * loss_spec)

    def save_textures(self, output_dir="./results/", display_when_done=False):
        """
        Saves the current tex_exemplar_image and output_image tensors that this model holds into
        output_dir as exemplar_<name>.png and synth_<name>.png (name = self.tex_exemplar_name).
        Whether output_dir is created if missing depends on utils.save_image_tensor.
        If display_when_done, calls .show() on the objects returned by utils.save_image_tensor.
        """
        tex_exemplar = utils.save_image_tensor(self.tex_exemplar_image.cpu(),
                                               output_dir=output_dir,
                                               image_name=f"exemplar_{self.tex_exemplar_name}.png")
        tex_synth = utils.save_image_tensor(self.output_image.detach().cpu(),
                                            output_dir=output_dir,
                                            image_name=f"synth_{self.tex_exemplar_name}.png")
        if display_when_done:
            tex_exemplar.show()
            print()
            tex_synth.show()