Fabrice-TIERCELIN commited on
Commit
8fb74d6
1 Parent(s): d565481

Validate parameters before

Browse files
Files changed (1) hide show
  1. gradio_demo.py +89 -28
gradio_demo.py CHANGED
@@ -30,7 +30,7 @@ parser.add_argument("--no_llava", action='store_true', default=True)#False
30
  parser.add_argument("--use_image_slider", action='store_true', default=False)
31
  parser.add_argument("--log_history", action='store_true', default=False)
32
  parser.add_argument("--loading_half_params", action='store_true', default=True)#False
33
- parser.add_argument("--use_tile_vae", action='store_true', default=False)
34
  parser.add_argument("--encoder_tile_size", type=int, default=512)
35
  parser.add_argument("--decoder_tile_size", type=int, default=64)
36
  parser.add_argument("--load_8bit_llava", action='store_true', default=False)
@@ -67,15 +67,16 @@ if torch.cuda.device_count() > 0:
67
  else:
68
  llava_agent = None
69
 
70
- @spaces.GPU(duration=120)
 
 
 
 
71
  def stage1_process(input_image, gamma_correction):
72
  print('Start stage1_process')
73
  if torch.cuda.device_count() == 0:
74
  gr.Warning('Set this space to GPU config to make it work.')
75
  return None
76
- if input_image is None:
77
- gr.Warning('Please provide an image to restore.')
78
- return None
79
  torch.cuda.set_device(SUPIR_device)
80
  LQ = HWC3(input_image)
81
  LQ = fix_resize(LQ, 512)
@@ -92,15 +93,12 @@ def stage1_process(input_image, gamma_correction):
92
  print('End stage1_process')
93
  return LQ
94
 
95
- @spaces.GPU(duration=120)
96
  def llave_process(input_image, temperature, top_p, qs=None):
97
  print('Start llave_process')
98
  if torch.cuda.device_count() == 0:
99
  gr.Warning('Set this space to GPU config to make it work.')
100
  return 'Set this space to GPU config to make it work.'
101
- if input_image is None:
102
- gr.Warning('Please provide an image to restore.')
103
- return 'Please provide an image to restore.'
104
  torch.cuda.set_device(LLaVA_device)
105
  if use_llava:
106
  LQ = HWC3(input_image)
@@ -111,7 +109,7 @@ def llave_process(input_image, temperature, top_p, qs=None):
111
  print('End llave_process')
112
  return captions[0]
113
 
114
- @spaces.GPU(duration=120)
115
  def stage2_process(input_image, prompt, a_prompt, n_prompt, num_samples, upscale, edm_steps, s_stage1, s_stage2,
116
  s_cfg, seed, s_churn, s_noise, color_fix_type, diff_dtype, ae_dtype, gamma_correction,
117
  linear_CFG, linear_s_stage2, spt_linear_CFG, spt_linear_s_stage2, model_select):
