vincent-doan commited on
Commit
2235ef5
·
1 Parent(s): fc75fbd

Updated loading from RCAN

Browse files
Files changed (1) hide show
  1. models/RCAN/rcan.py +17 -0
models/RCAN/rcan.py CHANGED
@@ -1,4 +1,8 @@
 
 
1
  from torch import nn
 
 
2
 
3
  NUM_RESIDUAL_GROUPS = 8
4
  NUM_RESIDUAL_BLOCKS = 16
@@ -103,3 +107,16 @@ class RCAN(nn.Module):
103
  reconstructed_image = self.reconstruction_conv(upscaled_image) # Reconstruction
104
 
105
  return reconstructed_image
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
  from torch import nn
4
+ from PIL import Image
5
+ from torchvision.transforms import ToTensor
6
 
7
  NUM_RESIDUAL_GROUPS = 8
8
  NUM_RESIDUAL_BLOCKS = 16
 
107
  reconstructed_image = self.reconstruction_conv(upscaled_image) # Reconstruction
108
 
109
  return reconstructed_image
110
+
111
+ if __name__ == '__main__':
112
+ current_dir = os.path.dirname(os.path.realpath(__file__))
113
+
114
+ model = RCAN()
115
+ model.load_state_dict(torch.load(current_dir + '/rcan_checkpoint.pth', map_location=torch.device('cpu')))
116
+ model.eval()
117
+ with torch.no_grad():
118
+ input_image = Image.open('images/demo.png')
119
+ input_tensor = ToTensor()(input_image).unsqueeze(0)
120
+ output_tensor = model(input_tensor)
121
+ print(input_tensor.shape)
122
+ print(output_tensor.shape)