Nguyễn Bá Thiêm commited on
Commit
b16ab70
1 Parent(s): 239e299

Add image super resolution functionality

Browse files
.gitignore CHANGED
@@ -0,0 +1 @@
 
 
1
+ models/HAT/__pycache__/hat.cpython-39.pyc
app.py ADDED
@@ -0,0 +1,105 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import time
2
+ import streamlit as st
3
+ import numpy as np
4
+ from PIL import Image
5
+ import cv2 # If you're using OpenCV for image processing
6
+ from io import BytesIO
7
+ import base64
8
+ from models.HAT.hat import *
9
+ # Initialize session state for enhanced images
10
+ if 'hat_enhanced_image' not in st.session_state:
11
+ st.session_state['hat_enhanced_image'] = None
12
+
13
+ if 'rcan_enhanced_image' not in st.session_state:
14
+ st.session_state['rcan_enhanced_image'] = None
15
+
16
+ if 'hat_clicked' not in st.session_state:
17
+ st.session_state['hat_clicked'] = False
18
+ if 'rcan_clicked' not in st.session_state:
19
+ st.session_state['rcan_clicked'] = False
20
+
21
+ st.markdown("<h1 style='text-align: center'>Image Super Resolution</h1>", unsafe_allow_html=True)
22
+ # Sidebar for navigation
23
+ st.sidebar.title("Options")
24
+ app_mode = st.sidebar.selectbox("Choose the input source",
25
+ ["Upload image", "Take a photo"])
26
+ # Depending on the choice, show the uploader widget or webcam capture
27
+ if app_mode == "Upload image":
28
+ uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "png"], on_change=lambda: reset_states())
29
+ if uploaded_file is not None:
30
+ image = Image.open(uploaded_file).convert("RGB")
31
+ elif app_mode == "Take a photo":
32
+ # Using JS code to access user's webcam
33
+ camera_input = st.camera_input("Take a picture", on_change=lambda: reset_states())
34
+ if camera_input is not None:
35
+ # Convert the camera image to an RGB image
36
+ image = Image.open(camera_input).convert("RGB")
37
+
38
+ def reset_states():
39
+ st.session_state['hat_enhanced_image'] = None
40
+ st.session_state['rcan_enhanced_image'] = None
41
+ st.session_state['hat_clicked'] = False
42
+ st.session_state['rcan_clicked'] = False
43
+
44
+ def get_image_download_link(img, filename):
45
+ """Generates a link allowing the PIL image to be downloaded"""
46
+ # Convert the PIL image to Bytes
47
+ buffered = BytesIO()
48
+ img.save(buffered, format="PNG")
49
+ return st.download_button(
50
+ label="Download Image",
51
+ data=buffered.getvalue(),
52
+ file_name=filename,
53
+ mime="image/png"
54
+ )
55
+
56
+ if 'image' in locals():
57
+ # st.image(image, caption='Uploaded Image', use_column_width=True)
58
+ st.write("")
59
+
60
+ if st.button('Enhance with HAT'):
61
+ with st.spinner('Processing using HAT...'):
62
+ with st.spinner('Wait for it... the model is processing the image'):
63
+ # Simulate a delay for processing image
64
+
65
+ enhanced_image = HAT_for_deployment(image)
66
+ st.session_state['hat_enhanced_image'] = enhanced_image
67
+ st.session_state['hat_clicked'] = True
68
+ st.success('Done!')
69
+ # Display the low and high resolution images side by side
70
+ if st.session_state['hat_enhanced_image'] is not None:
71
+ col1, col2 = st.columns(2)
72
+ col1.header("Original")
73
+ col1.image(image, use_column_width=True)
74
+
75
+ col2.header("Enhanced")
76
+ col2.image(st.session_state['hat_enhanced_image'], use_column_width=True)
77
+ with col2:
78
+ get_image_download_link(st.session_state['hat_enhanced_image'], 'hat_enhanced.jpg')
79
+
80
+ if st.button('Enhance with RCAN'):
81
+ with st.spinner('Processing using RCAN...'):
82
+ with st.spinner('Wait for it... the model is processing the image'):
83
+ # Simulate a delay for processing image
84
+ time.sleep(2) # replace this with actual model processing code
85
+
86
+ enhanced_image = image
87
+ # Display the low and high resolution images side by side
88
+ st.session_state['rcan_enhanced_image'] = enhanced_image
89
+
90
+ st.session_state['rcan_clicked'] = True
91
+ st.success('Done!')
92
+
93
+ if st.session_state['rcan_enhanced_image'] is not None:
94
+ col1, col2 = st.columns(2)
95
+ col1.header("Original")
96
+ col1.image(image, use_column_width=True)
97
+
98
+ col2.header("Enhanced")
99
+ col2.image(st.session_state['rcan_enhanced_image'], use_column_width=True)
100
+ with col2:
101
+ get_image_download_link(st.session_state['rcan_enhanced_image'], 'rcan_enhanced.jpg')
102
+
103
+
104
+
105
+
images/{img_003_SRF_4_LR.png → demo.png} RENAMED
File without changes
models/HAT/hat.py CHANGED
@@ -1,28 +1,15 @@
 
