vincent-doan commited on
Commit
8744832
·
1 Parent(s): eee6b71

Added clamping to delete noise in output

Browse files
Files changed (2) hide show
  1. models/RCAN/rcan.py +6 -4
  2. models/SRGAN/srgan.py +6 -4
models/RCAN/rcan.py CHANGED
@@ -106,15 +106,17 @@ class RCAN(nn.Module):
106
  upscaled_image = self.upscaling_module(deep_feature) # Upscaling module
107
  reconstructed_image = self.reconstruction_conv(upscaled_image) # Reconstruction
108
 
109
- return reconstructed_image
110
 
111
  def inference(self, x):
112
  """
113
  x is a PIL image
114
  """
115
- x = ToTensor()(x).unsqueeze(0)
116
- x = self.forward(x)
117
- x = Image.fromarray((x.squeeze(0).permute(1, 2, 0).detach().numpy() * 255).astype('uint8'))
 
 
118
  return x
119
 
120
  if __name__ == '__main__':
 
106
  upscaled_image = self.upscaling_module(deep_feature) # Upscaling module
107
  reconstructed_image = self.reconstruction_conv(upscaled_image) # Reconstruction
108
 
109
+ return reconstructed_image.clamp(0, 1)
110
 
111
  def inference(self, x):
112
  """
113
  x is a PIL image
114
  """
115
+ self.eval()
116
+ with torch.no_grad():
117
+ x = ToTensor()(x).unsqueeze(0)
118
+ x = self.forward(x)
119
+ x = Image.fromarray((x.squeeze(0).permute(1, 2, 0).detach().numpy() * 255).astype('uint8'))
120
  return x
121
 
122
  if __name__ == '__main__':
models/SRGAN/srgan.py CHANGED
@@ -51,15 +51,17 @@ class GeneratorResnet(nn.Module):
51
  out = torch.add(out1, out2)
52
  out = self.upsampling(out)
53
  out = self.conv3(out)
54
- return out
55
 
56
  def inference(self, x):
57
  """
58
  x is a PIL image
59
  """
60
- x = ToTensor()(x).unsqueeze(0)
61
- x = self.forward(x)
62
- x = Image.fromarray((x.squeeze(0).permute(1, 2, 0).detach().numpy() * 255).astype('uint8'))
 
 
63
  return x
64
 
65
  if __name__ == '__main__':
 
51
  out = torch.add(out1, out2)
52
  out = self.upsampling(out)
53
  out = self.conv3(out)
54
+ return out.clamp(0, 1)
55
 
56
  def inference(self, x):
57
  """
58
  x is a PIL image
59
  """
60
+ self.eval()
61
+ with torch.no_grad():
62
+ x = ToTensor()(x).unsqueeze(0)
63
+ x = self.forward(x)
64
+ x = Image.fromarray((x.squeeze(0).permute(1, 2, 0).detach().numpy() * 255).astype('uint8'))
65
  return x
66
 
67
  if __name__ == '__main__':