Spaces:
Runtime error
Runtime error
'''PaintTransformer Demo | |
- 2021-12-21 first created | |
- See: https://github.com/wzmsltw/PaintTransformer | |
''' | |
import os | |
# System packages installed before `import cv2` below; presumably OpenCV
# needs libGL at import time in this environment -- TODO confirm.
# Exit statuses are not checked (same as a plain sequence of os.system calls).
for _apt_cmd in ('apt-get update', 'apt-get -y install libgl1-mesa-glx'):
    os.system(_apt_cmd)
import cv2 | |
import network | |
from time import time | |
from glob import glob | |
from loguru import logger | |
import gradio as gr | |
import paddle | |
import render_utils | |
import render_parallel | |
import render_serial | |
# ---------- Settings ----------
# '-1' hides every GPU from CUDA, so inference runs on the CPU.
GPU_ID = '-1'
os.environ['CUDA_VISIBLE_DEVICES'] = GPU_ID
# paddle.set_device expects 'cpu' / 'gpu' / 'gpu:N' (not 'cuda:N').
# CUDA_VISIBLE_DEVICES above exposes only GPU_ID, and that card is then
# numbered 0 inside this process, so the selected GPU is always 'gpu:0'.
DEVICE = 'cpu' if GPU_ID == '-1' else 'gpu:0'
# Example inputs for the UI: one single-element row per image. This is the
# nested-list format every Gradio version accepts for a single input component.
examples = [[path] for path in sorted(glob(os.path.join('input', '*.jpg')))]
# Size (px) the input is resized to; also the output video's frame size.
WIDTH = 512
HEIGHT = 512
# Stroke count passed to network.Painter.
STROKE_NUM = 8
# Frame rate of the generated animation.
FPS = 10
# ---------- Logger ----------
# Append log records to app.log, then write a marker line so each process
# start is easy to spot in the file.
logger.add(sink='app.log', mode='a')
logger.info('===== APP RESTARTED =====')
# ---------- Model ----------
MODEL_FILE = 'paint_best.pdparams'
if os.path.exists(MODEL_FILE):
    logger.info('model already exists')
else:
    # Use a full Drive URL (accepted by every gdown version, unlike the
    # deprecated `--id` flag). Pin the output name to MODEL_FILE so the
    # existence check above matches on the next start. The URL is quoted
    # because '?' is a shell glob character.
    download_status = os.system(
        "gdown 'https://drive.google.com/uc?id=1G0O81qSvGp0kFCgyaQHmPygbVHFi1--q' "
        f"-O '{MODEL_FILE}'"
    )
    # Fail loudly here rather than letting paddle.load crash on a missing file.
    if download_status != 0 or not os.path.exists(MODEL_FILE):
        raise RuntimeError(
            f'failed to download {MODEL_FILE} (gdown exit status {download_status})'
        )
    logger.info('model downloaded')
paddle.set_device(DEVICE)
# Positional args are presumably (param_per_stroke, total_strokes, hidden_dim,
# n_heads, n_enc_layers, n_dec_layers) -- confirm against network.Painter.
net_g = network.Painter(5, STROKE_NUM, 256, 8, 3, 3)
net_g.set_state_dict(paddle.load(MODEL_FILE))
net_g.eval()
# Inference only: mark every weight as not requiring gradients.
for param in net_g.parameters():
    param.stop_gradient = True
# Load the two brush masks in 'L' mode and stack them along axis 0. The stack
# is what predict() hands to render_serial as `meta_brushes`.
brush_large_vertical = render_utils.read_img('brush/brush_large_vertical.png', 'L')
brush_large_horizontal = render_utils.read_img('brush/brush_large_horizontal.png', 'L')
meta_brushes = paddle.concat([brush_large_vertical, brush_large_horizontal], axis=0)
def predict(image_file):
    """Paint ``image_file`` stroke by stroke and return an MP4 of the process.

    Args:
        image_file: Path of the input image. The Gradio input component uses
            ``type='filepath'``.

    Returns:
        Path of the written video: ``'output.mp4'`` in the working directory.

    Raises:
        RuntimeError: If OpenCV cannot open the video writer, for example
            because the ``mp4v`` codec is unavailable.
    """
    original_img = render_utils.read_img(image_file, 'RGB', WIDTH, HEIGHT)
    logger.info(f'--- image loaded & resized {WIDTH}x{HEIGHT}')
    logger.info('--- doing inference...')
    t0 = time()
    # One entry per intermediate painting state; each entry becomes a video frame.
    final_result_list = render_serial.render_serial(original_img, net_g, meta_brushes)
    logger.info(f'--- inference took {time() - t0:.4f} sec')

    # NOTE(review): this fixed path is shared by every request, so concurrent
    # calls can overwrite each other's video.
    output_path = 'output.mp4'
    out = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*'mp4v'), FPS,
                          (WIDTH, HEIGHT))
    # Without this check a missing codec yields no video and no error.
    if not out.isOpened():
        raise RuntimeError(f'could not open video writer for {output_path}')
    try:
        # Frames are assumed to be WIDTH x HEIGHT uint8 images in the channel
        # order OpenCV expects; frames of another size may be skipped by
        # OpenCV without error -- TODO confirm render_serial's output format.
        for frame in final_result_list:
            out.write(frame)
    finally:
        # Always finalize the file, even if writing a frame raises.
        out.release()
    logger.info('--- animation generated')
    return output_path
iface = gr.Interface(
    predict,
    title='🎨 Paint Transformer',
    description='This demo converts an image into a sequence of painted images (animation)',
    inputs=[
        # The uploaded image is passed to predict() as a file path.
        gr.inputs.Image(label='Input image', type='filepath')
    ],
    outputs=[
        # predict() returns the path of an MP4 file.
        gr.outputs.Video(label='Output animation', type='mp4')
    ],
    examples=examples,
    article='<p style="text-align:center">Original work: <a href="https://github.com/wzmsltw/PaintTransformer">PaintTransformer</a></p>'
)
# Start the demo web server.
iface.launch(debug=True)