1
  import gdown
2
-
3
- # url = 'https://drive.google.com/file/d/1LHIUM7YoUDk8cXWzVZhroAcA1xXi-d87/view?usp=drive_link'
4
- output = 'models/HAT/hat_model_checkpoint_best.pth'
5
- # gdown.download(url, output, quiet=False)
6
-
7
  import gc
8
  import os
9
  import random
10
  import time
11
- import wandb
12
- from tqdm import tqdm
13
-
14
  import matplotlib.pyplot as plt
15
  from PIL import Image
16
- from skimage.metrics import structural_similarity as ssim
17
-
18
  import torch
19
  from torch import nn, optim
20
  import torch.nn.functional as F
21
- from torch.utils.data import Dataset, DataLoader, ConcatDataset
22
  from torchvision import transforms
23
- from torchvision.transforms import Compose
24
- from torchmetrics.functional.image import structural_similarity_index_measure as ssim
25
-
26
  from basicsr.archs.arch_util import to_2tuple, trunc_normal_
27
  from einops import rearrange
28
  import math
@@ -299,6 +286,117 @@ class OCAB(nn.Module):
299
  x = self.proj(x) + shortcut
300
 
301
  x = x + self.mlp(self.norm2(x))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
302
  return x
303
  class AttenBlocks(nn.Module):
