# FunSR / App_main.py
# Last commit f214b9c by KyanChen: "Update App_main.py"
# Copied from the Hugging Face file viewer (raw | history | blame | contribute | delete); scan: No virus; size: 4.06 kB
import os
os.system('mim install mmcv')
import numpy as np
import models
import gradio as gr
import torch
from torchvision import transforms
from torchvision.transforms import InterpolationMode
# Run inference on the first CUDA GPU when one is visible, else on the CPU.
if torch.cuda.is_available():
    device = 'cuda:0'
else:
    device = 'cpu'
def construct_sample(img, mean=0.5, std=0.5, size=(48, 48)):
    """Turn an input image into the normalized low-resolution tensor fed to FunSR.

    Args:
        img: Input image. PIL images in any mode are converted to RGB first;
            other inputs (e.g. numpy arrays) are passed to ``ToTensor`` as-is.
        mean: Mean passed to ``transforms.Normalize``.
        std: Standard deviation passed to ``transforms.Normalize``.
        size: ``(H, W)`` the image is bicubically resized to. Defaults to the
            previously hard-coded 48x48.

    Returns:
        A CHW float tensor normalized as ``(x - mean) / std``.
    """
    # RGBA PNGs, grayscale ('L') or palette ('P') images would otherwise give
    # 4 or 1 channels; the checkpoints presumably expect 3-channel RGB
    # input (TODO confirm). Objects without a PIL ``mode`` are left untouched.
    if getattr(img, 'mode', 'RGB') != 'RGB':
        img = img.convert('RGB')
    img = transforms.ToTensor()(img)  # PIL image -> CHW float in [0, 1]
    img = transforms.Resize(size, InterpolationMode.BICUBIC)(img)
    img = transforms.Normalize(mean, std)(img)
    return img
def build_model(cp):
    """Instantiate the FunSR model stored in checkpoint file *cp*.

    The checkpoint is read onto the CPU, its ``'model'`` spec (whose
    ``'args'`` are printed for logging) is handed to ``models.make`` with
    ``load_sd=True``, and the resulting network is moved to ``device``.

    NOTE: ``torch.load`` unpickles the file, so only pass trusted checkpoints.
    """
    state = torch.load(cp, map_location='cpu')
    spec = state['model']
    print(spec['args'])
    net = models.make(spec, load_sd=True)
    return net.to(device)
# Checkpoint file for each option of the "Checkpoint" radio button.
_CHECKPOINTS = {
    'UC': 'pretrain/UC_FunSR_RDN.pth',
    'AID': 'pretrain/AID_FunSR_RDN.pth',
}

# Models already loaded by sr_func, keyed by checkpoint path. Reusing them
# avoids re-reading the checkpoint from disk and re-uploading it to the
# device on every request.
_MODEL_CACHE = {}


def _get_model(checkpoint):
    """Return the eval-mode model for *checkpoint*, building it on first use."""
    model = _MODEL_CACHE.get(checkpoint)
    if model is None:
        model = build_model(checkpoint)
        model.eval()
        _MODEL_CACHE[checkpoint] = model
    return model


def _to_uint8_hwc(batch):
    """Convert the first image of a BCHW batch back to an HWC uint8 array.

    Undoes the mean=0.5/std=0.5 normalization applied by ``construct_sample``,
    rescales to [0, 255], clamps, and truncates to uint8.
    """
    img = (batch * 0.5 + 0.5) * 255
    img = img[0].detach().cpu()
    return torch.clamp(img, 0, 255).permute(1, 2, 0).numpy().astype(np.uint8)


