Max Reimann commited on
Commit
f0f40d2
·
1 Parent(s): 4b98912

add optimization server test

Browse files
Whitebox_style_transfer.py CHANGED
@@ -144,6 +144,7 @@ def optimize(effect, preset, result_image_placeholder):
144
  content = st.session_state["Content_im"]
145
  style = st.session_state["Style_im"]
146
  result_image_placeholder.text("<- Custom content/style needs to be style transferred")
 
147
  optimize_button = st.sidebar.button("Optimize Style Transfer")
148
  if optimize_button:
149
  if HUGGING_FACE:
 
144
  content = st.session_state["Content_im"]
145
  style = st.session_state["Style_im"]
146
  result_image_placeholder.text("<- Custom content/style needs to be style transferred")
147
+ st.sidebar.info("Note: Optimizing takes up to 5 minutes.")
148
  optimize_button = st.sidebar.button("Optimize Style Transfer")
149
  if optimize_button:
150
  if HUGGING_FACE:
pages/test.py ADDED
@@ -0,0 +1,119 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import base64
2
+ import datetime
3
+ import os
4
+ import sys
5
+ from io import BytesIO
6
+ from pathlib import Path
7
+ import numpy as np
8
+ import requests
9
+ import torch
10
+ import torch.nn.functional as F
11
+ from PIL import Image
12
+ import json
13
+ import time
14
+
15
+ PACKAGE_PARENT = 'wise'
16
+ SCRIPT_DIR = os.path.dirname(os.path.realpath(os.path.join(os.getcwd(), os.path.expanduser(__file__))))
17
+ sys.path.append(os.path.normpath(os.path.join(SCRIPT_DIR, PACKAGE_PARENT)))
18
+
19
+ import streamlit as st
20
+ from streamlit.logger import get_logger
21
+ from st_click_detector import click_detector
22
+ import streamlit.components.v1 as components
23
+ from streamlit.source_util import get_pages
24
+ from streamlit_extras.switch_page_button import switch_page
25
+
26
+ import helpers.session_state as session_state
27
+ from parameter_optimization.strotss_org import strotss, pil_resize_long_edge_to
28
+ from helpers import torch_to_np, np_to_torch
29
+ from effects import get_default_settings, MinimalPipelineEffect
30
+
31
+ st.set_page_config(layout="wide")
32
+ BASE_URL = "https://ivpg.hpi3d.de/wise/wise-demo/images/"
33
+ LOGGER = get_logger(__name__)
34
+
35
+
36
+ def upload_form(imgtype):
37
+ with st.form(imgtype + "-form", clear_on_submit=True):
38
+ uploaded_im = st.file_uploader(f"Load {imgtype} image:", type=["png", "jpg"], )
39
+ upload_pressed = st.form_submit_button("Upload")
40
+
41
+ if upload_pressed and uploaded_im is not None:
42
+ img = Image.open(uploaded_im)
43
+ buffered = BytesIO()
44
+ img.save(buffered, format="JPEG")
45
+ encoded = base64.b64encode(buffered.getvalue()).decode()
46
+ # session_state.get(uploaded_im=img, content_render_src=f"data:image/jpeg;base64,{encoded}")
47
+ session_state.get(**{f"{imgtype}_im": img, f"{imgtype}_render_src": f"data:image/jpeg;base64,{encoded}",
48
+ f"{imgtype}_id": "uploaded"})
49
+ st.session_state["action"] = "uploaded"
50
+ st.write("uploaded.")
51
+
52
+ upload_form("Content")
53
+ upload_form("Style")
54
+ content = st.session_state["Content_im"]
55
+ style = st.session_state["Style_im"]
56
+ base_url = "http://mr2632.byod.hpi.de:5000"
57
+
58
+ if content is not None and style is not None:
59
+ optimize_button = st.sidebar.button("Optimize Style Transfer")
60
+ if optimize_button:
61
+ url = base_url + "/upload"
62
+ content_path=f"/tmp/content-wise-uploaded{str(datetime.datetime.timestamp(datetime.datetime.now()))}.jpg"
63
+ style_path=f"/tmp/content-wise-uploaded{str(datetime.datetime.timestamp(datetime.datetime.now()))}.jpg"
64
+ content = pil_resize_long_edge_to(content, 1024)
65
+ content.save(content_path)
66
+ style = pil_resize_long_edge_to(style, 1024)
67
+ style.save(style_path)
68
+ files = {'style-image': open(style_path, "rb"), "content-image": open(content_path, "rb")}
69
+ print("start-optimizing")
70
+ task_id_res = requests.post(url, files=files)
71
+ if task_id_res.status_code != 200:
72
+ st.error(task_id_res.content)
73
+ st.stop()
74
+ else:
75
+ task_id = task_id_res.json()['task_id']
76
+
77
+ progress_bar = st.empty()
78
+ with st.spinner(text="Optimizing parameters.."):
79
+ started_time = time.time()
80
+ while True:
81
+ time.sleep(3)
82
+ status = requests.get(base_url+"/get_status", params={"task_id": task_id})
83
+ if status.status_code != 200:
84
+ print("get_status got status_code", status.status_code)
85
+ st.warning(status.content)
86
+ continue
87
+ status = status.json()
88
+ print(status)
89
+ if status["status"] != "running" and status["status"] != "queued" :
90
+ if status["msg"] != "":
91
+ st.error(status["msg"])
92
+ break
93
+ elif status["status"] == "queued":
94
+ started_time = time.time()
95
+ queue_length = requests.get(base_url+"/queue_length").json()
96
+ progress_bar.write(f"There are {queue_length['length']} tasks in the queue")
97
+ elif status["progress"] == 0.0:
98
+ progressed = min(0.5 * (time.time() - started_time) / 80.0, 0.5) #estimate 80s for strotts
99
+ progress_bar.progress(progressed)
100
+ else:
101
+ progress_bar.progress(min(0.5 + status["progress"] / 2.0, 1.0))
102
+ vp_res = requests.get(base_url+"/get_vp", params={"task_id": task_id})
103
+ if vp_res.status_code != 200:
104
+ st.warning("got status" + str(vp_res.status_code))
105
+ vp_res.raise_for_status()
106
+ else:
107
+ vp = np.load(BytesIO(vp_res.content))["vp"]
108
+ print("received vp from server")
109
+ print("got numpy array", vp.shape)
110
+ vp = torch.from_numpy(vp).cuda()
111
+
112
+ effect, preset, param_set = get_default_settings("minimal_pipeline")
113
+ effect.enable_checkpoints()
114
+ effect.cuda()
115
+ content_cuda = np_to_torch(content).cuda()
116
+ with torch.no_grad():
117
+ result_cuda = effect(content_cuda, vp)
118
+ img_res = Image.fromarray((torch_to_np(result_cuda) * 255.0).astype(np.uint8))
119
+ st.image(img_res)
worker/requirements.txt ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ imageio
2
+ imageio-ffmpeg
3
+ scipy
4
+ Pillow
5
+ numpy
6
+ --extra-index-url https://download.pytorch.org/whl/cu113
7
+ torch
8
+ torchvision
9
+ Flask
10
+ Flask-Reuploaded
worker/serve.py ADDED
@@ -0,0 +1,284 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import datetime
2
+ import os
3
+ from pathlib import Path
4
+ import sys
5
+ from flask import Flask, jsonify, request, send_file, abort
6
+ from flask_uploads import UploadSet, configure_uploads, IMAGES
7
+ from werkzeug.exceptions import default_exceptions
8
+ from werkzeug.exceptions import HTTPException, NotFound
9
+ import json
10
+ import torch
11
+ import time
12
+ import threading
13
+ import traceback
14
+ from PIL import Image
15
+ import numpy as np
16
+
17
+ PACKAGE_PARENT = '..'
18
+ WISE_DIR = '../wise/'
19
+ SCRIPT_DIR = os.path.dirname(os.path.realpath(os.path.join(os.getcwd(), os.path.expanduser(__file__))))
20
+ sys.path.append(os.path.normpath(os.path.join(SCRIPT_DIR, PACKAGE_PARENT)))
21
+ sys.path.append(os.path.normpath(os.path.join(SCRIPT_DIR, WISE_DIR)))
22
+
23
+
24
+
25
+ from parameter_optimization.parametric_styletransfer import single_optimize
26
+ from parameter_optimization.parametric_styletransfer import CONFIG as ST_CONFIG
27
+ from parameter_optimization.strotss_org import strotss, pil_resize_long_edge_to
28
+ from helpers import torch_to_np, np_to_torch
29
+ from effects import get_default_settings, MinimalPipelineEffect
30
+
31
+ class JSONExceptionHandler(object):
32
+
33
+ def __init__(self, app=None):
34
+ if app:
35
+ self.init_app(app)
36
+
37
+ def std_handler(self, error):
38
+ response = jsonify(message=error.message)
39
+ response.status_code = error.code if isinstance(error, HTTPException) else 500
40
+ return response
41
+
42
+
43
+ def init_app(self, app):
44
+ self.app = app
45
+ self.register(HTTPException)
46
+ for code, v in default_exceptions.items():
47
+ self.register(code)
48
+
49
+ def register(self, exception_or_code, handler=None):
50
+ self.app.errorhandler(exception_or_code)(handler or self.std_handler)
51
+
52
+
53
+
54
+ app = Flask(__name__)
55
+ handler = JSONExceptionHandler(app)
56
+
57
+ image_folder = 'img_received'
58
+ photos = UploadSet('photos', IMAGES)
59
+ app.config['UPLOADED_PHOTOS_DEST'] = image_folder
60
+ configure_uploads(app, photos)
61
+
62
+ class Args(object):
63
+ def __init__(self, initial_data):
64
+ for key in initial_data:
65
+ setattr(self, key, initial_data[key])
66
+ def set_attributes(self, val_dict):
67
+ for key in val_dict:
68
+ setattr(self, key, val_dict[key])
69
+
70
+ default_args = {
71
+ "output_image" : "output.jpg",
72
+ ## values always set by request ##
73
+ "content_image": "",
74
+ "style_image": "",
75
+ "output_vp": "",
76
+ "iters": 500
77
+ }
78
+
79
+
80
+ total_task_count = 0
81
+
82
+ class NeuralOptimizer():
83
+ def __init__(self, args) -> None:
84
+ self.cur_iteration = 0
85
+ self.args = args
86
+
87
+ def optimize(self):
88
+ base_dir = f"result/{datetime.datetime.now().strftime(r'%Y-%m-%d %H.%Mh %Ss')}"
89
+ os.makedirs(base_dir)
90
+
91
+ content = Image.open(self.args.content_image)
92
+ style = Image.open(self.args.style_image)
93
+
94
+ def set_iter(iter):
95
+ self.cur_iteration=iter
96
+
97
+ effect, preset, _ = get_default_settings("minimal_pipeline")
98
+ effect.enable_checkpoints()
99
+
100
+ reference = strotss(pil_resize_long_edge_to(content, 1024),
101
+ pil_resize_long_edge_to(style, 1024), content_weight=16.0,
102
+ device=torch.device("cuda"), space="uniform")
103
+
104
+ ref_save_path = os.path.join(base_dir, "reference.jpg")
105
+ resize_to = 720
106
+ reference = pil_resize_long_edge_to(reference, resize_to)
107
+ reference.save(ref_save_path)
108
+ ST_CONFIG["n_iterations"] = self.args.iters
109
+ vp, content_img_cuda = single_optimize(effect, preset, "l1", self.args.content_image, str(ref_save_path),
110
+ write_video=False, base_dir=base_dir,
111
+ iter_callback=set_iter)
112
+
113
+ output = Image.fromarray(torch_to_np(content_img_cuda.detach().cpu() * 255.0).astype(np.uint8))
114
+ output.save(self.args.output_image)
115
+ # torch.save (vp.detach().clone(), self.args.output_vp)
116
+ # preset_tensor = effect.vpd.preset_tensor(preset, np_to_torch(np.array(content)).cuda(), add_local_dims=True)
117
+ np.savez_compressed(self.args.output_vp, vp=vp.detach().cpu().numpy())
118
+
119
+
120
+
121
+ class StyleTask:
122
+ def __init__(self, task_id, style_filename, content_filename):
123
+ self.content_filename=content_filename
124
+ self.style_filename=style_filename
125
+
126
+ self.status = "queued"
127
+ self.task_id = task_id
128
+ self.error_msg = ""
129
+ self.output_filename = content_filename.split(".")[0] + "_output.jpg"
130
+ self.vp_output_filename = content_filename.split(".")[0] + "_output.npz"
131
+
132
+ # global neural_optimizer
133
+ # if neural_optimizer is None:
134
+ # neural_optimizer = NeuralOptimizer(Args(default_args))
135
+
136
+ self.neural_optimizer = NeuralOptimizer(Args(default_args))
137
+
138
+ def start(self):
139
+ self.neural_optimizer.args.set_attributes(default_args)
140
+
141
+ self.neural_optimizer.args.style_image = os.path.join(image_folder, self.style_filename)
142
+ self.neural_optimizer.args.content_image = os.path.join(image_folder, self.content_filename)
143
+ self.neural_optimizer.args.output_image = os.path.join(image_folder, self.output_filename)
144
+ self.neural_optimizer.args.output_vp = os.path.join(image_folder, self.vp_output_filename)
145
+
146
+ thread = threading.Thread(target=self.run, args=())
147
+ thread.daemon = True # Daemonize thread
148
+ thread.start() # Start the execution
149
+
150
+ def run(self):
151
+ self.status = "running"
152
+ try:
153
+ self.neural_optimizer.optimize()
154
+ except Exception as e:
155
+ print("Error in task %d :"%(self.task_id), str(e))
156
+ traceback.print_exc()
157
+
158
+ self.status = "error"
159
+ self.error_msg = str(e)
160
+ return
161
+
162
+ self.status = "finished"
163
+ print("finished styling task: " + str(self.task_id))
164
+
165
+ class StylerQueue:
166
+ queued_tasks = []
167
+ finished_tasks = []
168
+ running_task = None
169
+
170
+ def __init__(self):
171
+ thread = threading.Thread(target=self.status_checker, args=())
172
+ thread.daemon = True # Daemonize thread
173
+ thread.start() # Start the execution
174
+
175
+ def queue_task(self, *args):
176
+ global total_task_count
177
+ total_task_count += 1
178
+ task = StyleTask(total_task_count, *args)
179
+ self.queued_tasks.append(task)
180
+
181
+ return total_task_count
182
+
183
+ def get_task(self, task_id):
184
+ if self.running_task is not None and self.running_task.task_id == task_id:
185
+ return self.running_task
186
+ task = next((task for task in self.queued_tasks + self.finished_tasks if task.task_id == task_id), None)
187
+ return task
188
+
189
+ def status_checker(self):
190
+ while True:
191
+ time.sleep(0.3)
192
+
193
+ if self.running_task is None:
194
+ if len(self.queued_tasks) > 0:
195
+ self.running_task = self.queued_tasks[0]
196
+ self.running_task.start()
197
+ self.queued_tasks = self.queued_tasks[1:]
198
+ elif self.running_task.status == "finished" or self.running_task.status == "error":
199
+ self.finished_tasks.append(self.running_task)
200
+ if len(self.queued_tasks) > 0:
201
+ self.running_task = self.queued_tasks[0]
202
+ self.running_task.start()
203
+ self.queued_tasks = self.queued_tasks[1:]
204
+ else:
205
+ self.running_task = None
206
+
207
+ styler_queue = StylerQueue()
208
+
209
+
210
+ @app.route('/upload', methods=['POST'])
211
+ def upload():
212
+ if 'style-image' in request.files and \
213
+ 'content-image' in request.files:
214
+
215
+ style_filename = photos.save(request.files['style-image'])
216
+ content_filename = photos.save(request.files['content-image'])
217
+
218
+ job_id = styler_queue.queue_task(style_filename, content_filename)
219
+ print('added new stylization task', style_filename, content_filename)
220
+
221
+ return jsonify({"task_id": job_id})
222
+ abort(jsonify(message="request needs style, content image"), 400)
223
+
224
+ @app.route('/get_status')
225
+ def get_status():
226
+ task_id = int(request.args.get("task_id"))
227
+ task = styler_queue.get_task(task_id)
228
+
229
+ if task is None:
230
+ abort(jsonify(message="task with id %d not found"%task_id), 400)
231
+
232
+ status = {
233
+ "status": task.status,
234
+ "msg": task.error_msg
235
+ }
236
+
237
+ if task.status == "running":
238
+ if isinstance(task, StyleTask):
239
+ status["progress"] = float(task.neural_optimizer.cur_iteration) / float(default_args["iters"])
240
+
241
+ return jsonify(status)
242
+
243
+ @app.route('/queue_length')
244
+ def get_queue_length():
245
+ tasks = len(styler_queue.queued_tasks)
246
+ if styler_queue.running_task is not None:
247
+ tasks += 1
248
+
249
+ status = {
250
+ "length": tasks
251
+ }
252
+
253
+ return jsonify(status)
254
+
255
+
256
+ @app.route('/get_image')
257
+ def get_image():
258
+ task_id = int(request.args.get("task_id"))
259
+ task = styler_queue.get_task(task_id)
260
+
261
+ if task is None:
262
+ abort(jsonify(message="task with id %d not found"%task_id), 400)
263
+
264
+ if task.status != "finished":
265
+ abort(jsonify(message="task with id %d not in finished state"%task_id), 400)
266
+
267
+ return send_file(os.path.join(image_folder, task.output_filename), mimetype='image/jpg')
268
+
269
+ @app.route('/get_vp')
270
+ def get_vp():
271
+ task_id = int(request.args.get("task_id"))
272
+ task = styler_queue.get_task(task_id)
273
+
274
+ if task is None:
275
+ abort(jsonify(message="task with id %d not found"%task_id), 400)
276
+
277
+ if task.status != "finished":
278
+ abort(jsonify(message="task with id %d not in finished state"%task_id), 400)
279
+
280
+ return send_file(os.path.join(image_folder, task.vp_output_filename), mimetype='application/zip')
281
+
282
+
283
+ if __name__ == '__main__':
284
+ app.run(debug=False, host="0.0.0.0",port=5000)