304
  """ A series of attention blocks for one RHAG.
@@ -843,6 +941,8 @@ class HAT(nn.Module):
843
  x = x / self.img_range + self.mean
844
 
845
  return x
 
 
846
  # ------------------------------ HYPERPARAMS ------------------------------ #
847
  config = {
848
  "network_g": {
@@ -892,12 +992,12 @@ config = {
892
  }
893
 
894
  DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
895
- DEVICE
896
-
897
  class Network:
898
- def __init__(self, train_dataloader=train_dataloader, valid_dataloader=valid_dataloader,
899
- config = config, device=DEVICE, run_id=None, wandb_mode = False, STOP = float('inf'), save_temp_model = True, train_model_continue = False):
900
  self.config = config
 
901
  self.model = HAT(
902
  upscale=self.config['network_g']['upscale'],
903
  in_chans=self.config['network_g']['in_chans'],
@@ -914,59 +1014,15 @@ class Network:
914
  mlp_ratio=self.config['network_g']['mlp_ratio'],
915
  upsampler=self.config['network_g']['upsampler'],
916
  resi_connection=self.config['network_g']['resi_connection']
917
- ).to(device)
918
- self.device = device
919
- self.STOP = STOP
920
- self.wandb_mode = wandb_mode
921
- self.loss_fn = nn.L1Loss(reduction='mean').to(device)
922
-
923
  self.optimizer = optim.Adam(self.model.parameters(), lr=self.config['train']['optim_g']['lr'], weight_decay=config['train']['optim_g']['weight_decay'],betas=tuple(config['train']['optim_g']['betas']))
924
- self.scheduler = optim.lr_scheduler.MultiStepLR(self.optimizer, milestones = self.config['train']['scheduler']['milestones'], gamma=self.config['train']['scheduler']['gamma'])
925
- self.train_dataloader = train_dataloader
926
- self.valid_dataloader = valid_dataloader
927
- self.num_epochs = self.config['train']['total_iter']
928
- self.run_id = run_id
929
- self.save_temp_model = save_temp_model
930
- self.train_model_continue = train_model_continue
931
- self.last_valid_loss = float('inf')
932
- checkpoint_path = output
933
- if self.save_temp_model:
934
- if self.train_model_continue:
935
- # Load the network and other states from the checkpoint
936
- self.start_epoch, train_loss, valid_loss = self.load_network(checkpoint_path)
937
-
938
- initial_lr = self.config['train']['optim_g']['lr'] * self.config['train']['scheduler']['gamma'] # Define your initial or desired learning rate
939
- for param_group in self.optimizer.param_groups:
940
- param_group['lr'] = initial_lr # Resetting learning rate
941
-
942
- # Recreate the scheduler with the updated optimizer
943
- self.scheduler = optim.lr_scheduler.MultiStepLR(
944
- self.optimizer,
945
- milestones=self.config['train']['scheduler']['milestones'],
946
- gamma=self.config['train']['scheduler']['gamma'],
947
- last_epoch = self.start_epoch - 1 # Ensure to set the last_epoch to continue correctly
948
- )
949
-
950
- # Print the updated learning rate and scheduler state
951
- print("Updated Learning Rate is:", self.optimizer.param_groups[0]['lr'])
952
- print(self.scheduler.state_dict())
953
- self.last_valid_loss = valid_loss
954
- # self.num_epochs-= self.start_epoch
955
- print("Previous train loss: ", train_loss)
956
- print("Previous valid loss: ", self.last_valid_loss)
957
-
958
- # Resume training notice
959
- print("------------------- Resuming training -------------------")
960
-
961
- self.save_network(0, 0, 0, 'temp_model_checkpoint.pth')
962
-
963
- def del_model(self):
964
- del self.model
965
- del self.optimizer
966
- del self.scheduler
967
- gc.collect()
968
- torch.cuda.empty_cache()
969
-
970
  def pre_process(self):
971
  # pad to multiplication of window_size
972
  window_size = self.config['network_g']['window_size'] * 4
@@ -986,84 +1042,11 @@ class Network:
986
  self.mod_pad_w = window_size - w % window_size
987
  for i in range(self.mod_pad_w):
988
  self.input_tile = F.pad(self.input_tile, (0, 1, 0, 0), 'reflect')
989
-
990
-
991
  def post_process(self):
992
  _, _, h, w = self.output_tile.size()
993
  self.output_tile = self.output_tile[:, :, 0:h - self.mod_pad_h * self.scale, 0:w - self.mod_pad_w * self.scale]
994
 
995
-
996
- def save_network(self, epoch, train_loss, valid_loss, checkpoint_path):
997
- checkpoint = {
998
- 'epoch': epoch,
999
- 'train_loss': train_loss,
1000
- 'valid_loss': valid_loss,
1001
- 'model': self.model.state_dict(),
1002
- 'optimizer': self.optimizer.state_dict(),
1003
- 'learning_rate_scheduler': self.scheduler.state_dict(),
1004
- 'network': self
1005
- }
1006
- torch.save(checkpoint, checkpoint_path)
1007
-
1008
- def load_network(self, checkpoint_path):
1009
-
1010
- checkpoint = torch.load(checkpoint_path, map_location=self.device)
1011
- self.model = HAT(
1012
- upscale=self.config['network_g']['upscale'],
1013
- in_chans=self.config['network_g']['in_chans'],
1014
- img_size=self.config['network_g']['img_size'],
1015
- window_size=self.config['network_g']['window_size'],
1016
- compress_ratio=self.config['network_g']['compress_ratio'],
1017
- squeeze_factor=self.config['network_g']['squeeze_factor'],
1018
- conv_scale=self.config['network_g']['conv_scale'],
1019
- overlap_ratio=self.config['network_g']['overlap_ratio'],
1020
- img_range=self.config['network_g']['img_range'],
1021
- depths=self.config['network_g']['depths'],
1022
- embed_dim=self.config['network_g']['embed_dim'],
1023
- num_heads=self.config['network_g']['num_heads'],
1024
- mlp_ratio=self.config['network_g']['mlp_ratio'],
1025
- upsampler=self.config['network_g']['upsampler'],
1026
- resi_connection=self.config['network_g']['resi_connection']
1027
- ).to(self.device)
1028
- self.optimizer = optim.Adam(self.model.parameters(), lr=self.config['train']['optim_g']['lr'], weight_decay=config['train']['optim_g']['weight_decay'],betas=tuple(config['train']['optim_g']['betas']))
1029
- self.model.load_state_dict(checkpoint['model'])
1030
- self.optimizer.load_state_dict(checkpoint['optimizer']) # before create and load scheduler
1031
-
1032
- self.scheduler = optim.lr_scheduler.MultiStepLR(self.optimizer, milestones = self.config['train']['scheduler']['milestones'], gamma=self.config['train']['scheduler']['gamma'])
1033
- self.scheduler.load_state_dict(checkpoint['learning_rate_scheduler'])
1034
- return checkpoint['epoch'], checkpoint['train_loss'], checkpoint['valid_loss']
1035
-
1036
- def train_step(self, lr_images, hr_images):
1037
- lr_images, hr_images = lr_images.to(self.device), hr_images.to(self.device)
1038
- sr_images = self.model(lr_images)
1039
-
1040
- self.optimizer.zero_grad()
1041
- loss = self.loss_fn(sr_images, hr_images)
1042
- loss.backward()
1043
- self.optimizer.step()
1044
-
1045
- # Memory cleanup
1046
- del sr_images, lr_images, hr_images
1047
- gc.collect()
1048
- torch.cuda.empty_cache()
1049
-
1050
- return loss.item()
1051
-
1052
- def valid_step(self, lr_images, hr_images):
1053
- lr_images, hr_images = lr_images.to(self.device), hr_images.to(self.device)
1054
-
1055
- sr_images = self.tile_valid(lr_images)
1056
-
1057
- loss = self.loss_fn(sr_images, hr_images)
1058
-
1059
- # Memory cleanup
1060
- del sr_images, lr_images, hr_images
1061
- gc.collect()
1062
- torch.cuda.empty_cache()
1063
-
1064
- return loss.item()
1065
-
1066
-
1067
  def tile_valid(self, lr_images):
1068
  """
