Zongsheng commited on
Commit
d6fbbca
1 Parent(s): efc301f

change examples

Browse files
Files changed (2) hide show
  1. app.py +13 -14
  2. sampler.py +6 -1
app.py CHANGED
@@ -20,15 +20,17 @@ from sampler import DifIRSampler
20
  from ResizeRight.resize_right import resize
21
  from basicsr.utils.download_util import load_file_from_url
22
 
23
- def predict(im_path, background_enhance, face_upsample, upscale, started_timesteps):
24
- cfg_path = 'configs/sample/iddpm_ffhq512_swinir.yaml'
 
 
 
 
 
25
 
26
- # setting configurations
27
- configs = OmegaConf.load(cfg_path)
28
- configs.aligned = False
29
  configs.background_enhance = background_enhance
30
  configs.face_upsample = face_upsample
31
-
32
  started_timesteps = int(started_timesteps)
33
  assert started_timesteps < int(configs.diffusion.params.timestep_respacing)
34
 
@@ -56,9 +58,6 @@ def predict(im_path, background_enhance, face_upsample, upscale, started_timeste
56
  upscale = 2 # avoid momory exceeded due to too large img resolution
57
  configs.detection.upscale = int(upscale)
58
 
59
- # build the sampler for diffusion
60
- sampler_dist = DifIRSampler(configs)
61
-
62
  image_restored, face_restored, face_cropped = sampler_dist.sample_func_bfr_unaligned(
63
  y0=im_lq,
64
  start_timesteps=started_timesteps,
@@ -71,7 +70,7 @@ def predict(im_path, background_enhance, face_upsample, upscale, started_timeste
71
  restored_image_dir.mkdir()
72
  # save the whole image
73
  save_path = restored_image_dir / Path(im_path).name
74
- util_image.imwrite(image_restored, save_path, chn='bgr', dtype_in='uint8')
75
 
76
  return image_restored, str(save_path)
77
 
@@ -114,10 +113,10 @@ If you have any questions, please feel free to contact me via <b>zsyzam@gmail.co
114
  demo = gr.Interface(
115
  predict,
116
  inputs=[
117
- gr.inputs.Image(type="filepath", label="Input"),
118
- gr.inputs.Checkbox(default=True, label="Background_Enhance"),
119
- gr.inputs.Checkbox(default=True, label="Face_Upsample"),
120
- gr.inputs.Number(default=2, label="Rescaling_Factor (up to 4)"),
121
  gr.Slider(1, 200, value=100, step=10, label='Realism-Fidelity Trade-off')
122
  ],
123
  outputs=[
 
20
  from ResizeRight.resize_right import resize
21
  from basicsr.utils.download_util import load_file_from_url
22
 
23
+ # setting configurations
24
+ cfg_path = 'configs/sample/iddpm_ffhq512_swinir.yaml'
25
+ configs = OmegaConf.load(cfg_path)
26
+ configs.aligned = False
27
+
28
+ # build the sampler for diffusion
29
+ sampler_dist = DifIRSampler(configs)
30
 
31
+ def predict(im_path, background_enhance, face_upsample, upscale, started_timesteps):
 
 
32
  configs.background_enhance = background_enhance
33
  configs.face_upsample = face_upsample
 
34
  started_timesteps = int(started_timesteps)
35
  assert started_timesteps < int(configs.diffusion.params.timestep_respacing)
36
 
 
58
  upscale = 2 # avoid momory exceeded due to too large img resolution
59
  configs.detection.upscale = int(upscale)
60
 
 
 
 
61
  image_restored, face_restored, face_cropped = sampler_dist.sample_func_bfr_unaligned(
62
  y0=im_lq,
63
  start_timesteps=started_timesteps,
 
70
  restored_image_dir.mkdir()
71
  # save the whole image
72
  save_path = restored_image_dir / Path(im_path).name
73
+ util_image.imwrite(image_restored, save_path, chn='rgb', dtype_in='uint8')
74
 
75
  return image_restored, str(save_path)
76
 
 
113
  demo = gr.Interface(
114
  predict,
115
  inputs=[
116
+ gr.Image(type="filepath", label="Input"),
117
+ gr.Checkbox(default=True, label="Background_Enhance"),
118
+ gr.Checkbox(default=True, label="Face_Upsample"),
119
+ gr.Number(default=2, label="Rescaling_Factor (up to 4)"),
120
  gr.Slider(1, 200, value=100, step=10, label='Realism-Fidelity Trade-off')
121
  ],
122
  outputs=[
sampler.py CHANGED
@@ -54,7 +54,12 @@ class BaseSampler:
54
  torch.cuda.manual_seed_all(seed)
55
 
56
  def setup_dist(self):
57
- self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 
 
 
 
 
58
  self.rank = 0
59
 
60
  def build_model(self):
 
54
  torch.cuda.manual_seed_all(seed)
55
 
56
  def setup_dist(self):
57
+ if torch.cuda.is_available():
58
+ self.device = torch.device('cuda')
59
+ print(f'Runing on GPU...')
60
+ else:
61
+ self.device = torch.device('cpu')
62
+ print(f'Runing on CPU...')
63
  self.rank = 0
64
 
65
  def build_model(self):