from collections import OrderedDict
import os

import gradio as gr

# Check the GPU environment and install a CUDA 11.3 build of PyTorch at runtime
os.system('nvidia-smi')
os.system('ls /usr/local')
os.system('pip install torch torchvision --extra-index-url https://download.pytorch.org/whl/cu113')

import torch
from torchvision import transforms
from torchvision.transforms import InterpolationMode

from STTNet import STTNet


def construct_sample(img, mean, std):
    # Convert a PIL image to a normalized 512x512 tensor
    img = transforms.ToTensor()(img)
    img = transforms.Resize(512, InterpolationMode.BICUBIC)(img)
    img = transforms.Normalize(mean=mean, std=std)(img)
    return img


def build_model(checkpoint):
    model_infos = {
        # backbone options: vgg16_bn, resnet50, resnet18
        'backbone': 'resnet50',
        'pretrained': False,
        'out_keys': ['block4'],
        'in_channel': 3,
        'n_classes': 2,
        'top_k_s': 64,
        'top_k_c': 16,
        'encoder_pos': True,
        'decoder_pos': True,
        'model_pattern': ['X', 'A', 'S', 'C'],
    }
    model = STTNet(**model_infos)
    state_dict = torch.load(checkpoint, map_location='cpu')
    model_dict = state_dict['model_state_dict']
    try:
        # Strip the 'module.' prefix left by DataParallel checkpoints
        model_dict = OrderedDict({k.replace('module.', ''): v for k, v in model_dict.items()})
        model.load_state_dict(model_dict)
    except Exception:
        # Fall back to loading the state dict as-is
        model.load_state_dict(model_dict)
    return model


# Function for building extraction
def seg_buildings(Image, Checkpoint):
    if Checkpoint == 'WHU':
        mean = [0.4352682576428411, 0.44523221318154493, 0.41307610541534784]
        std = [0.026973196780331585, 0.026424642808887323, 0.02791246590291434]
        checkpoint = 'Pretrain/WHU_ckpt_latest.pt'
    elif Checkpoint == 'INRIA':
        mean = [0.40672500537632994, 0.42829032416229895, 0.39331840468605667]
        std = [0.029498464618176873, 0.027740088491668233, 0.028246722411879095]
        checkpoint = 'Pretrain/INRIA_ckpt_latest.pt'
    else:
        raise NotImplementedError

    sample = construct_sample(Image, mean, std)
    model = build_model(checkpoint)

    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
    print('Use: ', device)
    model = model.to(device)
    model.eval()

    sample = sample.to(device)
    sample = sample.unsqueeze(0)
    with torch.no_grad():
        logits, att_branch_output = model(sample)

    # Binary prediction map: 0 = background, 255 = building
    pred_label = torch.argmax(logits, 1, keepdim=True)
    pred_label *= 255
    pred_label = pred_label[0].detach().cpu()
    # pred_label = transforms.Resize(32, InterpolationMode.NEAREST)(pred_label)
    pred = pred_label.numpy()[0]
    return pred


title = "BuildingExtraction"
description = "Gradio Demo for Building Extraction. Upload an image from the INRIA or WHU dataset or click one of the examples, " \
              "then click \"Submit\" and wait for the segmentation result. " \
              "Paper: Building Extraction from Remote Sensing Images with Sparse Token Transformers"
article = "STT Github Repo"

examples = [
    ['Examples/2_970.png', 'WHU'],
    ['Examples/2_1139.png', 'WHU'],
    ['Examples/502.png', 'WHU'],
    ['Examples/austin24_460_3680.png', 'INRIA'],
    ['Examples/austin36_1380_1840.png', 'INRIA'],
    ['Examples/tyrol-w19_920_3220.png', 'INRIA'],
]

with gr.Row():
    image_input = gr.Image(type='pil', label='Input Img')
    image_output = gr.Image(image_mode='L', shape=(32, 32), label='Segmentation Result', tool='select')
with gr.Column():
    checkpoint = gr.Radio(['WHU', 'INRIA'], label='Checkpoint')

io = gr.Interface(fn=seg_buildings,
                  inputs=[image_input, checkpoint],
                  outputs=image_output,
                  title=title,
                  description=description,
                  article=article,
                  allow_flagging='auto',
                  examples=examples,
                  cache_examples=True,
                  )
io.launch()