1069
  Process all tiles of an image in a batch and then merge them back into the output image.
@@ -1167,115 +1150,8 @@ class Network:
1167
  gc.collect()
1168
  torch.cuda.empty_cache()
1169
  return sr_images
1170
-
1171
- def train_model(self):
1172
-
1173
- if self.wandb_mode:
1174
- wandb.init(project='HAT-for-image-sr',
1175
- resume='allow',
1176
- config= self.config,
1177
- id=self.run_id)
1178
- wandb.watch(self.model)
1179
- if self.train_model_continue:
1180
- epoch_lst = range(self.start_epoch, self.num_epochs)
1181
- else:
1182
- epoch_lst = range(self.num_epochs)
1183
- for epoch in epoch_lst:
1184
-
1185
- start1 = time.time()
1186
-
1187
- # ------------------- TRAIN -------------------
1188
- if self.save_temp_model:
1189
- self.load_network('temp_model_checkpoint.pth')
1190
- self.model.train()
1191
- train_epoch_loss = 0
1192
-
1193
- stop = 0
1194
- for hr_images, lr_images in tqdm(self.train_dataloader, desc=f'Epoch {epoch+1}/{self.num_epochs}'):
1195
-
1196
- if stop == self.STOP:
1197
- break
1198
- stop+=1
1199
-
1200
- loss = self.train_step(lr_images, hr_images)
1201
- train_epoch_loss += loss
1202
-
1203
- if self.wandb_mode:
1204
- wandb.log({
1205
- 'batch_loss': loss,
1206
- })
1207
-
1208
- if self.wandb_mode:
1209
- wandb.log({
1210
- 'learning_rate': self.optimizer.param_groups[0]['lr']
1211
- })
1212
- print("Learning Rate is:", self.optimizer.param_groups[0]['lr'])
1213
-
1214
- self.scheduler.step()
1215
-
1216
-
1217
- if self.save_temp_model:
1218
- self.save_network(epoch, train_epoch_loss, 0, 'temp_model_checkpoint.pth')
1219
- print(self.scheduler.state_dict())
1220
- self.del_model()
1221
-
1222
- del hr_images
1223
- del lr_images
1224
- gc.collect()
1225
-
1226
- train_epoch_loss /= len(self.train_dataloader)
1227
-
1228
- end1 = time.time()
1229
-
1230
-
1231
- # ------------------- VALID -------------------
1232
- start2 = time.time()
1233
- if self.save_temp_model:
1234
- self.load_network('temp_model_checkpoint.pth')
1235
-
1236
- self.model.eval()
1237
- with torch.no_grad():
1238
- valid_epoch_loss = 0
1239
-
1240
- stop = 0
1241
- for hr_images, lr_images in tqdm(self.valid_dataloader, desc=f'Epoch {epoch+1}/{self.num_epochs}'):
1242
- if stop == self.STOP:
1243
- break
1244
- stop+=1
1245
- loss = self.valid_step(lr_images, hr_images)
1246
- valid_epoch_loss += loss
1247
-
1248
- valid_epoch_loss /= len(self.valid_dataloader)
1249
-
1250
- end2 = time.time()
1251
-
1252
- # ------------------- LOG -------------------
1253
- if self.wandb_mode:
1254
- wandb.log({
1255
- 'train_loss': train_epoch_loss,
1256
- 'valid_loss': valid_epoch_loss,
1257
- })
1258
- # ------------------- VERBOSE -------------------
1259
- print(f'Epoch {epoch+1}/{self.num_epochs} | Train Loss: {train_epoch_loss:.4f} | Valid Loss: {valid_epoch_loss:.4f} | Time train: {end1-start1:.2f}s | Time valid: {end2-start2:.2f}s')
1260
-
1261
- # ------------------- CHECKPOINT -------------------
1262
- self.save_network(epoch, train_epoch_loss, valid_epoch_loss, 'model_checkpoint_latest.pth')
1263
- if valid_epoch_loss < self.last_valid_loss:
1264
- self.last_valid_loss = valid_epoch_loss
1265
- self.save_network(epoch, train_epoch_loss, valid_epoch_loss, 'model_checkpoint_best.pth')
1266
- print("New best checkpoint saved!")
1267
-
1268
- if self.save_temp_model:
1269
- self.del_model()
1270
-
1271
- del hr_images
1272
- del lr_images
1273
- gc.collect()
1274
-
1275
- if self.wandb_mode:
1276
- wandb.finish()
1277
 
