import render_utils import paddle import paddle.nn as nn import paddle.nn.functional as F import numpy as np import math def crop(img, h, w): H, W = img.shape[-2:] pad_h = (H - h) // 2 pad_w = (W - w) // 2 remainder_h = (H - h) % 2 remainder_w = (W - w) % 2 img = img[:, :, pad_h:H - pad_h - remainder_h, pad_w:W - pad_w - remainder_w] return img def stroke_net_predict(img_patch, result_patch, patch_size, net_g, stroke_num, patch_num): """ stroke_net_predict """ img_patch = img_patch.transpose([0, 2, 1]).reshape([-1, 3, patch_size, patch_size]) result_patch = result_patch.transpose([0, 2, 1]).reshape([-1, 3, patch_size, patch_size]) #*----- Stroke Predictor -----*# shape_param, stroke_decision = net_g(img_patch, result_patch) stroke_decision = (stroke_decision > 0).astype('float32') #*----- sampling color -----*# grid = shape_param[:, :, :2].reshape([img_patch.shape[0] * stroke_num, 1, 1, 2]) img_temp = img_patch.unsqueeze(1).tile([1, stroke_num, 1, 1, 1]).reshape([ img_patch.shape[0] * stroke_num, 3, patch_size, patch_size]) color = nn.functional.grid_sample(img_temp, 2 * grid - 1, align_corners=False).reshape([ img_patch.shape[0], stroke_num, 3]) param = paddle.concat([shape_param, color], axis=-1) param = param.reshape([-1, 8]) param[:, :2] = param[:, :2] / 2 + 0.25 param[:, 2:4] = param[:, 2:4] / 2 param = param.reshape([1, patch_num, patch_num, stroke_num, 8]) decision = stroke_decision.reshape([1, patch_num, patch_num, stroke_num])#.astype('bool') return param, decision def param2img_parallel(param, decision, meta_brushes, cur_canvas, stroke_num=8): """ Input stroke parameters and decisions for each patch, meta brushes, current canvas, frame directory, and whether there is a border (if intermediate painting results are required). Output the painting results of adding the corresponding strokes on the current canvas. Args: param: a tensor with shape batch size x patch along height dimension x patch along width dimension x n_stroke_per_patch x n_param_per_stroke decision: a 01 tensor with shape batch size x patch along height dimension x patch along width dimension x n_stroke_per_patch meta_brushes: a tensor with shape 2 x 3 x meta_brush_height x meta_brush_width. The first slice on the batch dimension denotes vertical brush and the second one denotes horizontal brush. cur_canvas: a tensor with shape batch size x 3 x H x W, where H and W denote height and width of padded results of original images. Returns: cur_canvas: a tensor with shape batch size x 3 x H x W, denoting painting results. """ # param: b, h, w, stroke_per_patch, param_per_stroke # decision: b, h, w, stroke_per_patch b, h, w, s, p = param.shape h, w = int(h), int(w) param = param.reshape([-1, 8]) decision = decision.reshape([-1, 8]) H, W = cur_canvas.shape[-2:] is_odd_y = h % 2 == 1 is_odd_x = w % 2 == 1 render_size_y = 2 * H // h render_size_x = 2 * W // w even_idx_y = paddle.arange(0, h, 2) even_idx_x = paddle.arange(0, w, 2) if h > 1: odd_idx_y = paddle.arange(1, h, 2) if w > 1: odd_idx_x = paddle.arange(1, w, 2) cur_canvas = F.pad(cur_canvas, [render_size_x // 4, render_size_x // 4, render_size_y // 4, render_size_y // 4]) valid_foregrounds = render_utils.param2stroke(param, render_size_y, render_size_x, meta_brushes) #* ----- load dilation/erosion ---- *# dilation = render_utils.Dilation2d(m=1) erosion = render_utils.Erosion2d(m=1) #* ----- generate alphas ----- *# valid_alphas = (valid_foregrounds > 0).astype('float32') valid_foregrounds = valid_foregrounds.reshape([-1, stroke_num, 1, render_size_y, render_size_x]) valid_alphas = valid_alphas.reshape([-1, stroke_num, 1, render_size_y, render_size_x]) temp = [dilation(valid_foregrounds[:, i, :, :, :]) for i in range(stroke_num)] valid_foregrounds = paddle.stack(temp, axis=1) valid_foregrounds = valid_foregrounds.reshape([-1, 1, render_size_y, render_size_x]) temp = [erosion(valid_alphas[:, i, :, :, :]) for i in range(stroke_num)] valid_alphas = paddle.stack(temp, axis=1) valid_alphas = valid_alphas.reshape([-1, 1, render_size_y, render_size_x]) foregrounds = valid_foregrounds.reshape([-1, h, w, stroke_num, 1, render_size_y, render_size_x]) alphas = valid_alphas.reshape([-1, h, w, stroke_num, 1, render_size_y, render_size_x]) decision = decision.reshape([-1, h, w, stroke_num, 1, 1, 1]) param = param.reshape([-1, h, w, stroke_num, 8]) def partial_render(this_canvas, patch_coord_y, patch_coord_x): canvas_patch = F.unfold(this_canvas, [render_size_y, render_size_x], strides=[render_size_y // 2, render_size_x // 2]) # canvas_patch: b, 3 * py * px, h * w canvas_patch = canvas_patch.reshape([b, 3, render_size_y, render_size_x, h, w]) canvas_patch = canvas_patch.transpose([0, 4, 5, 1, 2, 3]) selected_canvas_patch = paddle.gather(canvas_patch, patch_coord_y, 1) selected_canvas_patch = paddle.gather(selected_canvas_patch, patch_coord_x, 2) selected_canvas_patch = selected_canvas_patch.reshape([0, 0, 0, 1, 3, render_size_y, render_size_x]) selected_foregrounds = paddle.gather(foregrounds, patch_coord_y, 1) selected_foregrounds = paddle.gather(selected_foregrounds, patch_coord_x, 2) selected_alphas = paddle.gather(alphas, patch_coord_y, 1) selected_alphas = paddle.gather(selected_alphas, patch_coord_x, 2) selected_decisions = paddle.gather(decision, patch_coord_y, 1) selected_decisions = paddle.gather(selected_decisions, patch_coord_x, 2) selected_color = paddle.gather(param, patch_coord_y, 1) selected_color = paddle.gather(selected_color, patch_coord_x, 2) selected_color = paddle.gather(selected_color, paddle.to_tensor([5,6,7]), 4) selected_color = selected_color.reshape([0, 0, 0, stroke_num, 3, 1, 1]) for i in range(stroke_num): i = paddle.to_tensor(i) cur_foreground = paddle.gather(selected_foregrounds, i, 3) cur_alpha = paddle.gather(selected_alphas, i, 3) cur_decision = paddle.gather(selected_decisions, i, 3) cur_color = paddle.gather(selected_color, i, 3) cur_foreground = cur_foreground * cur_color selected_canvas_patch = cur_foreground * cur_alpha * cur_decision + selected_canvas_patch * (1 - cur_alpha * cur_decision) selected_canvas_patch = selected_canvas_patch.reshape([0, 0, 0, 3, render_size_y, render_size_x]) this_canvas = selected_canvas_patch.transpose([0, 3, 1, 4, 2, 5]) # this_canvas: b, 3, h_half, py, w_half, px h_half = this_canvas.shape[2] w_half = this_canvas.shape[4] this_canvas = this_canvas.reshape([b, 3, h_half * render_size_y, w_half * render_size_x]) # this_canvas: b, 3, h_half * py, w_half * px return this_canvas # even - even area # 1 | 0 # 0 | 0 canvas = partial_render(cur_canvas, even_idx_y, even_idx_x) if not is_odd_y: canvas = paddle.concat([canvas, cur_canvas[:, :, -render_size_y // 2:, :canvas.shape[3]]], axis=2) if not is_odd_x: canvas = paddle.concat([canvas, cur_canvas[:, :, :canvas.shape[2], -render_size_x // 2:]], axis=3) cur_canvas = canvas # odd - odd area # 0 | 0 # 0 | 1 if h > 1 and w > 1: canvas = partial_render(cur_canvas, odd_idx_y, odd_idx_x) canvas = paddle.concat([cur_canvas[:, :, :render_size_y // 2, -canvas.shape[3]:], canvas], axis=2) canvas = paddle.concat([cur_canvas[:, :, -canvas.shape[2]:, :render_size_x // 2], canvas], axis=3) if is_odd_y: canvas = paddle.concat([canvas, cur_canvas[:, :, -render_size_y // 2:, :canvas.shape[3]]], axis=2) if is_odd_x: canvas = paddle.concat([canvas, cur_canvas[:, :, :canvas.shape[2], -render_size_x // 2:]], axis=3) cur_canvas = canvas # odd - even area # 0 | 0 # 1 | 0 if h > 1: canvas = partial_render(cur_canvas, odd_idx_y, even_idx_x) canvas = paddle.concat([cur_canvas[:, :, :render_size_y // 2, :canvas.shape[3]], canvas], axis=2) if is_odd_y: canvas = paddle.concat([canvas, cur_canvas[:, :, -render_size_y // 2:, :canvas.shape[3]]], axis=2) if not is_odd_x: canvas = paddle.concat([canvas, cur_canvas[:, :, :canvas.shape[2], -render_size_x // 2:]], axis=3) cur_canvas = canvas # odd - even area # 0 | 1 # 0 | 0 if w > 1: canvas = partial_render(cur_canvas, even_idx_y, odd_idx_x) canvas = paddle.concat([cur_canvas[:, :, :canvas.shape[2], :render_size_x // 2], canvas], axis=3) if not is_odd_y: canvas = paddle.concat([canvas, cur_canvas[:, :, -render_size_y // 2:, -canvas.shape[3]:]], axis=2) if is_odd_x: canvas = paddle.concat([canvas, cur_canvas[:, :, :canvas.shape[2], -render_size_x // 2:]], axis=3) cur_canvas = canvas cur_canvas = cur_canvas[:, :, render_size_y // 4:-render_size_y // 4, render_size_x // 4:-render_size_x // 4] return cur_canvas def render_parallel(original_img, net_g, meta_brushes): patch_size = 32 stroke_num = 8 with paddle.no_grad(): original_h, original_w = original_img.shape[-2:] K = max(math.ceil(math.log2(max(original_h, original_w) / patch_size)), 0) original_img_pad_size = patch_size * (2 ** K) original_img_pad = render_utils.pad(original_img, original_img_pad_size, original_img_pad_size) final_result = paddle.zeros_like(original_img) for layer in range(0, K + 1): layer_size = patch_size * (2 ** layer) img = F.interpolate(original_img_pad, (layer_size, layer_size)) result = F.interpolate(final_result, (layer_size, layer_size)) img_patch = F.unfold(img, [patch_size, patch_size], strides=[patch_size, patch_size]) result_patch = F.unfold(result, [patch_size, patch_size], strides=[patch_size, patch_size]) # There are patch_num * patch_num patches in total patch_num = (layer_size - patch_size) // patch_size + 1 param, decision = stroke_net_predict(img_patch, result_patch, patch_size, net_g, stroke_num, patch_num) #print(param.shape, decision.shape) final_result = param2img_parallel(param, decision, meta_brushes, final_result) # paint another time for last layer border_size = original_img_pad_size // (2 * patch_num) img = F.interpolate(original_img_pad, (layer_size, layer_size)) result = F.interpolate(final_result, (layer_size, layer_size)) img = F.pad(img, [patch_size // 2, patch_size // 2, patch_size // 2, patch_size // 2]) result = F.pad(result, [patch_size // 2, patch_size // 2, patch_size // 2, patch_size // 2]) img_patch = F.unfold(img, [patch_size, patch_size], strides=[patch_size, patch_size]) result_patch = F.unfold(result, [patch_size, patch_size], strides=[patch_size, patch_size]) final_result = F.pad(final_result, [border_size, border_size, border_size, border_size]) patch_num = (img.shape[2] - patch_size) // patch_size + 1 #w = (img.shape[3] - patch_size) // patch_size + 1 param, decision = stroke_net_predict(img_patch, result_patch, patch_size, net_g, stroke_num, patch_num) final_result = param2img_parallel(param, decision, meta_brushes, final_result) final_result = final_result[:, :, border_size:-border_size, border_size:-border_size] final_result = (final_result.numpy().squeeze().transpose([1,2,0])[:,:,::-1] * 255).astype(np.uint8) return final_result