Max Reimann commited on
Commit
85cce87
·
1 Parent(s): f0f40d2

Integrate servertask into demo app, add dockerfile

Browse files
.gitignore ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ worker/img_received
2
+ worker/result
Dockerfile_worker ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM pytorch/pytorch:1.12.0-cuda11.3-cudnn8-runtime
2
+
3
+ WORKDIR /usr/app
4
+ ADD worker/requirements.txt .
5
+ RUN pip install -r requirements.txt
6
+
7
+ ADD wise .
8
+
9
+ WORKDIR /usr/app/worker
10
+ ADD worker/serve.py .
11
+
12
+ EXPOSE 8600
13
+
14
+ CMD ["python", "serve.py"]
Whitebox_style_transfer.py CHANGED
@@ -9,6 +9,7 @@ import requests
9
  import torch
10
  import torch.nn.functional as F
11
  from PIL import Image
 
12
 
13
  PACKAGE_PARENT = 'wise'
14
  SCRIPT_DIR = os.path.dirname(os.path.realpath(os.path.join(os.getcwd(), os.path.expanduser(__file__))))
@@ -144,34 +145,15 @@ def optimize(effect, preset, result_image_placeholder):
144
  content = st.session_state["Content_im"]
145
  style = st.session_state["Style_im"]
146
  result_image_placeholder.text("<- Custom content/style needs to be style transferred")
147
- st.sidebar.info("Note: Optimizing takes up to 5 minutes.")
148
  optimize_button = st.sidebar.button("Optimize Style Transfer")
149
  if optimize_button:
150
- if HUGGING_FACE:
151
- result_image_placeholder.warning("NST optimization is currently disabled in this HuggingFace Space because it takes ~5min to optimize. To try it out, please clone the repo and change the huggingface variable in demo_config.py")
152
- st.stop()
153
-
154
- result_image_placeholder.text("Executing NST to create reference image..")
155
- base_dir = f"result/{datetime.datetime.now().strftime(r'%Y-%m-%d %H.%Mh %Ss')}"
156
- os.makedirs(base_dir)
157
- with st.spinner(text="Running NST"):
158
- reference = strotss(pil_resize_long_edge_to(content, 1024),
159
- pil_resize_long_edge_to(style, 1024), content_weight=16.0,
160
- device=torch.device("cuda"), space="uniform")
161
- progress_bar = result_image_placeholder.progress(0.0)
162
- ref_save_path = os.path.join(base_dir, "reference.jpg")
163
- content_save_path = os.path.join(base_dir, "content.jpg")
164
- resize_to = 720
165
- reference = pil_resize_long_edge_to(reference, resize_to)
166
- reference.save(ref_save_path)
167
- content.save(content_save_path)
168
- ST_CONFIG["n_iterations"] = 300
169
  with st.spinner(text="Optimizing parameters.."):
170
- vp, content_img_cuda = single_optimize(effect, preset, "l1", content_save_path, str(ref_save_path),
171
- write_video=False, base_dir=base_dir,
172
- iter_callback=lambda i: progress_bar.progress(
173
- float(i) / ST_CONFIG["n_iterations"]))
174
- return content_img_cuda.detach(), vp.cuda().detach()
175
  else:
176
  if not "result_vp" in st.session_state:
177
  st.stop()
@@ -224,6 +206,15 @@ coll2.header("Global Edits")
224
  result_image_placeholder = coll1.empty()
225
  result_image_placeholder.markdown("## loading..")
226
 
 
 
 
 
 
 
 
 
 
227
  img_choice_panel("Content", content_urls, "portrait", expanded=True)
228
  img_choice_panel("Style", style_urls, "starry_night", expanded=True)
229
 
 
9
  import torch
10
  import torch.nn.functional as F
11
  from PIL import Image
12
+ import time
13
 
14
  PACKAGE_PARENT = 'wise'
15
  SCRIPT_DIR = os.path.dirname(os.path.realpath(os.path.join(os.getcwd(), os.path.expanduser(__file__))))
 
145
  content = st.session_state["Content_im"]
