hysts (HF staff) committed
Commit 51488de
1 parent: 69af19f
Files changed (4)
  1. .pre-commit-config.yaml +59 -34
  2. .style.yapf +0 -5
  3. .vscode/settings.json +30 -0
  4. app.py +23 -29
.pre-commit-config.yaml CHANGED
@@ -1,35 +1,60 @@
 repos:
-- repo: https://github.com/pre-commit/pre-commit-hooks
-  rev: v4.2.0
-  hooks:
-  - id: check-executables-have-shebangs
-  - id: check-json
-  - id: check-merge-conflict
-  - id: check-shebang-scripts-are-executable
-  - id: check-toml
-  - id: check-yaml
-  - id: double-quote-string-fixer
-  - id: end-of-file-fixer
-  - id: mixed-line-ending
-    args: ['--fix=lf']
-  - id: requirements-txt-fixer
-  - id: trailing-whitespace
-- repo: https://github.com/myint/docformatter
-  rev: v1.4
-  hooks:
-  - id: docformatter
-    args: ['--in-place']
-- repo: https://github.com/pycqa/isort
-  rev: 5.12.0
-  hooks:
-  - id: isort
-- repo: https://github.com/pre-commit/mirrors-mypy
-  rev: v0.991
-  hooks:
-  - id: mypy
-    args: ['--ignore-missing-imports']
-- repo: https://github.com/google/yapf
-  rev: v0.32.0
-  hooks:
-  - id: yapf
-    args: ['--parallel', '--in-place']
+  - repo: https://github.com/pre-commit/pre-commit-hooks
+    rev: v4.6.0
+    hooks:
+      - id: check-executables-have-shebangs
+      - id: check-json
+      - id: check-merge-conflict
+      - id: check-shebang-scripts-are-executable
+      - id: check-toml
+      - id: check-yaml
+      - id: end-of-file-fixer
+      - id: mixed-line-ending
+        args: ["--fix=lf"]
+      - id: requirements-txt-fixer
+      - id: trailing-whitespace
+  - repo: https://github.com/myint/docformatter
+    rev: v1.7.5
+    hooks:
+      - id: docformatter
+        args: ["--in-place"]
+  - repo: https://github.com/pycqa/isort
+    rev: 5.13.2
+    hooks:
+      - id: isort
+        args: ["--profile", "black"]
+  - repo: https://github.com/pre-commit/mirrors-mypy
+    rev: v1.9.0
+    hooks:
+      - id: mypy
+        args: ["--ignore-missing-imports"]
+        additional_dependencies:
+          [
+            "types-python-slugify",
+            "types-requests",
+            "types-PyYAML",
+            "types-pytz",
+          ]
+  - repo: https://github.com/psf/black
+    rev: 24.4.0
+    hooks:
+      - id: black
+        language_version: python3.10
+        args: ["--line-length", "119"]
+  - repo: https://github.com/kynan/nbstripout
+    rev: 0.7.1
+    hooks:
+      - id: nbstripout
+        args:
+          [
+            "--extra-keys",
+            "metadata.interpreter metadata.kernelspec cell.metadata.pycharm",
+          ]
+  - repo: https://github.com/nbQA-dev/nbQA
+    rev: 1.8.5
+    hooks:
+      - id: nbqa-black
+      - id: nbqa-pyupgrade
+        args: ["--py37-plus"]
+      - id: nbqa-isort
+        args: ["--float-to-top"]
.style.yapf DELETED
@@ -1,5 +0,0 @@
-[style]
-based_on_style = pep8
-blank_line_before_nested_class_or_def = false
-spaces_before_comment = 2
-split_before_logical_operator = true
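Removing `.style.yapf` follows from the hook changes above: black (pinned at line length 119) replaces yapf, so these style knobs no longer have a consumer. The practical difference shows up in the `app.py` hunk below; roughly, on a hypothetical call:

# yapf, based_on_style = pep8: continuations aligned to the opening paren
result = some_function(argument_one,
                       argument_two,
                       argument_three)

# black: a single line when it fits in 119 columns, otherwise one
# argument per line with a trailing comma
result = some_function(
    argument_one,
    argument_two,
    argument_three,
)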
.vscode/settings.json ADDED
@@ -0,0 +1,30 @@
+{
+  "editor.formatOnSave": true,
+  "files.insertFinalNewline": false,
+  "[python]": {
+    "editor.defaultFormatter": "ms-python.black-formatter",
+    "editor.formatOnType": true,
+    "editor.codeActionsOnSave": {
+      "source.organizeImports": "explicit"
+    }
+  },
+  "[jupyter]": {
+    "files.insertFinalNewline": false
+  },
+  "black-formatter.args": [
+    "--line-length=119"
+  ],
+  "isort.args": ["--profile", "black"],
+  "flake8.args": [
+    "--max-line-length=119"
+  ],
+  "ruff.lint.args": [
+    "--line-length=119"
+  ],
+  "notebook.output.scrolling": true,
+  "notebook.formatOnCellExecution": true,
+  "notebook.formatOnSave.enabled": true,
+  "notebook.codeActionsOnSave": {
+    "source.organizeImports": "explicit"
+  }
+}
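These workspace settings mirror the pre-commit hooks (black at line length 119, isort with the black profile), so format-on-save in VS Code matches what the hooks enforce at commit time. As an illustration using this app's own imports, `source.organizeImports` with the black profile normalizes imports like so:

# before
import pathlib
from transformers import DetaForObjectDetection, AutoImageProcessor
import gradio as gr

# after: stdlib first, third-party grouped and alphabetized,
# from-import names sorted
import pathlib

import gradio as gr
from transformers import AutoImageProcessor, DetaForObjectDetection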
app.py CHANGED
@@ -10,11 +10,11 @@ from sahi.prediction import ObjectPrediction
 from sahi.utils.cv import visualize_object_predictions
 from transformers import AutoImageProcessor, DetaForObjectDetection
 
-DESCRIPTION = '# DETA (Detection Transformers with Assignment)'
+DESCRIPTION = "# DETA (Detection Transformers with Assignment)"
 
-device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
+device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 
-MODEL_ID = 'jozhang97/deta-swin-large'
+MODEL_ID = "jozhang97/deta-swin-large"
 image_processor = AutoImageProcessor.from_pretrained(MODEL_ID)
 model = DetaForObjectDetection.from_pretrained(MODEL_ID)
 model.to(device)
@@ -23,50 +23,44 @@ model.to(device)
 @torch.inference_mode()
 def run(image_path: str, threshold: float) -> np.ndarray:
     image = PIL.Image.open(image_path)
-    inputs = image_processor(images=image, return_tensors='pt').to(device)
+    inputs = image_processor(images=image, return_tensors="pt").to(device)
     outputs = model(**inputs)
     target_sizes = torch.tensor([image.size[::-1]])
-    results = image_processor.post_process_object_detection(
-        outputs, threshold=threshold, target_sizes=target_sizes)[0]
+    results = image_processor.post_process_object_detection(outputs, threshold=threshold, target_sizes=target_sizes)[0]
 
-    boxes = results['boxes'].cpu().numpy()
-    scores = results['scores'].cpu().numpy()
-    cat_ids = results['labels'].cpu().numpy().tolist()
+    boxes = results["boxes"].cpu().numpy()
+    scores = results["scores"].cpu().numpy()
+    cat_ids = results["labels"].cpu().numpy().tolist()
 
     preds = []
     for box, score, cat_id in zip(boxes, scores, cat_ids):
         box = np.round(box).astype(int)
         cat_label = model.config.id2label[cat_id]
-        pred = ObjectPrediction(bbox=box,
-                                category_id=cat_id,
-                                category_name=cat_label,
-                                score=score)
+        pred = ObjectPrediction(bbox=box, category_id=cat_id, category_name=cat_label, score=score)
         preds.append(pred)
 
-    res = visualize_object_predictions(np.asarray(image), preds)['image']
+    res = visualize_object_predictions(np.asarray(image), preds)["image"]
     return res
 
 
-with gr.Blocks(css='style.css') as demo:
+with gr.Blocks(css="style.css") as demo:
     gr.Markdown(DESCRIPTION)
     with gr.Row():
         with gr.Column():
-            image = gr.Image(label='Input image', type='filepath')
-            threshold = gr.Slider(label='Score threshold',
-                                  minimum=0,
-                                  maximum=1,
-                                  value=0.1,
-                                  step=0.01)
-            run_button = gr.Button('Run')
-        result = gr.Image(label='Result', type='numpy')
+            image = gr.Image(label="Input image", type="filepath")
+            threshold = gr.Slider(label="Score threshold", minimum=0, maximum=1, value=0.1, step=0.01)
+            run_button = gr.Button("Run")
+        result = gr.Image(label="Result", type="numpy")
 
     with gr.Row():
-        paths = sorted(pathlib.Path('images').glob('*.jpg'))
-        gr.Examples(examples=[[path.as_posix(), 0.1] for path in paths],
-                    inputs=[image, threshold],
-                    outputs=result,
-                    fn=run,
-                    cache_examples=True)
+        paths = sorted(pathlib.Path("images").glob("*.jpg"))
+        gr.Examples(
+            examples=[[path.as_posix(), 0.1] for path in paths],
+            inputs=[image, threshold],
+            outputs=result,
+            fn=run,
+            cache_examples=True,
+        )
 
     run_button.click(fn=run, inputs=[image, threshold], outputs=result)
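The `app.py` changes are formatting-only (single to double quotes, black-style call layout), so behavior is unchanged: `run()` still returns the annotated image as a NumPy array. A quick sanity check outside the Gradio UI, assuming a sample image exists at the placeholder path:

import numpy as np
import PIL.Image

out = run("images/sample.jpg", threshold=0.1)  # placeholder path
assert isinstance(out, np.ndarray)
PIL.Image.fromarray(out).save("result.jpg")  # save the rendered detections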