KyanChen committed on
Commit 8335262
1 Parent(s): ab01e4a

add interface

App_main.py CHANGED
@@ -1,8 +1,107 @@
+from collections import OrderedDict
+
 import gradio as gr
 import os
 
-def greet(name):
-    return "Hello " + name + "!!"
-
-
-iface = gr.Interface(fn=greet, inputs="text", outputs="text")
-iface.launch()
+import torch
+from torchvision import transforms
+from torchvision.transforms import InterpolationMode
+from STTNet import STTNet
+
+def construct_sample(img, mean, std):
+    img = transforms.ToTensor()(img)
+    img = transforms.Resize(512, InterpolationMode.BICUBIC)(img)
+    img = transforms.Normalize(mean=mean, std=std)(img)
+
+    return img
+
+def build_model(checkpoint):
+    model_infos = {
+        # vgg16_bn, resnet50, resnet18
+        'backbone': 'resnet50',
+        'pretrained': False,
+        'out_keys': ['block4'],
+        'in_channel': 3,
+        'n_classes': 2,
+        'top_k_s': 64,
+        'top_k_c': 16,
+        'encoder_pos': True,
+        'decoder_pos': True,
+        'model_pattern': ['X', 'A', 'S', 'C'],
+    }
+    model = STTNet(**model_infos)
+    state_dict = torch.load(checkpoint, map_location='cpu')
+    model_dict = state_dict['model_state_dict']
+    try:
+        model_dict = OrderedDict({k.replace('module.', ''): v for k, v in model_dict.items()})
+        model.load_state_dict(model_dict)
+    except Exception as e:
+        model.load_state_dict(model_dict)
+    return model
+
+
+# Function for building extraction
+def seg_buildings(Image, Checkpoint):
+    if Checkpoint == 'WHU':
+        mean = [0.4352682576428411, 0.44523221318154493, 0.41307610541534784]
+        std = [0.026973196780331585, 0.026424642808887323, 0.02791246590291434]
+        checkpoint = 'Pretrain/WHU_ckpt_latest.pt'
+    elif Checkpoint == 'INRIA':
+        mean = [0.40672500537632994, 0.42829032416229895, 0.39331840468605667]
+        std = [0.029498464618176873, 0.027740088491668233, 0.028246722411879095]
+        checkpoint = 'Pretrain/INRIA_ckpt_latest.pt'
+    else:
+        raise NotImplementedError
+    sample = construct_sample(Image, mean, std)
+    model = build_model(checkpoint)
+    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
+
+    model = model.to(device)
+    model.eval()
+    sample = sample.to(device)
+    sample = sample.unsqueeze(0)
+
+    with torch.no_grad():
+        logits, att_branch_output = model(sample)
+    pred_label = torch.argmax(logits, 1, keepdim=True)
+    pred_label *= 255
+    pred_label = pred_label[0].detach().cpu()
+    # pred_label = transforms.Resize(32, InterpolationMode.NEAREST)(pred_label)
+    pred = pred_label.numpy()[0]
+
+    return pred
+
+title = "BuildingExtraction"
+description = "Gradio Demo for Building Extraction. Upload image from INRIA or WHU Dataset or click any one of the examples, " \
+              "Then click \"Submit\" and wait for the segmentation result. " \
+              "Paper: Building Extraction from Remote Sensing Images with Sparse Token Transformers"
+article = "<p style='text-align: center'><a href='https://github.com/KyanChen/BuildingExtraction' target='_blank'>STT Github " \
+          "Repo</a></p> "
+
+examples = [
+    ['Examples/2_970.png', 'WHU'],
+    ['Examples/2_1139.png', 'WHU'],
+    ['Examples/502.png', 'WHU'],
+    ['Examples/austin24_460_3680.png', 'INRIA'],
+    ['Examples/austin36_1380_1840.png', 'INRIA'],
+    ['Examples/tyrol-w19_920_3220.png', 'INRIA'],
+]
+
+with gr.Row():
+    image_input = gr.Image(type='pil', label='Input Img')
+    image_output = gr.Image(image_mode='L', shape=(32, 32), label='Segmentation Result', tool='select')
+with gr.Column():
+    checkpoint = gr.inputs.Radio(['WHU', 'INRIA'], label='Checkpoint')
+
+io = gr.Interface(fn=seg_buildings,
+                  inputs=[image_input,
+                          checkpoint],
+                  outputs=image_output,
+                  title=title,
+                  description=description,
+                  article=article,
+                  allow_flagging='auto',
+                  examples=examples,
+                  cache_examples=True
+                  )
+io.launch()
Examples/2_1139.png ADDED
Examples/2_12.png ADDED
Examples/2_775.png ADDED
Examples/2_970.png ADDED
Examples/502.png ADDED
Examples/austin24_460_3680.png ADDED
Examples/austin36_1380_1840.png ADDED
Examples/austin9_0_3680.png ADDED
Examples/tyrol-w19_920_3220.png ADDED
Examples/vienna26_3220_1840.png ADDED
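
For local verification, a minimal sketch of how the new seg_buildings entry point could be exercised outside the Gradio UI is shown below. It assumes the repository's STTNet module and the Pretrain/*_ckpt_latest.pt checkpoints are present next to App_main.py, and that the trailing io.launch() is temporarily placed under an `if __name__ == '__main__':` guard (not part of this commit) so the module can be imported without starting the server:

from PIL import Image

from App_main import seg_buildings  # assumes io.launch() has been moved under a __main__ guard

# One of the example images added in this commit
img = Image.open('Examples/2_970.png').convert('RGB')

# Returns a numpy mask with values 0/255 (building pixels assumed to map to 255)
mask = seg_buildings(img, 'WHU')
print(mask.shape, mask.min(), mask.max())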