1278
- def inference(self, lr_image, hr_image):
1279
  """
1280
  - lr_image: torch.Tensor
1281
  3D Tensor (C, H, W)
@@ -1284,80 +1160,87 @@ class Network:
1284
  ground-truth high-res image. If used solely for inference, skip this. Default is None/
1285
  """
1286
  lr_image = lr_image.unsqueeze(0).to(self.device)
 
1287
  self.for_inference = True
1288
  with torch.no_grad():
1289
  sr_image = self.tile_valid(lr_image)
 
1290
 
1291
- lr_image = lr_image.squeeze(0)
1292
- sr_image = sr_image.squeeze(0)
1293
-
1294
- print(">> Size of low-res image:", lr_image.size())
1295
- print(">> Size of super-res image:", sr_image.size())
1296
- if hr_image != None:
1297
- print(">> Size of high-res image:", hr_image.size())
1298
-
1299
- if hr_image != None:
1300
- fig, axes = plt.subplots(1, 3, figsize=(10, 6))
1301
- axes[0].imshow(lr_image.cpu().detach().permute((1, 2, 0)))
1302
- axes[0].set_title('Low Resolution')
1303
- axes[1].imshow(sr_image.cpu().detach().permute((1, 2, 0)))
1304
- axes[1].set_title('Super Resolution')
1305
- axes[2].imshow(hr_image.cpu().detach().permute((1, 2, 0)))
1306
- axes[2].set_title('High Resolution')
1307
- for ax in axes.flat:
1308
- ax.axis('off')
1309
  else:
1310
- fig, axes = plt.subplots(1, 2, figsize=(10, 6))
1311
- axes[0].imshow(lr_image.cpu().detach().permute((1, 2, 0)))
1312
- axes[0].set_title('Low Resolution')
1313
- axes[1].imshow(sr_image.cpu().detach().permute((1, 2, 0)))
1314
- axes[1].set_title('Super Resolution')
1315
- for ax in axes.flat:
1316
- ax.axis('off')
1317
 
1318
- plt.tight_layout()
1319
- plt.show()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1320
 
1321
- return sr_image
1322
-
1323
-
1324
- class TestDataset(Dataset):
1325
- def __init__(self, lr_images_path):
1326
- super(TestDataset, self).__init__()
1327
- # hr_images_list = os.listdir(hr_images_path)
1328
- self.lr_images_path = lr_images_path
 
 
 
 
 
 
1329
 
1330
- def __getitem__(self, idx):
1331
-
1332
- lr_image = Image.open(self.lr_image_path)
1333
-
1334
- lr_image = transforms.functional.to_tensor(lr_image)
1335
-
1336
- return lr_image
1337
-
1338
-
1339
  if __name__ == "__main__":
1340
  import os
1341
  import sys
1342
- # Getting to the Lambda directory
1343
  sys.path.append(os.path.join(os.path.dirname(os.path.realpath(__file__)), "../../"))
