Nguyễn Bá Thiêm committed
Commit • b16ab70 • 1 Parent(s): 239e299
Add image super resolution functionality
Browse files
- .gitignore +1 -0
- app.py +105 -0
- images/{img_003_SRF_4_LR.png → demo.png} +0 -0
- models/HAT/hat.py +197 -314
.gitignore CHANGED
@@ -0,0 +1 @@
+models/HAT/__pycache__/hat.cpython-39.pyc
app.py ADDED
@@ -0,0 +1,105 @@
+import time
+import streamlit as st
+import numpy as np
+from PIL import Image
+import cv2 # If you're using OpenCV for image processing
+from io import BytesIO
+import base64
+from models.HAT.hat import *
+# Initialize session state for enhanced images
+if 'hat_enhanced_image' not in st.session_state:
+    st.session_state['hat_enhanced_image'] = None
+
+if 'rcan_enhanced_image' not in st.session_state:
+    st.session_state['rcan_enhanced_image'] = None
+
+if 'hat_clicked' not in st.session_state:
+    st.session_state['hat_clicked'] = False
+if 'rcan_clicked' not in st.session_state:
+    st.session_state['rcan_clicked'] = False
+
+st.markdown("<h1 style='text-align: center'>Image Super Resolution</h1>", unsafe_allow_html=True)
+# Sidebar for navigation
+st.sidebar.title("Options")
+app_mode = st.sidebar.selectbox("Choose the input source",
+                                ["Upload image", "Take a photo"])
+# Depending on the choice, show the uploader widget or webcam capture
+if app_mode == "Upload image":
+    uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "png"], on_change=lambda: reset_states())
+    if uploaded_file is not None:
+        image = Image.open(uploaded_file).convert("RGB")
+elif app_mode == "Take a photo":
+    # Using JS code to access user's webcam
+    camera_input = st.camera_input("Take a picture", on_change=lambda: reset_states())
+    if camera_input is not None:
+        # Convert the camera image to an RGB image
+        image = Image.open(camera_input).convert("RGB")
+
+def reset_states():
+    st.session_state['hat_enhanced_image'] = None
+    st.session_state['rcan_enhanced_image'] = None
+    st.session_state['hat_clicked'] = False
+    st.session_state['rcan_clicked'] = False
+
+def get_image_download_link(img, filename):
+    """Generates a link allowing the PIL image to be downloaded"""
+    # Convert the PIL image to Bytes
+    buffered = BytesIO()
+    img.save(buffered, format="PNG")
+    return st.download_button(
+        label="Download Image",
+        data=buffered.getvalue(),
+        file_name=filename,
+        mime="image/png"
+    )
+
+if 'image' in locals():
+    # st.image(image, caption='Uploaded Image', use_column_width=True)
+    st.write("")
+
+    if st.button('Enhance with HAT'):
+        with st.spinner('Processing using HAT...'):
+            with st.spinner('Wait for it... the model is processing the image'):
+                # Simulate a delay for processing image
+
+                enhanced_image = HAT_for_deployment(image)
+                st.session_state['hat_enhanced_image'] = enhanced_image
+                st.session_state['hat_clicked'] = True
+        st.success('Done!')
+    # Display the low and high resolution images side by side
+    if st.session_state['hat_enhanced_image'] is not None:
+        col1, col2 = st.columns(2)
+        col1.header("Original")
+        col1.image(image, use_column_width=True)
+
+        col2.header("Enhanced")
+        col2.image(st.session_state['hat_enhanced_image'], use_column_width=True)
+        with col2:
+            get_image_download_link(st.session_state['hat_enhanced_image'], 'hat_enhanced.jpg')
+
+    if st.button('Enhance with RCAN'):
+        with st.spinner('Processing using RCAN...'):
+            with st.spinner('Wait for it... the model is processing the image'):
+                # Simulate a delay for processing image
+                time.sleep(2) # replace this with actual model processing code
+
+                enhanced_image = image
+                # Display the low and high resolution images side by side
+                st.session_state['rcan_enhanced_image'] = enhanced_image
+
+                st.session_state['rcan_clicked'] = True
+        st.success('Done!')
+
+    if st.session_state['rcan_enhanced_image'] is not None:
+        col1, col2 = st.columns(2)
+        col1.header("Original")
+        col1.image(image, use_column_width=True)
+
+        col2.header("Enhanced")
+        col2.image(st.session_state['rcan_enhanced_image'], use_column_width=True)
+        with col2:
+            get_image_download_link(st.session_state['rcan_enhanced_image'], 'rcan_enhanced.jpg')
+
+
+
+
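The HAT enhancement path above reduces to a single call: a PIL image goes into HAT_for_deployment (defined in models/HAT/hat.py below) and an upscaled PIL image comes back. A minimal sketch, not part of the commit, for exercising that path without the Streamlit UI; it assumes the repo root as the working directory and that the checkpoint has already been fetched to models/HAT/hat_model_checkpoint_best.pth:

    # Hypothetical smoke test (not in this commit); paths come from the diffs above and below.
    from PIL import Image
    from models.HAT.hat import HAT_for_deployment

    lr = Image.open("images/demo.png").convert("RGB")  # demo image shipped with this commit
    sr = HAT_for_deployment(lr)  # loads the checkpoint, runs tiled inference, returns a PIL image
    sr.save("demo_sr.png")
    print(lr.size, "->", sr.size)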
images/{img_003_SRF_4_LR.png → demo.png} RENAMED
File without changes
models/HAT/hat.py CHANGED
@@ -1,28 +1,15 @@
+import numpy as np
import gdown
-
-# url = 'https://drive.google.com/file/d/1LHIUM7YoUDk8cXWzVZhroAcA1xXi-d87/view?usp=drive_link'
-output = 'models/HAT/hat_model_checkpoint_best.pth'
-# gdown.download(url, output, quiet=False)
-
import gc
import os
import random
import time
-import wandb
-from tqdm import tqdm
-
import matplotlib.pyplot as plt
from PIL import Image
-from skimage.metrics import structural_similarity as ssim
-
import torch
from torch import nn, optim
import torch.nn.functional as F
-from torch.utils.data import Dataset, DataLoader, ConcatDataset
from torchvision import transforms
-from torchvision.transforms import Compose
-from torchmetrics.functional.image import structural_similarity_index_measure as ssim
-
from basicsr.archs.arch_util import to_2tuple, trunc_normal_
from einops import rearrange
import math
@@ -299,6 +286,117 @@ class OCAB(nn.Module):
        x = self.proj(x) + shortcut

        x = x + self.mlp(self.norm2(x))
+        return x
+class HAB(nn.Module):
+    r""" Hybrid Attention Block.
+
+    Args:
+        dim (int): Number of input channels.
+        input_resolution (tuple[int]): Input resolution.
+        num_heads (int): Number of attention heads.
+        window_size (int): Window size.
+        shift_size (int): Shift size for SW-MSA.
+        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
+        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
+        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
+        drop (float, optional): Dropout rate. Default: 0.0
+        attn_drop (float, optional): Attention dropout rate. Default: 0.0
+        drop_path (float, optional): Stochastic depth rate. Default: 0.0
+        act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
+        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
+    """
+
+    def __init__(self,
+                 dim,
+                 input_resolution,
+                 num_heads,
+                 window_size=7,
+                 shift_size=0,
+                 compress_ratio=3,
+                 squeeze_factor=30,
+                 conv_scale=0.01,
+                 mlp_ratio=4.,
+                 qkv_bias=True,
+                 qk_scale=None,
+                 drop=0.,
+                 attn_drop=0.,
+                 drop_path=0.,
+                 act_layer=nn.GELU,
+                 norm_layer=nn.LayerNorm):
+        super().__init__()
+        self.dim = dim
+        self.input_resolution = input_resolution
+        self.num_heads = num_heads
+        self.window_size = window_size
+        self.shift_size = shift_size
+        self.mlp_ratio = mlp_ratio
+        if min(self.input_resolution) <= self.window_size:
+            # if window size is larger than input resolution, we don't partition windows
+            self.shift_size = 0
+            self.window_size = min(self.input_resolution)
+        assert 0 <= self.shift_size < self.window_size, 'shift_size must in 0-window_size'
+
+        self.norm1 = norm_layer(dim)
+        self.attn = WindowAttention(
+            dim,
+            window_size=to_2tuple(self.window_size),
+            num_heads=num_heads,
+            qkv_bias=qkv_bias,
+            qk_scale=qk_scale,
+            attn_drop=attn_drop,
+            proj_drop=drop)
+
+        self.conv_scale = conv_scale
+        self.conv_block = CAB(num_feat=dim, compress_ratio=compress_ratio, squeeze_factor=squeeze_factor)
+
+        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+        self.norm2 = norm_layer(dim)
+        mlp_hidden_dim = int(dim * mlp_ratio)
+        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
+
+    def forward(self, x, x_size, rpi_sa, attn_mask):
+        h, w = x_size
+        b, _, c = x.shape
+        # assert seq_len == h * w, "input feature has wrong size"
+
+        shortcut = x
+        x = self.norm1(x)
+        x = x.view(b, h, w, c)
+
+        # Conv_X
+        conv_x = self.conv_block(x.permute(0, 3, 1, 2))
+        conv_x = conv_x.permute(0, 2, 3, 1).contiguous().view(b, h * w, c)
+
+        # cyclic shift
+        if self.shift_size > 0:
+            shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
+            attn_mask = attn_mask
+        else:
+            shifted_x = x
+            attn_mask = None
+
+        # partition windows
+        x_windows = window_partition(shifted_x, self.window_size)  # nw*b, window_size, window_size, c
+        x_windows = x_windows.view(-1, self.window_size * self.window_size, c)  # nw*b, window_size*window_size, c
+
+        # W-MSA/SW-MSA (to be compatible for testing on images whose shapes are the multiple of window size
+        attn_windows = self.attn(x_windows, rpi=rpi_sa, mask=attn_mask)
+
+        # merge windows
+        attn_windows = attn_windows.view(-1, self.window_size, self.window_size, c)
+        shifted_x = window_reverse(attn_windows, self.window_size, h, w)  # b h' w' c
+
+        # reverse cyclic shift
+        if self.shift_size > 0:
+            attn_x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
+        else:
+            attn_x = shifted_x
+        attn_x = attn_x.view(b, h * w, c)
+
+        # FFN
+        x = shortcut + self.drop_path(attn_x) + conv_x * self.conv_scale
+        x = x + self.drop_path(self.mlp(self.norm2(x)))
+
        return x
class AttenBlocks(nn.Module):
    """ A series of attention blocks for one RHAG.
@@ -843,6 +941,8 @@ class HAT(nn.Module):
        x = x / self.img_range + self.mean

        return x
+
+
# ------------------------------ HYPERPARAMS ------------------------------ #
config = {
    "network_g": {
@@ -892,12 +992,12 @@ config = {
}

DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
-DEVICE
-
+# DEVICE = torch.device('mps' if torch.backends.mps.is_built() else 'cpu')
+print('device', DEVICE)
class Network:
-    def __init__(self,
-                 config = config, device=DEVICE, run_id=None, wandb_mode = False, STOP = float('inf'), save_temp_model = True, train_model_continue = False):
+    def __init__(self,config = config, device=DEVICE):
        self.config = config
+        self.device = device
        self.model = HAT(
            upscale=self.config['network_g']['upscale'],
            in_chans=self.config['network_g']['in_chans'],
@@ -914,59 +1014,15 @@ class Network:
            mlp_ratio=self.config['network_g']['mlp_ratio'],
            upsampler=self.config['network_g']['upsampler'],
            resi_connection=self.config['network_g']['resi_connection']
-        ).to(device)
-        self.device = device
-        self.STOP = STOP
-        self.wandb_mode = wandb_mode
-        self.loss_fn = nn.L1Loss(reduction='mean').to(device)
-
+        ).to(self.device)
        self.optimizer = optim.Adam(self.model.parameters(), lr=self.config['train']['optim_g']['lr'], weight_decay=config['train']['optim_g']['weight_decay'],betas=tuple(config['train']['optim_g']['betas']))
-
-
-
-
-        self.
-        self.
-
-        self.last_valid_loss = float('inf')
-        checkpoint_path = output
-        if self.save_temp_model:
-            if self.train_model_continue:
-                # Load the network and other states from the checkpoint
-                self.start_epoch, train_loss, valid_loss = self.load_network(checkpoint_path)
-
-                initial_lr = self.config['train']['optim_g']['lr'] * self.config['train']['scheduler']['gamma'] # Define your initial or desired learning rate
-                for param_group in self.optimizer.param_groups:
-                    param_group['lr'] = initial_lr # Resetting learning rate
-
-                # Recreate the scheduler with the updated optimizer
-                self.scheduler = optim.lr_scheduler.MultiStepLR(
-                    self.optimizer,
-                    milestones=self.config['train']['scheduler']['milestones'],
-                    gamma=self.config['train']['scheduler']['gamma'],
-                    last_epoch = self.start_epoch - 1 # Ensure to set the last_epoch to continue correctly
-                )
-
-                # Print the updated learning rate and scheduler state
-                print("Updated Learning Rate is:", self.optimizer.param_groups[0]['lr'])
-                print(self.scheduler.state_dict())
-                self.last_valid_loss = valid_loss
-                # self.num_epochs-= self.start_epoch
-                print("Previous train loss: ", train_loss)
-                print("Previous valid loss: ", self.last_valid_loss)
-
-                # Resume training notice
-                print("------------------- Resuming training -------------------")
-
-            self.save_network(0, 0, 0, 'temp_model_checkpoint.pth')
-
-    def del_model(self):
-        del self.model
-        del self.optimizer
-        del self.scheduler
-        gc.collect()
-        torch.cuda.empty_cache()
-
+
+    def load_network(self, checkpoint_path):
+
+        checkpoint = torch.load(checkpoint_path, map_location=self.device)
+        self.model.load_state_dict(checkpoint['model'])
+        self.optimizer.load_state_dict(checkpoint['optimizer']) # before create and load scheduler
+
    def pre_process(self):
        # pad to multiplication of window_size
        window_size = self.config['network_g']['window_size'] * 4
@@ -986,84 +1042,11 @@ class Network:
            self.mod_pad_w = window_size - w % window_size
            for i in range(self.mod_pad_w):
                self.input_tile = F.pad(self.input_tile, (0, 1, 0, 0), 'reflect')
-
-
+
    def post_process(self):
        _, _, h, w = self.output_tile.size()
        self.output_tile = self.output_tile[:, :, 0:h - self.mod_pad_h * self.scale, 0:w - self.mod_pad_w * self.scale]

-
-    def save_network(self, epoch, train_loss, valid_loss, checkpoint_path):
-        checkpoint = {
-            'epoch': epoch,
-            'train_loss': train_loss,
-            'valid_loss': valid_loss,
-            'model': self.model.state_dict(),
-            'optimizer': self.optimizer.state_dict(),
-            'learning_rate_scheduler': self.scheduler.state_dict(),
-            'network': self
-        }
-        torch.save(checkpoint, checkpoint_path)
-
-    def load_network(self, checkpoint_path):
-
-        checkpoint = torch.load(checkpoint_path, map_location=self.device)
-        self.model = HAT(
-            upscale=self.config['network_g']['upscale'],
-            in_chans=self.config['network_g']['in_chans'],
-            img_size=self.config['network_g']['img_size'],
-            window_size=self.config['network_g']['window_size'],
-            compress_ratio=self.config['network_g']['compress_ratio'],
-            squeeze_factor=self.config['network_g']['squeeze_factor'],
-            conv_scale=self.config['network_g']['conv_scale'],
-            overlap_ratio=self.config['network_g']['overlap_ratio'],
-            img_range=self.config['network_g']['img_range'],
-            depths=self.config['network_g']['depths'],
-            embed_dim=self.config['network_g']['embed_dim'],
-            num_heads=self.config['network_g']['num_heads'],
-            mlp_ratio=self.config['network_g']['mlp_ratio'],
-            upsampler=self.config['network_g']['upsampler'],
-            resi_connection=self.config['network_g']['resi_connection']
-        ).to(self.device)
-        self.optimizer = optim.Adam(self.model.parameters(), lr=self.config['train']['optim_g']['lr'], weight_decay=config['train']['optim_g']['weight_decay'],betas=tuple(config['train']['optim_g']['betas']))
-        self.model.load_state_dict(checkpoint['model'])
-        self.optimizer.load_state_dict(checkpoint['optimizer']) # before create and load scheduler
-
-        self.scheduler = optim.lr_scheduler.MultiStepLR(self.optimizer, milestones = self.config['train']['scheduler']['milestones'], gamma=self.config['train']['scheduler']['gamma'])
-        self.scheduler.load_state_dict(checkpoint['learning_rate_scheduler'])
-        return checkpoint['epoch'], checkpoint['train_loss'], checkpoint['valid_loss']
-
-    def train_step(self, lr_images, hr_images):
-        lr_images, hr_images = lr_images.to(self.device), hr_images.to(self.device)
-        sr_images = self.model(lr_images)
-
-        self.optimizer.zero_grad()
-        loss = self.loss_fn(sr_images, hr_images)
-        loss.backward()
-        self.optimizer.step()
-
-        # Memory cleanup
-        del sr_images, lr_images, hr_images
-        gc.collect()
-        torch.cuda.empty_cache()
-
-        return loss.item()
-
-    def valid_step(self, lr_images, hr_images):
-        lr_images, hr_images = lr_images.to(self.device), hr_images.to(self.device)
-
-        sr_images = self.tile_valid(lr_images)
-
-        loss = self.loss_fn(sr_images, hr_images)
-
-        # Memory cleanup
-        del sr_images, lr_images, hr_images
-        gc.collect()
-        torch.cuda.empty_cache()
-
-        return loss.item()
-
-
    def tile_valid(self, lr_images):
        """
        Process all tiles of an image in a batch and then merge them back into the output image.
@@ -1167,115 +1150,8 @@ class Network:
        gc.collect()
        torch.cuda.empty_cache()
        return sr_images
-
-    def train_model(self):
-
-        if self.wandb_mode:
-            wandb.init(project='HAT-for-image-sr',
-                       resume='allow',
-                       config= self.config,
-                       id=self.run_id)
-            wandb.watch(self.model)
-        if self.train_model_continue:
-            epoch_lst = range(self.start_epoch, self.num_epochs)
-        else:
-            epoch_lst = range(self.num_epochs)
-        for epoch in epoch_lst:
-
-            start1 = time.time()
-
-            # ------------------- TRAIN -------------------
-            if self.save_temp_model:
-                self.load_network('temp_model_checkpoint.pth')
-            self.model.train()
-            train_epoch_loss = 0
-
-            stop = 0
-            for hr_images, lr_images in tqdm(self.train_dataloader, desc=f'Epoch {epoch+1}/{self.num_epochs}'):
-
-                if stop == self.STOP:
-                    break
-                stop+=1
-
-                loss = self.train_step(lr_images, hr_images)
-                train_epoch_loss += loss
-
-                if self.wandb_mode:
-                    wandb.log({
-                        'batch_loss': loss,
-                    })
-
-            if self.wandb_mode:
-                wandb.log({
-                    'learning_rate': self.optimizer.param_groups[0]['lr']
-                })
-            print("Learning Rate is:", self.optimizer.param_groups[0]['lr'])
-
-            self.scheduler.step()
-
-
-            if self.save_temp_model:
-                self.save_network(epoch, train_epoch_loss, 0, 'temp_model_checkpoint.pth')
-                print(self.scheduler.state_dict())
-                self.del_model()
-
-            del hr_images
-            del lr_images
-            gc.collect()
-
-            train_epoch_loss /= len(self.train_dataloader)
-
-            end1 = time.time()
-
-
-            # ------------------- VALID -------------------
-            start2 = time.time()
-            if self.save_temp_model:
-                self.load_network('temp_model_checkpoint.pth')
-
-            self.model.eval()
-            with torch.no_grad():
-                valid_epoch_loss = 0
-
-                stop = 0
-                for hr_images, lr_images in tqdm(self.valid_dataloader, desc=f'Epoch {epoch+1}/{self.num_epochs}'):
-                    if stop == self.STOP:
-                        break
-                    stop+=1
-                    loss = self.valid_step(lr_images, hr_images)
-                    valid_epoch_loss += loss
-
-                valid_epoch_loss /= len(self.valid_dataloader)
-
-            end2 = time.time()
-
-            # ------------------- LOG -------------------
-            if self.wandb_mode:
-                wandb.log({
-                    'train_loss': train_epoch_loss,
-                    'valid_loss': valid_epoch_loss,
-                })
-            # ------------------- VERBOSE -------------------
-            print(f'Epoch {epoch+1}/{self.num_epochs} | Train Loss: {train_epoch_loss:.4f} | Valid Loss: {valid_epoch_loss:.4f} | Time train: {end1-start1:.2f}s | Time valid: {end2-start2:.2f}s')
-
-            # ------------------- CHECKPOINT -------------------
-            self.save_network(epoch, train_epoch_loss, valid_epoch_loss, 'model_checkpoint_latest.pth')
-            if valid_epoch_loss < self.last_valid_loss:
-                self.last_valid_loss = valid_epoch_loss
-                self.save_network(epoch, train_epoch_loss, valid_epoch_loss, 'model_checkpoint_best.pth')
-                print("New best checkpoint saved!")
-
-            if self.save_temp_model:
-                self.del_model()
-
-            del hr_images
-            del lr_images
-            gc.collect()
-
-        if self.wandb_mode:
-            wandb.finish()

-    def inference(self, lr_image, hr_image):
+    def inference(self, lr_image, hr_image = None, deployment = False):
        """
        - lr_image: torch.Tensor
            3D Tensor (C, H, W)
@@ -1284,80 +1160,87 @@ class Network:
            ground-truth high-res image. If used solely for inference, skip this. Default is None/
        """
        lr_image = lr_image.unsqueeze(0).to(self.device)
+
        self.for_inference = True
        with torch.no_grad():
            sr_image = self.tile_valid(lr_image)
+            sr_image = torch.clamp(sr_image, 0, 1)

-
-
-
-        print(">> Size of low-res image:", lr_image.size())
-        print(">> Size of super-res image:", sr_image.size())
-        if hr_image != None:
-            print(">> Size of high-res image:", hr_image.size())
-
-        if hr_image != None:
-            fig, axes = plt.subplots(1, 3, figsize=(10, 6))
-            axes[0].imshow(lr_image.cpu().detach().permute((1, 2, 0)))
-            axes[0].set_title('Low Resolution')
-            axes[1].imshow(sr_image.cpu().detach().permute((1, 2, 0)))
-            axes[1].set_title('Super Resolution')
-            axes[2].imshow(hr_image.cpu().detach().permute((1, 2, 0)))
-            axes[2].set_title('High Resolution')
-            for ax in axes.flat:
-                ax.axis('off')
+        if deployment:
+            return sr_image.squeeze(0)
        else:
-
-
-            axes[0].set_title('Low Resolution')
-            axes[1].imshow(sr_image.cpu().detach().permute((1, 2, 0)))
-            axes[1].set_title('Super Resolution')
-            for ax in axes.flat:
-                ax.axis('off')
+            lr_image = lr_image.squeeze(0)
+            sr_image = sr_image.squeeze(0)

-
-
+            print(">> Size of low-res image:", lr_image.size())
+            print(">> Size of super-res image:", sr_image.size())
+            if hr_image != None:
+                print(">> Size of high-res image:", hr_image.size())
+
+            if hr_image != None:
+                fig, axes = plt.subplots(1, 3, figsize=(10, 6))
+                axes[0].imshow(lr_image.cpu().detach().permute((1, 2, 0)))
+                axes[0].set_title('Low Resolution')
+                axes[1].imshow(sr_image.cpu().detach().permute((1, 2, 0)))
+                axes[1].set_title('Super Resolution')
+                axes[2].imshow(hr_image.cpu().detach().permute((1, 2, 0)))
+                axes[2].set_title('High Resolution')
+                for ax in axes.flat:
+                    ax.axis('off')
+            else:
+                fig, axes = plt.subplots(1, 2, figsize=(10, 6))
+                axes[0].imshow(lr_image.cpu().detach().permute((1, 2, 0)))
+                axes[0].set_title('Low Resolution')
+                axes[1].imshow(sr_image.cpu().detach().permute((1, 2, 0)))
+                axes[1].set_title('Super Resolution')
+                for ax in axes.flat:
+                    ax.axis('off')
+
+            plt.tight_layout()
+            plt.show()
+            return sr_image

-
-
-
-
-
-
-
-
+def HAT_for_deployment(lr_image, model_path = 'models/HAT/hat_model_checkpoint_best.pth'):
+    lr_image = transforms.functional.to_tensor(lr_image)
+    hat = Network()
+    hat.load_network(model_path)
+    t1 = time.time()
+    sr_image = hat.inference(lr_image, deployment=True).cpu().numpy()
+    t2 = time.time()
+    print("Time taken to infer:", t2 - t1)
+    # If image is in [C, H, W] format, transpose it to [H, W, C]
+    sr_image = np.transpose(sr_image, (1, 2, 0))
+    if sr_image.max() <= 1.0:
+        sr_image = (sr_image * 255).astype(np.uint8)
+    sr_image = Image.fromarray(sr_image)
+    return sr_image

-    def __getitem__(self, idx):
-
-        lr_image = Image.open(self.lr_image_path)
-
-        lr_image = transforms.functional.to_tensor(lr_image)
-
-        return lr_image
-
-
if __name__ == "__main__":
    import os
    import sys
-    # Getting to the
+    # Getting to the true directory
    sys.path.append(os.path.join(os.path.dirname(os.path.realpath(__file__)), "../../"))
-    image_path = "images/img_003_SRF_4_LR.png"
-
-    infer_dataset = TestDataset(images_path=image_path)

-    #
-
-    #
+    # Define the model's file path and the Google Drive link
+    model_path = 'models/HAT/hat_model_checkpoint_best.pth'
+    gdrive_id = '1LHIUM7YoUDk8cXWzVZhroAcA1xXi-d87' # Replace with your actual Google Drive file URL
+
+    # Check if the model file exists
+    if not os.path.exists(model_path):
+        print(f"Model file not found at {model_path}. Downloading from Google Drive...")
+        # Ensure the directory exists, as gdown will not automatically create directory paths
+        os.makedirs(os.path.dirname(model_path), exist_ok=True)
+        # Download the file from Google Drive
+        # gdown.download(id=gdrive_id, output=model_path, quiet=False)
+    else:
+        print(f"Model file found at {model_path}. No need to download.")

-
-
-
-
-    hat
-
-
-
-    hat.inference(lr_image)
-
-
+    image_path = "images/demo.png"
+    lr_image = Image.open(image_path)
+    # lr_image = transforms.functional.to_tensor(lr_image)
+
+    # hat = Network()
+    # hat.load_network(model_path)
+    # hat.inference(lr_image)
+    print(HAT_for_deployment(lr_image, model_path))

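A note on the tiling code this commit keeps: Network.pre_process reflect-pads each low-res tile until its height and width are multiples of window_size * 4, and Network.post_process crops the upscaled output back by mod_pad * scale. The stand-alone sketch below only illustrates that arithmetic and is not the committed code: the per-row/per-column padding loop is collapsed into one F.pad call, the already-aligned case is guarded with a modulo, window_size = 16 and scale = 4 are assumed (the real values live in the config block, which this diff does not show), and F.interpolate stands in for the HAT forward pass.

    # Illustrative sketch of the pad-to-multiple / crop-back arithmetic (assumptions noted above).
    import torch
    import torch.nn.functional as F

    window_size = 16 * 4  # hat.py pads to a multiple of window_size * 4; 16 is assumed here
    scale = 4             # x4 super resolution assumed

    x = torch.rand(1, 3, 100, 130)  # a low-res tile (N, C, H, W)
    _, _, h, w = x.size()
    mod_pad_h = (window_size - h % window_size) % window_size
    mod_pad_w = (window_size - w % window_size) % window_size
    x_padded = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), 'reflect')  # pre_process equivalent

    y = F.interpolate(x_padded, scale_factor=scale)  # stand-in for the HAT model
    y = y[:, :, :y.size(2) - mod_pad_h * scale, :y.size(3) - mod_pad_w * scale]  # post_process crop
    print(x.shape, '->', y.shape)  # (1, 3, 100, 130) -> (1, 3, 400, 520)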