Neural_Style_Texture / src /commands.py
priyam314's picture
first commit
cbcb207
raw
history blame
3.11 kB
import os
from abc import ABC, abstractmethod

import numpy as np
import torch

import utils
class Tuning(ABC):
    """Abstract interface for objects that process an image via ``Image``."""

    @abstractmethod
    def Image(self, image):
        """Process ``image``; concrete subclasses must override this method.

        The base implementation does nothing and returns ``None``.
        """
        return None
class TuningReconstruction(Tuning):
    """Run single optimization steps that reconstruct content or style.

    The loss used by :meth:`Image` is selected on every call by the
    ``'reconstruct'`` setting obtained through ``utils.yamlGet``; the
    supported values are ``'Content'`` and ``'Style'``.
    """

    # Values of the 'reconstruct' setting that Image knows how to handle.
    _SUPPORTED_MODES = ('Content', 'Style')

    def __init__(self, model, optimizer, target_representation,
                 content_feature_maps_index, style_feature_maps_indices):
        """Store the collaborators used by :meth:`Image`.

        Args:
            model: Callable applied to the image; its result is indexed and
                enumerated as a sequence of feature maps.
            optimizer: Object exposing ``step()`` and ``zero_grad()``.
            target_representation: Target the current representation is
                compared with: one tensor in content mode, a sized iterable
                of Gram matrices in style mode.
            content_feature_maps_index: Index into the model output used in
                content mode.
            style_feature_maps_indices: Container of model-output indices
                used in style mode.
        """
        self.model = model
        self.optimizer = optimizer
        self.target_representation = target_representation
        self.content_feature_maps_index = content_feature_maps_index
        self.style_feature_maps_indices = style_feature_maps_indices

    def Image(self, image):
        """Perform one optimization step for ``image``.

        Args:
            image: Input passed to ``self.model``.
                NOTE(review): presumably the tensor registered with
                ``self.optimizer`` -- confirm against the caller.

        Returns:
            tuple: ``(loss_value, current_representation)``, where
            ``loss_value`` is ``loss.item()`` and ``current_representation``
            is a tensor in content mode or a list of Gram matrices in style
            mode.

        Raises:
            ValueError: If the ``'reconstruct'`` setting is not supported, or
                if style mode yields no (target, current) Gram-matrix pair.
        """
        # Read the setting once so both the representation and the loss are
        # guaranteed to be built for the same mode.
        mode = utils.yamlGet('reconstruct')
        if mode not in self._SUPPORTED_MODES:
            raise ValueError(
                f"Unsupported 'reconstruct' setting {mode!r}; "
                f"expected one of {self._SUPPORTED_MODES}."
            )

        # Finds the current representation
        set_of_feature_maps = self.model(image)

        if mode == 'Content':
            current_representation = set_of_feature_maps[
                self.content_feature_maps_index].squeeze(dim=0)
            loss = torch.nn.MSELoss(reduction='mean')(
                self.target_representation, current_representation)
        else:
            current_representation = [
                utils.gram_matrix(fmaps)
                for i, fmaps in enumerate(set_of_feature_maps)
                if i in self.style_feature_maps_indices
            ]
            gram_pairs = list(
                zip(self.target_representation, current_representation))
            if not gram_pairs:
                # Without any pair the loss would stay a plain float and
                # ``backward()`` would fail with an unrelated AttributeError.
                raise ValueError(
                    'Style reconstruction needs at least one target/current '
                    'Gram matrix pair.'
                )
            # Every layer contributes equally to the total style loss.
            layer_weight = 1 / len(self.target_representation)
            sum_squared_error = torch.nn.MSELoss(reduction='sum')
            loss = 0.0
            for gram_gt, gram_hat in gram_pairs:
                # Only the first entry of each Gram matrix is compared
                # (presumably the single batch element -- TODO confirm).
                loss += layer_weight * sum_squared_error(
                    gram_gt[0], gram_hat[0])

        loss.backward()
        self.optimizer.step()
        # Gradients are cleared after the update so the next call starts
        # from a clean state.
        self.optimizer.zero_grad()
        return loss.item(), current_representation
class Reconstruct(ABC):
    """Abstract interface for reconstructions that expose ``Visualize``."""

    @abstractmethod
    def Visualize(self):
        """Visualize the reconstruction; concrete subclasses must override.

        The base implementation does nothing and returns ``None``.
        """
        return None
class ContentReconstruct(Reconstruct):
    """Save every feature map of the target content representation.

    tcr -> target_content_representation
    nfm -> number of feature maps (size of the first dimension of tcr)
    """

    def __init__(self, feature_maps, config=None, dump_path=None):
        """Select the target content representation from ``feature_maps``.

        Args:
            feature_maps: Mapping with the keys ``'set_of_feature_maps'``
                (indexable collection of tensors) and
                ``'content_feature_maps_index_name'`` (sequence whose item 0
                is the index into that collection and whose item 1 is used
                as the layer name in file names).
            config: Optional mapping read by :meth:`Visualize`; it must
                provide ``'model'`` and ``'img_format'``, where
                ``img_format[0]`` is the zero-padding width of the running
                index and ``img_format[1]`` is the file-name suffix.
            dump_path: Optional directory path the file names are joined to.
        """
        self.fm = feature_maps
        self.config = config
        self.dump_path = dump_path
        self.tcr = self.fm['set_of_feature_maps'][
            self.fm['content_feature_maps_index_name'][0]].squeeze(dim=0)
        self.nfm = self.tcr.size()[0]

    def Visualize(self):
        """Convert each feature map to uint8 and hand it to ``utils.save_image``.

        Raises:
            ValueError: If ``config`` or ``dump_path`` was not supplied to
                the constructor; both are required to build the file paths.
        """
        if self.config is None or self.dump_path is None:
            raise ValueError(
                'ContentReconstruct.Visualize requires both `config` and '
                '`dump_path` to be passed to the constructor.'
            )

        model_name = self.config['model']
        index_width = self.config['img_format'][0]
        suffix = self.config['img_format'][1]
        layer_name = self.fm['content_feature_maps_index_name'][1]

        for i in range(self.nfm):
            # detach() lets tensors that still track gradients be converted
            # to numpy; it is a no-op for tensors that do not.
            feature_map = self.tcr[i].detach().to('cpu').numpy()
            feature_map = np.uint8(utils.get_uint8_range(feature_map))
            filename = (
                f'fm_{model_name}_{layer_name}_'
                f'{str(i).zfill(index_width)}{suffix}'
            )
            utils.save_image(
                feature_map, os.path.join(self.dump_path, filename))
class StyleReconstruct(Reconstruct):
    """Placeholder subclass of ``Reconstruct`` with no members of its own.

    ``Visualize`` is not overridden, so the class is still abstract and
    cannot be instantiated.
    """
class Invoker:
    """Empty placeholder class; it defines no attributes or methods."""