Nguyễn Bá Thiêm commited on
Commit
12f4dcf
1 Parent(s): bff701c

Add test methods to RCAN and SRGAN models

Browse files
models/RCAN/rcan.py CHANGED
@@ -118,6 +118,15 @@ class RCAN(nn.Module):
118
  x = self.forward(x)
119
  x = Image.fromarray((x.squeeze(0).permute(1, 2, 0).detach().numpy() * 255).astype('uint8'))
120
  return x
 
 
 
 
 
 
 
 
 
121
 
122
  if __name__ == '__main__':
123
  current_dir = os.path.dirname(os.path.realpath(__file__))
 
118
  x = self.forward(x)
119
  x = Image.fromarray((x.squeeze(0).permute(1, 2, 0).detach().numpy() * 255).astype('uint8'))
120
  return x
121
+
122
+ def test(self, x):
123
+ """
124
+ x is a tensor
125
+ """
126
+ self.eval()
127
+ with torch.no_grad():
128
+ x = self.forward(x)
129
+ return x
130
 
131
  if __name__ == '__main__':
132
  current_dir = os.path.dirname(os.path.realpath(__file__))
models/SRFlow/srflow.py CHANGED
@@ -39,7 +39,41 @@ def return_SRFlow_result(lr, conf_path='models/SRFlow/code/confs/SRFlow_DF2K_4X.
39
  sr = Image.fromarray((sr).astype('uint8'))
40
  return sr
41
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42
  if __name__ == '__main__':
43
  ip = Image.open('images/demo.png')
44
- sr = return_SRFlow_result(ip)
 
 
45
  print(sr.size)
 
39
  sr = Image.fromarray((sr).astype('uint8'))
40
  return sr
41
 
42
+ def test(lr, conf_path='models/SRFlow/code/confs/SRFlow_DF2K_4X.yml', heat=0.7):
43
+ """
44
+ Apply Super-Resolution using SRFlow model to the input LR (low-resolution) image.
45
+
46
+ Args:
47
+ - lr: tensor
48
+ - conf_path (str): Configuration file path for the SRFlow model. Default is SRFlow_DF2K_4X.yml.
49
+ - heat (float): Heat parameter for the SRFlow model. Default is 0.6.
50
+
51
+ Returns:
52
+ - sr: tensor
53
+ """
54
+ model, opt = load_model(conf_path)
55
+
56
+ scale = opt['scale']
57
+ pad_factor = 2
58
+
59
+ h, w, c = lr.shape
60
+ lr = impad(lr, bottom=int(np.ceil(h / pad_factor) * pad_factor - h),
61
+ right=int(np.ceil(w / pad_factor) * pad_factor - w))
62
+
63
+ lr_t = t(lr)
64
+ heat = opt['heat']
65
+
66
+ sr_t = model.get_sr(lq=lr_t, heat=heat)
67
+
68
+ sr = rgb(torch.clamp(sr_t, 0, 1))
69
+ sr = sr[:h * scale, :w * scale]
70
+
71
+ sr = Image.fromarray((sr).astype('uint8'))
72
+ return sr
73
+
74
  if __name__ == '__main__':
75
  ip = Image.open('images/demo.png')
76
+ lr = PILToTensor()(ip).permute(1, 2, 0).numpy()
77
+ sr = test(lr)
78
+ # sr = return_SRFlow_result(ip)
79
  print(sr.size)
models/SRGAN/srgan.py CHANGED
@@ -64,6 +64,15 @@ class GeneratorResnet(nn.Module):
64
  x = Image.fromarray((x.squeeze(0).permute(1, 2, 0).detach().numpy() * 255).astype('uint8'))
65
  return TF.adjust_brightness(x, 1.1)
66
 
 
 
 
 
 
 
 
 
 
67
  if __name__ == '__main__':
68
  current_dir = os.path.dirname(os.path.realpath(__file__))
69
 
@@ -72,4 +81,6 @@ if __name__ == '__main__':
72
  model.eval()
73
  with torch.no_grad():
74
  input_image = Image.open('images/demo.png')
75
- output_image = model.inference(input_image)
 
 
 
64
  x = Image.fromarray((x.squeeze(0).permute(1, 2, 0).detach().numpy() * 255).astype('uint8'))
65
  return TF.adjust_brightness(x, 1.1)
66
 
67
+ def test(self, x):
68
+ """
69
+ x is a tensor
70
+ """
71
+ self.eval()
72
+ with torch.no_grad():
73
+ x = self.forward(x)
74
+ return x
75
+
76
  if __name__ == '__main__':
77
  current_dir = os.path.dirname(os.path.realpath(__file__))
78
 
 
81
  model.eval()
82
  with torch.no_grad():
83
  input_image = Image.open('images/demo.png')
84
+ input_image = ToTensor()(input_image).unsqueeze(0)
85
+ output_image = model.test(input_image)
86
+ print(output_image.max())