Spaces:
Runtime error
Runtime error
File size: 7,109 Bytes
55f7dbc dad1b70 55f7dbc |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 |
import os
# Install the detection stack at start-up; the mmdet / mim / mmrotate imports
# further down rely on these packages being present.
# NOTE(review): presumably done at runtime because the host offers no other
# install hook - confirm, and consider pinning versions in a requirements file.
os.system("pip install mmcv-full")
# Raw string: "\<" is not a valid Python escape sequence, so a normal literal
# triggers an invalid-escape warning. The backslash itself is kept so the
# shell treats "<" as a literal character instead of an input redirection.
os.system(r"pip install mmdet\<3.0.0")
os.system("pip install mmrotate")
import fnmatch
import PIL
import cv2
import gradio as gr
import numpy as np
import torch
from mmdet.apis import inference_detector, init_detector, show_result_pyplot
from mim import download
import mmrotate
# Config names offered in the "Model Name" dropdown of the demo.
# The identifier's spelling is kept as-is because the UI definition refers to
# it under this exact name.
mmrorate_model_list = [
    'cfa_r50_fpn_1x_dota_le135',
    'cfa_r50_fpn_40e_dota_oc',
    'rotated_retinanet_obb_csl_gaussian_r50_fpn_fp16_1x_dota_le90',
    'g_reppoints_r50_fpn_1x_dota_le135',
    'gliding_vertex_r50_fpn_1x_dota_le90',
    'rotated_retinanet_hbb_gwd_r50_fpn_1x_dota_oc',
    'rotated_retinanet_hbb_kfiou_r50_fpn_1x_dota_le90',
    'rotated_retinanet_hbb_kfiou_r50_fpn_1x_dota_oc',
    'rotated_retinanet_hbb_kfiou_r50_fpn_1x_dota_le135',
    'r3det_kfiou_ln_r50_fpn_1x_dota_oc',
    'rotated_retinanet_hbb_kld_r50_fpn_1x_dota_oc',
    'rotated_retinanet_hbb_kld_stable_r50_fpn_1x_dota_oc',
    'rotated_retinanet_obb_kld_stable_r50_fpn_1x_dota_le90',
    'r3det_kld_r50_fpn_1x_dota_oc',
    'r3det_kld_stable_r50_fpn_1x_dota_oc',
    'r3det_tiny_kld_r50_fpn_1x_dota_oc',
    'rotated_retinanet_hbb_kld_stable_r50_fpn_6x_hrsc_rr_oc',
    'rotated_retinanet_obb_kld_stable_r50_fpn_6x_hrsc_rr_le90',
    'rotated_retinanet_obb_kld_stable_r50_adamw_fpn_1x_dota_le90',
    'oriented_rcnn_r50_fpn_fp16_1x_dota_le90',
    'oriented_rcnn_r50_fpn_1x_dota_le90',
    'r3det_r50_fpn_1x_dota_oc',
    'r3det_tiny_r50_fpn_1x_dota_oc',
    'redet_re50_refpn_fp16_1x_dota_le90',
    'redet_re50_refpn_1x_dota_le90',
    'redet_re50_refpn_1x_dota_ms_rr_le90',
    'redet_re50_refpn_3x_hrsc_le90',
    'roi_trans_r50_fpn_fp16_1x_dota_le90',
    'roi_trans_r50_fpn_1x_dota_le90',
    'roi_trans_swin_tiny_fpn_1x_dota_le90',
    'roi_trans_r50_fpn_1x_dota_ms_le90',
    'rotated_atss_hbb_r50_fpn_1x_dota_oc',
    'rotated_atss_obb_r50_fpn_1x_dota_le90',
    'rotated_atss_obb_r50_fpn_1x_dota_le135',
    'rotated_faster_rcnn_r50_fpn_1x_dota_le90',
    'rotated_fcos_sep_angle_r50_fpn_1x_dota_le90',
    'rotated_fcos_r50_fpn_1x_dota_le90',
    'rotated_fcos_csl_gaussian_r50_fpn_1x_dota_le90',
    'rotated_fcos_kld_r50_fpn_1x_dota_le90',
    'rotated_reppoints_r50_fpn_1x_dota_oc',
    'rotated_retinanet_hbb_r50_fpn_1x_dota_oc',
    'rotated_retinanet_obb_r50_fpn_1x_dota_le90',
    'rotated_retinanet_obb_r50_fpn_fp16_1x_dota_le90',
    'rotated_retinanet_obb_r50_fpn_1x_dota_le135',
    'rotated_retinanet_obb_r50_fpn_1x_dota_ms_rr_le90',
    'rotated_retinanet_hbb_r50_fpn_6x_hrsc_rr_oc',
    'rotated_retinanet_obb_r50_fpn_6x_hrsc_rr_le90',
    's2anet_r50_fpn_1x_dota_le135',
    's2anet_r50_fpn_fp16_1x_dota_le135',
    'sasm_reppoints_r50_fpn_1x_dota_oc',
]
# Working directory that receives the config and checkpoint of the model
# selected for the current request.
path = "./checkpoint"
# exist_ok=True creates the directory only when it is missing, without the
# check-then-create race of a separate os.path.exists() test.
os.makedirs(path, exist_ok=True)
def clear_folder(folder_path):
    """Delete every entry directly inside ``folder_path``; keep the folder itself.

    Regular files and symlinks are unlinked, sub-directories are removed
    recursively. A failure on one entry is reported on stdout and does not
    stop the remaining entries from being processed.

    Args:
        folder_path: Path of the directory to empty. It must already exist;
            ``os.listdir`` raises otherwise.

    Returns:
        None.
    """
    import shutil  # local import: shutil is not imported at module level

    failures = 0
    for filename in os.listdir(folder_path):
        file_path = os.path.join(folder_path, filename)
        try:
            if os.path.isfile(file_path) or os.path.islink(file_path):
                os.unlink(file_path)
            elif os.path.isdir(file_path):
                shutil.rmtree(file_path)
        except OSError as e:
            failures += 1
            print(f"Failed to delete {file_path}. Reason: {e}")
    # Only claim success when nothing was left behind by a failed deletion.
    if not failures:
        print(f"Clear {folder_path} successfully.")
