import os

import torch
from torch import fft
from tqdm import tqdm

import model.utils as utils
from model.vgg19 import VGG19
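
# Summary: synthesizes textures by matching Gram matrices of VGG19 feature maps between a texture exemplar
# and a noise-initialized output image, with an additional Fourier-spectrum constraint, optimized via L-BFGS.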
class TextureSynthesisCNN:
    def __init__(self, tex_exemplar_path):
        """
        tex_exemplar_path: path to the ideal texture image w.r.t. which we synthesize our textures
        """
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.tex_exemplar_name = os.path.splitext(os.path.basename(tex_exemplar_path))[0]

        # init VGGs
        vgg_exemplar = VGG19(freeze_weights=True)         # vgg to generate the ideal feature maps
        self.vgg_synthesis = VGG19(freeze_weights=False)  # vgg whose weights will be trained

        # calculate and save the gram matrices for the texture exemplar once (as this does not change)
        self.tex_exemplar_image = utils.load_image_tensor(tex_exemplar_path).to(self.device)  # image path -> image Tensor
        self.gram_matrices_ideal = vgg_exemplar(self.tex_exemplar_image).get_gram_matrices()

        # set up the initial random noise image which the network will optimize
        self.output_image = torch.sigmoid(torch.randn_like(self.tex_exemplar_image)).to(self.device)  # sigmoid keeps values between 0 and 1
        self.output_image.requires_grad = True  # True so that the random noise image can be optimized
        self.LBFGS = torch.optim.LBFGS([self.output_image])

        self.layer_weights = [10**9] * len(vgg_exemplar.output_layers)  # output layer weights as per the paper
        self.beta = 10**5  # beta as per the paper
        self.losses = []
def synthesize_texture(self, num_epochs=250, display_when_done=False): | |
""" | |
- Idea: Each time the optimizer starts off from a random noise image, the network optimizes/synthesizes | |
the original tex exemplar in a slightly different way - i.e. introduce variation in the synthesis. | |
- Can be called multiple times to generate different texture variations of the tex exemplar this model holds | |
- IMPT: resets the output_image to random noise each time this is called | |
""" | |
self.losses = [] | |
# reset output image to random noise | |
self.output_image = torch.sigmoid(torch.randn_like(self.tex_exemplar_image)).to(self.device) | |
self.output_image.requires_grad = True | |
self.LBFGS = torch.optim.LBFGS([self.output_image]) # update LBFGS to hold the new output image | |
synthesized_texture = self.optimize(num_epochs=num_epochs) | |
if display_when_done: | |
utils.display_image_tensor(synthesized_texture) | |
return synthesized_texture | |
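
    # Example (sketch): because output_image is re-initialized to fresh noise on every call, repeated calls
    # yield different variations of the same exemplar (the path below is hypothetical):
    #   model = TextureSynthesisCNN(tex_exemplar_path="textures/pebbles.png")
    #   variation_a = model.synthesize_texture(num_epochs=250)
    #   variation_b = model.synthesize_texture(num_epochs=250)  # a different synthesis of the same texture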

    def optimize(self, num_epochs=250):
        """
        Perform num_epochs steps of the L-BFGS algorithm.
        """
        progress_bar = tqdm(total=num_epochs, desc="Optimizing...")
        epoch_offset = len(self.losses)

        for epoch in range(num_epochs):
            # LBFGS expects the loss to be computed inside a closure; step() returns the loss from the
            # closure's first evaluation, so reuse it for logging instead of running an extra forward pass
            epoch_loss = self.LBFGS.step(self.LBFGS_closure).item()
            progress_bar.update(1)
            progress_bar.set_description(f"Loss @ Epoch {epoch_offset + epoch + 1} - {epoch_loss} ")
            self.losses.append(epoch_loss)

        return self.output_image.detach().cpu()

    def LBFGS_closure(self):
        """
        Closure for LBFGS: passes the current output_image through vgg_synthesis, computes the predicted
        gram matrices, and uses them to compute the loss for the network.
        Note: torch.optim.LBFGS may call this closure multiple times per step() to re-evaluate the loss.
        """
        self.LBFGS.zero_grad()
        loss = self.get_loss()
        loss.backward()
        return loss

    def get_loss(self):
        """
        CNN loss: generates the feature maps for the current synthesized output image and compares them against
        the ideal feature maps to compute a loss E_l at each layer l. The weighted E_l's are summed to give the
        total CNN loss.
        Spectrum loss: projects the synthesized texture onto the texture exemplar in the Fourier domain to
        enforce the spectrum constraint as per the paper.
        Overall loss = loss_cnn + beta * loss_spec
        """
        # calculate the spectrum constraint loss using the current output_image and tex_exemplar_image
        # - projects image I_hat (tex_synth) onto image I (tex_exemplar) and returns I_proj (equation as per the paper)
        I_hat = utils.get_grayscale(self.output_image)
        I_fourier = fft.fft2(utils.get_grayscale(self.tex_exemplar_image))
        I_hat_fourier = fft.fft2(I_hat)
        I_fourier_conj = torch.conj(I_fourier)
        epsilon = 10e-12  # epsilon to avoid division by zero and nan values
        I_proj = fft.ifft2((I_hat_fourier * I_fourier_conj) / (torch.abs(I_hat_fourier * I_fourier_conj) + epsilon) * I_fourier)
        loss_spec = (0.5 * (I_hat - I_proj) ** 2.).sum().real

        # get the gram matrices for the synthesized output_image by passing it through the second vgg network
        gram_matrices_pred = self.vgg_synthesis(self.output_image).get_gram_matrices()

        # calculate the cnn loss: w1*E1 + w2*E2 + ... + wl*El
        loss_cnn = 0.
        for i in range(len(self.layer_weights)):
            # E_l = w_l * ||G_ideal_l - G_pred_l||^2
            E = self.layer_weights[i] * ((self.gram_matrices_ideal[i] - gram_matrices_pred[i]) ** 2.).sum()
            loss_cnn += E

        return loss_cnn + (self.beta * loss_spec)

    def save_textures(self, output_dir="./results/", display_when_done=False):
        """
        Saves (and optionally displays) the current tex_exemplar_image and output_image tensors that this model
        holds into the results directory (creating it if it does not yet exist).
        """
        tex_exemplar = utils.save_image_tensor(self.tex_exemplar_image.cpu(),
                                               output_dir=output_dir,
                                               image_name=f"exemplar_{self.tex_exemplar_name}.png")
        tex_synth = utils.save_image_tensor(self.output_image.detach().cpu(),
                                            output_dir=output_dir,
                                            image_name=f"synth_{self.tex_exemplar_name}.png")
        if display_when_done:
            tex_exemplar.show()
            print()
            tex_synth.show()
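
# Example usage (a minimal sketch; the exemplar path below is a placeholder, adjust it to a real texture image):
if __name__ == "__main__":
    synthesizer = TextureSynthesisCNN(tex_exemplar_path="textures/pebbles.png")  # hypothetical path
    texture = synthesizer.synthesize_texture(num_epochs=250, display_when_done=False)
    synthesizer.save_textures(output_dir="./results/", display_when_done=True)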