@@ -119,9 +117,6 @@ def stage2_process(input_image, prompt, a_prompt, n_prompt, num_samples, upscale
119
  if torch.cuda.device_count() == 0:
120
  gr.Warning('Set this space to GPU config to make it work.')
121
  return None, None, None, None
122
- if input_image is None:
123
- gr.Warning('Please provide an image to restore.')
124
- return None, None, None, None
125
  torch.cuda.set_device(SUPIR_device)
126
  event_id = str(time.time_ns())
127
  event_dict = {'event_id': event_id, 'localtime': time.ctime(), 'prompt': prompt, 'a_prompt': a_prompt,
@@ -279,7 +274,7 @@ with gr.Blocks(title='SUPIR') as interface:
279
  qs = gr.Textbox(label="Question", info="Describe the image and its style in a very detailed manner", placeholder="The image is a realistic photography, not an art painting.")
280
 
281
  with gr.Accordion("Restoring options", open=False):
282
- num_samples = gr.Slider(label="Num Samples", info="Number of generated results; I discourage to increase because the process is limited to 2 min", minimum=1, maximum=4 if not args.use_image_slider else 1
283
  , value=1, step=1)
284
  upscale = gr.Slider(label="Upscale", info="The resolution increase factor", minimum=1, maximum=8, value=1, step=1)
285
  edm_steps = gr.Slider(label="Steps", info="lower=faster, higher=more details", minimum=1, maximum=200, value=default_setting.edm_steps if torch.cuda.device_count() > 0 else 1, step=1)
@@ -319,10 +314,10 @@ with gr.Blocks(title='SUPIR') as interface:
319
  ae_dtype = gr.Radio(['fp32', 'bf16'], label="Auto-Encoder Data Type", value="bf16",
320
  interactive=True)
321
  with gr.Column():
322
- color_fix_type = gr.Radio(["None", "AdaIn", "Wavelet"], label="Color-Fix Type", value="Wavelet",
323
  interactive=True)
324
  with gr.Column():
325
- model_select = gr.Radio(["v0-Q", "v0-F"], label="Model Selection", value="v0-Q",
326
  interactive=True)
327
 
328
  with gr.Column():
@@ -352,17 +347,83 @@ with gr.Blocks(title='SUPIR') as interface:
352
  with gr.Row():
353
  gr.Markdown(claim_md)
354
  event_id = gr.Textbox(label="Event ID", value="", visible=False)
355
-
356
- llave_button.click(fn=llave_process, inputs=[denoise_image, temperature, top_p, qs], outputs=[prompt])
357
- denoise_button.click(fn=stage1_process, inputs=[input_image, gamma_correction],
358
- outputs=[denoise_image])
359
- stage2_ips = [input_image, prompt, a_prompt, n_prompt, num_samples, upscale, edm_steps, s_stage1, s_stage2,
360
- s_cfg, seed, s_churn, s_noise, color_fix_type, diff_dtype, ae_dtype, gamma_correction,
361
- linear_CFG, linear_s_stage2, spt_linear_CFG, spt_linear_s_stage2, model_select]
362
- diffusion_button.click(fn=stage2_process, inputs=stage2_ips, outputs=[result_gallery, event_id, fb_score, fb_text])
363
- restart_button.click(fn=load_and_reset, inputs=[param_setting],
364
- outputs=[edm_steps, s_cfg, s_stage2, s_stage1, s_churn, s_noise, a_prompt, n_prompt,
365
- color_fix_type, linear_CFG, linear_s_stage2, spt_linear_CFG, spt_linear_s_stage2])
366
- submit_button.click(fn=submit_feedback, inputs=[event_id, fb_score, fb_text], outputs=[fb_text])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
367
 
368
  interface.queue(10).launch()
 
30
  parser.add_argument("--use_image_slider", action='store_true', default=False)
31
  parser.add_argument("--log_history", action='store_true', default=False)
32
  parser.add_argument("--loading_half_params", action='store_true', default=True)#False
33
+ parser.add_argument("--use_tile_vae", action='store_true', default=True)#False
34
  parser.add_argument("--encoder_tile_size", type=int, default=512)
35
  parser.add_argument("--decoder_tile_size", type=int, default=64)
36
  parser.add_argument("--load_8bit_llava", action='store_true', default=False)
 
67
  else:
68
  llava_agent = None
69
 
70
+ def check(input_image):
71
+ if input_image is None:
72
+ raise gr.Error("Please provide an image to restore.")
73
+
74
+ @spaces.GPU(duration=180)
75
  def stage1_process(input_image, gamma_correction):
76
  print('Start stage1_process')
77
  if torch.cuda.device_count() == 0:
78
  gr.Warning('Set this space to GPU config to make it work.')
79
  return None
 
 
 
80
  torch.cuda.set_device(SUPIR_device)
81
  LQ = HWC3(input_image)
82
  LQ = fix_resize(LQ, 512)
 
93
  print('End stage1_process')
94
  return LQ
95
 
96
+ @spaces.GPU(duration=180)
97
  def llave_process(input_image, temperature, top_p, qs=None):
98
  print('Start llave_process')
99
  if torch.cuda.device_count() == 0:
100
  gr.Warning('Set this space to GPU config to make it work.')
101
  return 'Set this space to GPU config to make it work.'
 
 
 
102
  torch.cuda.set_device(LLaVA_device)
103
  if use_llava:
104
  LQ = HWC3(input_image)
 
109
  print('End llave_process')
110
  return captions[0]
111
 
112
+ @spaces.GPU(duration=180)
113
  def stage2_process(input_image, prompt, a_prompt, n_prompt, num_samples, upscale, edm_steps, s_stage1, s_stage2,
114
  s_cfg, seed, s_churn, s_noise, color_fix_type, diff_dtype, ae_dtype, gamma_correction,
115
  linear_CFG, linear_s_stage2, spt_linear_CFG, spt_linear_s_stage2, model_select):
 
117
  if torch.cuda.device_count() == 0:
118
  gr.Warning('Set this space to GPU config to make it work.')
119
  return None, None, None, None
 
 
 
120
  torch.cuda.set_device(SUPIR_device)
121
  event_id = str(time.time_ns())
122
  event_dict = {'event_id': event_id, 'localtime': time.ctime(), 'prompt': prompt, 'a_prompt': a_prompt,
 
274
  qs = gr.Textbox(label="Question", info="Describe the image and its style in a very detailed manner", placeholder="The image is a realistic photography, not an art painting.")
275
 
276
  with gr.Accordion("Restoring options", open=False):
277
+ num_samples = gr.Slider(label="Num Samples", info="Number of generated results; I discourage to increase because the process is limited to 3 min", minimum=1, maximum=4 if not args.use_image_slider else 1
278
  , value=1, step=1)
279
  upscale = gr.Slider(label="Upscale", info="The resolution increase factor", minimum=1, maximum=8, value=1, step=1)
280
  edm_steps = gr.Slider(label="Steps", info="lower=faster, higher=more details", minimum=1, maximum=200, value=default_setting.edm_steps if torch.cuda.device_count() > 0 else 1, step=1)
 
314
  ae_dtype = gr.Radio(['fp32', 'bf16'], label="Auto-Encoder Data Type", value="bf16",
315
  interactive=True)
316
  with gr.Column():
317
+ color_fix_type = gr.Radio(["None", "AdaIn", "Wavelet"], label="Color-Fix Type", info="Wavelet=For JPEG artifacts", value="Wavelet",
318
  interactive=True)
319
  with gr.Column():
320
+ model_select = gr.Radio(["v0-Q", "v0-F"], label="Model Selection", info="Q=Quality, F=Fidelity", value="v0-Q",
321
  interactive=True)
322
 
323
  with gr.Column():
 
347
  with gr.Row():
348
  gr.Markdown(claim_md)
349
  event_id = gr.Textbox(label="Event ID", value="", visible=False)
350
+
351
+ denoise_button.click(fn = check, inputs = [
352
+ input_image
353
+ ], outputs = [], queue = False, show_progress = False).success(fn = stage1_process, inputs = [
354
+ input_image,
355
+ gamma_correction
356
+ ], outputs=[
357
+ denoise_image
358
+ ])
359
+
360
+ llave_button.click(fn = check, inputs = [
361
+ denoise_image
362
+ ], outputs = [], queue = False, show_progress = False).success(fn = llave_process, inputs = [
363
+ denoise_image,
364
+ temperature,
365
+ top_p,
366
+ qs
367
+ ], outputs = [
368
+ prompt
369
+ ])
370
+
371
+ diffusion_button.click(fn = check, inputs = [
372
+ input_image
373
+ ], outputs = [], queue = False, show_progress = False).success(fn=stage2_process, inputs = [
374
+ input_image,
375
+ prompt,
376
+ a_prompt,
377
+ n_prompt,
378
+ num_samples,
379
+ upscale,
380
+ edm_steps,
381
+ s_stage1,
382
+ s_stage2,
383
+ s_cfg,
384
+ seed,
385
+ s_churn,
386
+ s_noise,
387
+ color_fix_type,
388
+ diff_dtype,
389
+ ae_dtype,
390
+ gamma_correction,
391
+ linear_CFG,
392
+ linear_s_stage2,
393
+ spt_linear_CFG,
394
+ spt_linear_s_stage2,
395
+ model_select
396
+ ], outputs = [
397
+ result_gallery,
398
+ event_id,
399
+ fb_score,
400
+ fb_text
401
+ ])
402
+
403
+ restart_button.click(fn = load_and_reset, inputs = [
404
+ param_setting
405
+ ], outputs = [
406
+ edm_steps,
407
+ s_cfg,
408
+ s_stage2,
409
+ s_stage1,
410
+ s_churn,
411
+ s_noise,
412
+ a_prompt,
413
+ n_prompt,
414
+ color_fix_type,
415
+ linear_CFG,
416
+ linear_s_stage2,
417
+ spt_linear_CFG,
418
+ spt_linear_s_stage2
419
+ ])
420
+
421
+ submit_button.click(fn = submit_feedback, inputs = [
422
+ event_id,
423
+ fb_score,
424
+ fb_text
425
+ ], outputs = [
426
+ fb_text
427
+ ])
428
 
429
  interface.queue(10).launch()