def download_cfg_checkpoint_model_name(model_name):
    """Empty ./checkpoint, then download the files for ``model_name`` into it.

    Args:
        model_name: Config name passed to ``mim.download`` for the
            ``mmrotate`` package.
    """
    checkpoint_dir = './checkpoint'
    # Start from an empty folder so files of a previously selected model are
    # not left next to the new ones.
    clear_folder(checkpoint_dir)
    download(package='mmrotate', configs=[model_name], dest_root=checkpoint_dir)
def download_test_image():
    """Download the two sample images into the current working directory."""
    # (remote URL, local file name) for each sample image.
    sample_images = (
        (
            'https://user-images.githubusercontent.com/59380685/266800230-e8396b83-92a7-4367-bc4b-a36348e63dbe.jpg',
            'demo.jpg',
        ),
        (
            'https://user-images.githubusercontent.com/59380685/266800231-d544d5ea-fc91-45d5-b79e-97bb9c717259.jpg',
            'dota_demo.jpg',
        ),
    )
    for url, local_name in sample_images:
        torch.hub.download_url_to_file(url, local_name)
def save_image(img, img_path):
    """Write ``img`` to ``img_path`` using OpenCV.

    Args:
        img: Image to save; it is converted with ``np.array`` and then
            treated as RGB data (a PIL image, per the original comment).
        img_path: Destination file path handed to ``cv2.imwrite``.
    """
    rgb_pixels = np.array(img)
    # Reorder the channels from RGB to BGR before handing them to OpenCV.
    bgr_pixels = cv2.cvtColor(rgb_pixels, cv2.COLOR_RGB2BGR)
    cv2.imwrite(img_path, bgr_pixels)
def _first_match(folder, pattern):
    """Return the path of the first file in ``folder`` matching ``pattern``."""
    matches = sorted(f for f in os.listdir(folder) if fnmatch.fnmatch(f, pattern))
    if not matches:
        raise FileNotFoundError(
            f"No file matching {pattern!r} found in {folder!r}")
    return os.path.join(folder, matches[0])


