File size: 7,109 Bytes
55f7dbc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dad1b70
 
55f7dbc
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146

import os

# Install the OpenMMLab stack at start-up. mmdet is pinned below 3.0 because
# the mmdet.apis imports below (e.g. show_result_pyplot) are the 2.x API.
os.system("pip install mmcv-full")
# Raw string: "\<" is not a valid Python escape sequence (SyntaxWarning on
# 3.12+), but the shell needs the backslash so "<" is not read as a redirect.
os.system(r"pip install mmdet\<3.0.0")
os.system("pip install mmrotate")

import fnmatch
import functools

import PIL
import cv2
import gradio as gr
import numpy as np
import torch
from mmdet.apis import inference_detector, init_detector, show_result_pyplot
from mim import download

import mmrotate

# MMRotate config names offered in the "Model Name" dropdown. Each name is
# passed to `mim download` to fetch its config and checkpoint.
# (The variable name's spelling is kept as-is because other code refers to it.)
mmrorate_model_list = [
    'cfa_r50_fpn_1x_dota_le135',
    'cfa_r50_fpn_40e_dota_oc',
    'rotated_retinanet_obb_csl_gaussian_r50_fpn_fp16_1x_dota_le90',
    'g_reppoints_r50_fpn_1x_dota_le135',
    'gliding_vertex_r50_fpn_1x_dota_le90',
    'rotated_retinanet_hbb_gwd_r50_fpn_1x_dota_oc',
    'rotated_retinanet_hbb_kfiou_r50_fpn_1x_dota_le90',
    'rotated_retinanet_hbb_kfiou_r50_fpn_1x_dota_oc',
    'rotated_retinanet_hbb_kfiou_r50_fpn_1x_dota_le135',
    'r3det_kfiou_ln_r50_fpn_1x_dota_oc',
    'rotated_retinanet_hbb_kld_r50_fpn_1x_dota_oc',
    'rotated_retinanet_hbb_kld_stable_r50_fpn_1x_dota_oc',
    'rotated_retinanet_obb_kld_stable_r50_fpn_1x_dota_le90',
    'r3det_kld_r50_fpn_1x_dota_oc',
    'r3det_kld_stable_r50_fpn_1x_dota_oc',
    'r3det_tiny_kld_r50_fpn_1x_dota_oc',
    'rotated_retinanet_hbb_kld_stable_r50_fpn_6x_hrsc_rr_oc',
    'rotated_retinanet_obb_kld_stable_r50_fpn_6x_hrsc_rr_le90',
    'rotated_retinanet_obb_kld_stable_r50_adamw_fpn_1x_dota_le90',
    'oriented_rcnn_r50_fpn_fp16_1x_dota_le90',
    'oriented_rcnn_r50_fpn_1x_dota_le90',
    'r3det_r50_fpn_1x_dota_oc',
    'r3det_tiny_r50_fpn_1x_dota_oc',
    'redet_re50_refpn_fp16_1x_dota_le90',
    'redet_re50_refpn_1x_dota_le90',
    'redet_re50_refpn_1x_dota_ms_rr_le90',
    'redet_re50_refpn_3x_hrsc_le90',
    'roi_trans_r50_fpn_fp16_1x_dota_le90',
    'roi_trans_r50_fpn_1x_dota_le90',
    'roi_trans_swin_tiny_fpn_1x_dota_le90',
    'roi_trans_r50_fpn_1x_dota_ms_le90',
    'rotated_atss_hbb_r50_fpn_1x_dota_oc',
    'rotated_atss_obb_r50_fpn_1x_dota_le90',
    'rotated_atss_obb_r50_fpn_1x_dota_le135',
    'rotated_faster_rcnn_r50_fpn_1x_dota_le90',
    'rotated_fcos_sep_angle_r50_fpn_1x_dota_le90',
    'rotated_fcos_r50_fpn_1x_dota_le90',
    'rotated_fcos_csl_gaussian_r50_fpn_1x_dota_le90',
    'rotated_fcos_kld_r50_fpn_1x_dota_le90',
    'rotated_reppoints_r50_fpn_1x_dota_oc',
    'rotated_retinanet_hbb_r50_fpn_1x_dota_oc',
    'rotated_retinanet_obb_r50_fpn_1x_dota_le90',
    'rotated_retinanet_obb_r50_fpn_fp16_1x_dota_le90',
    'rotated_retinanet_obb_r50_fpn_1x_dota_le135',
    'rotated_retinanet_obb_r50_fpn_1x_dota_ms_rr_le90',
    'rotated_retinanet_hbb_r50_fpn_6x_hrsc_rr_oc',
    'rotated_retinanet_obb_r50_fpn_6x_hrsc_rr_le90',
    's2anet_r50_fpn_1x_dota_le135',
    's2anet_r50_fpn_fp16_1x_dota_le135',
    'sasm_reppoints_r50_fpn_1x_dota_oc',
]

