demo-painttransformer / render_parallel.py
jaekookang
first upload
9ff1108
import render_utils
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
import numpy as np
import math
def crop(img, h, w):
H, W = img.shape[-2:]
pad_h = (H - h) // 2
pad_w = (W - w) // 2
remainder_h = (H - h) % 2
remainder_w = (W - w) % 2
img = img[:, :, pad_h:H - pad_h - remainder_h, pad_w:W - pad_w - remainder_w]
return img
def stroke_net_predict(img_patch, result_patch, patch_size, net_g, stroke_num, patch_num):
"""
stroke_net_predict
"""
img_patch = img_patch.transpose([0, 2, 1]).reshape([-1, 3, patch_size, patch_size])
result_patch = result_patch.transpose([0, 2, 1]).reshape([-1, 3, patch_size, patch_size])
#*----- Stroke Predictor -----*#
shape_param, stroke_decision = net_g(img_patch, result_patch)
stroke_decision = (stroke_decision > 0).astype('float32')
#*----- sampling color -----*#
grid = shape_param[:, :, :2].reshape([img_patch.shape[0] * stroke_num, 1, 1, 2])
img_temp = img_patch.unsqueeze(1).tile([1, stroke_num, 1, 1, 1]).reshape([
img_patch.shape[0] * stroke_num, 3, patch_size, patch_size])
color = nn.functional.grid_sample(img_temp, 2 * grid - 1, align_corners=False).reshape([
img_patch.shape[0], stroke_num, 3])
param = paddle.concat([shape_param, color], axis=-1)
param = param.reshape([-1, 8])
param[:, :2] = param[:, :2] / 2 + 0.25
param[:, 2:4] = param[:, 2:4] / 2
param = param.reshape([1, patch_num, patch_num, stroke_num, 8])
decision = stroke_decision.reshape([1, patch_num, patch_num, stroke_num])#.astype('bool')
return param, decision
def param2img_parallel(param, decision, meta_brushes, cur_canvas, stroke_num=8):
"""
Input stroke parameters and decisions for each patch, meta brushes, current canvas, frame directory,
and whether there is a border (if intermediate painting results are required).
Output the painting results of adding the corresponding strokes on the current canvas.
Args:
param: a tensor with shape batch size x patch along height dimension x patch along width dimension
x n_stroke_per_patch x n_param_per_stroke
decision: a 01 tensor with shape batch size x patch along height dimension x patch along width dimension
x n_stroke_per_patch
meta_brushes: a tensor with shape 2 x 3 x meta_brush_height x meta_brush_width.
The first slice on the batch dimension denotes vertical brush and the second one denotes horizontal brush.
cur_canvas: a tensor with shape batch size x 3 x H x W,
where H and W denote height and width of padded results of original images.
Returns:
cur_canvas: a tensor with shape batch size x 3 x H x W, denoting painting results.
"""
# param: b, h, w, stroke_per_patch, param_per_stroke
# decision: b, h, w, stroke_per_patch
b, h, w, s, p = param.shape
h, w = int(h), int(w)
param = param.reshape([-1, 8])
decision = decision.reshape([-1, 8])
H, W = cur_canvas.shape[-2:]
is_odd_y = h % 2 == 1
is_odd_x = w % 2 == 1
render_size_y = 2 * H // h
render_size_x = 2 * W // w
even_idx_y = paddle.arange(0, h, 2)
even_idx_x = paddle.arange(0, w, 2)
if h > 1:
odd_idx_y = paddle.arange(1, h, 2)
if w > 1:
odd_idx_x = paddle.arange(1, w, 2)
cur_canvas = F.pad(cur_canvas, [render_size_x // 4, render_size_x // 4,
render_size_y // 4, render_size_y // 4])
valid_foregrounds = render_utils.param2stroke(param, render_size_y, render_size_x, meta_brushes)
#* ----- load dilation/erosion ---- *#
dilation = render_utils.Dilation2d(m=1)
erosion = render_utils.Erosion2d(m=1)
#* ----- generate alphas ----- *#
valid_alphas = (valid_foregrounds > 0).astype('float32')
valid_foregrounds = valid_foregrounds.reshape([-1, stroke_num, 1, render_size_y, render_size_x])
valid_alphas = valid_alphas.reshape([-1, stroke_num, 1, render_size_y, render_size_x])
temp = [dilation(valid_foregrounds[:, i, :, :, :]) for i in range(stroke_num)]
valid_foregrounds = paddle.stack(temp, axis=1)
valid_foregrounds = valid_foregrounds.reshape([-1, 1, render_size_y, render_size_x])
temp = [erosion(valid_alphas[:, i, :, :, :]) for i in range(stroke_num)]
valid_alphas = paddle.stack(temp, axis=1)
valid_alphas = valid_alphas.reshape([-1, 1, render_size_y, render_size_x])
foregrounds = valid_foregrounds.reshape([-1, h, w, stroke_num, 1, render_size_y, render_size_x])
alphas = valid_alphas.reshape([-1, h, w, stroke_num, 1, render_size_y, render_size_x])
decision = decision.reshape([-1, h, w, stroke_num, 1, 1, 1])
param = param.reshape([-1, h, w, stroke_num, 8])
def partial_render(this_canvas, patch_coord_y, patch_coord_x):
canvas_patch = F.unfold(this_canvas, [render_size_y, render_size_x], strides=[render_size_y // 2, render_size_x // 2])
# canvas_patch: b, 3 * py * px, h * w
canvas_patch = canvas_patch.reshape([b, 3, render_size_y, render_size_x, h, w])
canvas_patch = canvas_patch.transpose([0, 4, 5, 1, 2, 3])
selected_canvas_patch = paddle.gather(canvas_patch, patch_coord_y, 1)
selected_canvas_patch = paddle.gather(selected_canvas_patch, patch_coord_x, 2)
selected_canvas_patch = selected_canvas_patch.reshape([0, 0, 0, 1, 3, render_size_y, render_size_x])
selected_foregrounds = paddle.gather(foregrounds, patch_coord_y, 1)
selected_foregrounds = paddle.gather(selected_foregrounds, patch_coord_x, 2)
selected_alphas = paddle.gather(alphas, patch_coord_y, 1)
selected_alphas = paddle.gather(selected_alphas, patch_coord_x, 2)
selected_decisions = paddle.gather(decision, patch_coord_y, 1)
selected_decisions = paddle.gather(selected_decisions, patch_coord_x, 2)
selected_color = paddle.gather(param, patch_coord_y, 1)
selected_color = paddle.gather(selected_color, patch_coord_x, 2)
selected_color = paddle.gather(selected_color, paddle.to_tensor([5,6,7]), 4)
selected_color = selected_color.reshape([0, 0, 0, stroke_num, 3, 1, 1])
for i in range(stroke_num):
i = paddle.to_tensor(i)
cur_foreground = paddle.gather(selected_foregrounds, i, 3)
cur_alpha = paddle.gather(selected_alphas, i, 3)
cur_decision = paddle.gather(selected_decisions, i, 3)
cur_color = paddle.gather(selected_color, i, 3)
cur_foreground = cur_foreground * cur_color
selected_canvas_patch = cur_foreground * cur_alpha * cur_decision + selected_canvas_patch * (1 - cur_alpha * cur_decision)
selected_canvas_patch = selected_canvas_patch.reshape([0, 0, 0, 3, render_size_y, render_size_x])
this_canvas = selected_canvas_patch.transpose([0, 3, 1, 4, 2, 5])
# this_canvas: b, 3, h_half, py, w_half, px
h_half = this_canvas.shape[2]
w_half = this_canvas.shape[4]
this_canvas = this_canvas.reshape([b, 3, h_half * render_size_y, w_half * render_size_x])
# this_canvas: b, 3, h_half * py, w_half * px
return this_canvas
# even - even area
# 1 | 0
# 0 | 0
canvas = partial_render(cur_canvas, even_idx_y, even_idx_x)
if not is_odd_y:
canvas = paddle.concat([canvas, cur_canvas[:, :, -render_size_y // 2:, :canvas.shape[3]]], axis=2)
if not is_odd_x:
canvas = paddle.concat([canvas, cur_canvas[:, :, :canvas.shape[2], -render_size_x // 2:]], axis=3)
cur_canvas = canvas
# odd - odd area
# 0 | 0
# 0 | 1
if h > 1 and w > 1:
canvas = partial_render(cur_canvas, odd_idx_y, odd_idx_x)
canvas = paddle.concat([cur_canvas[:, :, :render_size_y // 2, -canvas.shape[3]:], canvas], axis=2)
canvas = paddle.concat([cur_canvas[:, :, -canvas.shape[2]:, :render_size_x // 2], canvas], axis=3)
if is_odd_y:
canvas = paddle.concat([canvas, cur_canvas[:, :, -render_size_y // 2:, :canvas.shape[3]]], axis=2)
if is_odd_x:
canvas = paddle.concat([canvas, cur_canvas[:, :, :canvas.shape[2], -render_size_x // 2:]], axis=3)
cur_canvas = canvas
# odd - even area
# 0 | 0
# 1 | 0
if h > 1:
canvas = partial_render(cur_canvas, odd_idx_y, even_idx_x)
canvas = paddle.concat([cur_canvas[:, :, :render_size_y // 2, :canvas.shape[3]], canvas], axis=2)
if is_odd_y:
canvas = paddle.concat([canvas, cur_canvas[:, :, -render_size_y // 2:, :canvas.shape[3]]], axis=2)
if not is_odd_x:
canvas = paddle.concat([canvas, cur_canvas[:, :, :canvas.shape[2], -render_size_x // 2:]], axis=3)
cur_canvas = canvas
# odd - even area
# 0 | 1
# 0 | 0
if w > 1:
canvas = partial_render(cur_canvas, even_idx_y, odd_idx_x)
canvas = paddle.concat([cur_canvas[:, :, :canvas.shape[2], :render_size_x // 2], canvas], axis=3)
if not is_odd_y:
canvas = paddle.concat([canvas, cur_canvas[:, :, -render_size_y // 2:, -canvas.shape[3]:]], axis=2)
if is_odd_x:
canvas = paddle.concat([canvas, cur_canvas[:, :, :canvas.shape[2], -render_size_x // 2:]], axis=3)
cur_canvas = canvas
cur_canvas = cur_canvas[:, :, render_size_y // 4:-render_size_y // 4, render_size_x // 4:-render_size_x // 4]
return cur_canvas
def render_parallel(original_img, net_g, meta_brushes):
patch_size = 32
stroke_num = 8
with paddle.no_grad():
original_h, original_w = original_img.shape[-2:]
K = max(math.ceil(math.log2(max(original_h, original_w) / patch_size)), 0)
original_img_pad_size = patch_size * (2 ** K)
original_img_pad = render_utils.pad(original_img, original_img_pad_size, original_img_pad_size)
final_result = paddle.zeros_like(original_img)
for layer in range(0, K + 1):
layer_size = patch_size * (2 ** layer)
img = F.interpolate(original_img_pad, (layer_size, layer_size))
result = F.interpolate(final_result, (layer_size, layer_size))
img_patch = F.unfold(img, [patch_size, patch_size], strides=[patch_size, patch_size])
result_patch = F.unfold(result, [patch_size, patch_size], strides=[patch_size, patch_size])
# There are patch_num * patch_num patches in total
patch_num = (layer_size - patch_size) // patch_size + 1
param, decision = stroke_net_predict(img_patch, result_patch, patch_size, net_g, stroke_num, patch_num)
#print(param.shape, decision.shape)
final_result = param2img_parallel(param, decision, meta_brushes, final_result)
# paint another time for last layer
border_size = original_img_pad_size // (2 * patch_num)
img = F.interpolate(original_img_pad, (layer_size, layer_size))
result = F.interpolate(final_result, (layer_size, layer_size))
img = F.pad(img, [patch_size // 2, patch_size // 2, patch_size // 2, patch_size // 2])
result = F.pad(result, [patch_size // 2, patch_size // 2, patch_size // 2, patch_size // 2])
img_patch = F.unfold(img, [patch_size, patch_size], strides=[patch_size, patch_size])
result_patch = F.unfold(result, [patch_size, patch_size], strides=[patch_size, patch_size])
final_result = F.pad(final_result, [border_size, border_size, border_size, border_size])
patch_num = (img.shape[2] - patch_size) // patch_size + 1
#w = (img.shape[3] - patch_size) // patch_size + 1
param, decision = stroke_net_predict(img_patch, result_patch, patch_size, net_g, stroke_num, patch_num)
final_result = param2img_parallel(param, decision, meta_brushes, final_result)
final_result = final_result[:, :, border_size:-border_size, border_size:-border_size]
final_result = (final_result.numpy().squeeze().transpose([1,2,0])[:,:,::-1] * 255).astype(np.uint8)
return final_result