Manga / drawing.py
Keiser41's picture
Upload 47 files
62456b0
raw
history blame contribute delete
No virus
5.83 kB
import os
from datetime import datetime
import base64
import random
import string
import shutil
import torch
import matplotlib.pyplot as plt
import numpy as np
from flask import Flask, request, jsonify, abort, redirect, url_for, render_template, send_file, Response
from flask_wtf import FlaskForm
from wtforms import StringField, FileField, BooleanField, DecimalField
from wtforms.validators import DataRequired
from flask import after_this_request
from model.models import Colorizer, Generator
from model.extractor import get_seresnext_extractor
from utils.xdog import XDoGSketcher
from utils.utils import open_json
from denoising.denoiser import FFDNetDenoiser
from inference import process_image_with_hint
from utils.utils import resize_pad
from utils.dataset_utils import get_sketch
def generate_id(size=25, chars=string.ascii_letters + string.digits):
    """Return a random identifier of ``size`` characters taken from ``chars``.

    Characters are drawn with ``random.SystemRandom`` (backed by the OS
    entropy source), so results cannot be reproduced from a seed.
    """
    rng = random.SystemRandom()
    picked = [rng.choice(chars) for _ in range(size)]
    return ''.join(picked)
# IDs handed out by generate_unique_id() during this process's lifetime.
# Kept in memory only: IDs from earlier runs of the server are not tracked.
_issued_ids = set()


def generate_unique_id(current_ids=None):
    """Return a new ID from generate_id() that is not already in ``current_ids``.

    Args:
        current_ids: set of IDs already in use; the new ID is added to it.
            When omitted, a module-level registry shared by all such calls is
            used (the original code got the same effect implicitly through a
            mutable default argument).

    Returns:
        The newly generated ID.
    """
    if current_ids is None:
        current_ids = _issued_ids
    new_id = generate_id()
    while new_id in current_ids:
        new_id = generate_id()
    current_ids.add(new_id)
    return new_id
app = Flask(__name__)
# Secrets are read from the environment when provided. The hard-coded values
# remain only as a fallback so existing deployments keep working; anyone with
# the source can forge session cookies / CSRF tokens signed with them, so set
# FLASK_SECRET_KEY and FLASK_WTF_CSRF_SECRET_KEY in production. (A random
# per-process fallback is deliberately NOT used: with several worker
# processes it would make CSRF tokens fail across workers.)
app.config.update(
    SECRET_KEY=os.environ.get('FLASK_SECRET_KEY', 'lol kek'),
    WTF_CSRF_SECRET_KEY=os.environ.get('FLASK_WTF_CSRF_SECRET_KEY', 'cheburek'),
)
# Use the GPU when torch can see one, otherwise run everything on the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# TorchScript colorizer; map_location loads its tensors onto the chosen device.
colorizer = torch.jit.load('./model/colorizer.zip', map_location=torch.device(device))

# XDoG sketch extractor. Only parameters the sketcher already defines are
# overridden from the JSON config; unknown keys are ignored.
sketcher = XDoGSketcher()
xdog_config = open_json('configs/xdog_config.json')
for key, value in xdog_config.items():
    if key in sketcher.params:
        sketcher.params[key] = value

denoiser = FFDNetDenoiser(device)

# Settings bundle passed to process_image_with_hint() by colorize_image().
color_args = {
    'colorizer': colorizer,
    'sketcher': sketcher,
    'device': device,
    'dfm': True,
    'auto_hint': False,
    'ignore_gray': False,
    'denoiser': denoiser,
    'denoiser_sigma': 25,
}
class SubmitForm(FlaskForm):
    """Upload form: one mandatory file field holding the image to colorize."""

    file = FileField(validators=[DataRequired()])
def preprocess_image(file_id, ext):
    """Turn an uploaded image into the working files used by the drawing page.

    Reads ``static/temp_images/<file_id>/original<ext>``, resizes/pads and
    denoises it, and extracts the ``bw`` sketch and ``dfm`` maps. It then writes
    ``resized_<height>_<width>.png``, ``bw.png``, ``dfm.png`` and an empty
    ``hint.png`` into the same directory and deletes the original upload.
    The size is encoded in the resized file name because paintapp() recovers
    the canvas size from it.
    """
    work_dir = os.path.join('static', 'temp_images', file_id)
    source_file = os.path.join(work_dir, 'original') + ext

    image, _ = resize_pad(plt.imread(source_file))
    # Sigma 25 matches color_args['denoiser_sigma'].
    image = denoiser.get_denoised_image(image, 25)
    sketch, dfm = get_sketch(image, sketcher, True)

    height, width = image.shape[0], image.shape[1]
    plt.imsave(os.path.join(work_dir, f'resized_{height}_{width}.png'), image)
    plt.imsave(os.path.join(work_dir, 'bw.png'), sketch, cmap='gray')
    plt.imsave(os.path.join(work_dir, 'dfm.png'), dfm, cmap='gray')
    os.remove(source_file)

    # All-zero 4-channel (RGBA) image: fully transparent, i.e. no colour hints yet.
    blank_hint = np.zeros((height, width, 4), dtype=np.float32)
    plt.imsave(os.path.join(work_dir, 'hint.png'), blank_hint)
