import os
from abc import ABC, abstractmethod

import numpy as np
import torch

import utils


class Tuning(ABC):
    """Abstract base class for a single image-tuning step."""

    @abstractmethod
    def Image(self, image):
        pass


class TuningReconstruction(Tuning):
    """Runs one optimization step that pulls `image` towards the stored
    target content or style representation."""

    def __init__(self, model, optimizer, target_representation,
                 content_feature_maps_index, style_feature_maps_indices):
        self.model = model
        self.optimizer = optimizer
        self.target_representation = target_representation
        self.content_feature_maps_index = content_feature_maps_index
        self.style_feature_maps_indices = style_feature_maps_indices

    def Image(self, image):
        # Forward pass: collect the feature maps for the current image.
        set_of_feature_maps = self.model(image)

        reconstruct = utils.yamlGet('reconstruct')
        if reconstruct == 'Content':
            current_representation = set_of_feature_maps[
                self.content_feature_maps_index].squeeze(axis=0)
            loss = torch.nn.MSELoss(reduction='mean')(
                self.target_representation, current_representation)
        elif reconstruct == 'Style':
            # Style is represented by the Gram matrices of the selected layers.
            current_representation = [
                utils.gram_matrix(fmaps)
                for i, fmaps in enumerate(set_of_feature_maps)
                if i in self.style_feature_maps_indices
            ]
            loss = 0.0
            for gram_gt, gram_hat in zip(self.target_representation,
                                         current_representation):
                loss += (1 / len(self.target_representation)) * \
                    torch.nn.MSELoss(reduction='sum')(gram_gt[0], gram_hat[0])
        else:
            raise ValueError(f"Unknown 'reconstruct' setting: {reconstruct!r}")

        loss.backward()
        self.optimizer.step()
        self.optimizer.zero_grad()
        return loss.item(), current_representation
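
# Usage sketch (assumption, not part of the original module): driving the
# tuning step in a loop. `model`, `optimizer`, `image` and the index arguments
# are placeholders for whatever upstream code constructs them.
#
#   tuner = TuningReconstruction(model, optimizer, target_representation,
#                                content_feature_maps_index,
#                                style_feature_maps_indices)
#   for step in range(num_iterations):
#       loss, current = tuner.Image(image)
#       print(f'step {step}: loss={loss:.4f}')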


class Reconstruct(ABC):
    @abstractmethod
    def Visualize(self):
        pass


class ContentReconstruct(Reconstruct):
    """
    Abbreviations:
        fm  -> feature_maps (dict built upstream)
        tcr -> target_content_representation
        nfm -> number of feature maps
    """

    def __init__(self, feature_maps):
        self.fm = feature_maps
        self.tcr = self.fm['set_of_feature_maps'][
            self.fm['content_feature_maps_index_name'][0]].squeeze(axis=0)
        self.nfm = self.tcr.size()[0]

    def Visualize(self):
        # NOTE: assumption -- 'config' and 'dump_path' are read from the
        # feature_maps dict here; adjust the lookups to wherever these values
        # actually live in your setup.
        config = self.fm['config']
        dump_path = self.fm['dump_path']
        layer_name = self.fm['content_feature_maps_index_name'][1]
        for i in range(self.nfm):
            feature_map = self.tcr[i].to('cpu').numpy()
            feature_map = np.uint8(utils.get_uint8_range(feature_map))
            # plt.imshow(feature_map)
            # plt.title(
            #     f'Feature map {i+1}/{self.nfm} from layer'
            #     f' {layer_name} '
            #     f'(model={config["model"]}) for'
            #     f' {config["content_img_name"]} image.'
            # )
            # plt.show()
            filename = (f'fm_{config["model"]}_{layer_name}_'
                        f'{str(i).zfill(config["img_format"][0])}'
                        f'{config["img_format"][1]}')
            utils.save_image(feature_map, os.path.join(dump_path, filename))


class StyleReconstruct(Reconstruct):
    pass


class Invoker:
    pass
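

# Usage sketch (assumption, not part of the original module): visualizing the
# stored content representation. The `feature_maps` dict is assumed to be
# built upstream with the keys read in ContentReconstruct above
# ('set_of_feature_maps', 'content_feature_maps_index_name', 'config',
# 'dump_path').
#
#   reconstructor = ContentReconstruct(feature_maps)
#   reconstructor.Visualize()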