# Gradio callback: continuous-scale super-resolution of one image.
def sr_func(img, cp, scale):
    """Super-resolve *img* by *scale* using the checkpoint selected by *cp*.

    Args:
        img: Input image (from the Gradio ``Image(type='pil')`` input); it is
            resized to 48x48 by ``construct_sample`` before inference.
        cp: Checkpoint name, ``'UC'`` or ``'AID'``.
        scale: Upscaling factor.

    Returns:
        An HWC uint8 array showing, side by side, the nearest-neighbour
        upscaled input, a 5-pixel white separator, and the model output.

    Raises:
        NotImplementedError: If *cp* is not a known checkpoint name (for
            example when no checkpoint was selected in the UI).
    """
    checkpoint = _CHECKPOINTS.get(cp)
    if checkpoint is None:
        raise NotImplementedError(
            f'Unknown checkpoint {cp!r}; choose one of {list(_CHECKPOINTS)}.')

    sample = construct_sample(img)
    print('Use: ', device)
    model = _get_model(checkpoint)
    sample = sample.to(device).unsqueeze(0)  # CHW -> BCHW with B = 1

    # floor(size * scale) per spatial dim, computed in Python floats - the
    # same rule F.interpolate(scale_factor=...) uses for the LR panel below,
    # so both panels get the same height.
    target_size = [int(s * scale) for s in sample.shape[2:]]

    with torch.no_grad():
        lr_target_size_img = torch.nn.functional.interpolate(
            sample, scale_factor=scale, mode='nearest')
        pred = model(sample, target_size)
    if isinstance(pred, list):
        # Presumably per-stage outputs (TODO confirm); the last one is used.
        pred = pred[-1]

    pred = _to_uint8_hwc(pred)
    lr_img = _to_uint8_hwc(lr_target_size_img)
    # White vertical separator matching the output's height and channels.
    separator = np.full((pred.shape[0], 5, pred.shape[2]), 255, dtype=np.uint8)
    return np.concatenate((lr_img, separator, pred), axis=1)
# Text shown by the Gradio interface.
title = "FunSR"
description = (
    "Gradio demo for continuous remote sensing image super-resolution. "
    "Upload an image from the UCMerced or AID dataset, or click any one of the examples, "
    "then change the upscaling magnification, click \"Submit\" and wait for the super-resolved result. \n"
    "Paper: Continuous Remote Sensing Image Super-Resolution based on Context Interaction in Implicit Function Space"
)
article = ("<p style='text-align: center'><a href='https://kyanchen.github.io/FunSR/' target='_blank'>FunSR Project "
           "Page</a></p> ")

# Upscaling factor used by the example rows below.
default_scale = 2.0
# Each row is [image path, checkpoint name, scale], matching the inputs of
# sr_func(img, cp, scale).
examples = [
    ['examples/AID_school_161_LR.png', 'AID', default_scale],
    ['examples/AID_bridge_19_LR.png', 'AID', default_scale],
    ['examples/AID_parking_60_LR.png', 'AID', default_scale],
    ['examples/AID_commercial_32_LR.png', 'AID', default_scale],
    ['examples/UC_airplane95_LR.png', 'UC', default_scale],
    ['examples/UC_freeway35_LR.png', 'UC', default_scale],
    ['examples/UC_storagetanks54_LR.png', 'UC', default_scale],
    ['examples/UC_airplane00_LR.png', 'UC', default_scale],
]
# Build the UI components inside a Blocks context; they are then handed to
# gr.Interface below, and it is the Interface (`io`) that gets launched.
# NOTE(review): `demo` itself is never launched here - confirm it is needed.
with gr.Blocks() as demo:
    # Input image delivered to sr_func as a PIL image.
    image_input = gr.Image(type='pil', label='Input Img')
    # with gr.Row().style(equal_height=True):
    # image_LR_output = gr.outputs.Image(label='LR Img', type='numpy')
    # sr_func returns an HWC uint8 numpy array (LR | separator | SR).
    image_output = gr.Image(label='SR Result', type='numpy')
    with gr.Row():
        # No `value=` is given, so presumably nothing is selected initially and
        # sr_func receives None (-> NotImplementedError) until the user picks one.
        checkpoint = gr.Radio(['UC', 'AID'], label='Checkpoint')
        # Upscaling factor in [1, 10] with 0.1 steps; starts at 4.0, whereas
        # the example rows use default_scale.
        scale = gr.Slider(1, 10, value=4.0, step=0.1, label='scale')
# Input order must match sr_func(img, cp, scale).
io = gr.Interface(fn=sr_func,
                  inputs=[image_input,
                          checkpoint,
                          scale
                          ],
                  outputs=[
                      # image_LR_output,
                      image_output
                  ],
                  title=title,
                  description=description,
                  article=article,
                  # NOTE(review): 'auto' flagging presumably records every
                  # submission without user action - confirm this is intended.
                  allow_flagging='auto',
                  examples=examples,
                  # Presumably precomputes sr_func on every example at startup,
                  # so the checkpoints must be present when the app starts.
                  cache_examples=True,
                  )
io.launch()