@app.route('/', methods=['GET', 'POST'])
def upload():
    """Show the upload form, or accept an uploaded image and start a session.

    On a valid POST the file is saved into a fresh
    ``static/temp_images/<id>/`` directory and preprocessed, and the client is
    redirected to ``/draw/<id>``. Unsupported extensions and processing
    failures answer 400; on failure the directory created here is removed.
    """
    form = SubmitForm()
    if not form.validate_on_submit():
        return render_template("upload.html", form=form)

    input_data = form.file.data
    # Compare case-insensitively so "photo.JPG" is accepted like "photo.jpg".
    ext = os.path.splitext(input_data.filename)[1].lower()
    if ext not in ('.jpg', '.png', '.jpeg'):
        abort(400)

    file_id = generate_unique_id()
    directory = os.path.join('static', 'temp_images', file_id)
    original_filename = os.path.join(directory, 'original') + ext
    created = False
    try:
        # makedirs also creates static/temp_images if it does not exist yet.
        os.makedirs(directory)
        created = True
        input_data.save(original_filename)
        preprocess_image(file_id, ext)
    except Exception:
        # Log the real cause (the old bare ``except:`` + print hid it and also
        # swallowed KeyboardInterrupt/SystemExit).
        app.logger.exception('Failed to preprocess upload %s', file_id)
        # Only remove a directory this request created, never a pre-existing one.
        if created:
            shutil.rmtree(directory, ignore_errors=True)
        abort(400)
    return redirect(f'/draw/{file_id}')
@app.route('/img/<file_id>')
def show_image(file_id):
    """Return an ``<img>`` tag pointing at the session's colorized result.

    A random query string is appended so the browser does not reuse a cached
    copy after re-colorizing. Unknown or malformed session ids answer 404.
    """
    # Session ids come from generate_id() and are alphanumeric. Rejecting any
    # other value blocks ids such as ".." and keeps the id safe to interpolate
    # into the HTML below.
    if not file_id.isalnum() or not os.path.exists(os.path.join('static', 'temp_images', file_id)):
        abort(404)
    return f'<img src="/static/temp_images/{file_id}/colorized.png?{random.randint(1, 1000000)}">'
def colorize_image(file_id):
    """Colorize a session's stored sketch using its current hint image.

    Loads ``bw.png``, ``dfm.png`` and ``hint.png`` from
    ``static/temp_images/<file_id>/`` and returns the result of
    ``process_image_with_hint`` for them with the module's ``color_args``.
    """
    session_dir = os.path.join('static', 'temp_images', file_id)

    def _load(name):
        return plt.imread(os.path.join(session_dir, name))

    # bw/dfm were written with a gray colormap, so they come back with
    # several channels; keep only the first, as a length-1 last axis.
    sketch = _load('bw.png')[..., :1]
    dfm = _load('dfm.png')[..., :1]
    hint = _load('hint.png')
    return process_image_with_hint(sketch, dfm, hint, color_args)
@app.route('/colorize', methods=['POST'])
def colorize():
    """Store the posted hint image and colorize the session's sketch with it.

    Form fields:
        save_file_id: session id; anything up to the last ``/`` is dropped,
            so a path or URL ending in the id is accepted.
        save_image: base64 image data; anything up to the first ``,`` is
            dropped, so a ``data:...;base64,`` data URL is accepted.

    Returns:
        Relative URL of the new ``colorized.png`` with a random cache-busting
        query string.

    Answers 404 for an unknown or malformed session id, and 400 when
    ``save_image`` is not valid base64.
    """
    file_id = request.form['save_file_id']
    file_id = file_id[file_id.rfind('/') + 1:]
    directory_path = os.path.join('static', 'temp_images', file_id)
    # Only accept generated (alphanumeric) ids of existing sessions, so a
    # crafted id such as ".." cannot make us write outside temp_images.
    if not file_id.isalnum() or not os.path.isdir(directory_path):
        abort(404)

    img_data = request.form['save_image']
    img_data = img_data[img_data.find(',') + 1:]
    # Decode before opening the file so a bad payload does not truncate the
    # existing hint.png. base64.decodestring was removed in Python 3.9;
    # decodebytes is the same decoder under its current name.
    try:
        hint_png = base64.decodebytes(img_data.encode())
    except ValueError:  # binascii.Error and UnicodeEncodeError both subclass it
        abort(400)
    with open(os.path.join(directory_path, 'hint.png'), 'wb') as im:
        im.write(hint_png)

    result = colorize_image(file_id)
    plt.imsave(os.path.join(directory_path, 'colorized.png'), result)
    return f'../static/temp_images/{file_id}/colorized.png?{random.randint(1, 1000000)}'
@app.route('/draw/<file_id>', methods=['GET', 'POST'])
def paintapp(file_id):
    """Render the drawing page for an uploaded image.

    The canvas size is recovered from the ``resized_<height>_<width>.png``
    file name written by preprocess_image(). Unknown or malformed ids, and
    sessions without a resized image, answer 404.
    """
    # The route lists POST for compatibility, but only GET is implemented.
    # Answer 405 rather than returning None, which Flask reports as a 500.
    if request.method != 'GET':
        abort(405)

    directory_path = os.path.join('static', 'temp_images', file_id)
    # Generated session ids are alphanumeric; rejecting anything else stops
    # values such as ".." from resolving to directories outside temp_images.
    if not file_id.isalnum() or not os.path.isdir(directory_path):
        abort(404)

    resized_name = next(
        (name for name in os.listdir(directory_path) if name.startswith('resized_')),
        None,
    )
    if resized_name is None:
        abort(404)

    # File name layout: resized_<height>_<width>.png
    parts = os.path.splitext(resized_name)[0].split('_')
    height = int(parts[1])
    width = int(parts[2])
    # img_path is relative to static/ and presumably turned into a URL by the
    # template, so join with '/' rather than the OS path separator.
    img_path = f'temp_images/{file_id}/{resized_name}'
    return render_template("drawing.html", height=height, width=width, img_path=img_path)