146
  style = st.session_state["Style_im"]
147
  result_image_placeholder.text("<- Custom content/style needs to be style transferred")
148
+ st.sidebar.warning("Note: Optimizing takes up to 5 minutes.")
149
  optimize_button = st.sidebar.button("Optimize Style Transfer")
150
  if optimize_button:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
151
  with st.spinner(text="Optimizing parameters.."):
152
+ if HUGGING_FACE:
153
+ optimize_on_server(content, style, result_image_placeholder)
154
+ else:
155
+ optimize_params(effect, preset, content, style, result_image_placeholder)
156
+ return st.session_state["effect_input"], st.session_state["result_vp"]
157
  else:
158
  if not "result_vp" in st.session_state:
159
  st.stop()
 
206
  result_image_placeholder = coll1.empty()
207
  result_image_placeholder.markdown("## loading..")
208
 
209
+ from tasks import optimize_on_server, optimize_params, monitor_task
210
+
211
+ if "current_server_task_id" not in st.session_state:
212
+ st.session_state['current_server_task_id'] = None
213
+
214
+ if HUGGING_FACE and st.session_state['current_server_task_id'] is not None:
215
+ with st.spinner(text="Optimizing parameters.."):
216
+ monitor_task(result_image_placeholder)
217
+
218
  img_choice_panel("Content", content_urls, "portrait", expanded=True)
219
  img_choice_panel("Style", style_urls, "starry_night", expanded=True)
220
 