# Working folder for the currently selected model's config (*.py) and
# checkpoint (*.pth). exist_ok=True avoids the check-then-create race of a
# separate os.path.exists() test.
path = "./checkpoint"
os.makedirs(path, exist_ok=True)

def clear_folder(folder_path):
    """Delete every file, symlink and sub-directory inside *folder_path*.

    The folder itself is kept. An entry that cannot be removed is reported and
    skipped, so one failure does not abort the rest of the cleanup. The success
    message is printed only when every entry was actually removed.

    Args:
        folder_path: directory whose contents are removed. It must exist;
            ``os.listdir`` raises ``FileNotFoundError`` otherwise.

    Returns:
        list[str]: paths that could not be removed (empty when fully cleared).
    """
    import shutil

    failed = []
    for filename in os.listdir(folder_path):
        file_path = os.path.join(folder_path, filename)
        try:
            # Symlinks are unlinked (not followed), even when they point to a
            # directory, so their targets are never touched.
            if os.path.isfile(file_path) or os.path.islink(file_path):
                os.unlink(file_path)
            elif os.path.isdir(file_path):
                shutil.rmtree(file_path)
        except OSError as e:
            failed.append(file_path)
            print(f"Failed to delete {file_path}. Reason: {e}")
    if failed:
        print(f"Could not fully clear {folder_path}: {len(failed)} entries left.")
    else:
        print(f"Clear {folder_path} successfully.")
    return failed


def download_cfg_checkpoint_model_name(model_name, dest_root="./checkpoint"):
    """Replace the contents of *dest_root* with *model_name*'s config and checkpoint.

    The folder is emptied first, so code that lists it afterwards sees only
    the files fetched for *model_name*.

    Args:
        model_name: MMRotate config name passed to ``mim``'s ``download``.
        dest_root: folder that receives the files. It is created if missing.
            Defaults to the demo's ``./checkpoint`` folder.
    """
    # A single variable for the destination keeps the cleared folder and the
    # download target from drifting apart.
    os.makedirs(dest_root, exist_ok=True)
    clear_folder(dest_root)
    download(package='mmrotate',
             configs=[model_name],
             dest_root=dest_root)
def download_test_image():
    """Fetch the two sample images used by the Gradio examples.

    Each image is saved under a relative filename, i.e. into the current
    working directory.
    """
    samples = (
        ('https://user-images.githubusercontent.com/59380685/266800230-e8396b83-92a7-4367-bc4b-a36348e63dbe.jpg',
         'demo.jpg'),
        ('https://user-images.githubusercontent.com/59380685/266800231-d544d5ea-fc91-45d5-b79e-97bb9c717259.jpg',
         'dota_demo.jpg'),
    )
    for url, filename in samples:
        torch.hub.download_url_to_file(url, filename)

def save_image(img, img_path):
    """Write an RGB image (e.g. a PIL image) to *img_path* with OpenCV.

    ``cv2.imwrite`` expects BGR channel order, so the RGB pixels are reordered
    before writing.
    """
    bgr_pixels = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
    cv2.imwrite(img_path, bgr_pixels)


def _find_file(folder, preferred_pattern, fallback_pattern):
    """Return the path of a file in *folder* matching *preferred_pattern*, else *fallback_pattern*.

    Raises:
        FileNotFoundError: if no file matches either pattern.
    """
    names = sorted(os.listdir(folder))
    for pattern in (preferred_pattern, fallback_pattern):
        matches = fnmatch.filter(names, pattern)
        if matches:
            return os.path.join(folder, matches[0])
    raise FileNotFoundError(f"No file matching {fallback_pattern!r} found in {folder!r}")


