shunk031 commited on
Commit
98de260
1 Parent(s): 7eaaf55

deploy: 70e91c39c93ea64c5ddd81aee7881cf044abb76f

Browse files
Files changed (3) hide show
  1. README.md +0 -12
  2. layout_maximum_iou.py +193 -0
  3. requirements.txt +90 -0
README.md CHANGED
@@ -1,12 +0,0 @@
1
- ---
2
- title: Layout Maximum Iou
3
- emoji: 🌍
4
- colorFrom: gray
5
- colorTo: yellow
6
- sdk: gradio
7
- sdk_version: 4.18.0
8
- app_file: app.py
9
- pinned: false
10
- ---
11
-
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
layout_maximum_iou.py ADDED
@@ -0,0 +1,193 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from collections import defaultdict
2
+ from itertools import chain
3
+ from typing import Dict, List, Tuple, TypedDict
4
+
5
+ import datasets as ds
6
+ import evaluate
7
+ import numpy as np
8
+ import numpy.typing as npt
9
+ from scipy.optimize import linear_sum_assignment
10
+
11
+
12
+ class Layout(TypedDict):
13
+ bboxes: npt.NDArray[np.float64]
14
+ categories: npt.NDArray[np.int64]
15
+
16
+
17
+ _DESCRIPTION = """\
18
+
19
+ """
20
+
21
+ _CITATION = """\
22
+
23
+ """
24
+
25
+
26
+ def convert_xywh_to_ltrb(
27
+ batch_bbox: npt.NDArray[np.float64],
28
+ ) -> Tuple[
29
+ npt.NDArray[np.float64],
30
+ npt.NDArray[np.float64],
31
+ npt.NDArray[np.float64],
32
+ npt.NDArray[np.float64],
33
+ ]:
34
+ xc, yc, w, h = batch_bbox
35
+ x1 = xc - w / 2
36
+ y1 = yc - h / 2
37
+ x2 = xc + w / 2
38
+ y2 = yc + h / 2
39
+ return (x1, y1, x2, y2)
40
+
41
+
42
+ def _compute_iou(
43
+ bbox1: npt.NDArray[np.float64],
44
+ bbox2: npt.NDArray[np.float64],
45
+ generalized: bool = False,
46
+ ):
47
+ # shape: bbox1 (N, 4), bbox2 (N, 4)
48
+ assert bbox1.shape[0] == bbox2.shape[0]
49
+ assert bbox1.shape[1] == bbox1.shape[1] == 4
50
+
51
+ l1, t1, r1, b1 = convert_xywh_to_ltrb(bbox1.T)
52
+ l2, t2, r2, b2 = convert_xywh_to_ltrb(bbox2.T)
53
+ a1, a2 = (r1 - l1) * (b1 - t1), (r2 - l2) * (b2 - t2)
54
+
55
+ # intersection
56
+ l_max = np.maximum(l1, l2)
57
+ r_min = np.minimum(r1, r2)
58
+ t_max = np.maximum(t1, t2)
59
+ b_min = np.minimum(b1, b2)
60
+ cond = (l_max < r_min) & (t_max < b_min)
61
+ ai = np.where(cond, (r_min - l_max) * (b_min - t_max), np.zeros_like(a1[0]))
62
+
63
+ au = a1 + a2 - ai
64
+ iou = ai / au
65
+
66
+ if not generalized:
67
+ return iou
68
+
69
+ # outer region
70
+ l_min = np.minimum(l1, l2)
71
+ r_max = np.maximum(r1, r2)
72
+ t_min = np.minimum(t1, t2)
73
+ b_max = np.maximum(b1, b2)
74
+ ac = (r_max - l_min) * (b_max - t_min)
75
+
76
+ giou = iou - (ac - au) / ac
77
+
78
+ return giou
79
+
80
+
81
+ def _compute_maximum_iou_for_layout(layout1: Layout, layout2: Layout):
82
+ score = 0.0
83
+ bi, ci = layout1["bboxes"], layout1["categories"]
84
+ bj, cj = layout2["bboxes"], layout2["categories"]
85
+ N = len(bi)
86
+
87
+ for c in list(set(ci.tolist())):
88
+ _bi = bi[np.where(ci == c)]
89
+ _bj = bj[np.where(cj == c)]
90
+ n = len(_bi)
91
+ ii, jj = np.meshgrid(range(n), range(n))
92
+ ii, jj = ii.flatten(), jj.flatten()
93
+ iou = _compute_iou(_bi[ii], _bj[jj]).reshape(n, n)
94
+ # Note: maximize is supported only when scipy >= 1.4
95
+ ii, jj = linear_sum_assignment(iou, maximize=True)
96
+ score += iou[ii, jj].sum().item()
97
+ return score / N
98
+
99
+
100
+ def _compute_maximum_iou(
101
+ layouts_1_and_2: Tuple[List[Layout], List[Layout]]
102
+ ) -> npt.NDArray[np.float64]:
103
+ assert len(layouts_1_and_2) == 2
104
+ layouts1, layouts2 = layouts_1_and_2
105
+
106
+ N, M = len(layouts1), len(layouts2)
107
+ ii, jj = np.meshgrid(range(N), range(M))
108
+ ii, jj = ii.flatten(), jj.flatten()
109
+ scores = np.asarray(
110
+ [
111
+ _compute_maximum_iou_for_layout(layouts1[i], layouts2[j])
112
+ for i, j in zip(ii, jj)
113
+ ]
114
+ )
115
+ scores = scores.reshape(N, M)
116
+ ii, jj = linear_sum_assignment(scores, maximize=True)
117
+ return scores[ii, jj]
118
+
119
+
120
+ def _get_cond_to_layouts(layouts: List[Layout]) -> Dict[str, List[Layout]]:
121
+ out = defaultdict(list)
122
+ for layout in layouts:
123
+ bboxes_list = layout["bboxes"]
124
+ categories_list = layout["categories"]
125
+ assert len(bboxes_list) == len(categories_list)
126
+
127
+ for bboxes, categories in zip(bboxes_list, categories_list):
128
+ bboxes = np.array(bboxes)
129
+ cond_key = str(sorted(categories))
130
+ categories = np.array(categories)
131
+ layout_dict: Layout = {"bboxes": bboxes, "categories": categories}
132
+ out[cond_key].append(layout_dict)
133
+ return out
134
+
135
+
136
+ def compute_maximum_iou(args):
137
+ return [_compute_maximum_iou(a) for a in args]
138
+
139
+
140
+ class LayoutMaximumIoU(evaluate.Metric):
141
+
142
+ def _info(self) -> evaluate.EvaluationModuleInfo:
143
+ return evaluate.EvaluationModuleInfo(
144
+ description=_DESCRIPTION,
145
+ citation=_CITATION,
146
+ features=ds.Features(
147
+ {
148
+ "layouts1": ds.Sequence(
149
+ {
150
+ "bboxes": ds.Sequence(ds.Sequence((ds.Value("float64")))),
151
+ "categories": ds.Sequence(ds.Value("int64")),
152
+ }
153
+ ),
154
+ "layouts2": ds.Sequence(
155
+ {
156
+ "bboxes": ds.Sequence(ds.Sequence((ds.Value("float64")))),
157
+ "categories": ds.Sequence(ds.Value("int64")),
158
+ }
159
+ ),
160
+ }
161
+ ),
162
+ codebase_urls=[
163
+ "https://github.com/CyberAgentAILab/layout-dm/blob/main/src/trainer/trainer/helpers/metric.py#L206-L247",
164
+ "https://github.com/CyberAgentAILab/layout-dm/blob/main/src/trainer/trainer/helpers/metric.py#L250-L297",
165
+ "https://github.com/CyberAgentAILab/layout-dm/blob/main/src/trainer/trainer/helpers/metric.py#L300-L314",
166
+ "https://github.com/CyberAgentAILab/layout-dm/blob/main/src/trainer/trainer/helpers/metric.py#L317-L329",
167
+ "https://github.com/CyberAgentAILab/layout-dm/blob/main/src/trainer/trainer/helpers/metric.py#L332-L340",
168
+ "https://github.com/CyberAgentAILab/layout-dm/blob/main/src/trainer/trainer/helpers/metric.py#L343-L371",
169
+ ],
170
+ )
171
+
172
+ def _compute(
173
+ self,
174
+ *,
175
+ layouts1: List[Layout],
176
+ layouts2: List[Layout],
177
+ ) -> float:
178
+ c2bl_1 = _get_cond_to_layouts(layouts1)
179
+ keys_1 = set(c2bl_1.keys())
180
+ c2bl_2 = _get_cond_to_layouts(layouts2)
181
+ keys_2 = set(c2bl_2.keys())
182
+ keys = list(keys_1.intersection(keys_2))
183
+ args = [(c2bl_1[key], c2bl_2[key]) for key in keys]
184
+
185
+ # to check actual number of layouts for evaluation
186
+ # ans = 0
187
+ # for x in args:
188
+ # ans += len(x[0])
189
+
190
+ scores = compute_maximum_iou(args)
191
+ scores = np.asarray(list(chain.from_iterable(scores)))
192
+
193
+ return scores.mean().item() if len(scores) != 0 else 0.0
requirements.txt ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ aiofiles==23.2.1 ; python_version >= "3.9" and python_version < "4.0"
2
+ aiohttp==3.9.3 ; python_version >= "3.9" and python_version < "4.0"
3
+ aiosignal==1.3.1 ; python_version >= "3.9" and python_version < "4.0"
4
+ altair==5.2.0 ; python_version >= "3.9" and python_version < "4.0"
5
+ annotated-types==0.6.0 ; python_version >= "3.9" and python_version < "4.0"
6
+ anyio==4.2.0 ; python_version >= "3.9" and python_version < "4.0"
7
+ arrow==1.3.0 ; python_version >= "3.9" and python_version < "4.0"
8
+ async-timeout==4.0.3 ; python_version >= "3.9" and python_version < "3.11"
9
+ attrs==23.2.0 ; python_version >= "3.9" and python_version < "4.0"
10
+ binaryornot==0.4.4 ; python_version >= "3.9" and python_version < "4.0"
11
+ certifi==2024.2.2 ; python_version >= "3.9" and python_version < "4.0"
12
+ chardet==5.2.0 ; python_version >= "3.9" and python_version < "4.0"
13
+ charset-normalizer==3.3.2 ; python_version >= "3.9" and python_version < "4.0"
14
+ click==8.1.7 ; python_version >= "3.9" and python_version < "4.0"
15
+ colorama==0.4.6 ; python_version >= "3.9" and python_version < "4.0"
16
+ contourpy==1.2.0 ; python_version >= "3.9" and python_version < "4.0"
17
+ cookiecutter==2.5.0 ; python_version >= "3.9" and python_version < "4.0"
18
+ cycler==0.12.1 ; python_version >= "3.9" and python_version < "4.0"
19
+ datasets==2.17.0 ; python_version >= "3.9" and python_version < "4.0"
20
+ dill==0.3.8 ; python_version >= "3.9" and python_version < "4.0"
21
+ evaluate[template]==0.4.1 ; python_version >= "3.9" and python_version < "4.0"
22
+ exceptiongroup==1.2.0 ; python_version >= "3.9" and python_version < "3.11"
23
+ fastapi==0.109.2 ; python_version >= "3.9" and python_version < "4.0"
24
+ ffmpy==0.3.1 ; python_version >= "3.9" and python_version < "4.0"
25
+ filelock==3.13.1 ; python_version >= "3.9" and python_version < "4.0"
26
+ fonttools==4.48.1 ; python_version >= "3.9" and python_version < "4.0"
27
+ frozenlist==1.4.1 ; python_version >= "3.9" and python_version < "4.0"
28
+ fsspec==2023.10.0 ; python_version >= "3.9" and python_version < "4.0"
29
+ fsspec[http]==2023.10.0 ; python_version >= "3.9" and python_version < "4.0"
30
+ gradio-client==0.10.0 ; python_version >= "3.9" and python_version < "4.0"
31
+ gradio==4.18.0 ; python_version >= "3.9" and python_version < "4.0"
32
+ h11==0.14.0 ; python_version >= "3.9" and python_version < "4.0"
33
+ httpcore==1.0.2 ; python_version >= "3.9" and python_version < "4.0"
34
+ httpx==0.26.0 ; python_version >= "3.9" and python_version < "4.0"
35
+ huggingface-hub==0.20.3 ; python_version >= "3.9" and python_version < "4.0"
36
+ idna==3.6 ; python_version >= "3.9" and python_version < "4.0"
37
+ importlib-resources==6.1.1 ; python_version >= "3.9" and python_version < "4.0"
38
+ jinja2==3.1.3 ; python_version >= "3.9" and python_version < "4.0"
39
+ jsonschema-specifications==2023.12.1 ; python_version >= "3.9" and python_version < "4.0"
40
+ jsonschema==4.21.1 ; python_version >= "3.9" and python_version < "4.0"
41
+ kiwisolver==1.4.5 ; python_version >= "3.9" and python_version < "4.0"
42
+ markdown-it-py==3.0.0 ; python_version >= "3.9" and python_version < "4.0"
43
+ markupsafe==2.1.5 ; python_version >= "3.9" and python_version < "4.0"
44
+ matplotlib==3.8.2 ; python_version >= "3.9" and python_version < "4.0"
45
+ mdurl==0.1.2 ; python_version >= "3.9" and python_version < "4.0"
46
+ multidict==6.0.5 ; python_version >= "3.9" and python_version < "4.0"
47
+ multiprocess==0.70.16 ; python_version >= "3.9" and python_version < "4.0"
48
+ numpy==1.26.4 ; python_version >= "3.9" and python_version < "4.0"
49
+ orjson==3.9.13 ; python_version >= "3.9" and python_version < "4.0"
50
+ packaging==23.2 ; python_version >= "3.9" and python_version < "4.0"
51
+ pandas==2.2.0 ; python_version >= "3.9" and python_version < "4.0"
52
+ pillow==10.2.0 ; python_version >= "3.9" and python_version < "4.0"
53
+ pyarrow-hotfix==0.6 ; python_version >= "3.9" and python_version < "4.0"
54
+ pyarrow==15.0.0 ; python_version >= "3.9" and python_version < "4.0"
55
+ pydantic-core==2.16.2 ; python_version >= "3.9" and python_version < "4.0"
56
+ pydantic==2.6.1 ; python_version >= "3.9" and python_version < "4.0"
57
+ pydub==0.25.1 ; python_version >= "3.9" and python_version < "4.0"
58
+ pygments==2.17.2 ; python_version >= "3.9" and python_version < "4.0"
59
+ pyparsing==3.1.1 ; python_version >= "3.9" and python_version < "4.0"
60
+ python-dateutil==2.8.2 ; python_version >= "3.9" and python_version < "4.0"
61
+ python-multipart==0.0.9 ; python_version >= "3.9" and python_version < "4.0"
62
+ python-slugify==8.0.4 ; python_version >= "3.9" and python_version < "4.0"
63
+ pytz==2024.1 ; python_version >= "3.9" and python_version < "4.0"
64
+ pyyaml==6.0.1 ; python_version >= "3.9" and python_version < "4.0"
65
+ referencing==0.33.0 ; python_version >= "3.9" and python_version < "4.0"
66
+ requests==2.31.0 ; python_version >= "3.9" and python_version < "4.0"
67
+ responses==0.18.0 ; python_version >= "3.9" and python_version < "4.0"
68
+ rich==13.7.0 ; python_version >= "3.9" and python_version < "4.0"
69
+ rpds-py==0.17.1 ; python_version >= "3.9" and python_version < "4.0"
70
+ ruff==0.2.1 ; python_version >= "3.9" and python_version < "4.0"
71
+ scipy==1.12.0 ; python_version >= "3.9" and python_version < "4.0"
72
+ semantic-version==2.10.0 ; python_version >= "3.9" and python_version < "4.0"
73
+ shellingham==1.5.4 ; python_version >= "3.9" and python_version < "4.0"
74
+ six==1.16.0 ; python_version >= "3.9" and python_version < "4.0"
75
+ sniffio==1.3.0 ; python_version >= "3.9" and python_version < "4.0"
76
+ starlette==0.36.3 ; python_version >= "3.9" and python_version < "4.0"
77
+ text-unidecode==1.3 ; python_version >= "3.9" and python_version < "4.0"
78
+ tomlkit==0.12.0 ; python_version >= "3.9" and python_version < "4.0"
79
+ toolz==0.12.1 ; python_version >= "3.9" and python_version < "4.0"
80
+ tqdm==4.66.2 ; python_version >= "3.9" and python_version < "4.0"
81
+ typer[all]==0.9.0 ; python_version >= "3.9" and python_version < "4.0"
82
+ types-python-dateutil==2.8.19.20240106 ; python_version >= "3.9" and python_version < "4.0"
83
+ typing-extensions==4.9.0 ; python_version >= "3.9" and python_version < "4.0"
84
+ tzdata==2024.1 ; python_version >= "3.9" and python_version < "4.0"
85
+ urllib3==2.2.0 ; python_version >= "3.9" and python_version < "4.0"
86
+ uvicorn==0.27.1 ; python_version >= "3.9" and python_version < "4.0"
87
+ websockets==11.0.3 ; python_version >= "3.9" and python_version < "4.0"
88
+ xxhash==3.4.1 ; python_version >= "3.9" and python_version < "4.0"
89
+ yarl==1.9.4 ; python_version >= "3.9" and python_version < "4.0"
90
+ zipp==3.17.0 ; python_version >= "3.9" and python_version < "3.10"