Spaces:
Runtime error
Runtime error
import os | |
import src.utils.utils as utils | |
from src.utils.video_utils import create_video_from_intermediate_results | |
import torch | |
from torch.autograd import Variable | |
from torch.optim import Adam, LBFGS | |
import numpy as np | |
def make_tuning_step(optimizer, config): | |
def tuning_step(optimizing_img): | |
config.current_set_of_feature_maps = config.neural_net(optimizing_img) | |
loss, config.current_representation = utils.getCurrentData(config) | |
loss.backward() | |
optimizer.step() | |
optimizer.zero_grad() | |
return loss.item(), config.current_representation | |
return tuning_step | |
def reconstruct_image_from_representation(): | |
dump_path = os.path.join(os.path.dirname(__file__), "data/reconstruct") | |
config = utils.Config() | |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | |
img, img_path = utils.getImageAndPath(device) | |
white_noise_img = np.random.uniform(-90., 90., | |
img.shape).astype(np.float32) | |
init_img = torch.from_numpy(white_noise_img).float().to(device) | |
optimizing_img = Variable(init_img, requires_grad=True) | |
# indices pick relevant feature maps (say conv4_1, relu1_1, etc.) | |
output = list(utils.prepare_model(device)) | |
config.neural_net = output[0] | |
content_feature_maps_index_name = output[1] | |
style_feature_maps_indices_names = output[2] | |
config.content_feature_maps_index = content_feature_maps_index_name[0] | |
config.style_feature_maps_indices = style_feature_maps_indices_names[0] | |
config.current_set_of_feature_maps = config.neural_net(img) | |
config.target_content_representation = config.current_set_of_feature_maps[ | |
config.content_feature_maps_index].squeeze(axis=0) | |
config.target_style_representation = [ | |
utils.gram_matrix(fmaps) | |
for i, fmaps in enumerate(config.current_set_of_feature_maps) | |
if i in config.style_feature_maps_indices | |
] | |
if utils.yamlGet('reconstruct') == "Content": | |
config.target_representation = config.target_content_representation | |
num_of_feature_maps = config.target_content_representation.size()[0] | |
for i in range(num_of_feature_maps): | |
feature_map = config.target_content_representation[i].to( | |
'cpu').numpy() | |
feature_map = np.uint8(utils.get_uint8_range(feature_map)) | |
# filename = f'fm_{config["model"]}_{content_feature_maps_index_name[1]}_{str(i).zfill(config["img_format"][0])}{config["img_format"][1]}' | |
# utils.save_image(feature_map, os.path.join(dump_path, filename)) | |
elif utils.yamlGet('reconstruct') == "Style": | |
config.target_representation = config.target_style_representation | |
num_of_gram_matrices = len(config.target_style_representation) | |
for i in range(num_of_gram_matrices): | |
Gram_matrix = config.target_style_representation[i].squeeze( | |
axis=0).to('cpu').numpy() | |
Gram_matrix = np.uint8(utils.get_uint8_range(Gram_matrix)) | |
# filename = f'gram_{config["model"]}_{style_feature_maps_indices_names[1][i]}_{str(i).zfill(config["img_format"][0])}{config["img_format"][1]}' | |
# utils.save_image(Gram_matrix, os.path.join(dump_path, filename)) | |
if utils.yamlGet('optimizer') == 'Adam': | |
optimizer = Adam((optimizing_img, ), lr=utils.yamlGet('learning_rate')) | |
tuning_step = make_tuning_step(optimizer, config) | |
for it in range(utils.yamlGet('optimizer')): | |
tuning_step(optimizing_img) | |
with torch.no_grad(): | |
utils.save_optimizing_image(optimizing_img, dump_path, it) | |
elif utils.yamlGet('optimizer') == 'LBFGS': | |
optimizer = LBFGS((optimizing_img, ), | |
max_iter=utils.yamlGet('optimizer'), | |
line_search_fn='strong_wolfe') | |
cnt = 0 | |
def closure(): | |
nonlocal cnt | |
loss = utils.getLBFGSReconstructLoss(config, optimizing_img) | |
loss.backward() | |
with torch.no_grad(): | |
utils.save_optimizing_image(optimizing_img, dump_path, cnt) | |
cnt += 1 | |
return loss | |
optimizer.step(closure) | |
return dump_path | |
if __name__ == "__main__": | |
# reconstruct style or content image purely from their representation | |
results_path = reconstruct_image_from_representation() | |
create_video_from_intermediate_results(results_path) | |