1344
- image_path = "images/img_003_SRF_4_LR.png"
1345
-
1346
- infer_dataset = TestDataset(images_path=image_path)
1347
 
1348
- # hat = Network(run_id="hat-for-image-sr-" + str(int(1704006834)),config = config, wandb_mode = False, save_temp_model = True, train_model_continue = False) # STOP = 2
1349
- # num_params = sum(p.numel() for p in hat.model.parameters() if p.requires_grad)
1350
- # print("Number of learnable parameters: ", num_params)
 
 
 
 
 
 
 
 
 
 
1351
 
1352
- # ---------- LOAD FROM LATEST CHECKPOINT ---------- #
1353
- gc.collect()
1354
- torch.cuda.empty_cache()
1355
- hat = Network()
1356
- hat.load_network(output)
1357
- num_params = sum(p.numel() for p in hat.model.parameters() if p.requires_grad)
1358
- print("Number of learnable parameters: ", num_params)
1359
- image = image.squeeze(0)
1360
- hat.inference(lr_image)
1361
-
1362
-
1363
 
 
1
+ import numpy as np
2
  import gdown
 
 
 
 
 
3
  import gc
4
  import os
5
  import random
6
  import time
 
 
 
7
  import matplotlib.pyplot as plt
8
  from PIL import Image
 
 
9
  import torch
10
  from torch import nn, optim
11
  import torch.nn.functional as F
 
12
  from torchvision import transforms
 
 
 
13
  from basicsr.archs.arch_util import to_2tuple, trunc_normal_
14
  from einops import rearrange
15
  import math
 
286
  x = self.proj(x) + shortcut
287
 
288
  x = x + self.mlp(self.norm2(x))
289
+ return x
290
+ class HAB(nn.Module):
291
+ r""" Hybrid Attention Block.
292
+
293
+ Args:
294
+ dim (int): Number of input channels.
295
+ input_resolution (tuple[int]): Input resolution.
296
+ num_heads (int): Number of attention heads.
297
+ window_size (int): Window size.
298
+ shift_size (int): Shift size for SW-MSA.
299
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
300
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
301
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
302
+ drop (float, optional): Dropout rate. Default: 0.0
303
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
304
+ drop_path (float, optional): Stochastic depth rate. Default: 0.0
305
+ act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
306
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
307
+ """
308
+
309
+ def __init__(self,
310
+ dim,
311
+ input_resolution,
312
+ num_heads,
313
+ window_size=7,
314
+ shift_size=0,
315
+ compress_ratio=3,
316
+ squeeze_factor=30,
317
+ conv_scale=0.01,
318
+ mlp_ratio=4.,
319
+ qkv_bias=True,
320
+ qk_scale=None,
321
+ drop=0.,
322
+ attn_drop=0.,
323
+ drop_path=0.,
324
+ act_layer=nn.GELU,
325
+ norm_layer=nn.LayerNorm):
326
+ super().__init__()
327
+ self.dim = dim
328
+ self.input_resolution = input_resolution
329
+ self.num_heads = num_heads
330
+ self.window_size = window_size
331
+ self.shift_size = shift_size
332
+ self.mlp_ratio = mlp_ratio
333
+ if min(self.input_resolution) <= self.window_size:
334
+ # if window size is larger than input resolution, we don't partition windows
335
+ self.shift_size = 0
336
+ self.window_size = min(self.input_resolution)
337
+ assert 0 <= self.shift_size < self.window_size, 'shift_size must in 0-window_size'
338
+
339
+ self.norm1 = norm_layer(dim)
340
+ self.attn = WindowAttention(
341
+ dim,
342
+ window_size=to_2tuple(self.window_size),
343
+ num_heads=num_heads,
344
+ qkv_bias=qkv_bias,
345
+ qk_scale=qk_scale,
346
+ attn_drop=attn_drop,
347
+ proj_drop=drop)
348
+
349
+ self.conv_scale = conv_scale
350
+ self.conv_block = CAB(num_feat=dim, compress_ratio=compress_ratio, squeeze_factor=squeeze_factor)
351
+
352
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
353
+ self.norm2 = norm_layer(dim)
354
+ mlp_hidden_dim = int(dim * mlp_ratio)
355
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
356
+
357
+ def forward(self, x, x_size, rpi_sa, attn_mask):
358
+ h, w = x_size
359
+ b, _, c = x.shape
360
+ # assert seq_len == h * w, "input feature has wrong size"
361
+
362
+ shortcut = x
363
+ x = self.norm1(x)
364
+ x = x.view(b, h, w, c)
365
+
366
+ # Conv_X
367
+ conv_x = self.conv_block(x.permute(0, 3, 1, 2))
368
+ conv_x = conv_x.permute(0, 2, 3, 1).contiguous().view(b, h * w, c)
369
+
370
+ # cyclic shift
371
+ if self.shift_size > 0:
372
+ shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
373
+ attn_mask = attn_mask
374
+ else:
375
+ shifted_x = x
376
+ attn_mask = None
377
+
378
+ # partition windows
379
+ x_windows = window_partition(shifted_x, self.window_size) # nw*b, window_size, window_size, c
380
+ x_windows = x_windows.view(-1, self.window_size * self.window_size, c) # nw*b, window_size*window_size, c
381
+
382
+ # W-MSA/SW-MSA (to be compatible for testing on images whose shapes are the multiple of window size
383
+ attn_windows = self.attn(x_windows, rpi=rpi_sa, mask=attn_mask)
384
+
385
+ # merge windows
386
+ attn_windows = attn_windows.view(-1, self.window_size, self.window_size, c)
387
+ shifted_x = window_reverse(attn_windows, self.window_size, h, w) # b h' w' c
388
+
389
+ # reverse cyclic shift
390
+ if self.shift_size > 0:
391
+ attn_x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
392
+ else:
393
+ attn_x = shifted_x
394
+ attn_x = attn_x.view(b, h * w, c)
395
+
396
+ # FFN
397
+ x = shortcut + self.drop_path(attn_x) + conv_x * self.conv_scale
398
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
399
+
400
  return x
