import torch
import torch.nn.functional as F
import numpy as np
from PIL import Image
import network
import os
import math
import render_utils
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
import cv2
import render_parallel
import render_serial


def main(input_path, model_path, output_dir, need_animation=False,
         resize_h=None, resize_w=None, serial=False):
    """Render one input image with the painter network and save the result.

    Args:
        input_path: Path of the image to render; read via
            ``render_utils.read_img`` in 'RGB' mode.
        model_path: Path of the weights file handed to ``paddle.load``.
        output_dir: Directory receiving the output; created (including
            missing parents) when it does not exist.
        need_animation: When true, every frame returned by the serial
            renderer is written as ``%03d.png`` into a sub-directory of
            ``output_dir`` named after the input file, and ``serial`` is
            forced to True.
        resize_h: Forwarded to ``render_utils.read_img``.
        resize_w: Forwarded to ``render_utils.read_img``.
        serial: Use ``render_serial.render_serial`` instead of
            ``render_parallel.render_parallel``.

    Returns:
        None. Results are written to disk with ``cv2.imwrite``.
    """
    import time

    # makedirs(exist_ok=True) also handles nested output paths and avoids the
    # check-then-create race of `os.path.exists` + `os.mkdir`.
    os.makedirs(output_dir, exist_ok=True)

    input_name = os.path.basename(input_path)
    output_path = os.path.join(output_dir, input_name)

    frame_dir = None
    if need_animation:
        if not serial:
            print('It must be under serial mode if animation results are required, so serial flag is set to True!')
            serial = True
        # Keep everything before the first '.', as before. Fall back to the
        # whole name when that would be empty or when there is no '.', since
        # slicing with find() == -1 used to drop the last character.
        frame_stem = input_name.partition('.')[0] or input_name
        frame_dir = os.path.join(output_dir, frame_stem)
        os.makedirs(frame_dir, exist_ok=True)

    stroke_num = 8

    #* ----- load model ----- *#
    # paddle.set_device('gpu')
    paddle.set_device('cpu')  # 2021-12-21 jkang edited to "cpu"
    # NOTE(review): meaning of the positional Painter arguments is defined in
    # `network`, which is not visible here -- confirm before changing them.
    net_g = network.Painter(5, stroke_num, 256, 8, 3, 3)
    net_g.set_state_dict(paddle.load(model_path))
    net_g.eval()
    for param in net_g.parameters():
        param.stop_gradient = True

    #* ----- load brush ----- *#
    brush_large_vertical = render_utils.read_img('brush/brush_large_vertical.png', 'L')
    brush_large_horizontal = render_utils.read_img('brush/brush_large_horizontal.png', 'L')
    meta_brushes = paddle.concat(
        [brush_large_vertical, brush_large_horizontal], axis=0)

    # Timing covers image loading, rendering and writing the results.
    t0 = time.time()

    original_img = render_utils.read_img(input_path, 'RGB', resize_h, resize_w)

    if serial:
        final_result_list = render_serial.render_serial(original_img, net_g, meta_brushes)
        if need_animation:
            print("total frame:", len(final_result_list))
            for idx, frame in enumerate(final_result_list):
                cv2.imwrite(os.path.join(frame_dir, '%03d.png' % idx), frame)
        else:
            # Without animation only the last frame is saved.
            cv2.imwrite(output_path, final_result_list[-1])
    else:
        final_result = render_parallel.render_parallel(original_img, net_g, meta_brushes)
        cv2.imwrite(output_path, final_result)

    print("total infer time:", time.time() - t0)


if __name__ == '__main__':
    main(input_path='input/chicago.jpg',
         model_path='paint_best.pdparams',
         output_dir='output/',
         need_animation=True,  # whether need intermediate results for animation.
         resize_h=512,         # resize original input to this size. None means do not resize.
         resize_w=512,         # resize original input to this size. None means do not resize.
         serial=True)          # if need animation, serial must be True.