Fabrice-TIERCELIN commited on
Commit
180098c
·
verified ·
1 Parent(s): 610ac0b
Files changed (1) hide show
  1. SUPIR/models/SUPIR_model.py +11 -2
SUPIR/models/SUPIR_model.py CHANGED
@@ -47,18 +47,29 @@ class SUPIRModel(DiffusionEngine):
47
 
48
  @torch.no_grad()
49
  def encode_first_stage_with_denoise(self, x, use_sample=True, is_stage1=False):
 
50
  with torch.autocast("cuda", dtype=self.ae_dtype):
 
51
  if is_stage1:
 
52
  h = self.first_stage_model.denoise_encoder_s1(x)
53
  else:
 
54
  h = self.first_stage_model.denoise_encoder(x)
 
55
  moments = self.first_stage_model.quant_conv(h)
 
56
  posterior = DiagonalGaussianDistribution(moments)
 
57
  if use_sample:
 
58
  z = posterior.sample()
59
  else:
 
60
  z = posterior.mode()
 
61
  z = self.scale_factor * z
 
62
  return z
63
 
64
  @torch.no_grad()
@@ -73,9 +84,7 @@ class SUPIRModel(DiffusionEngine):
73
  '''
74
  [N, C, H, W], [-1, 1], RGB
75
  '''
76
- print('Start batchify_denoise')
77
  x = self.encode_first_stage_with_denoise(x, use_sample=False, is_stage1=is_stage1)
78
- print('End batchify_denoise')
79
  return self.decode_first_stage(x)
80
 
81
  @torch.no_grad()
 
47
 
48
  @torch.no_grad()
49
  def encode_first_stage_with_denoise(self, x, use_sample=True, is_stage1=False):
50
+ print('encode_first_stage_with_denoise 1')
51
  with torch.autocast("cuda", dtype=self.ae_dtype):
52
+ print('encode_first_stage_with_denoise 2')
53
  if is_stage1:
54
+ print('encode_first_stage_with_denoise 3')
55
  h = self.first_stage_model.denoise_encoder_s1(x)
56
  else:
57
+ print('encode_first_stage_with_denoise 4')
58
  h = self.first_stage_model.denoise_encoder(x)
59
+ print('encode_first_stage_with_denoise 5')
60
  moments = self.first_stage_model.quant_conv(h)
61
+ print('encode_first_stage_with_denoise 6')
62
  posterior = DiagonalGaussianDistribution(moments)
63
+ print('encode_first_stage_with_denoise 7')
64
  if use_sample:
65
+ print('encode_first_stage_with_denoise 8')
66
  z = posterior.sample()
67
  else:
68
+ print('encode_first_stage_with_denoise 9')
69
  z = posterior.mode()
70
+ print('encode_first_stage_with_denoise 10')
71
  z = self.scale_factor * z
72
+ print('encode_first_stage_with_denoise 11')
73
  return z
74
 
75
  @torch.no_grad()
 
84
  '''
85
  [N, C, H, W], [-1, 1], RGB
86
  '''
 
87
  x = self.encode_first_stage_with_denoise(x, use_sample=False, is_stage1=is_stage1)
 
88
  return self.decode_first_stage(x)
89
 
90
  @torch.no_grad()