BuildingExtraction / App_main.py
KyanChen's picture
add requirements
649b010
raw
history blame
3.98 kB
from collections import OrderedDict
import gradio as gr
import os
# Startup diagnostics: run `nvidia-smi` and list /usr/local; the commands'
# output goes to the process's stdout and their exit codes are ignored.
os.system('nvidia-smi')
os.system('ls /usr/local')
# Install torch/torchvision from the CUDA 11.3 wheel index every time the
# script starts. This runs before the `import torch` below; presumably the
# hosting environment does not ship torch preinstalled -- confirm.
os.system('pip install torch torchvision --extra-index-url https://download.pytorch.org/whl/cu113')
import torch
from torchvision import transforms
from torchvision.transforms import InterpolationMode
from STTNet import STTNet
def construct_sample(img, mean, std):
    """Turn an input image into a normalized tensor ready for the network.

    The image is converted to a tensor, resized with bicubic interpolation
    (target size 512) and finally normalized with the given statistics.

    Args:
        img: Image accepted by ``transforms.ToTensor``.
        mean: Per-channel means handed to ``transforms.Normalize``.
        std: Per-channel standard deviations handed to ``transforms.Normalize``.

    Returns:
        The preprocessed image tensor; no batch dimension is added here.
    """
    preprocess = transforms.Compose([
        transforms.ToTensor(),
        transforms.Resize(512, InterpolationMode.BICUBIC),
        transforms.Normalize(mean=mean, std=std),
    ])
    return preprocess(img)
def build_model(checkpoint):
    """Build an STTNet and load its weights from a checkpoint file.

    The weights are first loaded with every ``'module.'`` fragment removed
    from the parameter names; if that does not match the model, the
    checkpoint's original parameter names are tried instead.

    Args:
        checkpoint: Path (or file-like object) readable by ``torch.load``
            whose content stores the weights under ``'model_state_dict'``.

    Returns:
        The ``STTNet`` instance with the checkpoint weights loaded.

    Raises:
        KeyError: If the checkpoint has no ``'model_state_dict'`` entry.
        RuntimeError: If neither the stripped nor the original parameter
            names can be loaded into the model.
    """
    model_infos = {
        # Backbone options noted by the author: vgg16_bn, resnet50, resnet18
        'backbone': 'resnet50',
        'pretrained': False,
        'out_keys': ['block4'],
        'in_channel': 3,
        'n_classes': 2,
        'top_k_s': 64,
        'top_k_c': 16,
        'encoder_pos': True,
        'decoder_pos': True,
        'model_pattern': ['X', 'A', 'S', 'C'],
    }
    model = STTNet(**model_infos)

    # NOTE: torch.load unpickles the file -- only use trusted checkpoints.
    state_dict = torch.load(checkpoint, map_location='cpu')
    raw_weights = state_dict['model_state_dict']

    # Keep the raw and the stripped dicts as separate objects so the
    # fallback below really retries with different keys (the previous code
    # rebound the same variable and retried the identical call).
    stripped_weights = OrderedDict(
        (key.replace('module.', ''), value) for key, value in raw_weights.items()
    )
    try:
        model.load_state_dict(stripped_weights)
    except RuntimeError:
        # Stripped names did not match the model; try the names exactly as
        # stored. A second failure propagates to the caller.
        model.load_state_dict(raw_weights)
    return model
# Normalization statistics and checkpoint file for each supported value of
# the `Checkpoint` argument.
_CHECKPOINT_CONFIGS = {
    'WHU': {
        'mean': [0.4352682576428411, 0.44523221318154493, 0.41307610541534784],
        'std': [0.026973196780331585, 0.026424642808887323, 0.02791246590291434],
        'path': 'Pretrain/WHU_ckpt_latest.pt',
    },
    'INRIA': {
        'mean': [0.40672500537632994, 0.42829032416229895, 0.39331840468605667],
        'std': [0.029498464618176873, 0.027740088491668233, 0.028246722411879095],
        'path': 'Pretrain/INRIA_ckpt_latest.pt',
    },
}