401
  class AttenBlocks(nn.Module):
402
  """ A series of attention blocks for one RHAG.
 
941
  x = x / self.img_range + self.mean
942
 
943
  return x
944
+
945
+
946
  # ------------------------------ HYPERPARAMS ------------------------------ #
947
  config = {
948
  "network_g": {
 
992
  }
993
 
994
  DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
995
+ # DEVICE = torch.device('mps' if torch.backends.mps.is_built() else 'cpu')
996
+ print('device', DEVICE)
997
  class Network:
998
+ def __init__(self,config = config, device=DEVICE):
 
999
  self.config = config
1000
+ self.device = device
1001
  self.model = HAT(
1002
  upscale=self.config['network_g']['upscale'],
1003
  in_chans=self.config['network_g']['in_chans'],
 
1014
  mlp_ratio=self.config['network_g']['mlp_ratio'],
1015
  upsampler=self.config['network_g']['upsampler'],
1016
  resi_connection=self.config['network_g']['resi_connection']
1017
+ ).to(self.device)
 
 
 
 
 
1018
  self.optimizer = optim.Adam(self.model.parameters(), lr=self.config['train']['optim_g']['lr'], weight_decay=config['train']['optim_g']['weight_decay'],betas=tuple(config['train']['optim_g']['betas']))
1019
+
1020
+ def load_network(self, checkpoint_path):
1021
+
1022
+ checkpoint = torch.load(checkpoint_path, map_location=self.device)
1023
+ self.model.load_state_dict(checkpoint['model'])
1024
+ self.optimizer.load_state_dict(checkpoint['optimizer']) # before create and load scheduler
1025
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1026
  def pre_process(self):
1027
  # pad to multiplication of window_size
1028
  window_size = self.config['network_g']['window_size'] * 4
 
1042
  self.mod_pad_w = window_size - w % window_size
1043
  for i in range(self.mod_pad_w):
1044
  self.input_tile = F.pad(self.input_tile, (0, 1, 0, 0), 'reflect')
1045
+
 
1046
  def post_process(self):
1047
  _, _, h, w = self.output_tile.size()
1048
  self.output_tile = self.output_tile[:, :, 0:h - self.mod_pad_h * self.scale, 0:w - self.mod_pad_w * self.scale]
1049
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1050
  def tile_valid(self, lr_images):
1051
  """
1052
  Process all tiles of an image in a batch and then merge them back into the output image.
 
1150
  gc.collect()
1151
  torch.cuda.empty_cache()
1152
  return sr_images
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1153
 
1154
+ def inference(self, lr_image, hr_image = None, deployment = False):
1155
  """
1156
  - lr_image: torch.Tensor
1157
  3D Tensor (C, H, W)
 
1160
  ground-truth high-res image. If used solely for inference, skip this. Default is None/
1161
  """
1162
  lr_image = lr_image.unsqueeze(0).to(self.device)