demo_config.py CHANGED
@@ -1 +1,2 @@
1
- HUGGING_FACE=True # if run in hugging face. Disables some things like full NST optimization
 
 
1
+ HUGGING_FACE=True # if run in hugging face. Huggingface uses extra server task for optim
2
+ WORKER_URL="http://mr2632.byod.hpi.de:8600"
pages/test.py DELETED
@@ -1,119 +0,0 @@
1
- import base64
2
- import datetime
3
- import os
4
- import sys
5
- from io import BytesIO
6
- from pathlib import Path
7
- import numpy as np
8
- import requests
9
- import torch
10
- import torch.nn.functional as F
11
- from PIL import Image
12
- import json
13
- import time
14
-
15
- PACKAGE_PARENT = 'wise'
16
- SCRIPT_DIR = os.path.dirname(os.path.realpath(os.path.join(os.getcwd(), os.path.expanduser(__file__))))
17
- sys.path.append(os.path.normpath(os.path.join(SCRIPT_DIR, PACKAGE_PARENT)))
18
-
19
- import streamlit as st
20
- from streamlit.logger import get_logger
21
- from st_click_detector import click_detector
22
- import streamlit.components.v1 as components
23
- from streamlit.source_util import get_pages
24
- from streamlit_extras.switch_page_button import switch_page
25
-
26
- import helpers.session_state as session_state
27
- from parameter_optimization.strotss_org import strotss, pil_resize_long_edge_to
28
- from helpers import torch_to_np, np_to_torch
29
- from effects import get_default_settings, MinimalPipelineEffect
30
-
31
- st.set_page_config(layout="wide")
32
- BASE_URL = "https://ivpg.hpi3d.de/wise/wise-demo/images/"
33
- LOGGER = get_logger(__name__)
34
-
35
-
36
- def upload_form(imgtype):
37
- with st.form(imgtype + "-form", clear_on_submit=True):
38
- uploaded_im = st.file_uploader(f"Load {imgtype} image:", type=["png", "jpg"], )
39
- upload_pressed = st.form_submit_button("Upload")
40
-
41
- if upload_pressed and uploaded_im is not None:
42
- img = Image.open(uploaded_im)
43
- buffered = BytesIO()
44
- img.save(buffered, format="JPEG")
45
- encoded = base64.b64encode(buffered.getvalue()).decode()
46
- # session_state.get(uploaded_im=img, content_render_src=f"data:image/jpeg;base64,{encoded}")
47
- session_state.get(**{f"{imgtype}_im": img, f"{imgtype}_render_src": f"data:image/jpeg;base64,{encoded}",
48
- f"{imgtype}_id": "uploaded"})
49
- st.session_state["action"] = "uploaded"
50
- st.write("uploaded.")
51
-
52
- upload_form("Content")
53
- upload_form("Style")
54
- content = st.session_state["Content_im"]
55
- style = st.session_state["Style_im"]
56
- base_url = "http://mr2632.byod.hpi.de:5000"
57
-
58
- if content is not None and style is not None:
59
- optimize_button = st.sidebar.button("Optimize Style Transfer")
60
- if optimize_button:
61
- url = base_url + "/upload"
62
- content_path=f"/tmp/content-wise-uploaded{str(datetime.datetime.timestamp(datetime.datetime.now()))}.jpg"
63
- style_path=f"/tmp/content-wise-uploaded{str(datetime.datetime.timestamp(datetime.datetime.now()))}.jpg"
64
- content = pil_resize_long_edge_to(content, 1024)
65
- content.save(content_path)
66
- style = pil_resize_long_edge_to(style, 1024)
67
- style.save(style_path)
68
- files = {'style-image': open(style_path, "rb"), "content-image": open(content_path, "rb")}
69
- print("start-optimizing")
70
- task_id_res = requests.post(url, files=files)
71
- if task_id_res.status_code != 200:
72
- st.error(task_id_res.content)
73
- st.stop()
74
- else:
75
- task_id = task_id_res.json()['task_id']
76
-
77
- progress_bar = st.empty()
78
- with st.spinner(text="Optimizing parameters.."):
79
- started_time = time.time()
80
- while True:
81
- time.sleep(3)
82
- status = requests.get(base_url+"/get_status", params={"task_id": task_id})
83
- if status.status_code != 200:
84
- print("get_status got status_code", status.status_code)
85
- st.warning(status.content)
86
- continue
87
- status = status.json()
88
- print(status)
89
- if status["status"] != "running" and status["status"] != "queued" :
90
- if status["msg"] != "":
91
- st.error(status["msg"])
92
- break
93
- elif status["status"] == "queued":
94
- started_time = time.time()
95
- queue_length = requests.get(base_url+"/queue_length").json()
96
- progress_bar.write(f"There are {queue_length['length']} tasks in the queue")
97
- elif status["progress"] == 0.0:
98
- progressed = min(0.5 * (time.time() - started_time) / 80.0, 0.5) #estimate 80s for strotts
99
- progress_bar.progress(progressed)
100
- else:
101
- progress_bar.progress(min(0.5 + status["progress"] / 2.0, 1.0))
102
- vp_res = requests.get(base_url+"/get_vp", params={"task_id": task_id})
103
- if vp_res.status_code != 200:
104
- st.warning("got status" + str(vp_res.status_code))
105
- vp_res.raise_for_status()
106
- else:
107
- vp = np.load(BytesIO(vp_res.content))["vp"]
108
- print("received vp from server")
109
- print("got numpy array", vp.shape)
110
- vp = torch.from_numpy(vp).cuda()
111
-
112
- effect, preset, param_set = get_default_settings("minimal_pipeline")
113
- effect.enable_checkpoints()
114
- effect.cuda()
115
- content_cuda = np_to_torch(content).cuda()
116
- with torch.no_grad():
117
- result_cuda = effect(content_cuda, vp)
118
- img_res = Image.fromarray((torch_to_np(result_cuda) * 255.0).astype(np.uint8))
119
- st.image(img_res)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
tasks.py ADDED
@@ -0,0 +1,135 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import base64
2
+ import datetime
3
+ import os
4
+ import sys
5
+ from io import BytesIO
6
+ from pathlib import Path
7
+ import numpy as np
8
+ import requests
9
+ import torch
10
+ import torch.nn.functional as F
11
+ from PIL import Image
12
+ import time
13
+ import streamlit as st
14
+ from demo_config import HUGGING_FACE, WORKER_URL
15
+
16
+
17
+
18
+ PACKAGE_PARENT = 'wise'
19
+ SCRIPT_DIR = os.path.dirname(os.path.realpath(os.path.join(os.getcwd(), os.path.expanduser(__file__))))
20
+ sys.path.append(os.path.normpath(os.path.join(SCRIPT_DIR, PACKAGE_PARENT)))
21
+
22
+ from parameter_optimization.parametric_styletransfer import single_optimize
23
+ from parameter_optimization.parametric_styletransfer import CONFIG as ST_CONFIG
24
+ from parameter_optimization.strotss_org import strotss, pil_resize_long_edge_to
25
+ from helpers import torch_to_np, np_to_torch
26
+
27
+ def retrieve_for_results_from_server():
28
+ task_id = st.session_state['current_server_task_id']
29
+ vp_res = requests.get(WORKER_URL+"/get_vp", params={"task_id": task_id})
30
+ image_res = requests.get(WORKER_URL+"/get_image", params={"task_id": task_id})
31
+ if vp_res.status_code != 200 or image_res.status_code != 200:
32
+ st.warning("got status for " + WORKER_URL+"/get_vp" + str(vp_res.status_code))
33
+ st.warning("got status for " + WORKER_URL+"/image_res" + str(image_res.status_code))
34
+ st.session_state['current_server_task_id'] = None
35
+ vp_res.raise_for_status()
36
+ image_res.raise_for_status()
37
+ else:
38
+ st.session_state['current_server_task_id'] = None
39
+ vp = np.load(BytesIO(vp_res.content))["vp"]
40
+ print("received vp from server")
41
+ print("got numpy array", vp.shape)
42
+ vp = torch.from_numpy(vp).cuda()
43
+ image = Image.open(BytesIO(image_res.content))
44
+ print("received image from server")
45
+ image = np_to_torch(np.asarray(image)).cuda()
46
+
47
+ st.session_state["effect_input"] = image
48
+ st.session_state["result_vp"] = vp
49
+
50
+
51
+ def monitor_task(progress_placeholder):
52
+ task_id = st.session_state['current_server_task_id']
53
+
54
+ started_time = time.time()
55
+ retries = 3
56
+ while True:
57
+ status = requests.get(WORKER_URL+"/get_status", params={"task_id": task_id})
58
+ if status.status_code != 200:
59
+ print("get_status got status_code", status.status_code)
60
+ st.warning(status.content)
61
+ retries -= 1
62
+ if retries == 0:
63
+ return
64
+ else:
65
+ time.sleep(2)
66
+ continue
67
+ status = status.json()
68
+ print(status)
69
+ if status["status"] != "running" and status["status"] != "queued" :
70
+ if status["msg"] != "":
71
+ print("got error for task", task_id, ":", status["msg"])
72
+ progress_placeholder.error(status["msg"])
73
+ st.session_state['current_server_task_id'] = None
74
+ st.stop()
75
+ if status["status"] == "finished":
76
+ retrieve_for_results_from_server()
77
+ return
78
+ elif status["status"] == "queued":
79
+ started_time = time.time()
80
+ queue_length = requests.get(WORKER_URL+"/queue_length").json()
81
+ progress_placeholder.write(f"There are {queue_length['length']} tasks in the queue")
82
+ elif status["progress"] == 0.0:
83
+ progressed = min(0.5 * (time.time() - started_time) / 80.0, 0.5) #estimate 80s for strotts
84
+ progress_placeholder.progress(progressed)
85
+ else:
86
+ progress_placeholder.progress(min(0.5 + status["progress"] / 2.0, 1.0))
87
+
88
+ time.sleep(2)
89
+
90
+
91
+ def optimize_on_server(content, style, result_image_placeholder):
92
+ url = WORKER_URL + "/upload"
93
+ content_path=f"/tmp/content-wise-uploaded{str(datetime.datetime.timestamp(datetime.datetime.now()))}.jpg"
94
+ style_path=f"/tmp/content-wise-uploaded{str(datetime.datetime.timestamp(datetime.datetime.now()))}.jpg"
95
+ asp_c, asp_s = content.height / content.width, style.height / style.width
96
+ if any(a < 0.5 or a > 2.0 for a in (asp_c, asp_s)):
97
+ result_image_placeholder.error('aspect ratio must be <= 2')
98
+ st.stop()
99
+ content = pil_resize_long_edge_to(content, 1024)
100
+ content.save(content_path)
101
+ style = pil_resize_long_edge_to(style, 1024)
102
+ style.save(style_path)
103
+ files = {'style-image': open(style_path, "rb"), "content-image": open(content_path, "rb")}
104
+ print("start-optimizing")
105
+ task_id_res = requests.post(url, files=files)
106
+ if task_id_res.status_code != 200:
107
+ result_image_placeholder.error(task_id_res.content)
108
+ st.stop()
109
+ else:
110
+ task_id = task_id_res.json()['task_id']
111
+ st.session_state['current_server_task_id'] = task_id
112
+
113
+ monitor_task(result_image_placeholder)
114
+
115
+ def optimize_params(effect, preset, content, style, result_image_placeholder):
116
+ result_image_placeholder.text("Executing NST to create reference image..")
117
+ base_dir = f"result/{datetime.datetime.now().strftime(r'%Y-%m-%d %H.%Mh %Ss')}"
118
+ os.makedirs(base_dir)
119
+ reference = strotss(pil_resize_long_edge_to(content, 1024),
120
+ pil_resize_long_edge_to(style, 1024), content_weight=16.0,
121
+ device=torch.device("cuda"), space="uniform")
122
+ progress_bar = result_image_placeholder.progress(0.0)
123
+ ref_save_path = os.path.join(base_dir, "reference.jpg")
124
+ content_save_path = os.path.join(base_dir, "content.jpg")
125
+ resize_to = 720
126
+ reference = pil_resize_long_edge_to(reference, resize_to)
127
+ reference.save(ref_save_path)
128
+ content.save(content_save_path)
129
+ ST_CONFIG["n_iterations"] = 300
130
+
131
+ vp, content_img_cuda = single_optimize(effect, preset, "l1", content_save_path, str(ref_save_path),
132
+ write_video=False, base_dir=base_dir,
133
+ iter_callback=lambda i: progress_bar.progress(
134
+ float(i) / ST_CONFIG["n_iterations"]))
135
+ st.session_state["effect_input"], st.session_state["result_vp"] = content_img_cuda.detach(), vp.cuda().detach()
worker/requirements.txt CHANGED
@@ -3,6 +3,7 @@ imageio-ffmpeg
3
  scipy