# Models already built by seg_buildings, keyed by checkpoint path, so each
# checkpoint is read from disk once per process instead of once per call.
_MODEL_CACHE = {}


# Function for building extraction
def seg_buildings(Image, Checkpoint):
    """Segment buildings in an image with the selected pretrained model.

    Args:
        Image: Input image; it is passed unchanged to ``construct_sample``.
        Checkpoint: Name of the pretrained weights, ``'WHU'`` or ``'INRIA'``.

    Returns:
        A uint8 NumPy array holding 255 where the argmax over the model's
        logits is class 1 and 0 where it is class 0 (assumes logits are
        laid out as (batch, class, H, W) -- TODO confirm against STTNet).

    Raises:
        NotImplementedError: If ``Checkpoint`` is not a supported name.
    """
    config = _CHECKPOINT_CONFIGS.get(Checkpoint)
    if config is None:
        raise NotImplementedError(
            'Unknown checkpoint {!r}; expected one of {}'.format(
                Checkpoint, sorted(_CHECKPOINT_CONFIGS)
            )
        )

    sample = construct_sample(Image, config['mean'], config['std'])

    weights_path = config['path']
    model = _MODEL_CACHE.get(weights_path)
    if model is None:
        model = build_model(weights_path)
        model.eval()
        _MODEL_CACHE[weights_path] = model

    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
    print('Use: ', device)
    model = model.to(device)

    # Add the batch dimension the model is called with.
    batch = sample.to(device).unsqueeze(0)
    with torch.no_grad():
        # The model returns a pair; only the logits are needed here.
        logits, _ = model(batch)

    # argmax over dim 1 picks the winning class per position; take the
    # first batch item and drop the kept class dimension.
    mask = torch.argmax(logits, 1, keepdim=True)[0, 0]
    # argmax yields int64; cast the 0/255 mask to uint8, the dtype 8-bit
    # grayscale image consumers expect.
    return (mask * 255).to(torch.uint8).cpu().numpy()
# Static text shown on the demo page.
title = 'BuildingExtraction'
description = (
    'Gradio Demo for Building Extraction. '
    'Upload image from INRIA or WHU Dataset or click any one of the examples, '
    'Then click "Submit" and wait for the segmentation result. '
    'Paper: Building Extraction from Remote Sensing Images with Sparse Token Transformers'
)
article = (
    "<p style='text-align: center'>"
    "<a href='https://github.com/KyanChen/BuildingExtraction' target='_blank'>STT Github Repo</a>"
    "</p> "
)

# Example inputs as [image path, checkpoint name] pairs, in display order.
examples = [
    list(pair)
    for pair in (
        ('Examples/2_970.png', 'WHU'),
        ('Examples/2_1139.png', 'WHU'),
        ('Examples/502.png', 'WHU'),
        ('Examples/austin24_460_3680.png', 'INRIA'),
        ('Examples/austin36_1380_1840.png', 'INRIA'),
        ('Examples/tyrol-w19_920_3220.png', 'INRIA'),
    )
]
# Image widgets for the model input and the predicted mask.
# NOTE(review): Row and Column are treated as sibling (not nested) layout
# contexts here -- confirm this matches the intended layout.
with gr.Row():
    image_input = gr.Image(type='pil', label='Input Img')
    image_output = gr.Image(
        image_mode='L',
        shape=(32, 32),
        label='Segmentation Result',
        tool='select',
    )

# Selector for which pretrained checkpoint seg_buildings should use.
# NOTE(review): `gr.inputs` looks like the legacy component namespace --
# check the pinned gradio version before migrating to `gr.Radio`.
with gr.Column():
    checkpoint = gr.inputs.Radio(['WHU', 'INRIA'], label='Checkpoint')

# Gather the Interface options in one place, then build and start the app.
_interface_options = dict(
    fn=seg_buildings,
    inputs=[image_input, checkpoint],
    outputs=image_output,
    title=title,
    description=description,
    article=article,
    allow_flagging='auto',
    examples=examples,
    cache_examples=True,
)
io = gr.Interface(**_interface_options)
io.launch()