hysts HF staff committed on
Commit 293637a
1 Parent(s): bef58c2
Files changed (6):
  1. .pre-commit-config.yaml +59 -35
  2. .vscode/settings.json +30 -0
  3. app.py +34 -37
  4. model.py +44 -59
  5. palette.py +10 -7
  6. style.css +6 -2
.pre-commit-config.yaml CHANGED
@@ -1,37 +1,61 @@
 exclude: ^patch
 repos:
-- repo: https://github.com/pre-commit/pre-commit-hooks
-  rev: v4.2.0
-  hooks:
-  - id: check-executables-have-shebangs
-  - id: check-json
-  - id: check-merge-conflict
-  - id: check-shebang-scripts-are-executable
-  - id: check-toml
-  - id: check-yaml
-  - id: double-quote-string-fixer
-  - id: end-of-file-fixer
-  - id: mixed-line-ending
-    args: ['--fix=lf']
-  - id: requirements-txt-fixer
-  - id: trailing-whitespace
-- repo: https://github.com/myint/docformatter
-  rev: v1.4
-  hooks:
-  - id: docformatter
-    args: ['--in-place']
-- repo: https://github.com/pycqa/isort
-  rev: 5.12.0
-  hooks:
-  - id: isort
-- repo: https://github.com/pre-commit/mirrors-mypy
-  rev: v0.991
-  hooks:
-  - id: mypy
-    args: ['--ignore-missing-imports']
-    additional_dependencies: ['types-python-slugify']
-- repo: https://github.com/google/yapf
-  rev: v0.32.0
-  hooks:
-  - id: yapf
-    args: ['--parallel', '--in-place']
+  - repo: https://github.com/pre-commit/pre-commit-hooks
+    rev: v4.6.0
+    hooks:
+      - id: check-executables-have-shebangs
+      - id: check-json
+      - id: check-merge-conflict
+      - id: check-shebang-scripts-are-executable
+      - id: check-toml
+      - id: check-yaml
+      - id: end-of-file-fixer
+      - id: mixed-line-ending
+        args: ["--fix=lf"]
+      - id: requirements-txt-fixer
+      - id: trailing-whitespace
+  - repo: https://github.com/myint/docformatter
+    rev: v1.7.5
+    hooks:
+      - id: docformatter
+        args: ["--in-place"]
+  - repo: https://github.com/pycqa/isort
+    rev: 5.13.2
+    hooks:
+      - id: isort
+        args: ["--profile", "black"]
+  - repo: https://github.com/pre-commit/mirrors-mypy
+    rev: v1.10.0
+    hooks:
+      - id: mypy
+        args: ["--ignore-missing-imports"]
+        additional_dependencies:
+          [
+            "types-python-slugify",
+            "types-requests",
+            "types-PyYAML",
+            "types-pytz",
+          ]
+  - repo: https://github.com/psf/black
+    rev: 24.4.2
+    hooks:
+      - id: black
+        language_version: python3.10
+        args: ["--line-length", "119"]
+  - repo: https://github.com/kynan/nbstripout
+    rev: 0.7.1
+    hooks:
+      - id: nbstripout
+        args:
+          [
+            "--extra-keys",
+            "metadata.interpreter metadata.kernelspec cell.metadata.pycharm",
+          ]
+  - repo: https://github.com/nbQA-dev/nbQA
+    rev: 1.8.5
+    hooks:
+      - id: nbqa-black
+      - id: nbqa-pyupgrade
+        args: ["--py37-plus"]
+      - id: nbqa-isort
+        args: ["--float-to-top"]
.vscode/settings.json ADDED
@@ -0,0 +1,30 @@
+{
+  "editor.formatOnSave": true,
+  "files.insertFinalNewline": false,
+  "[python]": {
+    "editor.defaultFormatter": "ms-python.black-formatter",
+    "editor.formatOnType": true,
+    "editor.codeActionsOnSave": {
+      "source.organizeImports": "explicit"
+    }
+  },
+  "[jupyter]": {
+    "files.insertFinalNewline": false
+  },
+  "black-formatter.args": [
+    "--line-length=119"
+  ],
+  "isort.args": ["--profile", "black"],
+  "flake8.args": [
+    "--max-line-length=119"
+  ],
+  "ruff.lint.args": [
+    "--line-length=119"
+  ],
+  "notebook.output.scrolling": true,
+  "notebook.formatOnCellExecution": true,
+  "notebook.formatOnSave.enabled": true,
+  "notebook.codeActionsOnSave": {
+    "source.organizeImports": "explicit"
+  }
+}
app.py CHANGED
@@ -8,60 +8,57 @@ import gradio as gr
 
 from model import Model
 
-DESCRIPTION = '# [CBNetV2](https://github.com/VDIGPKU/CBNetV2)'
+DESCRIPTION = "# [CBNetV2](https://github.com/VDIGPKU/CBNetV2)"
 
 model = Model()
 
-with gr.Blocks(css='style.css') as demo:
+with gr.Blocks(css="style.css") as demo:
     gr.Markdown(DESCRIPTION)
 
     with gr.Row():
         with gr.Column():
             with gr.Row():
-                input_image = gr.Image(label='Input Image', type='numpy')
+                input_image = gr.Image(label="Input Image", type="numpy")
             with gr.Row():
-                detector_name = gr.Dropdown(label='Detector',
-                                            choices=list(model.models.keys()),
-                                            value=model.model_name)
+                detector_name = gr.Dropdown(
+                    label="Detector", choices=list(model.models.keys()), value=model.model_name
+                )
             with gr.Row():
-                detect_button = gr.Button('Detect')
+                detect_button = gr.Button("Detect")
                 detection_results = gr.Variable()
         with gr.Column():
             with gr.Row():
-                detection_visualization = gr.Image(label='Detection Result',
-                                                   type='numpy')
+                detection_visualization = gr.Image(label="Detection Result", type="numpy")
             with gr.Row():
                 visualization_score_threshold = gr.Slider(
-                    label='Visualization Score Threshold',
-                    minimum=0,
-                    maximum=1,
-                    step=0.05,
-                    value=0.3)
+                    label="Visualization Score Threshold", minimum=0, maximum=1, step=0.05, value=0.3
+                )
             with gr.Row():
-                redraw_button = gr.Button('Redraw')
+                redraw_button = gr.Button("Redraw")
 
     with gr.Row():
-        paths = sorted(pathlib.Path('images').rglob('*.jpg'))
-        gr.Examples(examples=[[path.as_posix()] for path in paths],
-                    inputs=input_image)
+        paths = sorted(pathlib.Path("images").rglob("*.jpg"))
+        gr.Examples(examples=[[path.as_posix()] for path in paths], inputs=input_image)
 
-    detector_name.change(fn=model.set_model_name,
-                         inputs=[detector_name],
-                         outputs=None)
-    detect_button.click(fn=model.detect_and_visualize,
-                        inputs=[
-                            input_image,
-                            visualization_score_threshold,
-                        ],
-                        outputs=[
-                            detection_results,
-                            detection_visualization,
-                        ])
-    redraw_button.click(fn=model.visualize_detection_results,
-                        inputs=[
-                            input_image,
-                            detection_results,
-                            visualization_score_threshold,
-                        ],
-                        outputs=[detection_visualization])
+    detector_name.change(fn=model.set_model_name, inputs=[detector_name], outputs=None)
+    detect_button.click(
+        fn=model.detect_and_visualize,
+        inputs=[
+            input_image,
+            visualization_score_threshold,
+        ],
+        outputs=[
+            detection_results,
+            detection_visualization,
+        ],
+    )
+    redraw_button.click(
+        fn=model.visualize_detection_results,
+        inputs=[
+            input_image,
+            detection_results,
+            visualization_score_threshold,
+        ],
+        outputs=[detection_visualization],
+    )
 demo.queue(max_size=10).launch()
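
Note: this commit keeps `gr.Variable` for the cached detection results; newer Gradio releases deprecate it in favor of `gr.State`. A minimal sketch of the same click-wiring pattern with `gr.State` and a stand-in function (`fake_detect` is hypothetical, not part of this repo):

import gradio as gr

def fake_detect(image, threshold):
    # Stand-in for Model.detect_and_visualize: returns (raw results, visualization).
    return [image], image

with gr.Blocks() as demo:
    input_image = gr.Image(type="numpy")
    threshold = gr.Slider(minimum=0, maximum=1, step=0.05, value=0.3)
    detect_button = gr.Button("Detect")
    detection_results = gr.State()  # per-session store, successor to gr.Variable()
    output_image = gr.Image()
    detect_button.click(
        fn=fake_detect,
        inputs=[input_image, threshold],
        outputs=[detection_results, output_image],
    )

demo.launch()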
model.py CHANGED
@@ -6,26 +6,26 @@ import shlex
 import subprocess
 import sys
 
-if os.getenv('SYSTEM') == 'spaces':
+if os.getenv("SYSTEM") == "spaces":
     import mim
 
-    mim.uninstall('mmcv-full', confirm_yes=True)
-    mim.install('mmcv-full==1.5.0', is_yes=True)
+    mim.uninstall("mmcv-full", confirm_yes=True)
+    mim.install("mmcv-full==1.5.0", is_yes=True)
 
-    subprocess.run(shlex.split('pip uninstall -y opencv-python'))
-    subprocess.run(shlex.split('pip uninstall -y opencv-python-headless'))
-    subprocess.run(shlex.split('pip install opencv-python-headless==4.8.0.74'))
+    subprocess.run(shlex.split("pip uninstall -y opencv-python"))
+    subprocess.run(shlex.split("pip uninstall -y opencv-python-headless"))
+    subprocess.run(shlex.split("pip install opencv-python-headless==4.8.0.74"))
 
-    with open('patch') as f:
-        subprocess.run(shlex.split('patch -p1'), cwd='CBNetV2', stdin=f)
-    subprocess.run('mv palette.py CBNetV2/mmdet/core/visualization/'.split())
+    with open("patch") as f:
+        subprocess.run(shlex.split("patch -p1"), cwd="CBNetV2", stdin=f)
+    subprocess.run("mv palette.py CBNetV2/mmdet/core/visualization/".split())
 
 import numpy as np
 import torch
 import torch.nn as nn
 
 app_dir = pathlib.Path(__file__).parent
-submodule_dir = app_dir / 'CBNetV2/'
+submodule_dir = app_dir / "CBNetV2/"
 sys.path.insert(0, submodule_dir.as_posix())
 
 from mmdet.apis import inference_detector, init_detector
@@ -33,24 +33,19 @@ from mmdet.apis import inference_detector, init_detector
 
 class Model:
     def __init__(self):
-        self.device = torch.device(
-            'cuda:0' if torch.cuda.is_available() else 'cpu')
+        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
         self.models = self._load_models()
-        self.model_name = 'Improved HTC (DB-Swin-B)'
+        self.model_name = "Improved HTC (DB-Swin-B)"
 
     def _load_models(self) -> dict[str, nn.Module]:
         model_dict = {
-            'Faster R-CNN (DB-ResNet50)': {
-                'config':
-                'CBNetV2/configs/cbnet/faster_rcnn_cbv2d1_r50_fpn_1x_coco.py',
-                'model':
-                'https://github.com/CBNetwork/storage/releases/download/v1.0.0/faster_rcnn_cbv2d1_r50_fpn_1x_coco.pth.zip',
+            "Faster R-CNN (DB-ResNet50)": {
+                "config": "CBNetV2/configs/cbnet/faster_rcnn_cbv2d1_r50_fpn_1x_coco.py",
+                "model": "https://github.com/CBNetwork/storage/releases/download/v1.0.0/faster_rcnn_cbv2d1_r50_fpn_1x_coco.pth.zip",
             },
-            'Mask R-CNN (DB-Swin-T)': {
-                'config':
-                'CBNetV2/configs/cbnet/mask_rcnn_cbv2_swin_tiny_patch4_window7_mstrain_480-800_adamw_3x_coco.py',
-                'model':
-                'https://github.com/CBNetwork/storage/releases/download/v1.0.0/mask_rcnn_cbv2_swin_tiny_patch4_window7_mstrain_480-800_adamw_3x_coco.pth.zip',
+            "Mask R-CNN (DB-Swin-T)": {
+                "config": "CBNetV2/configs/cbnet/mask_rcnn_cbv2_swin_tiny_patch4_window7_mstrain_480-800_adamw_3x_coco.py",
+                "model": "https://github.com/CBNetwork/storage/releases/download/v1.0.0/mask_rcnn_cbv2_swin_tiny_patch4_window7_mstrain_480-800_adamw_3x_coco.pth.zip",
             },
             # 'Cascade Mask R-CNN (DB-Swin-S)': {
             #     'config':
@@ -58,34 +53,28 @@ class Model:
             #     'model':
             #     'https://github.com/CBNetwork/storage/releases/download/v1.0.0/cascade_mask_rcnn_cbv2_swin_small_patch4_window7_mstrain_400-1400_adamw_3x_coco.pth.zip',
             # },
-            'Improved HTC (DB-Swin-B)': {
-                'config':
-                'CBNetV2/configs/cbnet/htc_cbv2_swin_base_patch4_window7_mstrain_400-1400_giou_4conv1f_adamw_20e_coco.py',
-                'model':
-                'https://github.com/CBNetwork/storage/releases/download/v1.0.0/htc_cbv2_swin_base22k_patch4_window7_mstrain_400-1400_giou_4conv1f_adamw_20e_coco.pth.zip',
+            "Improved HTC (DB-Swin-B)": {
+                "config": "CBNetV2/configs/cbnet/htc_cbv2_swin_base_patch4_window7_mstrain_400-1400_giou_4conv1f_adamw_20e_coco.py",
+                "model": "https://github.com/CBNetwork/storage/releases/download/v1.0.0/htc_cbv2_swin_base22k_patch4_window7_mstrain_400-1400_giou_4conv1f_adamw_20e_coco.pth.zip",
             },
-            'Improved HTC (DB-Swin-L)': {
-                'config':
-                'CBNetV2/configs/cbnet/htc_cbv2_swin_large_patch4_window7_mstrain_400-1400_giou_4conv1f_adamw_1x_coco.py',
-                'model':
-                'https://github.com/CBNetwork/storage/releases/download/v1.0.0/htc_cbv2_swin_large22k_patch4_window7_mstrain_400-1400_giou_4conv1f_adamw_1x_coco.pth.zip',
+            "Improved HTC (DB-Swin-L)": {
+                "config": "CBNetV2/configs/cbnet/htc_cbv2_swin_large_patch4_window7_mstrain_400-1400_giou_4conv1f_adamw_1x_coco.py",
+                "model": "https://github.com/CBNetwork/storage/releases/download/v1.0.0/htc_cbv2_swin_large22k_patch4_window7_mstrain_400-1400_giou_4conv1f_adamw_1x_coco.pth.zip",
             },
-            'Improved HTC (DB-Swin-L (TTA))': {
-                'config':
-                'CBNetV2/configs/cbnet/htc_cbv2_swin_large_patch4_window7_mstrain_400-1400_giou_4conv1f_adamw_1x_coco.py',
-                'model':
-                'https://github.com/CBNetwork/storage/releases/download/v1.0.0/htc_cbv2_swin_large22k_patch4_window7_mstrain_400-1400_giou_4conv1f_adamw_1x_coco.pth.zip',
+            "Improved HTC (DB-Swin-L (TTA))": {
+                "config": "CBNetV2/configs/cbnet/htc_cbv2_swin_large_patch4_window7_mstrain_400-1400_giou_4conv1f_adamw_1x_coco.py",
+                "model": "https://github.com/CBNetwork/storage/releases/download/v1.0.0/htc_cbv2_swin_large22k_patch4_window7_mstrain_400-1400_giou_4conv1f_adamw_1x_coco.pth.zip",
             },
         }
 
-        weight_dir = pathlib.Path('weights')
+        weight_dir = pathlib.Path("weights")
         weight_dir.mkdir(exist_ok=True)
 
         def _download(model_name: str, out_dir: pathlib.Path) -> None:
             import zipfile
 
-            model_url = model_dict[model_name]['model']
-            zip_name = model_url.split('/')[-1]
+            model_url = model_dict[model_name]["model"]
+            zip_name = model_url.split("/")[-1]
 
             out_path = out_dir / zip_name
             if out_path.exists():
@@ -96,17 +85,15 @@ class Model:
                 f.extractall(out_dir)
 
         def _get_model_path(model_name: str) -> str:
-            model_url = model_dict[model_name]['model']
-            model_name = model_url.split('/')[-1][:-4]
+            model_url = model_dict[model_name]["model"]
+            model_name = model_url.split("/")[-1][:-4]
             return (weight_dir / model_name).as_posix()
 
         for model_name in model_dict:
             _download(model_name, weight_dir)
 
         models = {
-            key: init_detector(dic['config'],
-                               _get_model_path(key),
-                               device=self.device)
+            key: init_detector(dic["config"], _get_model_path(key), device=self.device)
             for key, dic in model_dict.items()
         }
         return models
@@ -114,9 +101,7 @@ class Model:
     def set_model_name(self, name: str) -> None:
        self.model_name = name
 
-    def detect_and_visualize(
-            self, image: np.ndarray,
-            score_threshold: float) -> tuple[list[np.ndarray], np.ndarray]:
+    def detect_and_visualize(self, image: np.ndarray, score_threshold: float) -> tuple[list[np.ndarray], np.ndarray]:
         out = self.detect(image)
         vis = self.visualize_detection_results(image, out, score_threshold)
         return out, vis
@@ -128,16 +113,16 @@ class Model:
         return out
 
     def visualize_detection_results(
-            self,
-            image: np.ndarray,
-            detection_results: list[np.ndarray],
-            score_threshold: float = 0.3) -> np.ndarray:
+        self, image: np.ndarray, detection_results: list[np.ndarray], score_threshold: float = 0.3
+    ) -> np.ndarray:
         image = image[:, :, ::-1]  # RGB -> BGR
         model = self.models[self.model_name]
-        vis = model.show_result(image,
-                                detection_results,
-                                score_thr=score_threshold,
-                                bbox_color=None,
-                                text_color=(200, 200, 200),
-                                mask_color=None)
+        vis = model.show_result(
+            image,
+            detection_results,
+            score_thr=score_threshold,
+            bbox_color=None,
+            text_color=(200, 200, 200),
+            mask_color=None,
+        )
         return vis[:, :, ::-1]  # BGR -> RGB
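
Note: a hypothetical usage sketch for the reformatted Model class; the image and output paths are placeholders, and running it requires the CBNetV2/mmdet environment this module sets up at import time:

import cv2
from model import Model

model = Model()  # downloads weights and builds every configured detector
model.set_model_name("Mask R-CNN (DB-Swin-T)")
image = cv2.imread("example.jpg")[:, :, ::-1]  # OpenCV loads BGR; the class expects RGB
results, visualization = model.detect_and_visualize(image, score_threshold=0.3)
cv2.imwrite("out.jpg", visualization[:, :, ::-1].copy())  # back to BGR (contiguous) for imwrite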
palette.py CHANGED
@@ -208,6 +208,7 @@ Copyright 2018-2023 OpenMMLab. All rights reserved.
 limitations under the License.
 ```
 """
+
 # Copyright (c) OpenMMLab. All rights reserved.
 import mmcv
 import numpy as np
@@ -245,29 +246,31 @@ def get_palette(palette, num_classes):
         dataset_palette = palette
     elif isinstance(palette, tuple):
         dataset_palette = [palette] * num_classes
-    elif palette == 'random' or palette is None:
+    elif palette == "random" or palette is None:
         state = np.random.get_state()
         # random color
         np.random.seed(42)
         palette = np.random.randint(0, 256, size=(num_classes, 3))
         np.random.set_state(state)
         dataset_palette = [tuple(c) for c in palette]
-    elif palette == 'coco':
+    elif palette == "coco":
         from mmdet.datasets import CocoDataset, CocoPanopticDataset
+
         dataset_palette = CocoDataset.PALETTE
         if len(dataset_palette) < num_classes:
             dataset_palette = CocoPanopticDataset.PALETTE
-    elif palette == 'citys':
+    elif palette == "citys":
         from mmdet.datasets import CityscapesDataset
+
         dataset_palette = CityscapesDataset.PALETTE
-    elif palette == 'voc':
+    elif palette == "voc":
         from mmdet.datasets import VOCDataset
+
         dataset_palette = VOCDataset.PALETTE
     elif mmcv.is_str(palette):
         dataset_palette = [mmcv.color_val(palette)[::-1]] * num_classes
     else:
-        raise TypeError(f'Invalid type for palette: {type(palette)}')
+        raise TypeError(f"Invalid type for palette: {type(palette)}")
 
-    assert len(dataset_palette) >= num_classes, \
-        'The length of palette should not be less than `num_classes`.'
+    assert len(dataset_palette) >= num_classes, "The length of palette should not be less than `num_classes`."
     return dataset_palette
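
Note: a small sketch of the seeded "random" branch above; because get_palette reseeds NumPy with 42 and then restores the RNG state, repeated calls return identical colors without disturbing the caller's random stream (assumes mmcv is importable and this file is reachable as palette.py):

from palette import get_palette

colors = get_palette("random", num_classes=5)
assert colors == get_palette("random", num_classes=5)  # deterministic across calls
print(colors[0])  # one 3-tuple of ints in [0, 256)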
style.css CHANGED
@@ -1,7 +1,11 @@
 h1 {
   text-align: center;
-}
-img#visitor-badge {
   display: block;
+}
+
+#duplicate-button {
   margin: auto;
+  color: #fff;
+  background: #1565c0;
+  border-radius: 100vh;
 }