4
  Pillow
5
  numpy
 
6
  --extra-index-url https://download.pytorch.org/whl/cu113
7
  torch
8
  torchvision
 
3
  scipy
4
  Pillow
5
  numpy
6
+ matplotlib
7
  --extra-index-url https://download.pytorch.org/whl/cu113
8
  torch
9
  torchvision
worker/serve.py CHANGED
@@ -175,10 +175,12 @@ class StylerQueue:
175
  def queue_task(self, *args):
176
  global total_task_count
177
  total_task_count += 1
178
- task = StyleTask(total_task_count, *args)
 
 
179
  self.queued_tasks.append(task)
180
 
181
- return total_task_count
182
 
183
  def get_task(self, task_id):
184
  if self.running_task is not None and self.running_task.task_id == task_id:
@@ -281,4 +283,4 @@ def get_vp():
281
 
282
 
283
  if __name__ == '__main__':
284
- app.run(debug=False, host="0.0.0.0",port=5000)
 
175
  def queue_task(self, *args):
176
  global total_task_count
177
  total_task_count += 1
178
+ task_id = abs(hash(str(time.time())))
179
+ print("queued task num. ", total_task_count, "with ID", task_id)
180
+ task = StyleTask(task_id, *args)
181
  self.queued_tasks.append(task)
182
 
183
+ return task_id
184
 
185
  def get_task(self, task_id):
186
  if self.running_task is not None and self.running_task.task_id == task_id:
 
283
 
284
 
285
  if __name__ == '__main__':
286
+ app.run(debug=False, host="0.0.0.0",port=8600)