Spaces:
Running
Running
vincent-doan
commited on
Commit
•
4381d4f
1
Parent(s):
54770f1
Configure for RCAN
Browse files
app.py
CHANGED
@@ -4,6 +4,8 @@ import numpy as np
|
|
4 |
from PIL import Image
|
5 |
from io import BytesIO
|
6 |
from models.HAT.hat import *
|
|
|
|
|
7 |
# Initialize session state for enhanced images
|
8 |
if 'hat_enhanced_image' not in st.session_state:
|
9 |
st.session_state['hat_enhanced_image'] = None
|
@@ -55,49 +57,39 @@ if 'image' in locals():
|
|
55 |
# st.image(image, caption='Uploaded Image', use_column_width=True)
|
56 |
st.write("")
|
57 |
|
|
|
58 |
if st.button('Enhance with HAT'):
|
59 |
-
with st.spinner('Processing using HAT...'):
|
60 |
-
with st.spinner('Wait for it... the model is processing the image'):
|
61 |
-
# Simulate a delay for processing image
|
62 |
-
|
63 |
enhanced_image = HAT_for_deployment(image)
|
64 |
st.session_state['hat_enhanced_image'] = enhanced_image
|
65 |
st.session_state['hat_clicked'] = True
|
66 |
st.success('Done!')
|
67 |
-
# Display the low and high resolution images side by side
|
68 |
if st.session_state['hat_enhanced_image'] is not None:
|
69 |
col1, col2 = st.columns(2)
|
70 |
col1.header("Original")
|
71 |
col1.image(image, use_column_width=True)
|
72 |
-
|
73 |
col2.header("Enhanced")
|
74 |
col2.image(st.session_state['hat_enhanced_image'], use_column_width=True)
|
75 |
with col2:
|
76 |
get_image_download_link(st.session_state['hat_enhanced_image'], 'hat_enhanced.jpg')
|
77 |
|
|
|
78 |
if st.button('Enhance with RCAN'):
|
79 |
with st.spinner('Processing using RCAN...'):
|
80 |
with st.spinner('Wait for it... the model is processing the image'):
|
81 |
-
|
82 |
-
|
83 |
-
|
84 |
-
enhanced_image = image
|
85 |
-
# Display the low and high resolution images side by side
|
86 |
st.session_state['rcan_enhanced_image'] = enhanced_image
|
87 |
-
|
88 |
st.session_state['rcan_clicked'] = True
|
89 |
st.success('Done!')
|
90 |
-
|
91 |
if st.session_state['rcan_enhanced_image'] is not None:
|
92 |
col1, col2 = st.columns(2)
|
93 |
col1.header("Original")
|
94 |
col1.image(image, use_column_width=True)
|
95 |
-
|
96 |
col2.header("Enhanced")
|
97 |
col2.image(st.session_state['rcan_enhanced_image'], use_column_width=True)
|
98 |
with col2:
|
99 |
get_image_download_link(st.session_state['rcan_enhanced_image'], 'rcan_enhanced.jpg')
|
100 |
-
|
101 |
-
|
102 |
-
|
103 |
-
|
|
|
4 |
from PIL import Image
|
5 |
from io import BytesIO
|
6 |
from models.HAT.hat import *
|
7 |
+
from models.RCAN.rcan import *
|
8 |
+
|
9 |
# Initialize session state for enhanced images
|
10 |
if 'hat_enhanced_image' not in st.session_state:
|
11 |
st.session_state['hat_enhanced_image'] = None
|
|
|
57 |
# st.image(image, caption='Uploaded Image', use_column_width=True)
|
58 |
st.write("")
|
59 |
|
60 |
+
# ------------------------ HAT ------------------------ #
|
61 |
if st.button('Enhance with HAT'):
|
62 |
+
with st.spinner('Processing using HAT...'):
|
63 |
+
with st.spinner('Wait for it... the model is processing the image'):
|
|
|
|
|
64 |
enhanced_image = HAT_for_deployment(image)
|
65 |
st.session_state['hat_enhanced_image'] = enhanced_image
|
66 |
st.session_state['hat_clicked'] = True
|
67 |
st.success('Done!')
|
|
|
68 |
if st.session_state['hat_enhanced_image'] is not None:
|
69 |
col1, col2 = st.columns(2)
|
70 |
col1.header("Original")
|
71 |
col1.image(image, use_column_width=True)
|
|
|
72 |
col2.header("Enhanced")
|
73 |
col2.image(st.session_state['hat_enhanced_image'], use_column_width=True)
|
74 |
with col2:
|
75 |
get_image_download_link(st.session_state['hat_enhanced_image'], 'hat_enhanced.jpg')
|
76 |
|
77 |
+
# ------------------------ RCAN ------------------------ #
|
78 |
if st.button('Enhance with RCAN'):
|
79 |
with st.spinner('Processing using RCAN...'):
|
80 |
with st.spinner('Wait for it... the model is processing the image'):
|
81 |
+
rcan_model = RCAN()
|
82 |
+
device = torch.device('cpu') if not torch.cuda.is_available() else torch.device('cuda')
|
83 |
+
rcan_model.load_state_dict(torch.load('models/RCAN/rcan_checkpoint.pth', map_location=device))
|
84 |
+
enhanced_image = rcan_model.inference(image)
|
|
|
85 |
st.session_state['rcan_enhanced_image'] = enhanced_image
|
|
|
86 |
st.session_state['rcan_clicked'] = True
|
87 |
st.success('Done!')
|
|
|
88 |
if st.session_state['rcan_enhanced_image'] is not None:
|
89 |
col1, col2 = st.columns(2)
|
90 |
col1.header("Original")
|
91 |
col1.image(image, use_column_width=True)
|
|
|
92 |
col2.header("Enhanced")
|
93 |
col2.image(st.session_state['rcan_enhanced_image'], use_column_width=True)
|
94 |
with col2:
|
95 |
get_image_download_link(st.session_state['rcan_enhanced_image'], 'rcan_enhanced.jpg')
|
|
|
|
|
|
|
|