Zongsheng commited on
Commit
4cd2c6a
1 Parent(s): f857ecf

add resize for arbitraty size

Browse files
Files changed (1) hide show
  1. sampler.py +5 -4
sampler.py CHANGED
@@ -166,6 +166,11 @@ class DifIRSampler(BaseSampler):
166
  # basical image restoration
167
  device = next(self.model.parameters()).device
168
  y0 = y0.to(device=device, dtype=torch.float32)
 
 
 
 
 
169
  if need_restoration:
170
  with torch.no_grad():
171
  if model_kwargs_ir is None:
@@ -176,10 +181,6 @@ class DifIRSampler(BaseSampler):
176
  im_hq = y0
177
  im_hq.clamp_(0.0, 1.0)
178
 
179
- h_old, w_old = im_hq.shape[2:4]
180
- if not (h_old == self.configs.im_size and w_old == self.configs.im_size):
181
- im_hq = resize(im_hq, out_shape=(self.configs.im_size,) * 2).to(torch.float32)
182
-
183
  # diffuse for im_hq
184
  yt = self.diffusion.q_sample(
185
  x_start=post_fun(im_hq),
166
  # basical image restoration
167
  device = next(self.model.parameters()).device
168
  y0 = y0.to(device=device, dtype=torch.float32)
169
+
170
+ h_old, w_old = y0.shape[2:4]
171
+ if not (h_old == self.configs.im_size and w_old == self.configs.im_size):
172
+ y0 = resize(y0, out_shape=(self.configs.im_size,) * 2).to(torch.float32)
173
+
174
  if need_restoration:
175
  with torch.no_grad():
176
  if model_kwargs_ir is None:
181
  im_hq = y0
182
  im_hq.clamp_(0.0, 1.0)
183
 
 
 
 
 
184
  # diffuse for im_hq
185
  yt = self.diffusion.q_sample(
186
  x_start=post_fun(im_hq),