.gitattributes CHANGED
@@ -37,3 +37,4 @@ models/HAT/hat-for-image-sr-2.ipynb filter=lfs diff=lfs merge=lfs -text
37
  model/RCAN/rcan-for-image-sr.ipynb filter=lfs diff=lfs merge=lfs -text
38
  models/RCAN/rcan-for-image-sr.ipynb filter=lfs diff=lfs merge=lfs -text
39
  models/SRGAN/srgan-dl-prj.ipynb filter=lfs diff=lfs merge=lfs -text
 
 
37
  model/RCAN/rcan-for-image-sr.ipynb filter=lfs diff=lfs merge=lfs -text
38
  models/RCAN/rcan-for-image-sr.ipynb filter=lfs diff=lfs merge=lfs -text
39
  models/SRGAN/srgan-dl-prj.ipynb filter=lfs diff=lfs merge=lfs -text
40
+ models/VDSR/vdsr.ipynb filter=lfs diff=lfs merge=lfs -text
app.py CHANGED
@@ -1,216 +1,241 @@
1
- import time
2
- import streamlit as st
3
- import subprocess
4
- import numpy as np
5
- from PIL import Image
6
- from io import BytesIO
7
- from models.HAT.hat import *
8
- from models.RCAN.rcan import *
9
- from models.SRGAN.srgan import *
10
- from models.Interpolation.nearest_neighbor import NearestNeighbor_for_deployment
11
- from models.Interpolation.bilinear import Bilinear_for_deployment
12
- from models.Interpolation.bicubic import Bicubic_for_deployment
13
-
14
- subprocess.call('pip install natsort', shell=True)
15
- from models.SRFlow.srflow import *
16
-
17
- # Initialize session state for enhanced images
18
- if 'nearest_enhanced_image' not in st.session_state:
19
- st.session_state['nearest_enhanced_image'] = None
20
- if 'bilinear_enhanced_image' not in st.session_state:
21
- st.session_state['bilinear_enhanced_image'] = None
22
- if 'bicubic_enhanced_image' not in st.session_state:
23
- st.session_state['bicubic_enhanced_image'] = None
24
- if 'hat_enhanced_image' not in st.session_state:
25
- st.session_state['hat_enhanced_image'] = None
26
- if 'rcan_enhanced_image' not in st.session_state:
27
- st.session_state['rcan_enhanced_image'] = None
28
- if 'srgan_enhanced_image' not in st.session_state:
29
- st.session_state['srgan_enhanced_image'] = None
30
- if 'srflow_enhanced_image' not in st.session_state:
31
- st.session_state['srflow_enhanced_image'] = None
32
-
33
- # Initialize session state for button clicks
34
- if 'nearest_clicked' not in st.session_state:
35
- st.session_state['nearest_clicked'] = False
36
- if 'bilinear_clicked' not in st.session_state:
37
- st.session_state['bilinear_clicked'] = False
38
- if 'bicubic_clicked' not in st.session_state:
39
- st.session_state['bicubic_clicked'] = False
40
- if 'hat_clicked' not in st.session_state:
41
- st.session_state['hat_clicked'] = False
42
- if 'rcan_clicked' not in st.session_state:
43
- st.session_state['rcan_clicked'] = False
44
- if 'srgan_clicked' not in st.session_state:
45
- st.session_state['srgan_clicked'] = False
46
- if 'srflow_clicked' not in st.session_state:
47
- st.session_state['srflow_clicked'] = False
48
-
49
- st.markdown("<h1 style='text-align: center'>Image Super Resolution</h1>", unsafe_allow_html=True)
50
-
51
- # Sidebar for navigation
52
- st.sidebar.title("Options")
53
- app_mode = st.sidebar.selectbox("Choose the input source", ["Upload image", "Take a photo"])
54
-
55
- # Depending on the choice, show the uploader widget or webcam capture
56
- if app_mode == "Upload image":
57
- uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "png"], on_change=lambda: reset_states())
58
- if uploaded_file is not None:
59
- image = Image.open(uploaded_file).convert("RGB")
60
- elif app_mode == "Take a photo":
61
- camera_input = st.camera_input("Take a picture", on_change=lambda: reset_states())
62
- if camera_input is not None:
63
- image = Image.open(camera_input).convert("RGB")
64
-
65
- def reset_states():
66
- st.session_state['hat_enhanced_image'] = None
67
- st.session_state['rcan_enhanced_image'] = None
68
- st.session_state['srgan_enhanced_image'] = None
69
- st.session_state['srflow_enhanced_image'] = None
70
- st.session_state['bicubic_enhanced_image'] = None
71
- st.session_state['bilinear_enhanced_image'] = None
72
- st.session_state['nearest_enhanced_image'] = None
73
- st.session_state['hat_clicked'] = False
74
- st.session_state['rcan_clicked'] = False
75
- st.session_state['srgan_clicked'] = False
76
- st.session_state['srflow_clicked'] = False
77
- st.session_state['bicubic_clicked'] = False
78
- st.session_state['bilinear_clicked'] = False
79
- st.session_state['nearest_clicked'] = False
80
-
81
- def get_image_download_link(img, filename):
82
- """Generates a link allowing the PIL image to be downloaded"""
83
- # Convert the PIL image to Bytes
84
- buffered = BytesIO()
85
- img.save(buffered, format="PNG")
86
- return st.download_button(
87
- label="Download Image",
88
- data=buffered.getvalue(),
89
- file_name=filename,
90
- mime="image/png"
91
- )
92
-
93
- if 'image' in locals():
94
- # st.image(image, caption='Uploaded Image', use_column_width=True)
95
- st.write("")
96
- # ------------------------ Nearest Neighbor ------------------------ #
97
- if st.button('Enhance with Nearest Neighbor'):
98
- with st.spinner('Processing using Nearest Neighbor...'):
99
- enhanced_image = NearestNeighbor_for_deployment(image)
100
- st.session_state['nearest_enhanced_image'] = enhanced_image
101
- st.session_state['nearest_clicked'] = True
102
- st.success('Done!')
103
- if st.session_state['nearest_enhanced_image'] is not None:
104
- col1, col2 = st.columns(2)
105
- col1.header("Original")
106
- col1.image(image, use_column_width=True)
107
- col2.header("Enhanced")
108
- col2.image(st.session_state['nearest_enhanced_image'], use_column_width=True)
109
- with col2:
110
- get_image_download_link(st.session_state['nearest_enhanced_image'], 'nearest_enhanced.jpg')
111
-
112
- # ------------------------ Bilinear ------------------------ #
113
- if st.button('Enhance with Bilinear'):
114
- with st.spinner('Processing using Bilinear...'):
115
- enhanced_image = Bilinear_for_deployment(image)
116
- st.session_state['bilinear_enhanced_image'] = enhanced_image
117
- st.session_state['bilinear_clicked'] = True
118
- st.success('Done!')
119
- if st.session_state['bilinear_enhanced_image'] is not None:
120
- col1, col2 = st.columns(2)
121
- col1.header("Original")
122
- col1.image(image, use_column_width=True)
123
- col2.header("Enhanced")
124
- col2.image(st.session_state['bilinear_enhanced_image'], use_column_width=True)
125
- with col2:
126
- get_image_download_link(st.session_state['bilinear_enhanced_image'], 'bilinear_enhanced.jpg')
127
-
128
- # ------------------------ Bicubic ------------------------ #
129
- if st.button('Enhance with Bicubic'):
130
- with st.spinner('Processing using Bicubic...'):
131
- enhanced_image = Bicubic_for_deployment(image)
132
- st.session_state['bicubic_enhanced_image'] = enhanced_image
133
- st.session_state['bicubic_clicked'] = True
134
- st.success('Done!')
135
- if st.session_state['bicubic_enhanced_image'] is not None:
136
- col1, col2 = st.columns(2)
137
- col1.header("Original")
138
- col1.image(image, use_column_width=True)
139
- col2.header("Enhanced")
140
- col2.image(st.session_state['bicubic_enhanced_image'], use_column_width=True)
141
- with col2:
142
- get_image_download_link(st.session_state['bicubic_enhanced_image'], 'bicubic_enhanced.jpg')
143
-
144
- # ------------------------ HAT ------------------------ #
145
- if st.button('Enhance with HAT'):
146
- with st.spinner('Processing using HAT...'):
147
- with st.spinner('Wait for it... the model is processing the image'):
148
- enhanced_image = HAT_for_deployment(image)
149
- st.session_state['hat_enhanced_image'] = enhanced_image
150
- st.session_state['hat_clicked'] = True
151
- st.success('Done!')
152
- if st.session_state['hat_enhanced_image'] is not None:
153
- col1, col2 = st.columns(2)
154
- col1.header("Original")
155
- col1.image(image, use_column_width=True)
156
- col2.header("Enhanced")
157
- col2.image(st.session_state['hat_enhanced_image'], use_column_width=True)
158
- with col2:
159
- get_image_download_link(st.session_state['hat_enhanced_image'], 'hat_enhanced.jpg')
160
-
161
- # ------------------------ RCAN ------------------------ #
162
- if st.button('Enhance with RCAN'):
163
- with st.spinner('Processing using RCAN...'):
164
- with st.spinner('Wait for it... the model is processing the image'):
165
- rcan_model = RCAN()
166
- device = torch.device('cpu') if not torch.cuda.is_available() else torch.device('cuda')
167
- rcan_model.load_state_dict(torch.load('models/RCAN/rcan_checkpoint.pth', map_location=device))
168
- enhanced_image = rcan_model.inference(image)
169
- st.session_state['rcan_enhanced_image'] = enhanced_image
170
- st.session_state['rcan_clicked'] = True
171
- st.success('Done!')
172
- if st.session_state['rcan_enhanced_image'] is not None:
173
- col1, col2 = st.columns(2)
174
- col1.header("Original")
175
- col1.image(image, use_column_width=True)
176
- col2.header("Enhanced")
177
- col2.image(st.session_state['rcan_enhanced_image'], use_column_width=True)
178
- with col2:
179
- get_image_download_link(st.session_state['rcan_enhanced_image'], 'rcan_enhanced.jpg')
180
-
181
- # --------------------------SRGAN-------------------------- #
182
- if st.button('Enhance with SRGAN'):
183
- with st.spinner('Processing using SRGAN...'):
184
- with st.spinner('Wait for it... the model is processing the image'):
185
- srgan_model = GeneratorResnet()
186
- device = torch.device('cpu') if not torch.cuda.is_available() else torch.device('cuda')
187
- srgan_model = torch.load('models/SRGAN/srgan_checkpoint.pth', map_location=device)
188
- enhanced_image = srgan_model.inference(image)
189
- st.session_state['srgan_enhanced_image'] = enhanced_image
190
- st.session_state['srgan_clicked'] = True
191
- st.success('Done!')
192
- if st.session_state['srgan_enhanced_image'] is not None:
193
- col1, col2 = st.columns(2)
194
- col1.header("Original")
195
- col1.image(image, use_column_width=True)
196
- col2.header("Enhanced")
197
- col2.image(st.session_state['srgan_enhanced_image'], use_column_width=True)
198
- with col2:
199
- get_image_download_link(st.session_state['srgan_enhanced_image'], 'srgan_enhanced.jpg')
200
-
201
- # ------------------------ SRFlow ------------------------ #
202
- if st.button('Enhance with SRFlow'):
203
- with st.spinner('Processing using SRFlow...'):
204
- with st.spinner('Wait for it... the model is processing the image'):
205
- enhanced_image = return_SRFlow_result(image)
206
- st.session_state['srflow_enhanced_image'] = enhanced_image
207
- st.session_state['srflow_clicked'] = True
208
- st.success('Done!')
209
- if st.session_state['srflow_enhanced_image'] is not None:
210
- col1, col2 = st.columns(2)
211
- col1.header("Original")
212
- col1.image(image, use_column_width=True)
213
- col2.header("Enhanced")
214
- col2.image(st.session_state['srflow_enhanced_image'], use_column_width=True)
215
- with col2:
216
- get_image_download_link(st.session_state['srflow_enhanced_image'], 'srflow_enhanced.jpg')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import time
2
+ import streamlit as st
3
+ import subprocess
4
+ import numpy as np
5
+ from PIL import Image
6
+ from io import BytesIO
7
+ from models.HAT.hat import *
8
+ from models.RCAN.rcan import *
9
+ from models.SRGAN.srgan import *
10
+ from models.VDSR.vdsr import *
11
+ from models.Interpolation.nearest_neighbor import NearestNeighbor_for_deployment
12
+ from models.Interpolation.bilinear import Bilinear_for_deployment
13
+ from models.Interpolation.bicubic import Bicubic_for_deployment
14
+
15
+ subprocess.call('pip install natsort', shell=True)
16
+ from models.SRFlow.srflow import *
17
+
18
+ # Initialize session state for enhanced images
19
+ if 'nearest_enhanced_image' not in st.session_state:
20
+ st.session_state['nearest_enhanced_image'] = None
21
+ if 'bilinear_enhanced_image' not in st.session_state:
22
+ st.session_state['bilinear_enhanced_image'] = None
23
+ if 'bicubic_enhanced_image' not in st.session_state:
24
+ st.session_state['bicubic_enhanced_image'] = None
25
+ if 'hat_enhanced_image' not in st.session_state:
26
+ st.session_state['hat_enhanced_image'] = None
27
+ if 'rcan_enhanced_image' not in st.session_state:
28
+ st.session_state['rcan_enhanced_image'] = None
29
+ if 'srgan_enhanced_image' not in st.session_state:
30
+ st.session_state['srgan_enhanced_image'] = None
31
+ if 'srflow_enhanced_image' not in st.session_state:
32
+ st.session_state['srflow_enhanced_image'] = None
33
+ if 'vdsr_enhanced_image' not in st.session_state:
34
+ st.session_state['vdsr_enhanced_image'] = None
35
+
36
+ # Initialize session state for button clicks
37
+ if 'nearest_clicked' not in st.session_state:
38
+ st.session_state['nearest_clicked'] = False
39
+ if 'bilinear_clicked' not in st.session_state:
40
+ st.session_state['bilinear_clicked'] = False
41
+ if 'bicubic_clicked' not in st.session_state:
42
+ st.session_state['bicubic_clicked'] = False
43
+ if 'hat_clicked' not in st.session_state:
44
+ st.session_state['hat_clicked'] = False
45
+ if 'rcan_clicked' not in st.session_state:
46
+ st.session_state['rcan_clicked'] = False
47
+ if 'srgan_clicked' not in st.session_state:
48
+ st.session_state['srgan_clicked'] = False
49
+ if 'srflow_clicked' not in st.session_state:
50
+ st.session_state['srflow_clicked'] = False
51
+ if 'vdsr_clicked' not in st.session_state:
52
+ st.session_state['vdsr_clicked'] = False
53
+
54
+ st.markdown("<h1 style='text-align: center'>Image Super Resolution</h1>", unsafe_allow_html=True)
55
+
56
+ # Sidebar for navigation
57
+ st.sidebar.title("Options")
58
+ app_mode = st.sidebar.selectbox("Choose the input source", ["Upload image", "Take a photo"])
59
+
60
+ # Depending on the choice, show the uploader widget or webcam capture
61
+ if app_mode == "Upload image":
62
+ uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "png"], on_change=lambda: reset_states())
63
+ if uploaded_file is not None:
64
+ image = Image.open(uploaded_file).convert("RGB")
65
+ elif app_mode == "Take a photo":
66
+ camera_input = st.camera_input("Take a picture", on_change=lambda: reset_states())
67
+ if camera_input is not None:
68
+ image = Image.open(camera_input).convert("RGB")
69
+
70
+ def reset_states():
71
+ st.session_state['hat_enhanced_image'] = None
72
+ st.session_state['rcan_enhanced_image'] = None
73
+ st.session_state['srgan_enhanced_image'] = None
74
+ st.session_state['srflow_enhanced_image'] = None
75
+ st.session_state['bicubic_enhanced_image'] = None
76
+ st.session_state['bilinear_enhanced_image'] = None
77
+ st.session_state['nearest_enhanced_image'] = None
78
+ st.session_state['vdsr_enhanced_image'] = None
79
+ st.session_state['hat_clicked'] = False
80
+ st.session_state['rcan_clicked'] = False
81
+ st.session_state['srgan_clicked'] = False
82
+ st.session_state['srflow_clicked'] = False
83
+ st.session_state['bicubic_clicked'] = False
84
+ st.session_state['bilinear_clicked'] = False
85
+ st.session_state['nearest_clicked'] = False
86
+ st.session_state['vdsr_clicked'] = False
87
+
88
+ def get_image_download_link(img, filename):
89
+ """Generates a link allowing the PIL image to be downloaded"""
90
+ # Convert the PIL image to Bytes
91
+ buffered = BytesIO()
92
+ img.save(buffered, format="PNG")
93
+ return st.download_button(
94
+ label="Download Image",
95
+ data=buffered.getvalue(),
96
+ file_name=filename,
97
+ mime="image/png"
98
+ )
99
+
100
+ if 'image' in locals():
101
+ # st.image(image, caption='Uploaded Image', use_column_width=True)
102
+ st.write("")
103
+ # ------------------------ Nearest Neighbor ------------------------ #
104
+ if st.button('Enhance with Nearest Neighbor'):
105
+ with st.spinner('Processing using Nearest Neighbor...'):
106
+ enhanced_image = NearestNeighbor_for_deployment(image)
107
+ st.session_state['nearest_enhanced_image'] = enhanced_image
108
+ st.session_state['nearest_clicked'] = True
109
+ st.success('Done!')
110
+ if st.session_state['nearest_enhanced_image'] is not None:
111
+ col1, col2 = st.columns(2)
112
+ col1.header("Original")
113
+ col1.image(image, use_column_width=True)
114
+ col2.header("Enhanced")
115
+ col2.image(st.session_state['nearest_enhanced_image'], use_column_width=True)
116
+ with col2:
117
+ get_image_download_link(st.session_state['nearest_enhanced_image'], 'nearest_enhanced.jpg')
118
+
119
+ # ------------------------ Bilinear ------------------------ #
120
+ if st.button('Enhance with Bilinear'):
121
+ with st.spinner('Processing using Bilinear...'):
122
+ enhanced_image = Bilinear_for_deployment(image)
123
+ st.session_state['bilinear_enhanced_image'] = enhanced_image
124
+ st.session_state['bilinear_clicked'] = True
125
+ st.success('Done!')
126
+ if st.session_state['bilinear_enhanced_image'] is not None:
127
+ col1, col2 = st.columns(2)
128
+ col1.header("Original")
129
+ col1.image(image, use_column_width=True)
130
+ col2.header("Enhanced")
131
+ col2.image(st.session_state['bilinear_enhanced_image'], use_column_width=True)
132
+ with col2:
133
+ get_image_download_link(st.session_state['bilinear_enhanced_image'], 'bilinear_enhanced.jpg')
134
+
135
+ # ------------------------ Bicubic ------------------------ #
136
+ if st.button('Enhance with Bicubic'):
137
+ with st.spinner('Processing using Bicubic...'):
138
+ enhanced_image = Bicubic_for_deployment(image)
139
+ st.session_state['bicubic_enhanced_image'] = enhanced_image
140
+ st.session_state['bicubic_clicked'] = True
141
+ st.success('Done!')
142
+ if st.session_state['bicubic_enhanced_image'] is not None:
143
+ col1, col2 = st.columns(2)
144
+ col1.header("Original")
145
+ col1.image(image, use_column_width=True)
146
+ col2.header("Enhanced")
147
+ col2.image(st.session_state['bicubic_enhanced_image'], use_column_width=True)
148
+ with col2:
149
+ get_image_download_link(st.session_state['bicubic_enhanced_image'], 'bicubic_enhanced.jpg')
150
+
151
+ # ------------------------ HAT ------------------------ #
152
+ if st.button('Enhance with HAT'):
153
+ with st.spinner('Processing using HAT...'):
154
+ with st.spinner('Wait for it... the model is processing the image'):
155
+ enhanced_image = HAT_for_deployment(image)
156
+ st.session_state['hat_enhanced_image'] = enhanced_image
157
+ st.session_state['hat_clicked'] = True
158
+ st.success('Done!')
159
+ if st.session_state['hat_enhanced_image'] is not None:
160
+ col1, col2 = st.columns(2)
161
+ col1.header("Original")
162
+ col1.image(image, use_column_width=True)
163
+ col2.header("Enhanced")
164
+ col2.image(st.session_state['hat_enhanced_image'], use_column_width=True)
165
+ with col2:
166
+ get_image_download_link(st.session_state['hat_enhanced_image'], 'hat_enhanced.jpg')
167
+
168
+ # ------------------------ RCAN ------------------------ #
169
+ if st.button('Enhance with RCAN'):
170
+ with st.spinner('Processing using RCAN...'):
171
+ with st.spinner('Wait for it... the model is processing the image'):
172
+ rcan_model = RCAN()
173
+ device = torch.device('cpu') if not torch.cuda.is_available() else torch.device('cuda')
174
+ rcan_model.load_state_dict(torch.load('models/RCAN/rcan_checkpoint.pth', map_location=device))
175
+ enhanced_image = rcan_model.inference(image)
176
+ st.session_state['rcan_enhanced_image'] = enhanced_image
177
+ st.session_state['rcan_clicked'] = True
178
+ st.success('Done!')
179
+ if st.session_state['rcan_enhanced_image'] is not None:
180
+ col1, col2 = st.columns(2)
181
+ col1.header("Original")
182
+ col1.image(image, use_column_width=True)
183
+ col2.header("Enhanced")
184
+ col2.image(st.session_state['rcan_enhanced_image'], use_column_width=True)
185
+ with col2:
186
+ get_image_download_link(st.session_state['rcan_enhanced_image'], 'rcan_enhanced.jpg')
187
+
188
+ # --------------------------SRGAN-------------------------- #
189
+ if st.button('Enhance with SRGAN'):
190
+ with st.spinner('Processing using SRGAN...'):
191
+ with st.spinner('Wait for it... the model is processing the image'):
192
+ srgan_model = GeneratorResnet()
193
+ device = torch.device('cpu') if not torch.cuda.is_available() else torch.device('cuda')
194
+ srgan_model = torch.load('models/SRGAN/srgan_checkpoint.pth', map_location=device)
195
+ enhanced_image = srgan_model.inference(image)
196
+ st.session_state['srgan_enhanced_image'] = enhanced_image
197
+ st.session_state['srgan_clicked'] = True
198
+ st.success('Done!')
199
+ if st.session_state['srgan_enhanced_image'] is not None:
200
+ col1, col2 = st.columns(2)
201
+ col1.header("Original")
202
+ col1.image(image, use_column_width=True)
203
+ col2.header("Enhanced")
204
+ col2.image(st.session_state['srgan_enhanced_image'], use_column_width=True)
205
+ with col2:
206
+ get_image_download_link(st.session_state['srgan_enhanced_image'], 'srgan_enhanced.jpg')
207
+
208
+ # ------------------------ SRFlow ------------------------ #
209
+ if st.button('Enhance with SRFlow'):
210
+ with st.spinner('Processing using SRFlow...'):
211
+ with st.spinner('Wait for it... the model is processing the image'):
212
+ enhanced_image = return_SRFlow_result(image)
213
+ st.session_state['srflow_enhanced_image'] = enhanced_image
214
+ st.session_state['srflow_clicked'] = True
215
+ st.success('Done!')
216
+ if st.session_state['srflow_enhanced_image'] is not None:
217
+ col1, col2 = st.columns(2)
218
+ col1.header("Original")
219
+ col1.image(image, use_column_width=True)
220
+ col2.header("Enhanced")
221
+ col2.image(st.session_state['srflow_enhanced_image'], use_column_width=True)
222
+ with col2:
223
+ get_image_download_link(st.session_state['srflow_enhanced_image'], 'srflow_enhanced.jpg')
224
+
225
+ # ------------------------ VDSR ------------------------ #
226
+ if st.button('Enhance with VDSR'):
227
+ with st.spinner('Processing using VDSR...'):
228
+ # Load the VDSR model
229
+ vdsr_model = torch.load('models/VDSR/vdsr_checkpoint.pth', map_location=torch.device('cpu'))
230
+ enhanced_image = vdsr_model.inference(image)
231
+ st.session_state['vdsr_enhanced_image'] = enhanced_image
232
+ st.session_state['vdsr_clicked'] = True
233
+ st.success('Done!')
234
+ if st.session_state['vdsr_enhanced_image'] is not None:
235
+ col1, col2 = st.columns(2)
236
+ col1.header("Original")
237
+ col1.image(image, use_column_width=True)
238
+ col2.header("Enhanced")
239
+ col2.image(st.session_state['vdsr_enhanced_image'], use_column_width=True)
240
+ with col2:
241
+ get_image_download_link(st.session_state['vdsr_enhanced_image'], 'vdsr_enhanced.jpg')
models/VDSR/vdsr.ipynb ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:acef3257ccc3ca0d7071b66a57f326b921e5696b917fe75c773afc08d14debcb
3
+ size 11173590
models/VDSR/vdsr.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from torchvision.transforms import ToTensor
4
+ from PIL import Image
5
+ import os
6
+ from math import sqrt
7
+ import torch.nn.functional as F
8
+
9
+ #define class Block contain conv and relu layer
10
+ class Block(nn.Module):
11
+ def __init__(self):
12
+ super(Block, self).__init__()
13
+ self.conv = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1, bias=False)
14
+ self.relu = nn.ReLU(inplace=True)
15
+
16
+ def forward(self, x):
17
+ return self.relu(self.conv(x))
18
+
19
+ class VDSR(nn.Module):
20
+ def __init__(self, in_channels=3, out_channels=3, num_blocks=18):
21
+ super(VDSR, self).__init__()
22
+ self.residual_layer = self.make_layer(Block, num_blocks)
23
+ self.input = nn.Conv2d(in_channels=in_channels, out_channels=64, kernel_size=3, stride=1, padding=1, bias=False)
24
+ self.output = nn.Conv2d(in_channels=64, out_channels=out_channels, kernel_size=3, stride=1, padding=1, bias=False)
25
+ self.relu = nn.ReLU(inplace=True)
26
+
27
+ for m in self.modules():
28
+ if isinstance(m, nn.Conv2d):
29
+ n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
30
+ m.weight.data.normal_(0, sqrt(2. / n))
31
+
32
+ def make_layer(self, block, num_layers):
33
+ layers=[]
34
+ for _ in range(num_layers):
35
+ layers.append(block())
36
+ return nn.Sequential(*layers)
37
+
38
+ def forward(self, x):
39
+ residual = x
40
+ out = self.relu(self.input(x))
41
+ out = self.residual_layer(out)
42
+ out = self.output(out)
43
+ out = torch.add(residual, out)
44
+ return out
45
+
46
+ def inference(self, x):
47
+ """
48
+ x is a PIL image
49
+ """
50
+ self.eval()
51
+ with torch.no_grad():
52
+ x = ToTensor()(x).unsqueeze(0)
53
+ x = F.interpolate(x, scale_factor=4, mode='bicubic', align_corners=False).clamp(0, 1)
54
+ x = self.forward(x).clamp(0, 1)
55
+ x = Image.fromarray((x.squeeze(0).permute(1, 2, 0).detach().numpy() * 255).astype('uint8'))
56
+ return x
57
+
58
+ if __name__ == '__main__':
59
+ current_dir = os.path.dirname(os.path.realpath(__file__))
60
+
61
+ model = torch.load(current_dir + '/vdsr_checkpoint.pth', map_location=torch.device('cpu'))
62
+ model.eval()
63
+ with torch.no_grad():
64
+ input_image = Image.open('images/demo.png')
65
+ output_image = model.inference(input_image)
66
+ print(input_image.size, output_image.size)
models/VDSR/vdsr_checkpoint.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:839903f0373cbbd60ee00c4367436a718dd8689c8fda1d901471aa0f570e54be
3
+ size 2689946