def predict_image(image, model_name, palette, score_thr, device):
    """Run the selected detector on ``image`` and return the rendered result.

    The config and checkpoint for ``model_name`` are downloaded on every call
    (the checkpoint folder is emptied first), a detector is built from them,
    and the visualised detections are written to ``./output_img.jpg`` and
    loaded back.

    Args:
        image: Input picture; the UI supplies a PIL image.
        model_name: Config name to download and run.
        palette: Colour palette name forwarded to ``show_result_pyplot``.
        score_thr: Score threshold forwarded to ``show_result_pyplot``.
        device: Device string forwarded to ``init_detector``.

    Returns:
        A PIL image holding the visualisation, fully loaded into memory.

    Raises:
        ValueError: If no image was supplied.
        FileNotFoundError: If the download left no ``*.py`` config or no
            ``*.pth`` checkpoint in the checkpoint folder.
    """
    # `import PIL` alone does not guarantee that the PIL.Image submodule is
    # loaded, so import it explicitly where it is used.
    from PIL import Image

    if image is None:
        raise ValueError("No input image was provided.")
    if isinstance(image, Image.Image):
        # Drop alpha / palette modes so the array always has three channels.
        image = image.convert('RGB')
    # PIL delivers RGB, while the OpenCV-based mmcv/mmdet pipeline (and the
    # image writer behind `out_file`) works on BGR arrays; without this swap
    # red and blue are exchanged for both inference and the saved picture.
    image = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
    save_dir = './output_img.jpg'

    download_cfg_checkpoint_model_name(model_name)
    config = _first_match(path, "*.py")
    checkpoint = _first_match(path, "*.pth")

    # build the model from a config file and a checkpoint file
    model = init_detector(config, checkpoint, device=device)
    result = inference_detector(model, image)

    # render the detections into the output file
    show_result_pyplot(
        model,
        image,
        result,
        palette=palette,
        score_thr=score_thr,
        out_file=save_dir)

    # Image.open is lazy; copy the pixels now so the returned image does not
    # depend on a file that the next request overwrites.
    with Image.open(save_dir) as rendered:
        img_out = rendered.copy()
    return img_out
# Fetch the sample pictures before the interface is built.
download_test_image()

# Input widgets, listed in the same order as predict_image's parameters.
inputs = [
    gr.inputs.Image(type='pil', label="Input Image"),
    gr.inputs.Dropdown(
        label="Model Name",
        choices=list(mmrorate_model_list),
        default='oriented_rcnn_r50_fpn_1x_dota_le90',
    ),
    gr.inputs.Dropdown(
        label="Color palette used for visualization",
        choices=['dota', 'sar', 'hrsc', 'hrsc_classwise', 'random'],
        default='dota',
    ),
    gr.inputs.Slider(
        label="bbox score threshold",
        minimum=0.0,
        maximum=1.0,
        step=0.01,
        default=0.3,
    ),
    gr.inputs.Dropdown(
        label="Device used for inference",
        choices=['cuda:0', 'cpu'],
        default='cpu',
    ),
]

# Single output widget showing the rendered detections.
output = gr.outputs.Image(type='pil', label="Output Image")
# Static text shown around the demo: page title, HTML header and HTML footer.
title = "MMRotate detection web demo"

# Header: centred logo followed by a short project blurb.
# Fixed markup: the stray quote after width='450', the <div> that was meant to
# be a closing </div>, and the link text that named a different project.
description = (
    "<div align='center'><img src='https://raw.githubusercontent.com/open-mmlab/mmrotate/main/resources/mmrotate-logo.png' width='450'/></div>"
    "<p style='text-align: center'><a href='https://github.com/open-mmlab/mmrotate'>MMRotate</a> MMRotate 是一款基于 PyTorch 的旋转框检测的开源工具箱,是 OpenMMLab 项目的成员之一。"
    "OpenMMLab Rotated Object Detection Toolbox and Benchmark.</p>"
)

# Footer: project link and demo author link (duplicate </a> removed).
article = (
    "<p style='text-align: center'><a href='https://github.com/open-mmlab/mmrotate'>MMRotate</a></p>"
    "<p style='text-align: center'><a href='https://github.com/isLinXu'>gradio build by gatilin</a></p>"
)
# Example rows for the interface: each row is
# [image file, model name, palette, score threshold, device].
examples = [
    [image_file, config_name, 'dota', 0.3, 'cpu']
    for image_file, config_name in (
        ("demo.jpg", "oriented_rcnn_r50_fpn_1x_dota_le90"),
        ("dota_demo.jpg", "r3det_r50_fpn_1x_dota_oc"),
    )
]
# Assemble the web UI around predict_image and start serving it.
demo = gr.Interface(
    fn=predict_image,
    inputs=inputs,
    outputs=output,
    title=title,
    description=description,
    article=article,
    examples=examples,
    allow_flagging=False,
    theme="default",
)
demo.launch()
|