# HuggingFace Spaces demo app for RSPrompter (remote-sensing instance segmentation).
import os

# Install OpenMMLab tooling at startup: HF Spaces images do not ship mmcv,
# and these two calls MUST run before `import mmcv` below — order matters.
os.system('pip install -U openmim')
os.system('mim install mmcv')

import glob
import mmcv
import mmengine
import numpy as np
from mmengine import Config, get
from mmengine.dataset import Compose
from mmpl.registry import MODELS, VISUALIZERS
from mmpl.utils import register_all_modules

# Populate the MMEngine registries so MODELS.build / VISUALIZERS.build below
# can resolve the `type` names referenced in the config files.
register_all_modules()

# os.system('nvidia-smi')
# os.system('ls /usr/local')
#
import gradio as gr
import torch

# Use the first GPU when available, otherwise fall back to CPU.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
def construct_sample(img, pipeline):
    """Preprocess a PIL image through the predict pipeline.

    Converts the image to a BGR numpy array (OpenCV/mmcv channel order),
    then applies the composed transform pipeline to produce a data sample.
    """
    bgr = np.array(img)[:, :, ::-1]  # PIL gives RGB; flip to BGR
    transform = Compose(pipeline)
    return transform({'ori_shape': bgr.shape[:2], 'img': bgr})
def build_model(cp, model_cfg):
    """Instantiate the model from config, load weights from `cp`, and
    return it in eval mode on the module-level `device`."""
    net = MODELS.build(model_cfg)
    state = torch.load(cp, map_location='cpu')
    net.load_state_dict(state, strict=True)
    net.to(device=device)
    net.eval()
    return net
# Function for building extraction
def inference_func(ori_img, cp):
    """Segment `ori_img` with the checkpoint selected by `cp` and return
    the rendered RGB visualization.

    Args:
        ori_img: Input PIL image.
        cp: Dataset tag ('WHU', 'NWPU' or 'SSDD') that selects both the
            pretrained weights file and the matching config.
    """
    ckpt_path = f'pretrain/{cp}_anchor.pth'
    config = Config.fromfile(f'configs/huggingface/rsprompter_anchor_{cp}_config.py')

    sample = construct_sample(ori_img, config.predict_pipeline)
    # predict_step expects batched lists; wrap the single sample.
    sample['inputs'] = [sample['inputs']]
    sample['data_samples'] = [sample['data_samples']]

    print('Use: ', device)
    # NOTE(review): the model is rebuilt and weights reloaded on every request.
    model = build_model(ckpt_path, config.model_cfg)
    with torch.no_grad():
        pred_results = model.predict_step(sample, batch_idx=0)

    config.visualizer.setdefault('save_dir', 'visualizer')
    visualizer = VISUALIZERS.build(config.visualizer)

    out_file = 'visualizer/test_img.jpg'
    mmengine.mkdir_or_exist(os.path.dirname(out_file))
    visualizer.add_datasample(
        'test_img',
        np.array(ori_img).copy(),
        draw_gt=False,
        data_sample=pred_results[0],
        show=False,
        wait_time=0.01,
        pred_score_thr=0.4,  # hide low-confidence instances
        out_file=out_file,
    )
    # Read the rendered file back and return it RGB for Gradio display.
    return mmcv.imfrombytes(get(out_file), channel_order='rgb')
# UI copy and example gallery for the Gradio interface.
title = "RSPrompter"
description = "Gradio demo for RSPrompter. Upload image from WHU building dataset, NWPU dataset, or SSDD Dataset or click any one of the examples, " \
              "Then select the prompt model, and click \"Submit\" and wait for the result. \n \n" \
              "Paper: RSPrompter: Learning to Prompt for Remote Sensing Instance Segmentation based on Visual Foundation Model"
article = "<p style='text-align: center'><a href='https://kyanchen.github.io/RSPrompter/' target='_blank'>RSPrompter Project " \
          "Page</a></p> "

files = glob.glob('examples/*')
# Each example is [image_path, dataset_tag]; the tag is the filename prefix
# before the first '_' (e.g. 'WHU_1.png' -> 'WHU').
# Fix: use os.path.basename instead of split('/') so this also works on
# platforms whose path separator is not '/'.
examples = [[f, os.path.basename(f).split('_')[0]] for f in files]
# Assemble and launch the Gradio UI.
# NOTE(review): components are created inside a gr.Blocks context but wired
# together via gr.Interface — an older-gradio pattern; `demo` itself is never
# launched, only `io`. Confirm against the pinned gradio version.
with gr.Blocks() as demo:
    image_input = gr.Image(type='pil', label='Input Img')
    # with gr.Row().style(equal_height=True):
    #     image_LR_output = gr.outputs.Image(label='LR Img', type='numpy')
    image_output = gr.Image(label='Segment Result', type='numpy')
    with gr.Row():
        checkpoint = gr.Radio(['WHU', 'NWPU', 'SSDD'], label='Checkpoint')
    io = gr.Interface(fn=inference_func,
                      inputs=[image_input, checkpoint],
                      outputs=[image_output],
                      title=title,
                      description=description,
                      article=article,
                      # NOTE(review): allow_flagging and layout are removed in
                      # gradio 4.x — this requires an older gradio release.
                      allow_flagging='auto',
                      examples=examples,
                      # Pre-computes example outputs at startup (runs inference
                      # once per example before serving).
                      cache_examples=True,
                      layout="grid"
                      )
io.launch()