hysts HF staff committed on
Commit
d2aba00
1 Parent(s): 85fcf0d
Files changed (4) hide show
  1. .pre-commit-config.yaml +2 -12
  2. README.md +1 -1
  3. app.py +75 -101
  4. model.py +4 -3
.pre-commit-config.yaml CHANGED
@@ -21,11 +21,11 @@ repos:
21
  - id: docformatter
22
  args: ['--in-place']
23
  - repo: https://github.com/pycqa/isort
24
- rev: 5.10.1
25
  hooks:
26
  - id: isort
27
  - repo: https://github.com/pre-commit/mirrors-mypy
28
- rev: v0.812
29
  hooks:
30
  - id: mypy
31
  args: ['--ignore-missing-imports']
@@ -34,13 +34,3 @@ repos:
34
  hooks:
35
  - id: yapf
36
  args: ['--parallel', '--in-place']
37
- - repo: https://github.com/kynan/nbstripout
38
- rev: 0.5.0
39
- hooks:
40
- - id: nbstripout
41
- args: ['--extra-keys', 'metadata.interpreter metadata.kernelspec cell.metadata.pycharm']
42
- - repo: https://github.com/nbQA-dev/nbQA
43
- rev: 1.3.1
44
- hooks:
45
- - id: nbqa-isort
46
- - id: nbqa-yapf
 
21
  - id: docformatter
22
  args: ['--in-place']
23
  - repo: https://github.com/pycqa/isort
24
+ rev: 5.12.0
25
  hooks:
26
  - id: isort
27
  - repo: https://github.com/pre-commit/mirrors-mypy
28
+ rev: v0.991
29
  hooks:
30
  - id: mypy
31
  args: ['--ignore-missing-imports']
 
34
  hooks:
35
  - id: yapf
36
  args: ['--parallel', '--in-place']
 
 
 
 
 
 
 
 
 
 
README.md CHANGED
@@ -5,7 +5,7 @@ colorFrom: pink
5
  colorTo: purple
6
  sdk: gradio
7
  python_version: 3.9.13
8
- sdk_version: 3.0.11
9
  app_file: app.py
10
  pinned: false
11
  ---
 
5
  colorTo: purple
6
  sdk: gradio
7
  python_version: 3.9.13
8
+ sdk_version: 3.19.1
9
  app_file: app.py
10
  pinned: false
11
  ---
app.py CHANGED
@@ -2,7 +2,6 @@
2
 
3
  from __future__ import annotations
4
 
5
- import argparse
6
  import os
7
  import pathlib
8
  import subprocess
@@ -29,7 +28,6 @@ DESCRIPTION = '''# MMDetection
29
  This is an unofficial demo for [https://github.com/open-mmlab/mmdetection](https://github.com/open-mmlab/mmdetection).
30
  <img id="overview" alt="overview" src="https://user-images.githubusercontent.com/12907710/137271636-56ba1cd2-b110-4812-8221-b4c120320aa9.png" />
31
  '''
32
- FOOTER = '<img id="visitor-badge" src="https://visitor-badge.glitch.me/badge?page_id=hysts.mmdetection" alt="visitor badge" />'
33
 
34
  DEFAULT_MODEL_TYPE = 'detection'