# maxsize=1: reuse the last model across requests (no re-download or weight
# reload on every click) while holding at most one model in memory. Switching
# model or device evicts the previous entry. Failed loads are not cached.
@functools.lru_cache(maxsize=1)
def _load_model(model_name, device):
    """Download *model_name*'s config and checkpoint into ``path`` and build the detector."""
    download_cfg_checkpoint_model_name(model_name)
    # Prefer the file named after the model; fall back to any match.
    config = _find_file(path, f"{model_name}.py", "*.py")
    checkpoint = _find_file(path, f"{model_name}*.pth", "*.pth")
    return init_detector(config, checkpoint, device=device)


def predict_image(image, model_name, palette, score_thr, device):
    """Detect rotated objects in *image* and return the rendered result.

    Args:
        image: input picture from the Gradio ``Image(type='pil')`` component.
            A NumPy array is also accepted and is assumed to be RGB.
        model_name: MMRotate config name to download and run.
        palette: colour palette name forwarded to ``show_result_pyplot``.
        score_thr: minimum bbox score for a detection to be drawn.
        device: torch device string, e.g. ``'cpu'`` or ``'cuda:0'``.

    Returns:
        PIL.Image.Image: the input with detections drawn on it.

    Raises:
        ValueError: if no image was supplied.
        FileNotFoundError: if the downloaded folder lacks a config or checkpoint.
    """
    # `import PIL` alone does not guarantee the PIL.Image submodule is loaded.
    from PIL import Image

    if image is None:
        raise ValueError("No input image was provided.")
    if isinstance(image, Image.Image):
        # Normalise RGBA / grayscale uploads to 3 channels.
        image = image.convert("RGB")
    # MMCV treats ndarray images as BGR (the cv2/mmcv.imread convention), but
    # the UI supplies RGB. Without this swap the model sees red and blue
    # exchanged and the saved visualisation comes back channel-swapped.
    image = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
    save_dir = './output_img.jpg'

    # build (or reuse) the model from the downloaded config and checkpoint
    model = _load_model(model_name, device)
    result = inference_detector(model, image)
    # draw the results and write them to save_dir
    show_result_pyplot(
        model,
        image,
        result,
        palette=palette,
        score_thr=score_thr,
        out_file=save_dir)
    # Copy the pixels into memory. Image.open is lazy, and the next request
    # overwrites save_dir.
    with Image.open(save_dir) as rendered:
        img_out = rendered.copy()
    return img_out

# Fetch the sample images referenced by `examples` before building the UI.
download_test_image()
inputs = [
    gr.inputs.Image(type='pil', label="Input Image"),
    gr.inputs.Dropdown(label="Model Name", choices=list(mmrorate_model_list), default='oriented_rcnn_r50_fpn_1x_dota_le90'),
    gr.inputs.Dropdown(label="Color palette used for visualization",
                       choices=['dota', 'sar', 'hrsc', 'hrsc_classwise', 'random'], default='dota'),
    gr.inputs.Slider(label="bbox score threshold", minimum=0.0, maximum=1.0, step=0.01, default=0.3),
    gr.inputs.Dropdown(label="Device used for inference", choices=['cuda:0', 'cpu'], default='cpu'),
]

output = gr.outputs.Image(type='pil', label="Output Image")

title = "MMRotate detection web demo"
# HTML fixes: stray quote in the img width, unclosed <div>, and the MMRotate
# link was labelled "MMSegmentation".
description = "<div align='center'><img src='https://raw.githubusercontent.com/open-mmlab/mmrotate/main/resources/mmrotate-logo.png' width='450'/></div>" \
              "<p style='text-align: center'><a href='https://github.com/open-mmlab/mmrotate'>MMRotate</a> 是一款基于 PyTorch 的旋转框检测的开源工具箱,是 OpenMMLab 项目的成员之一。" \
              "OpenMMLab Rotated Object Detection Toolbox and Benchmark.</p>"
article = "<p style='text-align: center'><a href='https://github.com/open-mmlab/mmrotate'>MMRotate</a></p>" \
          "<p style='text-align: center'><a href='https://github.com/isLinXu'>gradio build by gatilin</a></p>"
# Each example row follows the order of `inputs`:
# image, model name, palette, score threshold, device.
examples = [["demo.jpg", "oriented_rcnn_r50_fpn_1x_dota_le90", 'dota', 0.3, 'cpu'],
            ["dota_demo.jpg", "r3det_r50_fpn_1x_dota_oc", 'dota', 0.3, 'cpu'],
            ]
gr.Interface(
    fn=predict_image,
    inputs=inputs,
    outputs=output,
    title=title,
    description=description,
    article=article,
    examples=examples,
    allow_flagging=False,
    theme="default"
).launch()