Spaces:
Running
Running
vincent-doan
commited on
Commit
·
2235ef5
1
Parent(s):
fc75fbd
Updated loading from RCAN
Browse files- models/RCAN/rcan.py +17 -0
models/RCAN/rcan.py
CHANGED
@@ -1,4 +1,8 @@
|
|
|
|
|
|
1 |
from torch import nn
|
|
|
|
|
2 |
|
3 |
NUM_RESIDUAL_GROUPS = 8
|
4 |
NUM_RESIDUAL_BLOCKS = 16
|
@@ -103,3 +107,16 @@ class RCAN(nn.Module):
|
|
103 |
reconstructed_image = self.reconstruction_conv(upscaled_image) # Reconstruction
|
104 |
|
105 |
return reconstructed_image
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import torch
|
3 |
from torch import nn
|
4 |
+
from PIL import Image
|
5 |
+
from torchvision.transforms import ToTensor
|
6 |
|
7 |
NUM_RESIDUAL_GROUPS = 8
|
8 |
NUM_RESIDUAL_BLOCKS = 16
|
|
|
107 |
reconstructed_image = self.reconstruction_conv(upscaled_image) # Reconstruction
|
108 |
|
109 |
return reconstructed_image
|
110 |
+
|
111 |
+
if __name__ == '__main__':
|
112 |
+
current_dir = os.path.dirname(os.path.realpath(__file__))
|
113 |
+
|
114 |
+
model = RCAN()
|
115 |
+
model.load_state_dict(torch.load(current_dir + '/rcan_checkpoint.pth', map_location=torch.device('cpu')))
|
116 |
+
model.eval()
|
117 |
+
with torch.no_grad():
|
118 |
+
input_image = Image.open('images/demo.png')
|
119 |
+
input_tensor = ToTensor()(input_image).unsqueeze(0)
|
120 |
+
output_tensor = model(input_tensor)
|
121 |
+
print(input_tensor.shape)
|
122 |
+
print(output_tensor.shape)
|