35
  DEFAULT_MODEL_NAMES = {
@@ -40,18 +38,6 @@ DEFAULT_MODEL_NAMES = {
40
  DEFAULT_MODEL_NAME = DEFAULT_MODEL_NAMES[DEFAULT_MODEL_TYPE]
41
 
42
 
43
- def parse_args() -> argparse.Namespace:
44
- parser = argparse.ArgumentParser()
45
- parser.add_argument('--device', type=str, default='cpu')
46
- parser.add_argument('--theme', type=str)
47
- parser.add_argument('--share', action='store_true')
48
- parser.add_argument('--port', type=int)
49
- parser.add_argument('--disable-queue',
50
- dest='enable_queue',
51
- action='store_false')
52
- return parser.parse_args()
53
-
54
-
55
  def extract_tar() -> None:
56
  if pathlib.Path('mmdet_configs/configs').exists():
57
  return
@@ -87,94 +73,82 @@ def set_example_image(example: list) -> dict:
87
  return gr.Image.update(value=example[0])
88
 
89
 
90
- def main():
91
- args = parse_args()
92
- extract_tar()
93
- model = AppModel(DEFAULT_MODEL_NAME, args.device)
94
 
95
- with gr.Blocks(theme=args.theme, css='style.css') as demo:
96
- gr.Markdown(DESCRIPTION)
97
 
98
- with gr.Row():
99
- with gr.Column():
100
- with gr.Row():
101
- input_image = gr.Image(label='Input Image', type='numpy')
102
- with gr.Group():
103
- with gr.Row():
104
- model_type = gr.Radio(list(DEFAULT_MODEL_NAMES.keys()),
105
- value=DEFAULT_MODEL_TYPE,
106
- label='Model Type')
107
- with gr.Row():
108
- model_name = gr.Dropdown(list(
109
- model.DETECTION_MODEL_DICT.keys()),
110
- value=DEFAULT_MODEL_NAME,
111
- label='Model')
112
- with gr.Row():
113
- run_button = gr.Button(value='Run')
114
- prediction_results = gr.Variable()
115
- with gr.Column():
116
- with gr.Row():
117
- visualization = gr.Image(label='Result', type='numpy')
118
  with gr.Row():
119
- visualization_score_threshold = gr.Slider(
120
- 0,
121
- 1,
122
- step=0.05,
123
- value=0.3,
124
- label='Visualization Score Threshold')
125
  with gr.Row():
126
- redraw_button = gr.Button(value='Redraw')
127
-
128
- with gr.Row():
129
- paths = sorted(pathlib.Path('images').rglob('*.jpg'))
130
- example_images = gr.Dataset(components=[input_image],
131
- samples=[[path.as_posix()]
132
- for path in paths])
133
-
134
- gr.Markdown(FOOTER)
135
-
136
- input_image.change(fn=update_input_image,
137
- inputs=input_image,
138
- outputs=input_image)
139
-
140
- model_type.change(fn=update_model_name,
141
- inputs=model_type,
142
- outputs=model_name)
143
- model_type.change(fn=update_visualization_score_threshold,
144
- inputs=model_type,
145
- outputs=visualization_score_threshold)
146
- model_type.change(fn=update_redraw_button,
147
- inputs=model_type,
148
- outputs=redraw_button)
149
-
150
- model_name.change(fn=model.set_model, inputs=model_name, outputs=None)
151
- run_button.click(fn=model.run,
152
- inputs=[
153
- model_name,
154
- input_image,
155
- visualization_score_threshold,
156
- ],
157
- outputs=[
158
- prediction_results,
159
- visualization,
160
- ])
161
- redraw_button.click(fn=model.visualize_detection_results,
162
- inputs=[
163
- input_image,
164
- prediction_results,
165
- visualization_score_threshold,
166
- ],
167
- outputs=visualization)
168
- example_images.click(fn=set_example_image,
169
- inputs=example_images,
170
- outputs=input_image)
171
-
172
- demo.launch(
173
- enable_queue=args.enable_queue,
174
- server_port=args.port,
175
- share=args.share,
176
- )
177
-
178
-
179
- if __name__ == '__main__':
180
- main()
 
 
 
 
 
 
 
 
 
2
 
3
  from __future__ import annotations
4
 
 
5
  import os
6
  import pathlib
7
  import subprocess
 
28
  This is an unofficial demo for [https://github.com/open-mmlab/mmdetection](https://github.com/open-mmlab/mmdetection).
29
  <img id="overview" alt="overview" src="https://user-images.githubusercontent.com/12907710/137271636-56ba1cd2-b110-4812-8221-b4c120320aa9.png" />
30
  '''
 
31
 
32
  DEFAULT_MODEL_TYPE = 'detection'
33
  DEFAULT_MODEL_NAMES = {
 
38
  DEFAULT_MODEL_NAME = DEFAULT_MODEL_NAMES[DEFAULT_MODEL_TYPE]
39
 
40
 
 
 
 
 
 
 
 
 
 
 
 
 
41
  def extract_tar() -> None:
42
  if pathlib.Path('mmdet_configs/configs').exists():
43
  return
 
73
  return gr.Image.update(value=example[0])
74
 
75
 
76
+ extract_tar()
77
+ model = AppModel(DEFAULT_MODEL_NAME)
 
 
78
 
79
+ with gr.Blocks(css='style.css') as demo:
80
+ gr.Markdown(DESCRIPTION)
81
 
82
+ with gr.Row():
83
+ with gr.Column():
84
+ with gr.Row():
85
+ input_image = gr.Image(label='Input Image', type='numpy')
86
+ with gr.Group():
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
87
  with gr.Row():
88
+ model_type = gr.Radio(list(DEFAULT_MODEL_NAMES.keys()),
89
+ value=DEFAULT_MODEL_TYPE,
90
+ label='Model Type')
 
 
 
91
  with gr.Row():
92
+ model_name = gr.Dropdown(list(
93
+ model.DETECTION_MODEL_DICT.keys()),
94
+ value=DEFAULT_MODEL_NAME,
95
+ label='Model')
96
+ with gr.Row():
97
+ run_button = gr.Button(value='Run')
98
+ prediction_results = gr.Variable()
99
+ with gr.Column():
100
+ with gr.Row():
101
+ visualization = gr.Image(label='Result', type='numpy')
102
+ with gr.Row():
103
+ visualization_score_threshold = gr.Slider(
104
+ 0,
105
+ 1,
106
+ step=0.05,
107
+ value=0.3,
108
+ label='Visualization Score Threshold')
109
+ with gr.Row():
110
+ redraw_button = gr.Button(value='Redraw')
111
+
112
+ with gr.Row():
113
+ paths = sorted(pathlib.Path('images').rglob('*.jpg'))
114
+ example_images = gr.Dataset(components=[input_image],
115
+ samples=[[path.as_posix()]
116
+ for path in paths])
117
+
118
+ input_image.change(fn=update_input_image,
119
+ inputs=input_image,
120
+ outputs=input_image)
121
+
122
+ model_type.change(fn=update_model_name,
123
+ inputs=model_type,
124
+ outputs=model_name)
125
+ model_type.change(fn=update_visualization_score_threshold,
126
+ inputs=model_type,
127
+ outputs=visualization_score_threshold)
128
+ model_type.change(fn=update_redraw_button,
129
+ inputs=model_type,
130
+ outputs=redraw_button)
131
+
132
+ model_name.change(fn=model.set_model, inputs=model_name, outputs=None)
133
+ run_button.click(fn=model.run,
134
+ inputs=[
135
+ model_name,
136
+ input_image,
137
+ visualization_score_threshold,
138
+ ],
139
+ outputs=[
140
+ prediction_results,
141
+ visualization,
142
+ ])
143
+ redraw_button.click(fn=model.visualize_detection_results,
144
+ inputs=[
145
+ input_image,
146
+ prediction_results,
147
+ visualization_score_threshold,
148
+ ],
149
+ outputs=visualization)
150
+ example_images.click(fn=set_example_image,
151
+ inputs=example_images,
152
+ outputs=input_image)
153
+
154
+ demo.queue().launch(show_api=False)
model.py CHANGED
@@ -6,7 +6,7 @@ import huggingface_hub
6
  import numpy as np
7
  import torch
8
  import torch.nn as nn
9
- import yaml
10
  from mmdet.apis import inference_detector, init_detector
11
 
12
 
@@ -48,8 +48,9 @@ class Model:
48
  'model_dict/panoptic_segmentation.yaml')
49
  MODEL_DICT = DETECTION_MODEL_DICT | INSTANCE_SEGMENTATION_MODEL_DICT | PANOPTIC_SEGMENTATION_MODEL_DICT
50
 
51
- def __init__(self, model_name: str, device: str | torch.device):
52
- self.device = torch.device(device)
 
53
  self._load_all_models_once()
54
  self.model_name = model_name
55
  self.model = self._load_model(model_name)
 
6
  import numpy as np
7
  import torch
8
  import torch.nn as nn
9
+ import yaml # type: ignore
10
  from mmdet.apis import inference_detector, init_detector
11
 
12
 
 
48
  'model_dict/panoptic_segmentation.yaml')
49
  MODEL_DICT = DETECTION_MODEL_DICT | INSTANCE_SEGMENTATION_MODEL_DICT | PANOPTIC_SEGMENTATION_MODEL_DICT
50
 
51
+ def __init__(self, model_name: str):
52
+ self.device = torch.device(
53
+ 'cuda:0' if torch.cuda.is_available() else 'cpu')
54
  self._load_all_models_once()
55
  self.model_name = model_name
56
  self.model = self._load_model(model_name)