hysts (HF staff) committed
Commit 7fd17d1
1 Parent(s): 0c0d56d
Files changed (2)
  1. app.py +1 -134
  2. model.py +139 -0
app.py CHANGED
@@ -3,33 +3,11 @@
 from __future__ import annotations
 
 import argparse
-import os
 import pathlib
-import subprocess
-import sys
-
-if os.getenv('SYSTEM') == 'spaces':
-    import mim
-
-    mim.uninstall('mmcv-full', confirm_yes=True)
-    mim.install('mmcv-full==1.5.0', is_yes=True)
-
-    subprocess.run('pip uninstall -y opencv-python'.split())
-    subprocess.run('pip uninstall -y opencv-python-headless'.split())
-    subprocess.run('pip install opencv-python-headless==4.5.5.64'.split())
-
-    with open('patch') as f:
-        subprocess.run('patch -p1'.split(), cwd='CBNetV2', stdin=f)
-    subprocess.run('mv palette.py CBNetV2/mmdet/core/visualization/'.split())
 
 import gradio as gr
-import numpy as np
-import torch
-import torch.nn as nn
 
-sys.path.insert(0, 'CBNetV2/')
-
-from mmdet.apis import inference_detector, init_detector
+from model import Model
 
 DESCRIPTION = '''# CBNetV2
 
@@ -49,117 +27,6 @@ def parse_args() -> argparse.Namespace:
     return parser.parse_args()
 
 
-class Model:
-    def __init__(self, device: str | torch.device):
-        self.device = torch.device(device)
-        self.models = self._load_models()
-        self.model_name = 'Improved HTC (DB-Swin-B)'
-
-    def _load_models(self) -> dict[str, nn.Module]:
-        model_dict = {
-            'Faster R-CNN (DB-ResNet50)': {
-                'config':
-                'CBNetV2/configs/cbnet/faster_rcnn_cbv2d1_r50_fpn_1x_coco.py',
-                'model':
-                'https://github.com/CBNetwork/storage/releases/download/v1.0.0/faster_rcnn_cbv2d1_r50_fpn_1x_coco.pth.zip',
-            },
-            'Mask R-CNN (DB-Swin-T)': {
-                'config':
-                'CBNetV2/configs/cbnet/mask_rcnn_cbv2_swin_tiny_patch4_window7_mstrain_480-800_adamw_3x_coco.py',
-                'model':
-                'https://github.com/CBNetwork/storage/releases/download/v1.0.0/mask_rcnn_cbv2_swin_tiny_patch4_window7_mstrain_480-800_adamw_3x_coco.pth.zip',
-            },
-            # 'Cascade Mask R-CNN (DB-Swin-S)': {
-            #     'config':
-            #     'CBNetV2/configs/cbnet/cascade_mask_rcnn_cbv2_swin_small_patch4_window7_mstrain_400-1400_adamw_3x_coco.py',
-            #     'model':
-            #     'https://github.com/CBNetwork/storage/releases/download/v1.0.0/cascade_mask_rcnn_cbv2_swin_small_patch4_window7_mstrain_400-1400_adamw_3x_coco.pth.zip',
-            # },
-            'Improved HTC (DB-Swin-B)': {
-                'config':
-                'CBNetV2/configs/cbnet/htc_cbv2_swin_base_patch4_window7_mstrain_400-1400_giou_4conv1f_adamw_20e_coco.py',
-                'model':
-                'https://github.com/CBNetwork/storage/releases/download/v1.0.0/htc_cbv2_swin_base22k_patch4_window7_mstrain_400-1400_giou_4conv1f_adamw_20e_coco.pth.zip',
-            },
-            'Improved HTC (DB-Swin-L)': {
-                'config':
-                'CBNetV2/configs/cbnet/htc_cbv2_swin_large_patch4_window7_mstrain_400-1400_giou_4conv1f_adamw_1x_coco.py',
-                'model':
-                'https://github.com/CBNetwork/storage/releases/download/v1.0.0/htc_cbv2_swin_large22k_patch4_window7_mstrain_400-1400_giou_4conv1f_adamw_1x_coco.pth.zip',
-            },
-            'Improved HTC (DB-Swin-L (TTA))': {
-                'config':
-                'CBNetV2/configs/cbnet/htc_cbv2_swin_large_patch4_window7_mstrain_400-1400_giou_4conv1f_adamw_1x_coco.py',
-                'model':
-                'https://github.com/CBNetwork/storage/releases/download/v1.0.0/htc_cbv2_swin_large22k_patch4_window7_mstrain_400-1400_giou_4conv1f_adamw_1x_coco.pth.zip',
-            },
-        }
-
-        weight_dir = pathlib.Path('weights')
-        weight_dir.mkdir(exist_ok=True)
-
-        def _download(model_name: str, out_dir: pathlib.Path) -> None:
-            import zipfile
-
-            model_url = model_dict[model_name]['model']
-            zip_name = model_url.split('/')[-1]
-
-            out_path = out_dir / zip_name
-            if out_path.exists():
-                return
-            torch.hub.download_url_to_file(model_url, out_path)
-
-            with zipfile.ZipFile(out_path) as f:
-                f.extractall(out_dir)
-
-        def _get_model_path(model_name: str) -> str:
-            model_url = model_dict[model_name]['model']
-            model_name = model_url.split('/')[-1][:-4]
-            return (weight_dir / model_name).as_posix()
-
-        for model_name in model_dict:
-            _download(model_name, weight_dir)
-
-        models = {
-            key: init_detector(dic['config'],
-                               _get_model_path(key),
-                               device=self.device)
-            for key, dic in model_dict.items()
-        }
-        return models
-
-    def set_model_name(self, name: str) -> None:
-        self.model_name = name
-
-    def detect_and_visualize(
-            self, image: np.ndarray,
-            score_threshold: float) -> tuple[list[np.ndarray], np.ndarray]:
-        out = self.detect(image)
-        vis = self.visualize_detection_results(image, out, score_threshold)
-        return out, vis
-
-    def detect(self, image: np.ndarray) -> list[np.ndarray]:
-        image = image[:, :, ::-1]  # RGB -> BGR
-        model = self.models[self.model_name]
-        out = inference_detector(model, image)
-        return out
-
-    def visualize_detection_results(
-            self,
-            image: np.ndarray,
-            detection_results: list[np.ndarray],
-            score_threshold: float = 0.3) -> np.ndarray:
-        image = image[:, :, ::-1]  # RGB -> BGR
-        model = self.models[self.model_name]
-        vis = model.show_result(image,
-                                detection_results,
-                                score_thr=score_threshold,
-                                bbox_color=None,
-                                text_color=(200, 200, 200),
-                                mask_color=None)
-        return vis[:, :, ::-1]  # BGR -> RGB
-
-
 def set_example_image(example: list) -> dict:
     return gr.Image.update(value=example[0])
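Net effect on app.py: all of the environment setup and the Model class are gone, replaced by a single "from model import Model". Below is a minimal sketch of how the slimmed-down app.py can wire the imported class into Gradio; the Blocks layout, widget names, and the DESCRIPTION stand-in are assumptions, since the unchanged remainder of app.py is not shown in this diff:

import gradio as gr
import torch

from model import Model

DESCRIPTION = '# CBNetV2'  # stand-in for the docstring defined in app.py

model = Model(device='cuda:0' if torch.cuda.is_available() else 'cpu')

with gr.Blocks() as demo:
    gr.Markdown(DESCRIPTION)
    input_image = gr.Image(type='numpy', label='Input')
    score_threshold = gr.Slider(0, 1, value=0.3, label='Score threshold')
    run_button = gr.Button('Run')
    output_image = gr.Image(type='numpy', label='Result')
    # detect_and_visualize() returns (raw detections, rendered image);
    # only the rendered image is displayed here.
    run_button.click(
        lambda image, threshold: model.detect_and_visualize(
            image, threshold)[1],
        inputs=[input_image, score_threshold],
        outputs=output_image)

demo.launch()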
 
 
model.py ADDED
@@ -0,0 +1,139 @@
+from __future__ import annotations
+
+import os
+import pathlib
+import subprocess
+import sys
+
+if os.getenv('SYSTEM') == 'spaces':
+    import mim
+
+    mim.uninstall('mmcv-full', confirm_yes=True)
+    mim.install('mmcv-full==1.5.0', is_yes=True)
+
+    subprocess.run('pip uninstall -y opencv-python'.split())
+    subprocess.run('pip uninstall -y opencv-python-headless'.split())
+    subprocess.run('pip install opencv-python-headless==4.5.5.64'.split())
+
+    with open('patch') as f:
+        subprocess.run('patch -p1'.split(), cwd='CBNetV2', stdin=f)
+    subprocess.run('mv palette.py CBNetV2/mmdet/core/visualization/'.split())
+
+import numpy as np
+import torch
+import torch.nn as nn
+
+sys.path.insert(0, 'CBNetV2/')
+
+from mmdet.apis import inference_detector, init_detector
+
+
+class Model:
+    def __init__(self, device: str | torch.device):
+        self.device = torch.device(device)
+        self.models = self._load_models()
+        self.model_name = 'Improved HTC (DB-Swin-B)'
+
+    def _load_models(self) -> dict[str, nn.Module]:
+        model_dict = {
+            'Faster R-CNN (DB-ResNet50)': {
+                'config':
+                'CBNetV2/configs/cbnet/faster_rcnn_cbv2d1_r50_fpn_1x_coco.py',
+                'model':
+                'https://github.com/CBNetwork/storage/releases/download/v1.0.0/faster_rcnn_cbv2d1_r50_fpn_1x_coco.pth.zip',
+            },
+            'Mask R-CNN (DB-Swin-T)': {
+                'config':
+                'CBNetV2/configs/cbnet/mask_rcnn_cbv2_swin_tiny_patch4_window7_mstrain_480-800_adamw_3x_coco.py',
+                'model':
+                'https://github.com/CBNetwork/storage/releases/download/v1.0.0/mask_rcnn_cbv2_swin_tiny_patch4_window7_mstrain_480-800_adamw_3x_coco.pth.zip',
+            },
+            # 'Cascade Mask R-CNN (DB-Swin-S)': {
+            #     'config':
+            #     'CBNetV2/configs/cbnet/cascade_mask_rcnn_cbv2_swin_small_patch4_window7_mstrain_400-1400_adamw_3x_coco.py',
+            #     'model':
+            #     'https://github.com/CBNetwork/storage/releases/download/v1.0.0/cascade_mask_rcnn_cbv2_swin_small_patch4_window7_mstrain_400-1400_adamw_3x_coco.pth.zip',
+            # },
+            'Improved HTC (DB-Swin-B)': {
+                'config':
+                'CBNetV2/configs/cbnet/htc_cbv2_swin_base_patch4_window7_mstrain_400-1400_giou_4conv1f_adamw_20e_coco.py',
+                'model':
+                'https://github.com/CBNetwork/storage/releases/download/v1.0.0/htc_cbv2_swin_base22k_patch4_window7_mstrain_400-1400_giou_4conv1f_adamw_20e_coco.pth.zip',
+            },
+            'Improved HTC (DB-Swin-L)': {
+                'config':
+                'CBNetV2/configs/cbnet/htc_cbv2_swin_large_patch4_window7_mstrain_400-1400_giou_4conv1f_adamw_1x_coco.py',
+                'model':
+                'https://github.com/CBNetwork/storage/releases/download/v1.0.0/htc_cbv2_swin_large22k_patch4_window7_mstrain_400-1400_giou_4conv1f_adamw_1x_coco.pth.zip',
+            },
+            'Improved HTC (DB-Swin-L (TTA))': {
+                'config':
+                'CBNetV2/configs/cbnet/htc_cbv2_swin_large_patch4_window7_mstrain_400-1400_giou_4conv1f_adamw_1x_coco.py',
+                'model':
+                'https://github.com/CBNetwork/storage/releases/download/v1.0.0/htc_cbv2_swin_large22k_patch4_window7_mstrain_400-1400_giou_4conv1f_adamw_1x_coco.pth.zip',
+            },
+        }
+
+        weight_dir = pathlib.Path('weights')
+        weight_dir.mkdir(exist_ok=True)
+
+        def _download(model_name: str, out_dir: pathlib.Path) -> None:
+            import zipfile
+
+            model_url = model_dict[model_name]['model']
+            zip_name = model_url.split('/')[-1]
+
+            out_path = out_dir / zip_name
+            if out_path.exists():
+                return
+            torch.hub.download_url_to_file(model_url, out_path)
+
+            with zipfile.ZipFile(out_path) as f:
+                f.extractall(out_dir)
+
+        def _get_model_path(model_name: str) -> str:
+            model_url = model_dict[model_name]['model']
+            model_name = model_url.split('/')[-1][:-4]
+            return (weight_dir / model_name).as_posix()
+
+        for model_name in model_dict:
+            _download(model_name, weight_dir)
+
+        models = {
+            key: init_detector(dic['config'],
+                               _get_model_path(key),
+                               device=self.device)
+            for key, dic in model_dict.items()
+        }
+        return models
+
+    def set_model_name(self, name: str) -> None:
+        self.model_name = name
+
+    def detect_and_visualize(
+            self, image: np.ndarray,
+            score_threshold: float) -> tuple[list[np.ndarray], np.ndarray]:
+        out = self.detect(image)
+        vis = self.visualize_detection_results(image, out, score_threshold)
+        return out, vis
+
+    def detect(self, image: np.ndarray) -> list[np.ndarray]:
+        image = image[:, :, ::-1]  # RGB -> BGR
+        model = self.models[self.model_name]
+        out = inference_detector(model, image)
+        return out
+
+    def visualize_detection_results(
+            self,
+            image: np.ndarray,
+            detection_results: list[np.ndarray],
+            score_threshold: float = 0.3) -> np.ndarray:
+        image = image[:, :, ::-1]  # RGB -> BGR
+        model = self.models[self.model_name]
+        vis = model.show_result(image,
+                                detection_results,
+                                score_thr=score_threshold,
+                                bbox_color=None,
+                                text_color=(200, 200, 200),
+                                mask_color=None)
+        return vis[:, :, ::-1]  # BGR -> RGB
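
As a standalone check of the new module, the class can also be exercised outside Gradio. A minimal smoke-test sketch, assuming the CBNetV2 checkout, patch step, and weight downloads succeed, and using a blank placeholder frame instead of a real photo:

import numpy as np
import torch

from model import Model

# Constructing Model downloads and unzips every checkpoint listed in
# _load_models(), so the first run is slow and needs several GB of disk.
model = Model(device='cuda:0' if torch.cuda.is_available() else 'cpu')
model.set_model_name('Faster R-CNN (DB-ResNet50)')

image = np.zeros((480, 640, 3), dtype=np.uint8)  # placeholder RGB image
detections, rendered = model.detect_and_visualize(image, score_threshold=0.3)
print(len(detections), rendered.shape)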