thiemcun203 namnguyen2103 commited on
Commit
1374a6d
1 Parent(s): 92d51b9

Interpolation added (#5)

Browse files

- Interpolation added (ff4d653f5d182fc8825c2bf92bb9e4a65905c97f)


Co-authored-by: Nguyễn Nam <namnguyen2103@users.noreply.huggingface.co>

app.py CHANGED
@@ -7,11 +7,20 @@ from io import BytesIO
7
  from models.HAT.hat import *
8
  from models.RCAN.rcan import *
9
  from models.SRGAN.srgan import *
 
 
 
10
 
11
  subprocess.call('pip install natsort', shell=True)
12
  from models.SRFlow.srflow import *
13
 
14
  # Initialize session state for enhanced images
 
 
 
 
 
 
15
  if 'hat_enhanced_image' not in st.session_state:
16
  st.session_state['hat_enhanced_image'] = None
17
  if 'rcan_enhanced_image' not in st.session_state:
@@ -22,6 +31,12 @@ if 'srflow_enhanced_image' not in st.session_state:
22
  st.session_state['srflow_enhanced_image'] = None
23
 
24
  # Initialize session state for button clicks
 
 
 
 
 
 
25
  if 'hat_clicked' not in st.session_state:
26
  st.session_state['hat_clicked'] = False
27
  if 'rcan_clicked' not in st.session_state:
@@ -52,10 +67,16 @@ def reset_states():
52
  st.session_state['rcan_enhanced_image'] = None
53
  st.session_state['srgan_enhanced_image'] = None
54
  st.session_state['srflow_enhanced_image'] = None
 
 
 
55
  st.session_state['hat_clicked'] = False
56
  st.session_state['rcan_clicked'] = False
57
  st.session_state['srgan_clicked'] = False
58
  st.session_state['srflow_clicked'] = False
 
 
 
59
 
60
  def get_image_download_link(img, filename):
61
  """Generates a link allowing the PIL image to be downloaded"""
@@ -72,6 +93,53 @@ def get_image_download_link(img, filename):
72
  if 'image' in locals():
73
  # st.image(image, caption='Uploaded Image', use_column_width=True)
74
  st.write("")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
75
 
76
  # ------------------------ HAT ------------------------ #
77
  if st.button('Enhance with HAT'):
 
7
  from models.HAT.hat import *
8
  from models.RCAN.rcan import *
9
  from models.SRGAN.srgan import *
10
+ from models.Interpolation.nearest_neighbor import NearestNeighbor_for_deployment
11
+ from models.Interpolation.bilinear import Bilinear_for_deployment
12
+ from models.Interpolation.bicubic import Bicubic_for_deployment
13
 
14
  subprocess.call('pip install natsort', shell=True)
15
  from models.SRFlow.srflow import *
16
 
17
  # Initialize session state for enhanced images
18
+ if 'nearest_enhanced_image' not in st.session_state:
19
+ st.session_state['nearest_enhanced_image'] = None
20
+ if 'bilinear_enhanced_image' not in st.session_state:
21
+ st.session_state['bilinear_enhanced_image'] = None
22
+ if 'bicubic_enhanced_image' not in st.session_state:
23
+ st.session_state['bicubic_enhanced_image'] = None
24
  if 'hat_enhanced_image' not in st.session_state:
25
  st.session_state['hat_enhanced_image'] = None
26
  if 'rcan_enhanced_image' not in st.session_state:
 
31
  st.session_state['srflow_enhanced_image'] = None
32
 
33
  # Initialize session state for button clicks
34
+ if 'nearest_clicked' not in st.session_state:
35
+ st.session_state['nearest_clicked'] = False
36
+ if 'bilinear_clicked' not in st.session_state:
37
+ st.session_state['bilinear_clicked'] = False
38
+ if 'bicubic_clicked' not in st.session_state:
39
+ st.session_state['bicubic_clicked'] = False
40
  if 'hat_clicked' not in st.session_state:
41
  st.session_state['hat_clicked'] = False
42
  if 'rcan_clicked' not in st.session_state:
 
67
  st.session_state['rcan_enhanced_image'] = None
68
  st.session_state['srgan_enhanced_image'] = None
69
  st.session_state['srflow_enhanced_image'] = None
70
+ st.session_state['bicubic_enhanced_image'] = None
71
+ st.session_state['bilinear_enhanced_image'] = None
72
+ st.session_state['nearest_enhanced_image'] = None
73
  st.session_state['hat_clicked'] = False
74
  st.session_state['rcan_clicked'] = False
75
  st.session_state['srgan_clicked'] = False
76
  st.session_state['srflow_clicked'] = False
77
+ st.session_state['bicubic_clicked'] = False
78
+ st.session_state['bilinear_clicked'] = False
79
+ st.session_state['nearest_clicked'] = False
80
 
81
  def get_image_download_link(img, filename):
82
  """Generates a link allowing the PIL image to be downloaded"""
 
93
  if 'image' in locals():
94
  # st.image(image, caption='Uploaded Image', use_column_width=True)
95
  st.write("")
96
+ # ------------------------ Nearest Neighbor ------------------------ #
97
+ if st.button('Enhance with Nearest Neighbor'):
98
+ with st.spinner('Processing using Nearest Neighbor...'):
99
+ enhanced_image = NearestNeighbor_for_deployment(image)
100
+ st.session_state['nearest_enhanced_image'] = enhanced_image
101
+ st.session_state['nearest_clicked'] = True
102
+ st.success('Done!')
103
+ if st.session_state['nearest_enhanced_image'] is not None:
104
+ col1, col2 = st.columns(2)
105
+ col1.header("Original")
106
+ col1.image(image, use_column_width=True)
107
+ col2.header("Enhanced")
108
+ col2.image(st.session_state['nearest_enhanced_image'], use_column_width=True)
109
+ with col2:
110
+ get_image_download_link(st.session_state['nearest_enhanced_image'], 'nearest_enhanced.jpg')
111
+
112
+ # ------------------------ Bilinear ------------------------ #
113
+ if st.button('Enhance with Bilinear'):
114
+ with st.spinner('Processing using Bilinear...'):
115
+ enhanced_image = Bilinear_for_deployment(image)
116
+ st.session_state['bilinear_enhanced_image'] = enhanced_image
117
+ st.session_state['bilinear_clicked'] = True
118
+ st.success('Done!')
119
+ if st.session_state['bilinear_enhanced_image'] is not None:
120
+ col1, col2 = st.columns(2)
121
+ col1.header("Original")
122
+ col1.image(image, use_column_width=True)
123
+ col2.header("Enhanced")
124
+ col2.image(st.session_state['bilinear_enhanced_image'], use_column_width=True)
125
+ with col2:
126
+ get_image_download_link(st.session_state['bilinear_enhanced_image'], 'bilinear_enhanced.jpg')
127
+
128
+ # ------------------------ Bicubic ------------------------ #
129
+ if st.button('Enhance with Bicubic'):
130
+ with st.spinner('Processing using Bicubic...'):
131
+ enhanced_image = Bicubic_for_deployment(image)
132
+ st.session_state['bicubic_enhanced_image'] = enhanced_image
133
+ st.session_state['bicubic_clicked'] = True
134
+ st.success('Done!')
135
+ if st.session_state['bicubic_enhanced_image'] is not None:
136
+ col1, col2 = st.columns(2)
137
+ col1.header("Original")
138
+ col1.image(image, use_column_width=True)
139
+ col2.header("Enhanced")
140
+ col2.image(st.session_state['bicubic_enhanced_image'], use_column_width=True)
141
+ with col2:
142
+ get_image_download_link(st.session_state['bicubic_enhanced_image'], 'bicubic_enhanced.jpg')
143
 
144
  # ------------------------ HAT ------------------------ #
145
  if st.button('Enhance with HAT'):
models/Interpolation/bicubic.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ from PIL import Image
2
+ from torchvision import transforms
3
+
4
+ def Bicubic_for_deployment(lr_image):
5
+ w, h = lr_image.size
6
+ sr_image = transforms.functional.resize(lr_image, size=(h*4, w*4), interpolation=transforms.InterpolationMode.BICUBIC, antialias=True)
7
+ return sr_image
models/Interpolation/bilinear.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ from PIL import Image
2
+ from torchvision import transforms
3
+
4
+ def Bilinear_for_deployment(lr_image):
5
+ w, h = lr_image.size
6
+ sr_image = transforms.functional.resize(lr_image, size=(h*4, w*4), interpolation=transforms.InterpolationMode.BILINEAR, antialias=True)
7
+ return sr_image
models/Interpolation/nearest_neighbor.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ from PIL import Image
2
+ from torchvision import transforms
3
+
4
+ def NearestNeighbor_for_deployment(lr_image):
5
+ w, h = lr_image.size
6
+ sr_image = transforms.functional.resize(lr_image, size=(h*4, w*4),interpolation=transforms.InterpolationMode.NEAREST,antialias=False)
7
+ return sr_image