1163
+
1164
  self.for_inference = True
1165
  with torch.no_grad():
1166
  sr_image = self.tile_valid(lr_image)
1167
+ sr_image = torch.clamp(sr_image, 0, 1)
1168
 
1169
+ if deployment:
1170
+ return sr_image.squeeze(0)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1171
  else:
1172
+ lr_image = lr_image.squeeze(0)
1173
+ sr_image = sr_image.squeeze(0)
 
 
 
 
 
1174
 
1175
+ print(">> Size of low-res image:", lr_image.size())
1176
+ print(">> Size of super-res image:", sr_image.size())
1177
+ if hr_image != None:
1178
+ print(">> Size of high-res image:", hr_image.size())
1179
+
1180
+ if hr_image != None:
1181
+ fig, axes = plt.subplots(1, 3, figsize=(10, 6))
1182
+ axes[0].imshow(lr_image.cpu().detach().permute((1, 2, 0)))
1183
+ axes[0].set_title('Low Resolution')
1184
+ axes[1].imshow(sr_image.cpu().detach().permute((1, 2, 0)))
1185
+ axes[1].set_title('Super Resolution')
1186
+ axes[2].imshow(hr_image.cpu().detach().permute((1, 2, 0)))
1187
+ axes[2].set_title('High Resolution')
1188
+ for ax in axes.flat:
1189
+ ax.axis('off')
1190
+ else:
1191
+ fig, axes = plt.subplots(1, 2, figsize=(10, 6))
1192
+ axes[0].imshow(lr_image.cpu().detach().permute((1, 2, 0)))
1193
+ axes[0].set_title('Low Resolution')
1194
+ axes[1].imshow(sr_image.cpu().detach().permute((1, 2, 0)))
1195
+ axes[1].set_title('Super Resolution')
1196
+ for ax in axes.flat:
1197
+ ax.axis('off')
1198
+
1199
+ plt.tight_layout()
1200
+ plt.show()
1201
+ return sr_image
1202
 
1203
+ def HAT_for_deployment(lr_image, model_path = 'models/HAT/hat_model_checkpoint_best.pth'):
1204
+ lr_image = transforms.functional.to_tensor(lr_image)
1205
+ hat = Network()
1206
+ hat.load_network(model_path)
1207
+ t1 = time.time()
1208
+ sr_image = hat.inference(lr_image, deployment=True).cpu().numpy()
1209
+ t2 = time.time()
1210
+ print("Time taken to infer:", t2 - t1)
1211
+ # If image is in [C, H, W] format, transpose it to [H, W, C]
1212
+ sr_image = np.transpose(sr_image, (1, 2, 0))
1213
+ if sr_image.max() <= 1.0:
1214
+ sr_image = (sr_image * 255).astype(np.uint8)
1215
+ sr_image = Image.fromarray(sr_image)
1216
+ return sr_image
1217
 
 
 
 
 
 
 
 
 
 
1218
  if __name__ == "__main__":
1219
  import os
1220
  import sys
1221
+ # Getting to the true directory
1222
  sys.path.append(os.path.join(os.path.dirname(os.path.realpath(__file__)), "../../"))
 
 
 
1223
 
1224
+ # Define the model's file path and the Google Drive link
1225
+ model_path = 'models/HAT/hat_model_checkpoint_best.pth'
1226
+ gdrive_id = '1LHIUM7YoUDk8cXWzVZhroAcA1xXi-d87' # Replace with your actual Google Drive file URL
1227
+
1228
+ # Check if the model file exists
1229
+ if not os.path.exists(model_path):
1230
+ print(f"Model file not found at {model_path}. Downloading from Google Drive...")
1231
+ # Ensure the directory exists, as gdown will not automatically create directory paths
1232
+ os.makedirs(os.path.dirname(model_path), exist_ok=True)
1233
+ # Download the file from Google Drive
1234
+ # gdown.download(id=gdrive_id, output=model_path, quiet=False)
1235
+ else:
1236
+ print(f"Model file found at {model_path}. No need to download.")
1237
 
1238
+ image_path = "images/demo.png"
1239
+ lr_image = Image.open(image_path)
1240
+ # lr_image = transforms.functional.to_tensor(lr_image)
1241
+
1242
+ # hat = Network()
1243
+ # hat.load_network(model_path)
1244
+ # hat.inference(lr_image)
1245
+ print(HAT_for_deployment(lr_image, model_path))
 
 
 
1246