Spaces:
Runtime error
Runtime error
| import warnings | |
| warnings.filterwarnings('ignore') | |
| warnings.simplefilter('ignore') | |
| import argparse | |
| import math | |
| import os | |
| import sys | |
| import time | |
| import traceback | |
| import numpy as np | |
| import PIL | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import wandb | |
| from PIL import Image | |
| from torchvision import models, transforms | |
| from tqdm.auto import tqdm, trange | |
| import config | |
| import sketch_utils as utils | |
| from models.loss import Loss | |
| from models.painter_params import Painter, PainterOptimizer | |
| from IPython.display import display, SVG | |
def load_renderer(args, target_im=None, mask=None):
    """Build the stroke renderer and move it to the configured device.

    Args:
        args: Run configuration. ``num_paths``, ``num_segments``,
            ``image_scale`` and ``device`` are read here, and the whole
            object is also handed to ``Painter``.
        target_im: Optional target image, forwarded unchanged to ``Painter``.
        mask: Optional mask, forwarded unchanged to ``Painter``.

    Returns:
        The ``Painter`` instance after ``.to(args.device)``.
    """
    painter_kwargs = dict(
        num_strokes=args.num_paths,
        args=args,
        num_segments=args.num_segments,
        imsize=args.image_scale,
        device=args.device,
        target_im=target_im,
        mask=mask,
    )
    return Painter(**painter_kwargs).to(args.device)
def get_target(args):
    """Load the target image and convert it into a batched tensor.

    The image at ``args.target`` is flattened onto a white background when it
    has an alpha channel, converted to RGB, optionally replaced by its masked
    version (``args.mask_object``) and optionally rescaled
    (``args.fix_scale``). It is then resized to
    ``args.image_scale`` x ``args.image_scale`` and converted to a tensor.

    Args:
        args: Run configuration. Reads ``target``, ``mask_object``,
            ``fix_scale``, ``image_scale`` and ``device``.

    Returns:
        tuple: ``(target_tensor, mask)`` where ``target_tensor`` has a leading
        batch dimension of size 1 and lives on ``args.device``, and ``mask``
        is whatever ``utils.get_mask_u2net`` returned as its second value.
    """
    image = Image.open(args.target)
    if image.mode == "RGBA":
        # Flatten transparency onto a white canvas, using the image itself
        # as the paste mask.
        canvas = Image.new("RGBA", image.size, "WHITE")
        canvas.paste(image, (0, 0), image)
        image = canvas
    image = image.convert("RGB")

    # The mask is always computed; the masked image is used only on request.
    masked_im, mask = utils.get_mask_u2net(args, image)
    if args.mask_object:
        image = masked_im
    if args.fix_scale:
        image = utils.fix_image_scale(image)

    if image.size[0] == image.size[1]:
        # Square input: resize the shorter side, then centre-crop.
        steps = [
            transforms.Resize(
                args.image_scale, interpolation=PIL.Image.BICUBIC),
            transforms.CenterCrop(args.image_scale),
        ]
    else:
        # Non-square input: force both sides to the requested scale.
        steps = [
            transforms.Resize(
                (args.image_scale, args.image_scale),
                interpolation=PIL.Image.BICUBIC),
        ]
    steps.append(transforms.ToTensor())

    pipeline = transforms.Compose(steps)
    batched = pipeline(image).unsqueeze(0).to(args.device)
    return batched, mask
