vincent-doan commited on
Commit
4381d4f
1 Parent(s): 54770f1

Configure for RCAN

Browse files
Files changed (1) hide show
  1. app.py +10 -18
app.py CHANGED
@@ -4,6 +4,8 @@ import numpy as np
4
  from PIL import Image
5
  from io import BytesIO
6
  from models.HAT.hat import *
 
 
7
  # Initialize session state for enhanced images
8
  if 'hat_enhanced_image' not in st.session_state:
9
  st.session_state['hat_enhanced_image'] = None
@@ -55,49 +57,39 @@ if 'image' in locals():
55
  # st.image(image, caption='Uploaded Image', use_column_width=True)
56
  st.write("")
57
 
 
58
  if st.button('Enhance with HAT'):
59
- with st.spinner('Processing using HAT...'):
60
- with st.spinner('Wait for it... the model is processing the image'):
61
- # Simulate a delay for processing image
62
-
63
  enhanced_image = HAT_for_deployment(image)
64
  st.session_state['hat_enhanced_image'] = enhanced_image
65
  st.session_state['hat_clicked'] = True
66
  st.success('Done!')
67
- # Display the low and high resolution images side by side
68
  if st.session_state['hat_enhanced_image'] is not None:
69
  col1, col2 = st.columns(2)
70
  col1.header("Original")
71
  col1.image(image, use_column_width=True)
72
-
73
  col2.header("Enhanced")
74
  col2.image(st.session_state['hat_enhanced_image'], use_column_width=True)
75
  with col2:
76
  get_image_download_link(st.session_state['hat_enhanced_image'], 'hat_enhanced.jpg')
77
 
 
78
  if st.button('Enhance with RCAN'):
79
  with st.spinner('Processing using RCAN...'):
80
  with st.spinner('Wait for it... the model is processing the image'):
81
- # Simulate a delay for processing image
82
- time.sleep(2) # replace this with actual model processing code
83
-
84
- enhanced_image = image
85
- # Display the low and high resolution images side by side
86
  st.session_state['rcan_enhanced_image'] = enhanced_image
87
-
88
  st.session_state['rcan_clicked'] = True
89
  st.success('Done!')
90
-
91
  if st.session_state['rcan_enhanced_image'] is not None:
92
  col1, col2 = st.columns(2)
93
  col1.header("Original")
94
  col1.image(image, use_column_width=True)
95
-
96
  col2.header("Enhanced")
97
  col2.image(st.session_state['rcan_enhanced_image'], use_column_width=True)
98
  with col2:
99
  get_image_download_link(st.session_state['rcan_enhanced_image'], 'rcan_enhanced.jpg')
100
-
101
-
102
-
103
-
 
4
  from PIL import Image
5
  from io import BytesIO
6
  from models.HAT.hat import *
7
+ from models.RCAN.rcan import *
8
+
9
  # Initialize session state for enhanced images
10
  if 'hat_enhanced_image' not in st.session_state:
11
  st.session_state['hat_enhanced_image'] = None
 
57
  # st.image(image, caption='Uploaded Image', use_column_width=True)
58
  st.write("")
59
 
60
+ # ------------------------ HAT ------------------------ #
61
  if st.button('Enhance with HAT'):
62
+ with st.spinner('Processing using HAT...'):
63
+ with st.spinner('Wait for it... the model is processing the image'):
 
 
64
  enhanced_image = HAT_for_deployment(image)
65
  st.session_state['hat_enhanced_image'] = enhanced_image
66
  st.session_state['hat_clicked'] = True
67
  st.success('Done!')
 
68
  if st.session_state['hat_enhanced_image'] is not None:
69
  col1, col2 = st.columns(2)
70
  col1.header("Original")
71
  col1.image(image, use_column_width=True)
 
72
  col2.header("Enhanced")
73
  col2.image(st.session_state['hat_enhanced_image'], use_column_width=True)
74
  with col2:
75
  get_image_download_link(st.session_state['hat_enhanced_image'], 'hat_enhanced.jpg')
76
 
77
+ # ------------------------ RCAN ------------------------ #
78
  if st.button('Enhance with RCAN'):
79
  with st.spinner('Processing using RCAN...'):
80
  with st.spinner('Wait for it... the model is processing the image'):
81
+ rcan_model = RCAN()
82
+ device = torch.device('cpu') if not torch.cuda.is_available() else torch.device('cuda')
83
+ rcan_model.load_state_dict(torch.load('models/RCAN/rcan_checkpoint.pth', map_location=device))
84
+ enhanced_image = rcan_model.inference(image)
 
85
  st.session_state['rcan_enhanced_image'] = enhanced_image
 
86
  st.session_state['rcan_clicked'] = True
87
  st.success('Done!')
 
88
  if st.session_state['rcan_enhanced_image'] is not None:
89
  col1, col2 = st.columns(2)
90
  col1.header("Original")
91
  col1.image(image, use_column_width=True)
 
92
  col2.header("Enhanced")
93
  col2.image(st.session_state['rcan_enhanced_image'], use_column_width=True)
94
  with col2:
95
  get_image_download_link(st.session_state['rcan_enhanced_image'], 'rcan_enhanced.jpg')