Max Reimann commited on
Commit
6124669
0 Parent(s):

Initial commit of app

Browse files
.gitattributes ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.png filter=lfs diff=lfs merge=lfs -text
25
+ *.jpg filter=lfs diff=lfs merge=lfs -text
26
+ *.jpeg filter=lfs diff=lfs merge=lfs -text
27
+ *.rar filter=lfs diff=lfs merge=lfs -text
28
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
29
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
30
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
31
+ *.tflite filter=lfs diff=lfs merge=lfs -text
32
+ *.tgz filter=lfs diff=lfs merge=lfs -text
33
+ *.wasm filter=lfs diff=lfs merge=lfs -text
34
+ *.xz filter=lfs diff=lfs merge=lfs -text
35
+ *.zip filter=lfs diff=lfs merge=lfs -text
36
+ *.zst filter=lfs diff=lfs merge=lfs -text
37
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
.gitmodules ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ [submodule "wise"]
2
+ path = wise
3
+ url = https://github.com/winfried-ripken/wise
README.md ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: White-box Style Transfer Editing (WISE)
3
+ emoji: 🎨
4
+ colorFrom: pink
5
+ colorTo: red
6
+ sdk: streamlit
7
+ sdk_version: 1.10.0
8
+ app_file: Whitebox_style_transfer.py
9
+ tags: [Style Transfer,Image Synthesis,Editing,Painting]
10
+ pinned: false
11
+ license: mit
12
+ ---
13
+
14
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
Whitebox_style_transfer.py ADDED
@@ -0,0 +1,307 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import base64
2
+ import datetime
3
+ import os
4
+ import sys
5
+ from io import BytesIO
6
+ from pathlib import Path
7
+ import numpy as np
8
+ import requests
9
+ import torch
10
+ import torch.nn.functional as F
11
+ from PIL import Image
12
+
13
+ PACKAGE_PARENT = 'wise'
14
+ SCRIPT_DIR = os.path.dirname(os.path.realpath(os.path.join(os.getcwd(), os.path.expanduser(__file__))))
15
+ sys.path.append(os.path.normpath(os.path.join(SCRIPT_DIR, PACKAGE_PARENT)))
16
+
17
+ import streamlit as st
18
+ from streamlit.logger import get_logger
19
+ from st_click_detector import click_detector
20
+ import streamlit.components.v1 as components
21
+ from streamlit_extras.switch_page_button import switch_page
22
+
23
+ from demo_config import HUGGING_FACE
24
+ from parameter_optimization.parametric_styletransfer import single_optimize
25
+ from parameter_optimization.parametric_styletransfer import CONFIG as ST_CONFIG
26
+ from parameter_optimization.strotss_org import strotss, pil_resize_long_edge_to
27
+ import helpers.session_state as session_state
28
+ from helpers import torch_to_np, np_to_torch
29
+ from effects import get_default_settings, MinimalPipelineEffect
30
+
31
+ st.set_page_config(layout="wide")
32
+ BASE_URL = "https://ivpg.hpi3d.de/wise/wise-demo/images/"
33
+ LOGGER = get_logger(__name__)
34
+
35
+ effect_type = "minimal_pipeline"
36
+
37
+ if "click_counter" not in st.session_state:
38
+ st.session_state.click_counter = 1
39
+
40
+ if "action" not in st.session_state:
41
+ st.session_state["action"] = ""
42
+
43
+ content_urls = [
44
+ {
45
+ "name": "Portrait", "id": "portrait",
46
+ "src": BASE_URL + "/content/portrait.jpeg"
47
+ },
48
+ {
49
+ "name": "Tuebingen", "id": "tubingen",
50
+ "src": BASE_URL + "/content/tubingen.jpeg"
51
+ },
52
+ {
53
+ "name": "Colibri", "id": "colibri",
54
+ "src": BASE_URL + "/content/colibri.jpeg"
55
+ }
56
+ ]
57
+
58
+ style_urls = [
59
+ {
60
+ "name": "Starry Night, Van Gogh", "id": "starry_night",
61
+ "src": BASE_URL + "/style/starry_night.jpg"
62
+ },
63
+ {
64
+ "name": "The Scream, Edward Munch", "id": "the_scream",
65
+ "src": BASE_URL + "/style/the_scream.jpg"
66
+ },
67
+ {
68
+ "name": "The Great Wave, Ukiyo-e", "id": "wave",
69
+ "src": BASE_URL + "/style/wave.jpg"
70
+ },
71
+ {
72
+ "name": "Woman with Hat, Henry Matisse", "id": "woman_with_hat",
73
+ "src": BASE_URL + "/style/woman_with_hat.jpg"
74
+ }
75
+ ]
76
+
77
+
78
+ def last_image_clicked(type="content", action=None, ):
79
+ kw = "last_image_clicked" + "_" + type
80
+ if action:
81
+ session_state.get(**{kw: action})
82
+ elif kw not in session_state.get():
83
+ return None
84
+ else:
85
+ return session_state.get()[kw]
86
+
87
+
88
+ @st.cache
89
+ def _retrieve_from_id(clicked, urls):
90
+ src = [x["src"] for x in urls if x["id"] == clicked][0]
91
+ img = Image.open(requests.get(src, stream=True).raw)
92
+ return img, src
93
+
94
+
95
+ def store_img_from_id(clicked, urls, imgtype):
96
+ img, src = _retrieve_from_id(clicked, urls)
97
+ session_state.get(**{f"{imgtype}_im": img, f"{imgtype}_render_src": src, f"{imgtype}_id": clicked})
98
+
99
+
100
+ def img_choice_panel(imgtype, urls, default_choice, expanded):
101
+ with st.expander(f"Select {imgtype} image:", expanded=expanded):
102
+ html_code = '<div class="column" style="display: flex; flex-wrap: wrap; padding: 0 4px;">'
103
+ for url in urls:
104
+ html_code += f"<a href='#' id='{url['id']}' style='padding: 0px 5px'><img height='160px' style='margin-top: 8px;' src='{url['src']}'></a>"
105
+ html_code += "</div>"
106
+ clicked = click_detector(html_code)
107
+
108
+ if not clicked and st.session_state["action"] not in ("uploaded", "switch_page_from_local_edits", "switch_page_from_presets", "slider_change", "reset"): # default val
109
+ store_img_from_id(default_choice, urls, imgtype)
110
+
111
+ st.write("OR: ")
112
+
113
+ with st.form(imgtype + "-form", clear_on_submit=True):
114
+ uploaded_im = st.file_uploader(f"Load {imgtype} image:", type=["png", "jpg"], )
115
+ upload_pressed = st.form_submit_button("Upload")
116
+
117
+ if upload_pressed and uploaded_im is not None:
118
+ img = Image.open(uploaded_im)
119
+ buffered = BytesIO()
120
+ img.save(buffered, format="JPEG")
121
+ encoded = base64.b64encode(buffered.getvalue()).decode()
122
+ # session_state.get(uploaded_im=img, content_render_src=f"data:image/jpeg;base64,{encoded}")
123
+ session_state.get(**{f"{imgtype}_im": img, f"{imgtype}_render_src": f"data:image/jpeg;base64,{encoded}",
124
+ f"{imgtype}_id": "uploaded"})
125
+ st.session_state["action"] = "uploaded"
126
+ st.write("uploaded.")
127
+
128
+ last_clicked = last_image_clicked(type=imgtype)
129
+ print("last_clicked", last_clicked, "clicked", clicked, "action", st.session_state["action"] )
130
+ if not upload_pressed and clicked != "": # trigger when no file uploaded
131
+ if last_clicked != clicked: # only activate when content was actually clicked
132
+ store_img_from_id(clicked, urls, imgtype)
133
+ last_image_clicked(type=imgtype, action=clicked)
134
+ st.session_state["action"] = "clicked"
135
+ st.session_state.click_counter += 1 # hack to get page to reload at top
136
+
137
+ state = session_state.get()
138
+ st.sidebar.write(f'Selected {imgtype} image:')
139
+ st.sidebar.markdown(f'<img src="{state[f"{imgtype}_render_src"]}" width=240px></img>', unsafe_allow_html=True)
140
+
141
+
142
+ def optimize(effect, preset, result_image_placeholder):
143
+ content = st.session_state["Content_im"]
144
+ style = st.session_state["Style_im"]
145
+ result_image_placeholder.text("<- Custom content/style needs to be style transferred")
146
+ optimize_button = st.sidebar.button("Optimize Style Transfer")
147
+ if optimize_button:
148
+ if HUGGING_FACE:
149
+ result_image_placeholder.warning("NST optimization is currently disabled in this HuggingFace Space because it takes ~5min to optimize. To try it out, please clone the repo and change the huggingface variable in demo_config.py")
150
+ st.stop()
151
+
152
+ result_image_placeholder.text("Executing NST to create reference image..")
153
+ base_dir = f"result/{datetime.datetime.now().strftime(r'%Y-%m-%d %H.%Mh %Ss')}"
154
+ os.makedirs(base_dir)
155
+ with st.spinner(text="Running NST"):
156
+ reference = strotss(pil_resize_long_edge_to(content, 1024),
157
+ pil_resize_long_edge_to(style, 1024), content_weight=16.0,
158
+ device=torch.device("cuda"), space="uniform")
159
+ progress_bar = result_image_placeholder.progress(0.0)
160
+ ref_save_path = os.path.join(base_dir, "reference.jpg")
161
+ content_save_path = os.path.join(base_dir, "content.jpg")
162
+ resize_to = 720
163
+ reference = pil_resize_long_edge_to(reference, resize_to)
164
+ reference.save(ref_save_path)
165
+ content.save(content_save_path)
166
+ ST_CONFIG["n_iterations"] = 300
167
+ with st.spinner(text="Optimizing parameters.."):
168
+ vp, content_img_cuda = single_optimize(effect, preset, "l1", content_save_path, str(ref_save_path),
169
+ write_video=False, base_dir=base_dir,
170
+ iter_callback=lambda i: progress_bar.progress(
171
+ float(i) / ST_CONFIG["n_iterations"]))
172
+ return content_img_cuda.detach(), vp.cuda().detach()
173
+ else:
174
+ if not "result_vp" in st.session_state:
175
+ st.stop()
176
+ else:
177
+ return st.session_state["effect_input"], st.session_state["result_vp"]
178
+
179
+
180
+ @st.cache(hash_funcs={MinimalPipelineEffect: id})
181
+ def create_effect():
182
+ effect, preset, param_set = get_default_settings(effect_type)
183
+ effect.enable_checkpoints()
184
+ effect.cuda()
185
+ return effect, preset
186
+
187
+
188
+ def load_visual_params(vp_path: str, img_org: Image, org_cuda: torch.Tensor, effect) -> torch.Tensor:
189
+ if Path(vp_path).exists():
190
+ vp = torch.load(vp_path).detach().clone()
191
+ vp = F.interpolate(vp, (img_org.height, img_org.width))
192
+ if len(effect.vpd.vp_ranges) == vp.shape[1]:
193
+ return vp
194
+ # use preset and save it
195
+ vp = effect.vpd.preset_tensor(preset, org_cuda, add_local_dims=True)
196
+ torch.save(vp, vp_path)
197
+ return vp
198
+
199
+
200
+ # @st.cache(hash_funcs={torch.Tensor: id})
201
+ @st.experimental_memo
202
+ def load_params(content_id, style_id):#, effect):
203
+ preoptim_param_path = os.path.join("precomputed", effect_type, content_id, style_id)
204
+ img_org = Image.open(os.path.join(preoptim_param_path, "input.png"))
205
+ content_cuda = np_to_torch(img_org).cuda()
206
+ vp_path = os.path.join(preoptim_param_path, "vp.pt")
207
+ vp = load_visual_params(vp_path, img_org, content_cuda, effect)
208
+ return content_cuda, vp
209
+
210
+
211
+ def render_effect(effect, content_cuda, vp):
212
+ with torch.no_grad():
213
+ result_cuda = effect(content_cuda, vp)
214
+ img_res = Image.fromarray((torch_to_np(result_cuda) * 255.0).astype(np.uint8))
215
+ return img_res
216
+
217
+
218
+ result_container = st.container()
219
+ coll1, coll2 = result_container.columns([3,2])
220
+ coll1.header("Result")
221
+ coll2.header("Global Edits")
222
+ result_image_placeholder = coll1.empty()
223
+ result_image_placeholder.markdown("## loading..")
224
+
225
+ img_choice_panel("Content", content_urls, "portrait", expanded=True)
226
+ img_choice_panel("Style", style_urls, "starry_night", expanded=True)
227
+
228
+ state = session_state.get()
229
+ content_id = state["Content_id"]
230
+ style_id = state["Style_id"]
231
+
232
+ effect, preset = create_effect()
233
+
234
+ print("content id, style id", content_id, style_id )
235
+ if st.session_state["action"] == "uploaded":
236
+ content_img, _vp = optimize(effect, preset, result_image_placeholder)
237
+ elif st.session_state["action"] in ("switch_page_from_local_edits", "switch_page_from_presets", "slider_change") or \
238
+ content_id == "uploaded" or style_id == "uploaded":
239
+ print("restore param")
240
+ _vp = st.session_state["result_vp"]
241
+ content_img = st.session_state["effect_input"]
242
+ else:
243
+ print("load_params")
244
+ content_img, _vp = load_params(content_id, style_id)#, effect)
245
+
246
+ vp = torch.clone(_vp)
247
+
248
+
249
+ def reset_params(means, names):
250
+ for i, name in enumerate(names):
251
+ st.session_state["slider_" + name] = means[i]
252
+
253
+ def on_slider():
254
+ st.session_state["action"] = "slider_change"
255
+
256
+
257
+ with coll2:
258
+ show_params_names = [ 'bumpScale', "bumpOpacity", "contourOpacity"]
259
+ display_means = []
260
+ def create_slider(name):
261
+ mean = torch.mean(vp[:, effect.vpd.name2idx[name]]).item()
262
+ display_mean = mean + 0.5
263
+ display_means.append(display_mean)
264
+ if "slider_" + name not in st.session_state or st.session_state["action"] != "slider_change":
265
+ st.session_state["slider_" + name] = display_mean
266
+ slider = st.slider(f"Mean {name}: ", 0.0, 1.0, step=0.05, key="slider_" + name, on_change=on_slider)
267
+ vp[:, effect.vpd.name2idx[name]] += slider - display_mean
268
+ vp.clamp_(-0.5, 0.5)
269
+
270
+ for name in show_params_names:
271
+ create_slider(name)
272
+
273
+ others_idx = set(range(len(effect.vpd.vp_ranges))) - set([effect.vpd.name2idx[name] for name in show_params_names])
274
+ others_names = [effect.vpd.vp_ranges[i][0] for i in sorted(list(others_idx))]
275
+ other_param = st.selectbox("Other parameters: ", others_names)
276
+ create_slider(other_param)
277
+
278
+
279
+ reset_button = st.button("Reset Parameters", on_click=reset_params, args=(display_means, show_params_names))
280
+ if reset_button:
281
+ st.session_state["action"] = "reset"
282
+ st.experimental_rerun()
283
+
284
+ edit_locally_btn = st.button("Edit Local Parameter Maps")
285
+ if edit_locally_btn:
286
+ switch_page("Local_edits")
287
+
288
+ img_res = render_effect(effect, content_img, vp)
289
+
290
+ st.session_state["result_vp"] = vp
291
+ st.session_state["effect_input"] = content_img
292
+ st.session_state["last_result"] = img_res
293
+
294
+ with coll1:
295
+ # width = int(img_res.width * 500 / img_res.height)
296
+ result_image_placeholder.image(img_res)#, width=width)
297
+
298
+ # a bit hacky way to return focus to top of page after clicking on images
299
+ components.html(
300
+ f"""
301
+ <p>{st.session_state.click_counter}</p>
302
+ <script>
303
+ window.parent.document.querySelector('section.main').scrollTo(0, 0);
304
+ </script>
305
+ """,
306
+ height=0
307
+ )
demo_config.py ADDED
@@ -0,0 +1 @@
 
 
1
+ HUGGING_FACE=True # if run in hugging face. Disables some things like full NST optimization
pages/Apply_preset.py ADDED
@@ -0,0 +1,118 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ import torch.nn.functional as F
4
+ import torch
5
+
6
+ PACKAGE_PARENT = '../wise/'
7
+ SCRIPT_DIR = os.path.dirname(os.path.realpath(os.path.join(os.getcwd(), os.path.expanduser(__file__))))
8
+ sys.path.append(os.path.normpath(os.path.join(SCRIPT_DIR, PACKAGE_PARENT)))
9
+
10
+
11
+ import numpy as np
12
+ from PIL import Image
13
+ import streamlit as st
14
+ from streamlit_drawable_canvas import st_canvas
15
+
16
+ from effects.minimal_pipeline import MinimalPipelineEffect
17
+ from helpers.visual_parameter_def import minimal_pipeline_presets, minimal_pipeline_bump_mapping_preset, minimal_pipeline_xdog_preset
18
+ from helpers import torch_to_np, np_to_torch
19
+ from effects import get_default_settings
20
+
21
+ st.set_page_config(page_title="Preset Edit Demo", layout="wide")
22
+
23
+
24
+ # @st.cache(hash_funcs={OilPaintEffect: id})
25
+ @st.cache(hash_funcs={MinimalPipelineEffect: id})
26
+ def local_edits_create_effect():
27
+ effect, preset, param_set = get_default_settings("minimal_pipeline")
28
+ effect.enable_checkpoints()
29
+ effect.cuda()
30
+ return effect, param_set
31
+
32
+
33
+ effect, param_set = local_edits_create_effect()
34
+ presets = {
35
+ "original": minimal_pipeline_presets,
36
+ "bump mapped": minimal_pipeline_bump_mapping_preset,
37
+ "contoured": minimal_pipeline_xdog_preset
38
+ }
39
+
40
+ st.session_state["action"] = "switch_page_from_presets" # on switchback, remember effect input
41
+
42
+ active_preset = st.sidebar.selectbox("apply preset: ", ["original", "bump mapped", "contoured"])
43
+ blend_strength = st.sidebar.slider("Parameter blending strength (non-hue) : ", 0.0, 1.0, 1.0, 0.05)
44
+ hue_blend_strength = st.sidebar.slider("Hue-shift blending strength : ", 0.0, 1.0, 1.0, 0.05)
45
+
46
+ st.sidebar.text("Drawing options:")
47
+ stroke_width = st.sidebar.slider("Stroke width: ", 1, 80, 40)
48
+ drawing_mode = st.sidebar.selectbox(
49
+ "Drawing tool:", ("freedraw", "line", "rect", "circle", "transform")
50
+ )
51
+
52
+ st.session_state["preset_canvas_key"] ="preset_canvas"
53
+
54
+ vp = torch.clone(st.session_state["result_vp"])
55
+ org_cuda = st.session_state["effect_input"]
56
+
57
+ @st.experimental_memo
58
+ def greyscale_original(_org_cuda, content_id): #content_id is used for hashing
59
+ if HUGGING_FACE:
60
+ wsize = 450
61
+ img_org_height, img_org_width = _org_cuda.shape[-2:]
62
+ wpercent = (wsize / float(img_org_width))
63
+ hsize = int((float(img_org_height) * float(wpercent)))
64
+ else:
65
+ longest_edge = 670
66
+ img_org_height, img_org_width = _org_cuda.shape[-2:]
67
+ max_width_height = max(img_org_width, img_org_height)
68
+ hsize = int((float(longest_edge) * float(float(img_org_height) / max_width_height)))
69
+ wsize = int((float(longest_edge) * float(float(img_org_width) / max_width_height)))
70
+
71
+ org_img = F.interpolate(_org_cuda, (hsize, wsize), mode="bilinear")
72
+ org_img = torch.mean(org_img, dim=1, keepdim=True) / 2.0
73
+ org_img = torch_to_np(org_img, multiply_by_255=True)[..., np.newaxis].repeat(3, axis=2)
74
+ org_img = Image.fromarray(org_img.astype(np.uint8))
75
+ return org_img, hsize, wsize
76
+
77
+ greyscale_img, hsize, wsize = greyscale_original(org_cuda, st.session_state["Content_id"])
78
+
79
+ coll1, coll2 = st.columns(2)
80
+ coll1.header("Draw Mask")
81
+ coll2.header("Live Result")
82
+
83
+ with coll1:
84
+ # Create a canvas component
85
+ canvas_result = st_canvas(
86
+ fill_color="rgba(0, 0, 0, 1)", # Fixed fill color with some opacity
87
+ stroke_width=stroke_width,
88
+ background_image=greyscale_img,
89
+ width=greyscale_img.width,
90
+ height=greyscale_img.height,
91
+ drawing_mode=drawing_mode,
92
+ key=st.session_state["preset_canvas_key"]
93
+ )
94
+
95
+
96
+ res_data = None
97
+ if canvas_result.image_data is not None:
98
+ abc = np_to_torch(canvas_result.image_data.astype(np.float)).sum(dim=1, keepdim=True).cuda()
99
+
100
+ img_org_width = org_cuda.shape[-1]
101
+ img_org_height = org_cuda.shape[-2]
102
+ res_data = F.interpolate(abc, (img_org_height, img_org_width)).squeeze(1)
103
+
104
+ preset_tensor = effect.vpd.preset_tensor(presets[active_preset], org_cuda, add_local_dims=True)
105
+ hue = torch.clone(vp[:,effect.vpd.name2idx["hueShift"]])
106
+ vp[:] = preset_tensor * res_data * blend_strength + vp[:] * (1 - res_data * blend_strength)
107
+ vp[:, effect.vpd.name2idx["hueShift"]] = \
108
+ preset_tensor[:,effect.vpd.name2idx["hueShift"]] * res_data * hue_blend_strength + hue * (1 - res_data * hue_blend_strength)
109
+
110
+ with torch.no_grad():
111
+ result_cuda = effect(org_cuda, vp)
112
+
113
+ img_res = Image.fromarray((torch_to_np(result_cuda) * 255.0).astype(np.uint8))
114
+ coll2.image(img_res)
115
+
116
+ apply_btn = st.sidebar.button("Apply")
117
+ if apply_btn:
118
+ st.session_state["result_vp"] = vp
pages/Local_edits.py ADDED
@@ -0,0 +1,241 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+
4
+ import torch.nn.functional as F
5
+ import torch
6
+ import numpy as np
7
+ import matplotlib
8
+ from matplotlib import pyplot as plt
9
+ import matplotlib.cm
10
+ from PIL import Image
11
+
12
+ import streamlit as st
13
+ from streamlit_drawable_canvas import st_canvas
14
+
15
+ from .. import demo_config
16
+ from demo_config import HUGGING_FACE
17
+
18
+ PACKAGE_PARENT = '../wise/'
19
+ SCRIPT_DIR = os.path.dirname(os.path.realpath(os.path.join(os.getcwd(), os.path.expanduser(__file__))))
20
+ sys.path.append(os.path.normpath(os.path.join(SCRIPT_DIR, PACKAGE_PARENT)))
21
+
22
+
23
+
24
+ from effects.gauss2d_xy_separated import Gauss2DEffect
25
+ from effects.minimal_pipeline import MinimalPipelineEffect
26
+ from helpers import torch_to_np, np_to_torch
27
+ from effects import get_default_settings
28
+
29
+ st.set_page_config(page_title="Editing Demo", layout="wide")
30
+
31
+ # @st.cache(hash_funcs={OilPaintEffect: id})
32
+ @st.cache(hash_funcs={MinimalPipelineEffect: id})
33
+ def local_edits_create_effect():
34
+ effect, preset, param_set = get_default_settings("minimal_pipeline")
35
+ effect.enable_checkpoints()
36
+ effect.cuda()
37
+ return effect, param_set
38
+
39
+
40
+ effect, param_set = local_edits_create_effect()
41
+
42
+ @st.experimental_memo
43
+ def gen_param_strength_fig():
44
+ cmap = matplotlib.cm.get_cmap('plasma')
45
+ # cmap show
46
+ gradient = np.linspace(0, 1, 256)
47
+ gradient = np.vstack((gradient, gradient))
48
+ fig, ax = plt.subplots(figsize=(3, 0.1))
49
+ fig.patch.set_alpha(0.0)
50
+ ax.set_title("parameter strength", fontsize=6.5, loc="left")
51
+ ax.imshow(gradient, aspect='auto', cmap=cmap)
52
+ ax.set_axis_off()
53
+ return fig, cmap
54
+
55
+ cmap_fig, cmap = gen_param_strength_fig()
56
+
57
+ st.session_state["canvas_key"] = "canvas"
58
+ try:
59
+ vp = st.session_state["result_vp"]
60
+ org_cuda = st.session_state["effect_input"]
61
+ except KeyError as e:
62
+ print("init run, certain keys not found. If this happens once its ok.")
63
+
64
+ if st.session_state["action"] != "switch_page_from_local_edits":
65
+ st.session_state.local_edit_action = "init"
66
+
67
+ st.session_state["action"] = "switch_page_from_local_edits" # on switchback, remember effect input
68
+
69
+ if "mask_edit_counter" not in st.session_state:
70
+ st.session_state["mask_edit_counter"] = 1
71
+ if "initial_drawing" not in st.session_state:
72
+ st.session_state["initial_drawing"] = {"random": st.session_state["mask_edit_counter"], "background": "#eee"}
73
+
74
+ def on_slider_change():
75
+ if st.session_state.local_edit_action == "init":
76
+ st.stop()
77
+ st.session_state.local_edit_action = "slider"
78
+
79
+ def on_param_change():
80
+ st.session_state.local_edit_action = "param_change"
81
+
82
+ active_param = st.sidebar.selectbox("active parameter: ", param_set + ["smooth"], index=2, on_change=on_param_change)
83
+
84
+ st.sidebar.text("Drawing options")
85
+ if active_param != "smooth":
86
+ plus_or_minus = st.sidebar.slider("Increase or decrease param map: ", -1.0, 1.0, 0.8, 0.05,
87
+ on_change=on_slider_change)
88
+ else:
89
+ sigma = st.sidebar.slider("Sigma: ", 0.1, 10.0, 0.5, 0.1, on_change=on_slider_change)
90
+
91
+ stroke_width = st.sidebar.slider("Stroke width: ", 1, 50, 20, on_change=on_slider_change)
92
+ drawing_mode = st.sidebar.selectbox(
93
+ "Drawing tool:", ("freedraw", "line", "rect", "circle", "transform"), on_change=on_slider_change,
94
+ )
95
+
96
+ st.sidebar.text("Viewing options")
97
+ if active_param != "smooth":
98
+ overlay = st.sidebar.slider("show parameter overlay: ", 0.0, 1.0, 0.8, 0.02, on_change=on_slider_change)
99
+ st.sidebar.pyplot(cmap_fig, bbox_inches='tight', pad_inches=0)
100
+
101
+ st.sidebar.text("Update:")
102
+ realtime_update = st.sidebar.checkbox("Update in realtime", True)
103
+ clear_after_draw = st.sidebar.checkbox("Clear Canvas after each Stroke", False)
104
+ invert_selection = st.sidebar.checkbox("Invert Selection", False)
105
+
106
+
107
+ @st.experimental_memo
108
+ def greyscale_org(_org_cuda, content_id): #content_id is used for hashing
109
+ if HUGGING_FACE:
110
+ wsize = 450
111
+ img_org_height, img_org_width = _org_cuda.shape[-2:]
112
+ wpercent = (wsize / float(img_org_width))
113
+ hsize = int((float(img_org_height) * float(wpercent)))
114
+ else:
115
+ longest_edge = 670
116
+ img_org_height, img_org_width = _org_cuda.shape[-2:]
117
+ max_width_height = max(img_org_width, img_org_height)
118
+ hsize = int((float(longest_edge) * float(float(img_org_height) / max_width_height)))
119
+ wsize = int((float(longest_edge) * float(float(img_org_width) / max_width_height)))
120
+
121
+ org_img = F.interpolate(_org_cuda, (hsize, wsize), mode="bilinear")
122
+ org_img = torch.mean(org_img, dim=1, keepdim=True) / 2.0
123
+ org_img = torch_to_np(org_img)[..., np.newaxis].repeat(3, axis=2)
124
+ return org_img, hsize, wsize
125
+
126
+ def generate_param_mask(vp):
127
+ greyscale_img, hsize, wsize = greyscale_org(org_cuda, st.session_state["Content_id"])
128
+ if active_param != "smooth":
129
+ scaled_vp = F.interpolate(vp, (hsize, wsize))[:, effect.vpd.name2idx[active_param]]
130
+ param_cmapped = cmap((scaled_vp + 0.5).cpu().numpy())[...,:3][0]
131
+ greyscale_img = greyscale_img * (1 - overlay) + param_cmapped * overlay
132
+ return Image.fromarray((greyscale_img * 255).astype(np.uint8))
133
+
134
+ def compute_results(_vp):
135
+ if "cached_canvas" in st.session_state and st.session_state["cached_canvas"].image_data is not None:
136
+ canvas_result = st.session_state["cached_canvas"]
137
+ abc = np_to_torch(canvas_result.image_data.astype(np.float32)).sum(dim=1, keepdim=True).cuda()
138
+
139
+ if invert_selection:
140
+ abc = abc * (- 1.0) + 1.0
141
+
142
+ img_org_width = org_cuda.shape[-1]
143
+ img_org_height = org_cuda.shape[-2]
144
+ res_data = F.interpolate(abc, (img_org_height, img_org_width)).squeeze(1)
145
+
146
+ if active_param != "smooth":
147
+ _vp[:, effect.vpd.name2idx[active_param]] += plus_or_minus * res_data
148
+ _vp.clamp_(-0.5, 0.5)
149
+ else:
150
+ gauss2dx = Gauss2DEffect(dxdy=[1.0, 0.0], dim_kernsize=5)
151
+ gauss2dy = Gauss2DEffect(dxdy=[0.0, 1.0], dim_kernsize=5)
152
+
153
+ vp_smoothed = gauss2dx(_vp, torch.tensor(sigma).cuda())
154
+ vp_smoothed = gauss2dy(vp_smoothed, torch.tensor(sigma).cuda())
155
+
156
+ print(res_data.shape)
157
+ print(_vp.shape)
158
+ print(vp_smoothed.shape)
159
+ _vp = torch.lerp(_vp, vp_smoothed, res_data.unsqueeze(1))
160
+
161
+ with torch.no_grad():
162
+ result_cuda = effect(org_cuda, _vp)
163
+
164
+ _, hsize, wsize = greyscale_org(org_cuda, st.session_state["Content_id"])
165
+ result_cuda = F.interpolate(result_cuda, (hsize, wsize), mode="bilinear")
166
+
167
+ return Image.fromarray((torch_to_np(result_cuda) * 255.0).astype(np.uint8)), _vp
168
+
169
+ coll1, coll2 = st.columns(2)
170
+ coll1.header("Draw Mask:")
171
+ coll2.header("Live Result")
172
+
173
+ # there is no way of removing the canvas history/state without rerunning the whole program.
174
+ # therefore, giving the canvas a initial_drawing that differs from the canvas state will clear the background
175
+ def mark_canvas_for_redraw():
176
+ print("mark for redraw")
177
+ st.session_state["mask_edit_counter"] += 1 # change state of initial drawing
178
+ initial_drawing = {"random": st.session_state["mask_edit_counter"], "background": "#eee"}
179
+ st.session_state["initial_drawing"] = initial_drawing
180
+
181
+
182
+ with coll1:
183
+ print("edit action", st.session_state.local_edit_action)
184
+ if clear_after_draw and st.session_state.local_edit_action not in ("slider", "param_change", "init"):
185
+ if st.session_state.local_edit_action == "redraw":
186
+ st.session_state.local_edit_action = "draw"
187
+ mark_canvas_for_redraw()
188
+ else:
189
+ st.session_state.local_edit_action = "redraw"
190
+
191
+ mask = generate_param_mask(st.session_state["result_vp"])
192
+ st.session_state["last_mask"] = mask
193
+
194
+ # Create a canvas component
195
+ canvas_result = st_canvas(
196
+ fill_color="rgba(0, 0, 0, 1)",
197
+ stroke_width=stroke_width,
198
+ background_image=mask,
199
+ update_streamlit=realtime_update,
200
+ width=mask.width,
201
+ height=mask.height,
202
+ initial_drawing=st.session_state["initial_drawing"],
203
+ drawing_mode=drawing_mode,
204
+ key=st.session_state.canvas_key,
205
+ )
206
+
207
+ if canvas_result.json_data is None:
208
+ print("stops")
209
+ st.stop()
210
+
211
+ st.session_state["cached_canvas"] = canvas_result
212
+
213
+ print("compute result")
214
+ img_res, vp = compute_results(vp)
215
+ st.session_state["last_result"] = img_res
216
+ st.session_state["result_vp"] = vp
217
+
218
+ st.markdown("### Mask: " + active_param)
219
+
220
+ if st.session_state.local_edit_action in ("slider", "param_change", "init"):
221
+ print("set redraw")
222
+ st.session_state.local_edit_action = "redraw"
223
+
224
+
225
+ print("plot masks")
226
+ texts = []
227
+ preview_masks = []
228
+ img = st.session_state["last_mask"]
229
+ for i, p in enumerate(param_set):
230
+ idx = effect.vpd.name2idx[p]
231
+ iii = F.interpolate(vp[:, idx:idx + 1] + 0.5, (int(img.height * 0.2), int(img.width * 0.2)))
232
+ texts.append(p[:15])
233
+ preview_masks.append(torch_to_np(iii))
234
+
235
+ coll2.image(img_res) # , use_column_width="auto")
236
+ ppp = st.columns(len(param_set))
237
+ for i, (txt, im) in enumerate(zip(texts, preview_masks)):
238
+ ppp[i].text(txt)
239
+ ppp[i].image(im, clamp=True)
240
+
241
+ print("....")
pages/Readme.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+
3
+ st.title("White-box Style Transfer Editing")
4
+
5
+ st.markdown("""
6
+ This app demonstrates the editing capabilities of the White-box Style Transfer Editing (WISE) framework.
7
+ It optimizes the parameters of classical image processing filters to match a given style image.
8
+ After optimization, parameters can be tuned by hand to achieve a desired look.
9
+
10
+ ### How does it work?
11
+ We provide a small stylization effect that contains several filters such as bump mapping or edge enhancement that can be optimized. The optimization yields so-called parameter masks, which contain per pixel parameter settings of each filter.
12
+
13
+ ### How to use the app ?
14
+ - On the first page select existing content/style combinations or upload images to optimize.
15
+ - After the effect has been applied, use the parameter sliders to adjust a parameter value globally
16
+ - On the "apply preset" page, we defined several parameter presets that can be drawn on the image. Press "Apply" to make the changes permanent
17
+ - On the " local editing" page, individual parameter masks can be edited regionally. Choose the parameter on the left sidebar, and use the parameter strength slider to either increase or decrease the strength of the drawn strokes
18
+ - Strokes on the drawing canvas (left column) are updated in real-time on the result in the right column.
19
+ - Strokes stay on the canvas unless manually deleted by clicking the trash button. To remove them from the canvas after each stroke, tick the corresponding checkbox in the sidebar.
20
+
21
+ ### Links & Paper
22
+ [Project page](https://ivpg.hpi3d.de/wise/),
23
+ [arxiv link](https://arxiv.org/abs/2207.14606)
24
+
25
+ "WISE: Whitebox Image Stylization by Example-based Learning", by Winfried Lötzsch*, Max Reimann*, Martin Büßemeyer, Amir Semmo, Jürgen Döllner, Matthias Trapp, in ECCV 2022
26
+
27
+ ### Further notes
28
+ Pull Requests and further improvements are very welcome.
29
+ Please note that the shown effect is a minimal pipeline in terms of stylization capability, the much more feature-rich oilpaint and watercolor pipelines we show in our ECCV paper cannot be open-sourced due to IP reasons.
30
+ """)
requirements.txt ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ imageio
2
+ imageio-ffmpeg
3
+ matplotlib
4
+ Pillow
5
+ numpy
6
+ --extra-index-url https://download.pytorch.org/whl/cu113
7
+ torch
8
+ torchvision
9
+ streamlit==1.10.0
10
+ streamlit_drawable_canvas==0.8.0
11
+ streamlit_extras==0.1.5
12
+ st_click_detector
13
+ scipy