aletrn commited on
Commit
793909c
·
1 Parent(s): 0aa759d

[test] SegmentAnythingONNX test case (encode and predict_masks - check map)

Browse files
src/prediction_api/sam_onnx.py CHANGED
@@ -1,5 +1,6 @@
1
  """
2
- machine learning segment anything class.
 
3
  Modified from https://github.com/vietanhdev/samexporter/
4
 
5
  Copyright (c) 2023 Viet Anh Nguyen
 
1
  """
2
+ Define a machine learning executed by ONNX Runtime (https://onnxruntime.ai/)
3
+ for Segment Anything (https://segment-anything.com).
4
  Modified from https://github.com/vietanhdev/samexporter/
5
 
6
  Copyright (c) 2023 Viet Anh Nguyen
tests/events/SegmentAnythingONNX/mask_output.npy ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a3e9a925a21dfc07943d587e00013e6f380bf1072d1e87ea89c8e5e4f78e4cad
3
+ size 700544
tests/prediction_api/test_sam_onnx.py ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import unittest
3
+ from unittest.mock import patch
4
+
5
+ import numpy as np
6
+
7
+ from src import MODEL_FOLDER
8
+ from src.prediction_api import sam_onnx
9
+ from src.prediction_api.sam_onnx import SegmentAnythingONNX
10
+ from src.utilities.constants import MODEL_ENCODER_NAME, MODEL_DECODER_NAME
11
+ from src.utilities.utilities import hash_calculate
12
+ from tests import TEST_EVENTS_FOLDER
13
+
14
+
15
+ instance_sam_onnx = SegmentAnythingONNX(
16
+ encoder_model_path=MODEL_FOLDER / MODEL_ENCODER_NAME,
17
+ decoder_model_path=MODEL_FOLDER / MODEL_DECODER_NAME
18
+ )
19
+ np_img = np.load(TEST_EVENTS_FOLDER / "samexporter_predict" / "oceania" / "img.npy")
20
+ prompt = [{
21
+ "type": "point",
22
+ "data": [934, 510],
23
+ "label": 0
24
+ }]
25
+
26
+
27
+ class TestSegmentAnythingONNX(unittest.TestCase):
28
+ def test_encode_predict_masks_ok(self):
29
+ embedding = instance_sam_onnx.encode(np_img)
30
+ try:
31
+ assert hash_calculate(embedding) == b"m2O3y7pNUwlLuAZhBHkRIu8cDIIej0oOmWOXevs39r4="
32
+ except AssertionError as ae1:
33
+ logging.warning(f"ae1:{ae1}.")
34
+ inference_mask = instance_sam_onnx.predict_masks(embedding, prompt)
35
+ try:
36
+ assert hash_calculate(inference_mask) == b'YSKKNCs3AMpbeDUVwqIwNQqJ365OG4239hxjFnW7XTM='
37
+ except AssertionError as ae2:
38
+ logging.warning(f"ae2:{ae2}.")
39
+ mask_output = np.zeros((inference_mask.shape[2], inference_mask.shape[3]), dtype=np.uint8)
40
+ for n, m in enumerate(inference_mask[0, :, :, :]):
41
+ logging.debug(f"{n}th of prediction_masks shape {inference_mask.shape}"
42
+ f" => mask shape:{mask_output.shape}, {mask_output.dtype}.")
43
+ mask_output[m > 0.0] = 255
44
+ mask_expected = np.load(TEST_EVENTS_FOLDER / "SegmentAnythingONNX" / "mask_output.npy")
45
+
46
+ # assert MAP (mean average precision) is 100%
47
+ # sum expected mask to output mask:
48
+ # - asserted "good" inference values are 2 (matched object) or 0 (matched background)
49
+ # - "bad" inference value is 1 (there are differences between expected and output mask)
50
+ sum_mask_output_vs_expected = mask_expected / 255 + mask_output / 255
51
+ unique_values__output_vs_expected = np.unique(sum_mask_output_vs_expected, return_counts=True)
52
+ tot = sum_mask_output_vs_expected.size
53
+ perc = {
54
+ k: 100 * v / tot for
55
+ k, v in
56
+ zip(unique_values__output_vs_expected[0], unique_values__output_vs_expected[1])
57
+ }
58
+ try:
59
+ assert 1 not in perc
60
+ except AssertionError:
61
+ logging.error(f"found {perc[1]} % different pixels between expected masks and output mask.")
62
+ # try to assert that the % of different pixels are minor than 5%
63
+ assert perc[1] < 5
64
+
65
+ def test_encode_predict_masks_ex1(self):
66
+ instance_sam_onnx = SegmentAnythingONNX(
67
+ encoder_model_path=MODEL_FOLDER / MODEL_ENCODER_NAME,
68
+ decoder_model_path=MODEL_FOLDER / MODEL_DECODER_NAME
69
+ )
70
+ with self.assertRaises(Exception):
71
+ try:
72
+ np_input = np.zeros((10, 10))
73
+ instance_sam_onnx.encode(np_input)
74
+ except Exception as e:
75
+ logging.error(f"e:{e}.")
76
+ msg = "[ONNXRuntimeError] : 2 : INVALID_ARGUMENT : Invalid rank for input: input_image "
77
+ msg += "Got: 2 Expected: 3 Please fix either the inputs or the model."
78
+ assert str(e) == msg
79
+ raise e
80
+
81
+ def test_encode_predict_masks_ex2(self):
82
+ wrong_prompt = [{
83
+ "type": "rectangle",
84
+ "data": [934, 510],
85
+ "label": 0
86
+ }]
87
+ embedding = instance_sam_onnx.encode(np_img)
88
+
89
+ with self.assertRaises(IndexError):
90
+ try:
91
+ instance_sam_onnx.predict_masks(embedding, wrong_prompt)
92
+ except IndexError as ie:
93
+ print(ie)
94
+ assert str(ie) == "list index out of range"
95
+ raise ie