Spaces:
Paused
Paused
Andranik Sargsyan
commited on
Commit
·
1df97f6
1
Parent(s):
fd3e2fa
add saving/recovering tmp user data for faster processing
Browse files- app.py +77 -25
- assets/sr_info.png +3 -0
- lib/methods/sr.py +8 -3
app.py
CHANGED
@@ -75,11 +75,57 @@ def set_model_from_name(inp_model_name):
|
|
75 |
inp_model = inpainting_models[inp_model_name]
|
76 |
|
77 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
78 |
def rasg_run(
|
79 |
use_painta, prompt, input, seed, eta,
|
80 |
negative_prompt, positive_prompt, ddim_steps,
|
81 |
guidance_scale=7.5,
|
82 |
-
batch_size=1
|
83 |
):
|
84 |
torch.cuda.empty_cache()
|
85 |
|
@@ -119,15 +165,18 @@ def rasg_run(
|
|
119 |
dilation=12
|
120 |
)
|
121 |
blended_images.append(blended_image)
|
122 |
-
inpainted_images.append(inpainted_image.
|
|
|
|
|
|
|
123 |
|
124 |
-
return blended_images,
|
125 |
|
126 |
|
127 |
def sd_run(use_painta, prompt, input, seed, eta,
|
128 |
negative_prompt, positive_prompt, ddim_steps,
|
129 |
guidance_scale=7.5,
|
130 |
-
batch_size=1
|
131 |
):
|
132 |
torch.cuda.empty_cache()
|
133 |
|
@@ -167,32 +216,37 @@ def sd_run(use_painta, prompt, input, seed, eta,
|
|
167 |
dilation=12
|
168 |
)
|
169 |
blended_images.append(blended_image)
|
170 |
-
inpainted_images.append(inpainted_image.
|
171 |
|
172 |
-
|
|
|
|
|
|
|
173 |
|
174 |
|
175 |
def upscale_run(
|
176 |
-
|
177 |
negative_prompt='',
|
178 |
positive_prompt=', high resolution professional photo'
|
179 |
):
|
|
|
|
|
|
|
|
|
|
|
180 |
torch.cuda.empty_cache()
|
181 |
|
182 |
seed = int(seed)
|
183 |
img_index = int(img_index)
|
184 |
-
|
185 |
img_index = 0 if img_index < 0 else img_index
|
186 |
img_index = len(gallery) - 1 if img_index >= len(gallery) else img_index
|
187 |
-
|
188 |
-
|
189 |
-
lr_image = IImage(inpainted_image)
|
190 |
-
hr_image = IImage(input['image']).resize(2048)
|
191 |
-
hr_mask = IImage(input['mask']).resize(2048)
|
192 |
output_image = sr.run(
|
193 |
sr_model,
|
194 |
sam_predictor,
|
195 |
-
|
196 |
hr_image,
|
197 |
hr_mask,
|
198 |
prompt=prompt + positive_prompt,
|
@@ -203,8 +257,8 @@ def upscale_run(
|
|
203 |
seed=seed,
|
204 |
use_sam_mask=use_sam_mask
|
205 |
)
|
206 |
-
|
207 |
-
return output_image
|
208 |
|
209 |
|
210 |
def switch_run(use_rasg, model_name, *args):
|
@@ -316,8 +370,7 @@ with gr.Blocks(css='style.css') as demo:
|
|
316 |
[input, prompt, example_container]
|
317 |
)
|
318 |
|
319 |
-
|
320 |
-
mock_hires = gr.Image(label = "__MHRO__", visible = False)
|
321 |
html_info = gr.HTML(elem_id=f'html_info', elem_classes="infotext")
|
322 |
|
323 |
inpaint_btn.click(
|
@@ -334,25 +387,24 @@ with gr.Blocks(css='style.css') as demo:
|
|
334 |
positive_prompt,
|
335 |
ddim_steps,
|
336 |
guidance_scale,
|
337 |
-
batch_size
|
|
|
338 |
],
|
339 |
-
outputs=[output_gallery,
|
340 |
api_name="inpaint"
|
341 |
)
|
342 |
upscale_btn.click(
|
343 |
fn=upscale_run,
|
344 |
inputs=[
|
345 |
-
prompt,
|
346 |
-
input,
|
347 |
ddim_steps,
|
348 |
seed,
|
349 |
use_sam_mask,
|
350 |
-
|
351 |
html_info
|
352 |
],
|
353 |
-
outputs=[hires_image
|
354 |
api_name="upscale",
|
355 |
-
_js="function(a, b, c, d, e
|
356 |
)
|
357 |
|
358 |
demo.queue(max_size=20)
|
|
|
75 |
inp_model = inpainting_models[inp_model_name]
|
76 |
|
77 |
|
78 |
+
def save_user_session(hr_image, hr_mask, lr_results, prompt, session_id=None):
|
79 |
+
if session_id == '':
|
80 |
+
session_id = str(uuid.uuid4())
|
81 |
+
|
82 |
+
tmp_dir = Path(TMP_DIR)
|
83 |
+
session_dir = tmp_dir / session_id
|
84 |
+
session_dir.mkdir(exist_ok=True, parents=True)
|
85 |
+
|
86 |
+
hr_image.save(session_dir / 'hr_image.png')
|
87 |
+
hr_mask.save(session_dir / 'hr_mask.png')
|
88 |
+
|
89 |
+
lr_results_dir = session_dir / 'lr_results'
|
90 |
+
if lr_results_dir.exists():
|
91 |
+
shutil.rmtree(lr_results_dir)
|
92 |
+
lr_results_dir.mkdir(parents=True)
|
93 |
+
for i, lr_result in enumerate(lr_results):
|
94 |
+
lr_result.save(lr_results_dir / f'{i}.png')
|
95 |
+
|
96 |
+
with open(session_dir / 'prompt.txt', 'w') as f:
|
97 |
+
f.write(prompt)
|
98 |
+
|
99 |
+
return session_id
|
100 |
+
|
101 |
+
|
102 |
+
def recover_user_session(session_id):
|
103 |
+
if session_id == '':
|
104 |
+
return None, None, []
|
105 |
+
|
106 |
+
tmp_dir = Path(TMP_DIR)
|
107 |
+
session_dir = tmp_dir / session_id
|
108 |
+
lr_results_dir = session_dir / 'lr_results'
|
109 |
+
|
110 |
+
hr_image = Image.open(session_dir / 'hr_image.png')
|
111 |
+
hr_mask = Image.open(session_dir / 'hr_mask.png')
|
112 |
+
|
113 |
+
lr_result_paths = list(lr_results_dir.glob('*.png'))
|
114 |
+
gallery = []
|
115 |
+
for lr_result_path in sorted(lr_result_paths):
|
116 |
+
gallery.append(Image.open(lr_result_path))
|
117 |
+
|
118 |
+
with open(session_dir / 'prompt.txt', "r") as f:
|
119 |
+
prompt = f.read()
|
120 |
+
|
121 |
+
return hr_image, hr_mask, gallery, prompt
|
122 |
+
|
123 |
+
|
124 |
def rasg_run(
|
125 |
use_painta, prompt, input, seed, eta,
|
126 |
negative_prompt, positive_prompt, ddim_steps,
|
127 |
guidance_scale=7.5,
|
128 |
+
batch_size=1, session_id=''
|
129 |
):
|
130 |
torch.cuda.empty_cache()
|
131 |
|
|
|
165 |
dilation=12
|
166 |
)
|
167 |
blended_images.append(blended_image)
|
168 |
+
inpainted_images.append(inpainted_image.pil())
|
169 |
+
|
170 |
+
session_id = save_user_session(
|
171 |
+
input['image'], input['mask'], inpainted_images, prompt, session_id=session_id)
|
172 |
|
173 |
+
return blended_images, session_id
|
174 |
|
175 |
|
176 |
def sd_run(use_painta, prompt, input, seed, eta,
|
177 |
negative_prompt, positive_prompt, ddim_steps,
|
178 |
guidance_scale=7.5,
|
179 |
+
batch_size=1, session_id=''
|
180 |
):
|
181 |
torch.cuda.empty_cache()
|
182 |
|
|
|
216 |
dilation=12
|
217 |
)
|
218 |
blended_images.append(blended_image)
|
219 |
+
inpainted_images.append(inpainted_image.pil())
|
220 |
|
221 |
+
session_id = save_user_session(
|
222 |
+
input['image'], input['mask'], inpainted_images, prompt, session_id=session_id)
|
223 |
+
|
224 |
+
return blended_images, session_id
|
225 |
|
226 |
|
227 |
def upscale_run(
|
228 |
+
ddim_steps, seed, use_sam_mask, session_id, img_index,
|
229 |
negative_prompt='',
|
230 |
positive_prompt=', high resolution professional photo'
|
231 |
):
|
232 |
+
hr_image, hr_mask, gallery, prompt = recover_user_session(session_id)
|
233 |
+
|
234 |
+
if len(gallery) == 0:
|
235 |
+
return Image.open('./assets/sr_info.png')
|
236 |
+
|
237 |
torch.cuda.empty_cache()
|
238 |
|
239 |
seed = int(seed)
|
240 |
img_index = int(img_index)
|
241 |
+
|
242 |
img_index = 0 if img_index < 0 else img_index
|
243 |
img_index = len(gallery) - 1 if img_index >= len(gallery) else img_index
|
244 |
+
inpainted_image = gallery[img_index if img_index >= 0 else 0]
|
245 |
+
|
|
|
|
|
|
|
246 |
output_image = sr.run(
|
247 |
sr_model,
|
248 |
sam_predictor,
|
249 |
+
inpainted_image,
|
250 |
hr_image,
|
251 |
hr_mask,
|
252 |
prompt=prompt + positive_prompt,
|
|
|
257 |
seed=seed,
|
258 |
use_sam_mask=use_sam_mask
|
259 |
)
|
260 |
+
|
261 |
+
return output_image
|
262 |
|
263 |
|
264 |
def switch_run(use_rasg, model_name, *args):
|
|
|
370 |
[input, prompt, example_container]
|
371 |
)
|
372 |
|
373 |
+
session_id = gr.Textbox(value='', visible=False)
|
|
|
374 |
html_info = gr.HTML(elem_id=f'html_info', elem_classes="infotext")
|
375 |
|
376 |
inpaint_btn.click(
|
|
|
387 |
positive_prompt,
|
388 |
ddim_steps,
|
389 |
guidance_scale,
|
390 |
+
batch_size,
|
391 |
+
session_id
|
392 |
],
|
393 |
+
outputs=[output_gallery, session_id],
|
394 |
api_name="inpaint"
|
395 |
)
|
396 |
upscale_btn.click(
|
397 |
fn=upscale_run,
|
398 |
inputs=[
|
|
|
|
|
399 |
ddim_steps,
|
400 |
seed,
|
401 |
use_sam_mask,
|
402 |
+
session_id,
|
403 |
html_info
|
404 |
],
|
405 |
+
outputs=[hires_image],
|
406 |
api_name="upscale",
|
407 |
+
_js="function(a, b, c, d, e){ return [a, b, c, d, selected_gallery_index()] }",
|
408 |
)
|
409 |
|
410 |
demo.queue(max_size=20)
|
assets/sr_info.png
ADDED
Git LFS Details
|
lib/methods/sr.py
CHANGED
@@ -73,6 +73,11 @@ def run(
|
|
73 |
negative_prompt = '',
|
74 |
use_sam_mask = False
|
75 |
):
|
|
|
|
|
|
|
|
|
|
|
76 |
torch.manual_seed(seed)
|
77 |
dtype = ddim.vae.encoder.conv_in.weight.dtype
|
78 |
device = ddim.vae.encoder.conv_in.weight.device
|
@@ -143,6 +148,6 @@ def run(
|
|
143 |
fake_img=hr_result,
|
144 |
mask=hr_mask_orig.alpha().data[0]
|
145 |
)
|
146 |
-
|
147 |
-
|
148 |
-
|
|
|
73 |
negative_prompt = '',
|
74 |
use_sam_mask = False
|
75 |
):
|
76 |
+
hr_image_info = hr_image.info
|
77 |
+
lr_image = IImage(lr_image)
|
78 |
+
hr_image = IImage(hr_image).resize(2048)
|
79 |
+
hr_mask = IImage(hr_mask).resize(2048)
|
80 |
+
|
81 |
torch.manual_seed(seed)
|
82 |
dtype = ddim.vae.encoder.conv_in.weight.dtype
|
83 |
device = ddim.vae.encoder.conv_in.weight.device
|
|
|
148 |
fake_img=hr_result,
|
149 |
mask=hr_mask_orig.alpha().data[0]
|
150 |
)
|
151 |
+
hr_result = Image.fromarray(hr_result)
|
152 |
+
hr_result.info = hr_image_info # save metadata
|
153 |
+
return hr_result
|