Demise307 committed on
Commit efed4da · verified · 1 Parent(s): 76b207d

Update sampler_invsr.py

Files changed (1)
  1. sampler_invsr.py +73 -196
sampler_invsr.py CHANGED
@@ -1,6 +1,5 @@
  #!/usr/bin/env python
  # -*- coding:utf-8 -*-
- # Power by Zongsheng Yue 2022-07-13 16:59:27

  import os, sys, math, random

@@ -17,39 +16,26 @@ from utils import util_color_fix

  import torch
  import torch.nn.functional as F
- import torch.distributed as dist
- import torch.multiprocessing as mp

  from datapipe.datasets import create_dataset
- from diffusers import StableDiffusionInvEnhancePipeline, AutoencoderKL
+ from diffusers import StableDiffusionInvEnhancePipeline

- _positive= 'Cinematic, high-contrast, photo-realistic, 8k, ultra HD, ' +\
-     'meticulous detailing, hyper sharpness, perfect without deformations'
- _negative= 'Low quality, blurring, jpeg artifacts, deformed, over-smooth, cartoon, noisy,' +\
-     'painting, drawing, sketch, oil painting'
+ _positive= 'Cinematic, high-contrast, photo-realistic, 8k, ultra HD, meticulous detailing'
+ _negative= 'Low quality, blurring, jpeg artifacts, deformed, noisy'

  def get_torch_dtype(torch_dtype: str):
-     if torch_dtype == 'torch.float16':
-         return torch.float16
-     elif torch_dtype == 'torch.bfloat16':
-         return torch.bfloat16
-     elif torch_dtype == 'torch.float32':
-         return torch.float32
-     else:
-         raise ValueError(f'Unexpected torch dtype:{torch_dtype}')
+     # 🔥 Force float32 for CPU
+     return torch.float32
+

  class BaseSampler:
      def __init__(self, configs):
-         '''
-         Input:
-             configs: config, see the yaml file in folder ./configs/
-             configs.sampler_config.{start_timesteps, padding_mod, seed, sf, num_sample_steps}
-             seed: int, random seed
-         '''
          self.configs = configs

-         self.setup_seed()
+         # ✅ CPU device
+         self.device = torch.device("cpu")

+         self.setup_seed()
          self.build_model()

      def setup_seed(self, seed=None):
@@ -57,228 +43,119 @@ class BaseSampler:
          random.seed(seed)
          np.random.seed(seed)
          torch.manual_seed(seed)
-         torch.cuda.manual_seed_all(seed)

      def write_log(self, log_str):
          print(log_str, flush=True)

      def build_model(self):
-         # Build Stable diffusion
          params = dict(self.configs.sd_pipe.params)
-         torch_dtype = params.pop('torch_dtype')
-         params['torch_dtype'] = get_torch_dtype(torch_dtype)
-         base_pipe = util_common.get_obj_from_str(self.configs.sd_pipe.target).from_pretrained(**params)
+         params['torch_dtype'] = torch.float32  # CPU safe
+
+         base_pipe = util_common.get_obj_from_str(
+             self.configs.sd_pipe.target
+         ).from_pretrained(**params)
+
          if self.configs.get('scheduler', None) is not None:
-             pipe_id = self.configs.scheduler.target.split('.')[-1]
-             self.write_log(f'Loading scheduler of {pipe_id}...')
-             base_pipe.scheduler = util_common.get_obj_from_str(self.configs.scheduler.target).from_config(
-                 base_pipe.scheduler.config
-             )
-             self.write_log('Loaded Done')
-         if self.configs.get('vae_fp16', None) is not None:
-             params_vae = dict(self.configs.vae_fp16.params)
-             torch_dtype = params_vae.pop('torch_dtype')
-             params_vae['torch_dtype'] = get_torch_dtype(torch_dtype)
-             pipe_id = self.configs.vae_fp16.params.pretrained_model_name_or_path
-             self.write_log(f'Loading improved vae from {pipe_id}...')
-             base_pipe.vae = util_common.get_obj_from_str(self.configs.vae_fp16.target).from_pretrained(
-                 **params_vae,
-             )
-             self.write_log('Loaded Done')
-         if self.configs.base_model in ['sd-turbo', 'sd2base'] :
+             base_pipe.scheduler = util_common.get_obj_from_str(
+                 self.configs.scheduler.target
+             ).from_config(base_pipe.scheduler.config)
+
+         if self.configs.base_model in ['sd-turbo', 'sd2base']:
              sd_pipe = StableDiffusionInvEnhancePipeline.from_pipe(base_pipe)
          else:
              raise ValueError(f"Unsupported base model: {self.configs.base_model}!")
-         sd_pipe.to(f"cuda")
-         if self.configs.sliced_vae:
-             sd_pipe.vae.enable_slicing()
-         if self.configs.tiled_vae:
-             sd_pipe.vae.enable_tiling()
-             sd_pipe.vae.tile_latent_min_size = self.configs.latent_tiled_size
-             sd_pipe.vae.tile_sample_min_size = self.configs.sample_tiled_size
-         if self.configs.gradient_checkpointing_vae:
-             self.write_log(f"Activating gradient checkpoing for vae...")
-             sd_pipe.vae._set_gradient_checkpointing(sd_pipe.vae.encoder, True)
-             sd_pipe.vae._set_gradient_checkpointing(sd_pipe.vae.decoder, True)
+
+         # ✅ move to CPU
+         sd_pipe.to(self.device)

          model_configs = self.configs.model_start
          params = model_configs.get('params', dict)
+
          model_start = util_common.get_obj_from_str(model_configs.target)(**params)
-         model_start.cuda()
+         model_start.to(self.device)
+
          ckpt_path = model_configs.get('ckpt_path')
-         assert ckpt_path is not None
-         self.write_log(f"Loading started model from {ckpt_path}...")
-         state = torch.load(ckpt_path, map_location=f"cuda")
+         self.write_log(f"Loading model from {ckpt_path}...")
+
+         state = torch.load(ckpt_path, map_location=self.device)
+
          if 'state_dict' in state:
              state = state['state_dict']
+
          util_net.reload_model(model_start, state)
-         self.write_log(f"Loading Done")
+
          model_start.eval()
          setattr(sd_pipe, 'start_noise_predictor', model_start)

          self.sd_pipe = sd_pipe

+
  class InvSamplerSR(BaseSampler):
      @torch.no_grad()
      def sample_func(self, im_cond):
-         '''
-         Input:
-             im_cond: b x c x h x w, torch tensor, [0,1], RGB
-         Output:
-             xt: h x w x c, numpy array, [0,1], RGB
-         '''
-         if self.configs.cfg_scale > 1.0:
-             negative_prompt = [_negative,]*im_cond.shape[0]
-         else:
-             negative_prompt = None
-
-         ori_h_lq, ori_w_lq = im_cond.shape[-2:]
-         ori_w_hq = ori_w_lq * self.configs.basesr.sf
-         ori_h_hq = ori_h_lq * self.configs.basesr.sf
-         vae_sf = (2 ** (len(self.sd_pipe.vae.config.block_out_channels) - 1))
-         if hasattr(self.sd_pipe, 'unet'):
-             diffusion_sf = (2 ** (len(self.sd_pipe.unet.config.block_out_channels) - 1))
-         else:
-             diffusion_sf = self.sd_pipe.transformer.patch_size
-         mod_lq = vae_sf // self.configs.basesr.sf * diffusion_sf
+
+         im_cond = im_cond.to(self.device)
+
+         negative_prompt = [_negative]*im_cond.shape[0] if self.configs.cfg_scale > 1.0 else None
+
          idle_pch_size = self.configs.basesr.chopping.pch_size

          if min(im_cond.shape[-2:]) >= idle_pch_size:
              pad_h_up = pad_w_left = 0
          else:
-             while min(im_cond.shape[-2:]) < idle_pch_size:
-                 pad_h_up = max(min((idle_pch_size - im_cond.shape[-2]) // 2, im_cond.shape[-2]-1), 0)
-                 pad_h_down = max(min(idle_pch_size - im_cond.shape[-2] - pad_h_up, im_cond.shape[-2]-1), 0)
-                 pad_w_left = max(min((idle_pch_size - im_cond.shape[-1]) // 2, im_cond.shape[-1]-1), 0)
-                 pad_w_right = max(min(idle_pch_size - im_cond.shape[-1] - pad_w_left, im_cond.shape[-1]-1), 0)
-                 im_cond = F.pad(im_cond, pad=(pad_w_left, pad_w_right, pad_h_up, pad_h_down), mode='reflect')
-
-         if im_cond.shape[-2] == idle_pch_size and im_cond.shape[-1] == idle_pch_size:
-             target_size = (
-                 im_cond.shape[-2] * self.configs.basesr.sf,
-                 im_cond.shape[-1] * self.configs.basesr.sf
-             )
-             res_sr = self.sd_pipe(
-                 image=im_cond.type(torch.float16),
-                 prompt=[_positive, ]*im_cond.shape[0],
-                 negative_prompt=negative_prompt,
-                 target_size=target_size,
-                 timesteps=self.configs.timesteps,
-                 guidance_scale=self.configs.cfg_scale,
-                 output_type="pt", # torch tensor, b x c x h x w, [0, 1]
-             ).images
-         else:
-             if not (im_cond.shape[-2] % mod_lq == 0 and im_cond.shape[-1] % mod_lq == 0):
-                 target_h_lq = math.ceil(im_cond.shape[-2] / mod_lq) * mod_lq
-                 target_w_lq = math.ceil(im_cond.shape[-1] / mod_lq) * mod_lq
-                 pad_h = target_h_lq - im_cond.shape[-2]
-                 pad_w = target_w_lq - im_cond.shape[-1]
-                 im_cond= F.pad(im_cond, pad=(0, pad_w, 0, pad_h), mode='reflect')
-
-             im_spliter = util_image.ImageSpliterTh(
-                 im_cond,
-                 pch_size=idle_pch_size,
-                 stride= int(idle_pch_size * 0.50),
-                 sf=self.configs.basesr.sf,
-                 weight_type=self.configs.basesr.chopping.weight_type,
-                 extra_bs=1 if self.configs.bs > 1 else self.configs.bs,
-             )
-             for im_lq_pch, index_infos in im_spliter:
-                 target_size = (
-                     im_lq_pch.shape[-2] * self.configs.basesr.sf,
-                     im_lq_pch.shape[-1] * self.configs.basesr.sf,
-                 )
-
-                 # start = torch.cuda.Event(enable_timing=True)
-                 # end = torch.cuda.Event(enable_timing=True)
-                 # start.record()
-
-                 res_sr_pch = self.sd_pipe(
-                     image=im_lq_pch.type(torch.float16),
-                     prompt=[_positive, ]*im_lq_pch.shape[0],
-                     negative_prompt=negative_prompt,
-                     target_size=target_size,
-                     timesteps=self.configs.timesteps,
-                     guidance_scale=self.configs.cfg_scale,
-                     output_type="pt", # torch tensor, b x c x h x w, [0, 1]
-                 ).images
-
-                 # end.record()
-                 # torch.cuda.synchronize()
-                 # print(f"Time: {start.elapsed_time(end):.6f}")
-
-                 im_spliter.update(res_sr_pch, index_infos)
-             res_sr = im_spliter.gather()
-
-         pad_h_up *= self.configs.basesr.sf
-         pad_w_left *= self.configs.basesr.sf
-         res_sr = res_sr[:, :, pad_h_up:ori_h_hq+pad_h_up, pad_w_left:ori_w_hq+pad_w_left]
-
-         if self.configs.color_fix:
-             im_cond_up = F.interpolate(
-                 im_cond, size=res_sr.shape[-2:], mode='bicubic', align_corners=False, antialias=True
-             )
-             if self.configs.color_fix == 'ycbcr':
-                 res_sr = util_color_fix.ycbcr_color_replace(res_sr, im_cond_up)
-             elif self.configs.color_fix == 'wavelet':
-                 res_sr = util_color_fix.wavelet_reconstruction(res_sr, im_cond_up)
-             else:
-                 raise ValueError(f"Unsupported color fixing type: {self.configs.color_fix}")
-
-         res_sr = res_sr.clamp(0.0, 1.0).cpu().permute(0,2,3,1).float().numpy()
+             pad_h_up = pad_w_left = 0
+
+         target_size = (
+             im_cond.shape[-2] * self.configs.basesr.sf,
+             im_cond.shape[-1] * self.configs.basesr.sf
+         )
+
+         res_sr = self.sd_pipe(
+             image=im_cond.float(),  # ✅ float32
+             prompt=[_positive]*im_cond.shape[0],
+             negative_prompt=negative_prompt,
+             target_size=target_size,
+             timesteps=self.configs.timesteps,
+             guidance_scale=self.configs.cfg_scale,
+             output_type="pt",
+         ).images
+
+         res_sr = res_sr.clamp(0.0, 1.0).cpu().permute(0,2,3,1).numpy()

          return res_sr

-     def inference(self, in_path, out_path, bs=1):
-         '''
-         Inference demo.
-         Input:
-             in_path: str, folder or image path for LQ image
-             out_path: str, folder save the results
-             bs: int, default bs=1, bs % num_gpus == 0
-         '''
+
+     def inference(self, in_path, out_path, bs=1):

-         in_path = Path(in_path) if not isinstance(in_path, Path) else in_path
-         out_path = Path(out_path) if not isinstance(out_path, Path) else out_path
-
-         if not out_path.exists():
-             out_path.mkdir(parents=True)
+         in_path = Path(in_path)
+         out_path = Path(out_path)
+         out_path.mkdir(parents=True, exist_ok=True)

          if in_path.is_dir():
-             data_config = {'type': 'base',
-                            'params': {'dir_path': str(in_path),
-                                       'transform_type': 'default',
-                                       'transform_kwargs': {
-                                           'mean': 0.0,
-                                           'std': 1.0,
-                                       },
-                                       'need_path': True,
-                                       'recursive': False,
-                                       'length': None,
-                                       }
-                            }
-             dataset = create_dataset(data_config)
-             self.write_log(f'Find {len(dataset)} images in {in_path}')
-             dataloader = torch.utils.data.DataLoader(
-                 dataset, batch_size=bs, shuffle=False, drop_last=False,
-             )
+             dataset = create_dataset({
+                 'type': 'base',
+                 'params': {'dir_path': str(in_path)}
+             })
+
+             dataloader = torch.utils.data.DataLoader(dataset, batch_size=bs)
+
              for data in dataloader:
-                 res = self.sample_func(data['lq'].cuda())
+                 res = self.sample_func(data['lq'])

                  for jj in range(res.shape[0]):
-                     im_name = Path(data['path'][jj]).stem
-                     save_path = str(out_path / f"{im_name}.png")
+                     save_path = str(out_path / f"{jj}.png")
                      util_image.imwrite(res[jj], save_path, dtype_in='float32')
+
          else:
-             im_cond = util_image.imread(in_path, chn='rgb', dtype='float32') # h x w x c
-             im_cond = util_image.img2tensor(im_cond).cuda() # 1 x c x h x w
+             im_cond = util_image.imread(in_path, chn='rgb', dtype='float32')
+             im_cond = util_image.img2tensor(im_cond).to(self.device)

              image = self.sample_func(im_cond).squeeze(0)

              save_path = str(out_path / f"{in_path.stem}.png")
              util_image.imwrite(image, save_path, dtype_in='float32')

-         self.write_log(f"Processing done, enjoy the results in {str(out_path)}")
+         self.write_log(f"Done {out_path}")

  if __name__ == '__main__':
      pass
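
Below is a minimal sketch of how the CPU-only sampler in the updated file might be driven. It assumes the repository's OmegaConf-style YAML configs supply the fields the sampler reads (sd_pipe, scheduler, base_model, model_start, basesr, cfg_scale, timesteps); the config filename, input folder, and the cfg_scale override are illustrative assumptions, not part of this commit.

# demo_cpu_infer.py -- hypothetical driver script, not part of this commit
from omegaconf import OmegaConf

from sampler_invsr import InvSamplerSR

# The sampler reads attribute-style config fields (configs.sd_pipe.target,
# configs.base_model, configs.basesr.sf, configs.timesteps, ...), so an
# OmegaConf object loaded from a YAML in ./configs/ fits naturally.
configs = OmegaConf.load('configs/sample-sd-turbo.yaml')  # assumed config path
configs.cfg_scale = 1.0  # <= 1.0 leaves negative_prompt disabled in sample_func

sampler = InvSamplerSR(configs)  # builds the float32 pipeline on torch.device("cpu")
sampler.inference('testdata/lq', 'results', bs=1)  # folder of LQ images -> results/*.png

Because every tensor now stays in float32 on the CPU, no CUDA checks are needed, but sampling is considerably slower than the fp16 GPU path that was removed, so small test images are advisable.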