File size: 3,761 Bytes
d7dbcdd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
from flask import Flask, render_template, request, redirect, url_for, abort
import json

app = Flask(__name__)

import sys
sys.path.append(".")
sys.path.append("..")

import argparse
from PIL import Image, ImageOps
import numpy as np
import base64
import cv2
from inference import demo

def Base64ToNdarry(img_base64):
    img_data = base64.b64decode(img_base64)
    img_np = np.fromstring(img_data, np.uint8)
    src = cv2.imdecode(img_np, cv2.IMREAD_ANYCOLOR)

    return src

def NdarrayToBase64(dst):
    result, dst_data = cv2.imencode('.png', dst)
    dst_base64 = base64.b64encode(dst_data)

    return dst_base64

parser = argparse.ArgumentParser(description='User controllable latent transformer')
parser.add_argument('--checkpoint_path', default='pretrained_models/latent_transformer/cat.pt')
args = parser.parse_args()

demo = demo(args.checkpoint_path)

@app.route("/", methods=["GET", "POST"])
#@auth.login_required
def init():
    if request.method == "GET":
        input_img = demo.run()
        input_base64 = "data:image/png;base64,"+NdarrayToBase64(input_img).decode()
        return render_template("index.html", filepath1=input_base64, canvas_img=input_base64, result=True)
    if request.method == "POST":
        if 'zi' in request.form.keys():
            input_img = demo.move(z=-0.05)
        elif 'zo' in request.form.keys():
            input_img = demo.move(z=0.05)
        elif 'u' in request.form.keys():
            input_img = demo.move(y=-0.5, z=-0.0)
        elif 'd' in request.form.keys():
            input_img = demo.move(y=0.5, z=-0.0)
        elif 'l' in request.form.keys():
            input_img = demo.move(x=-0.5, z=-0.0)
        elif 'r' in request.form.keys():
            input_img = demo.move(x=0.5, z=-0.0)
        else:
            input_img = demo.run()
        
        input_base64 = "data:image/png;base64,"+NdarrayToBase64(input_img).decode()
        return render_template("index.html", filepath1=input_base64, canvas_img=input_base64, result=True)

@app.route('/zoom', methods=["POST"])
def zoom_func():
    
    dz = json.loads(request.form['dz'])
    sx = json.loads(request.form['sx'])
    sy = json.loads(request.form['sy'])
    stop_points = json.loads(request.form['stop_points'])
    
    input_img = demo.zoom(dz,sxsy=[sx,sy],stop_points=stop_points)
    input_base64 = "data:image/png;base64,"+NdarrayToBase64(input_img).decode()
    res = {'img':input_base64}
    return json.dumps(res)

@app.route('/translate', methods=["POST"])
def translate_func():
    
    dx = json.loads(request.form['dx'])
    dy = json.loads(request.form['dy'])
    dz = json.loads(request.form['dz'])
    sx = json.loads(request.form['sx'])
    sy = json.loads(request.form['sy'])
    stop_points = json.loads(request.form['stop_points'])
    zi = json.loads(request.form['zi'])
    zo = json.loads(request.form['zo'])

    input_img = demo.translate([dx,dy],sxsy=[sx,sy],stop_points=stop_points,zoom_in=zi,zoom_out=zo)
    input_base64 = "data:image/png;base64,"+NdarrayToBase64(input_img).decode()
    res = {'img':input_base64}
    return json.dumps(res)

@app.route('/changestyle', methods=["POST"])
def changestyle_func():
    input_img = demo.change_style()
    input_base64 = "data:image/png;base64,"+NdarrayToBase64(input_img).decode()
    res = {'img':input_base64}
    return json.dumps(res)

@app.route('/reset', methods=["POST"])
def reset_func():
    input_img = demo.reset()
    input_base64 = "data:image/png;base64,"+NdarrayToBase64(input_img).decode()
    res = {'img':input_base64}
    return json.dumps(res)
    
if __name__ == "__main__":
    app.run(debug=False, host='0.0.0.0', port=8000)