Spaces:
Runtime error
Runtime error
add interface
Browse files- App_main.py +103 -4
- Examples/2_1139.png +0 -0
- Examples/2_12.png +0 -0
- Examples/2_775.png +0 -0
- Examples/2_970.png +0 -0
- Examples/502.png +0 -0
- Examples/austin24_460_3680.png +0 -0
- Examples/austin36_1380_1840.png +0 -0
- Examples/austin9_0_3680.png +0 -0
- Examples/tyrol-w19_920_3220.png +0 -0
- Examples/vienna26_3220_1840.png +0 -0
App_main.py
CHANGED
@@ -1,8 +1,107 @@
|
|
|
|
|
|
1 |
import gradio as gr
|
2 |
import os
|
3 |
|
4 |
-
|
5 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
6 |
|
7 |
-
|
8 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from collections import OrderedDict
|
2 |
+
|
3 |
import gradio as gr
|
4 |
import os
|
5 |
|
6 |
+
import torch
|
7 |
+
from torchvision import transforms
|
8 |
+
from torchvision.transforms import InterpolationMode
|
9 |
+
from STTNet import STTNet
|
10 |
+
|
11 |
+
def construct_sample(img, mean, std):
|
12 |
+
img = transforms.ToTensor()(img)
|
13 |
+
img = transforms.Resize(512, InterpolationMode.BICUBIC)(img)
|
14 |
+
img = transforms.Normalize(mean=mean, std=std)(img)
|
15 |
+
|
16 |
+
return img
|
17 |
+
|
18 |
+
def build_model(checkpoint):
|
19 |
+
model_infos = {
|
20 |
+
# vgg16_bn, resnet50, resnet18
|
21 |
+
'backbone': 'resnet50',
|
22 |
+
'pretrained': False,
|
23 |
+
'out_keys': ['block4'],
|
24 |
+
'in_channel': 3,
|
25 |
+
'n_classes': 2,
|
26 |
+
'top_k_s': 64,
|
27 |
+
'top_k_c': 16,
|
28 |
+
'encoder_pos': True,
|
29 |
+
'decoder_pos': True,
|
30 |
+
'model_pattern': ['X', 'A', 'S', 'C'],
|
31 |
+
}
|
32 |
+
model = STTNet(**model_infos)
|
33 |
+
state_dict = torch.load(checkpoint, map_location='cpu')
|
34 |
+
model_dict = state_dict['model_state_dict']
|
35 |
+
try:
|
36 |
+
model_dict = OrderedDict({k.replace('module.', ''): v for k, v in model_dict.items()})
|
37 |
+
model.load_state_dict(model_dict)
|
38 |
+
except Exception as e:
|
39 |
+
model.load_state_dict(model_dict)
|
40 |
+
return model
|
41 |
+
|
42 |
+
|
43 |
+
# Function for building extraction
|
44 |
+
def seg_buildings(Image, Checkpoint):
|
45 |
+
if Checkpoint == 'WHU':
|
46 |
+
mean = [0.4352682576428411, 0.44523221318154493, 0.41307610541534784]
|
47 |
+
std = [0.026973196780331585, 0.026424642808887323, 0.02791246590291434]
|
48 |
+
checkpoint = 'Pretrain/WHU_ckpt_latest.pt'
|
49 |
+
elif Checkpoint == 'INRIA':
|
50 |
+
mean = [0.40672500537632994, 0.42829032416229895, 0.39331840468605667]
|
51 |
+
std = [0.029498464618176873, 0.027740088491668233, 0.028246722411879095]
|
52 |
+
checkpoint = 'Pretrain/INRIA_ckpt_latest.pt'
|
53 |
+
else:
|
54 |
+
raise NotImplementedError
|
55 |
+
sample = construct_sample(Image, mean, std)
|
56 |
+
model = build_model(checkpoint)
|
57 |
+
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
|
58 |
+
|
59 |
+
model = model.to(device)
|
60 |
+
model.eval()
|
61 |
+
sample = sample.to(device)
|
62 |
+
sample = sample.unsqueeze(0)
|
63 |
+
|
64 |
+
with torch.no_grad():
|
65 |
+
logits, att_branch_output = model(sample)
|
66 |
+
pred_label = torch.argmax(logits, 1, keepdim=True)
|
67 |
+
pred_label *= 255
|
68 |
+
pred_label = pred_label[0].detach().cpu()
|
69 |
+
# pred_label = transforms.Resize(32, InterpolationMode.NEAREST)(pred_label)
|
70 |
+
pred = pred_label.numpy()[0]
|
71 |
+
|
72 |
+
return pred
|
73 |
+
|
74 |
+
title = "BuildingExtraction"
|
75 |
+
description = "Gradio Demo for Building Extraction. Upload image from INRIA or WHU Dataset or click any one of the examples, " \
|
76 |
+
"Then click \"Submit\" and wait for the segmentation result. " \
|
77 |
+
"Paper: Building Extraction from Remote Sensing Images with Sparse Token Transformers"
|
78 |
+
article = "<p style='text-align: center'><a href='https://github.com/KyanChen/BuildingExtraction' target='_blank'>STT Github " \
|
79 |
+
"Repo</a></p> "
|
80 |
+
|
81 |
+
examples = [
|
82 |
+
['Examples/2_970.png', 'WHU'],
|
83 |
+
['Examples/2_1139.png', 'WHU'],
|
84 |
+
['Examples/502.png', 'WHU'],
|
85 |
+
['Examples/austin24_460_3680.png', 'INRIA'],
|
86 |
+
['Examples/austin36_1380_1840.png', 'INRIA'],
|
87 |
+
['Examples/tyrol-w19_920_3220.png', 'INRIA'],
|
88 |
+
]
|
89 |
+
|
90 |
+
with gr.Row():
|
91 |
+
image_input = gr.Image(type='pil', label='Input Img')
|
92 |
+
image_output = gr.Image(image_mode='L', shape=(32, 32), label='Segmentation Result', tool='select')
|
93 |
+
with gr.Column():
|
94 |
+
checkpoint = gr.inputs.Radio(['WHU', 'INRIA'], label='Checkpoint')
|
95 |
|
96 |
+
io = gr.Interface(fn=seg_buildings,
|
97 |
+
inputs=[image_input,
|
98 |
+
checkpoint],
|
99 |
+
outputs=image_output,
|
100 |
+
title=title,
|
101 |
+
description=description,
|
102 |
+
article=article,
|
103 |
+
allow_flagging='auto',
|
104 |
+
examples=examples,
|
105 |
+
cache_examples=True
|
106 |
+
)
|
107 |
+
io.launch()
|
Examples/2_1139.png
ADDED
Examples/2_12.png
ADDED
Examples/2_775.png
ADDED
Examples/2_970.png
ADDED
Examples/502.png
ADDED
Examples/austin24_460_3680.png
ADDED
Examples/austin36_1380_1840.png
ADDED
Examples/austin9_0_3680.png
ADDED
Examples/tyrol-w19_920_3220.png
ADDED
Examples/vienna26_3220_1840.png
ADDED