def main(args):
    """Optimize a stroke sketch towards the target image and save the results.

    Runs up to ``args.num_iter`` optimization steps. Every
    ``args.save_interval`` steps a comparison image and an SVG snapshot are
    written under ``args.output_dir``. Every ``args.eval_interval`` steps the
    loss is re-evaluated with ``mode="eval"``; that value drives the
    best-iteration snapshot (``best_iter.jpg`` / ``best_iter`` SVG) and the
    early-stopping check.

    Early stopping: the loop breaks on the second evaluation whose total loss
    is within ``1e-5`` of the best loss, unless a strictly better loss was
    recorded in between (which resets the flag).

    Args:
        args: Parsed run configuration (as produced by
            ``config.parse_arguments``).

    Returns:
        dict: ``"loss_eval"`` plus one key per entry of the evaluation loss
        dictionary, each mapping to the list of float values recorded at the
        evaluation steps.
    """
    loss_func = Loss(args)
    inputs, mask = get_target(args)
    utils.log_input(args.use_wandb, 0, inputs, args.output_dir)
    renderer = load_renderer(args, inputs, mask)

    optimizer = PainterOptimizer(args, renderer)
    counter = 0
    configs_to_save = {"loss_eval": []}
    best_loss, best_fc_loss = 100, 100
    best_iter, best_iter_fc = 0, 0
    min_delta = 1e-5
    terminate = False

    renderer.set_random_noise(0)
    # The returned image is not used here; the call is kept because it
    # presumably sets up the renderer's initial strokes (confirm in Painter).
    renderer.init_image(stage=0)
    optimizer.init_optimizers()

    # Not using tqdm for the Jupyter demo (args.display).
    if args.display:
        epoch_range = range(args.num_iter)
    else:
        epoch_range = tqdm(range(args.num_iter))

    for epoch in epoch_range:
        if not args.display:
            epoch_range.refresh()
        renderer.set_random_noise(epoch)
        if args.lr_scheduler:
            optimizer.update_lr(counter)

        optimizer.zero_grad_()
        sketches = renderer.get_image().to(args.device)
        losses_dict = loss_func(
            sketches, inputs.detach(), renderer.get_color_parameters(),
            renderer, counter, optimizer)
        loss = sum(losses_dict.values())
        loss.backward()
        optimizer.step_()

        if epoch % args.save_interval == 0:
            utils.plot_batch(inputs, sketches, f"{args.output_dir}/jpg_logs", counter,
                             use_wandb=args.use_wandb, title=f"iter{epoch}.jpg")
            renderer.save_svg(
                f"{args.output_dir}/svg_logs", f"svg_iter{epoch}")

        if epoch % args.eval_interval == 0:
            with torch.no_grad():
                # NOTE(review): the training call above passes the renderer
                # itself in the fourth position, while this one passes
                # renderer.get_points_parans() -- confirm Loss accepts both.
                losses_dict_eval = loss_func(
                    sketches, inputs, renderer.get_color_parameters(),
                    renderer.get_points_parans(), counter, optimizer,
                    mode="eval")
                loss_eval = sum(losses_dict_eval.values())
                loss_eval_value = loss_eval.item()
                configs_to_save["loss_eval"].append(loss_eval_value)
                for k, term in losses_dict_eval.items():
                    configs_to_save.setdefault(k, []).append(term.item())

                if args.clip_fc_loss_weight:
                    # Normalize before comparing: best_fc_loss is stored
                    # divided by the weight, so the candidate must be on the
                    # same scale or the "best" value is not actually tracked.
                    fc_eval = (losses_dict_eval["fc"].item()
                               / args.clip_fc_loss_weight)
                    if fc_eval < best_fc_loss:
                        best_fc_loss = fc_eval
                        best_iter_fc = epoch

                cur_delta = loss_eval_value - best_loss
                if abs(cur_delta) > min_delta and cur_delta < 0:
                    # Meaningful improvement: record it, re-arm early
                    # stopping and refresh the best-iteration snapshot.
                    best_loss = loss_eval_value
                    best_iter = epoch
                    terminate = False
                    utils.plot_batch(
                        inputs, sketches, args.output_dir, counter,
                        use_wandb=args.use_wandb, title="best_iter.jpg")
                    renderer.save_svg(args.output_dir, "best_iter")

                if args.use_wandb:
                    wandb.run.summary["best_loss"] = best_loss
                    wandb.run.summary["best_loss_fc"] = best_fc_loss
                    wandb_dict = {"delta": cur_delta,
                                  "loss_eval": loss_eval_value}
                    for k, term in losses_dict_eval.items():
                        wandb_dict[k + "_eval"] = term.item()
                    wandb.log(wandb_dict, step=counter)

                if abs(cur_delta) <= min_delta:
                    # Plateau: stop on the second such evaluation.
                    if terminate:
                        break
                    terminate = True

        if counter == 0 and args.attention_init:
            utils.plot_atten(renderer.get_attn(), renderer.get_thresh(), inputs, renderer.get_inds(),
                             args.use_wandb, "{}/{}.jpg".format(
                                 args.output_dir, "attention_map"),
                             args.saliency_model, args.display_logs)

        if args.use_wandb:
            wandb_dict = {"loss": loss.item(), "lr": optimizer.get_lr()}
            for k, term in losses_dict.items():
                wandb_dict[k] = term.item()
            wandb.log(wandb_dict, step=counter)

        counter += 1

    renderer.save_svg(args.output_dir, "final_svg")
    path_svg = os.path.join(args.output_dir, "best_iter.svg")
    utils.log_sketch_summary_final(
        path_svg, args.use_wandb, args.device, best_iter, best_loss, "best total")

    return configs_to_save
if __name__ == "__main__":
    # Script entry point: parse the configuration, run the optimization, then
    # save the configuration merged with the recorded loss curves.
    args = config.parse_arguments()
    final_config = vars(args)
    try:
        configs_to_save = main(args)
    except BaseException as err:
        # BaseException also covers KeyboardInterrupt/SystemExit; every
        # failure is reported with its traceback and mapped to exit status 1.
        print(f"Unexpected error occurred:\n {err}")
        print(traceback.format_exc())
        sys.exit(1)
    else:
        # vars(args) aliases the namespace's dict, so this also adds the
        # loss histories to ``args`` itself.
        final_config.update(configs_to_save)
        np.save(f"{args.output_dir}/config.npy", final_config)
        if args.use_wandb:
            wandb.finish()