Spaces:
Configuration error
Configuration error
deploy: 70e91c39c93ea64c5ddd81aee7881cf044abb76f
Browse files- README.md +0 -12
- layout_maximum_iou.py +193 -0
- requirements.txt +90 -0
README.md
CHANGED
@@ -1,12 +0,0 @@
|
|
1 |
-
---
|
2 |
-
title: Layout Maximum Iou
|
3 |
-
emoji: 🌍
|
4 |
-
colorFrom: gray
|
5 |
-
colorTo: yellow
|
6 |
-
sdk: gradio
|
7 |
-
sdk_version: 4.18.0
|
8 |
-
app_file: app.py
|
9 |
-
pinned: false
|
10 |
-
---
|
11 |
-
|
12 |
-
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
layout_maximum_iou.py
ADDED
@@ -0,0 +1,193 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from collections import defaultdict
|
2 |
+
from itertools import chain
|
3 |
+
from typing import Dict, List, Tuple, TypedDict
|
4 |
+
|
5 |
+
import datasets as ds
|
6 |
+
import evaluate
|
7 |
+
import numpy as np
|
8 |
+
import numpy.typing as npt
|
9 |
+
from scipy.optimize import linear_sum_assignment
|
10 |
+
|
11 |
+
|
12 |
+
class Layout(TypedDict):
|
13 |
+
bboxes: npt.NDArray[np.float64]
|
14 |
+
categories: npt.NDArray[np.int64]
|
15 |
+
|
16 |
+
|
17 |
+
_DESCRIPTION = """\
|
18 |
+
|
19 |
+
"""
|
20 |
+
|
21 |
+
_CITATION = """\
|
22 |
+
|
23 |
+
"""
|
24 |
+
|
25 |
+
|
26 |
+
def convert_xywh_to_ltrb(
|
27 |
+
batch_bbox: npt.NDArray[np.float64],
|
28 |
+
) -> Tuple[
|
29 |
+
npt.NDArray[np.float64],
|
30 |
+
npt.NDArray[np.float64],
|
31 |
+
npt.NDArray[np.float64],
|
32 |
+
npt.NDArray[np.float64],
|
33 |
+
]:
|
34 |
+
xc, yc, w, h = batch_bbox
|
35 |
+
x1 = xc - w / 2
|
36 |
+
y1 = yc - h / 2
|
37 |
+
x2 = xc + w / 2
|
38 |
+
y2 = yc + h / 2
|
39 |
+
return (x1, y1, x2, y2)
|
40 |
+
|
41 |
+
|
42 |
+
def _compute_iou(
|
43 |
+
bbox1: npt.NDArray[np.float64],
|
44 |
+
bbox2: npt.NDArray[np.float64],
|
45 |
+
generalized: bool = False,
|
46 |
+
):
|
47 |
+
# shape: bbox1 (N, 4), bbox2 (N, 4)
|
48 |
+
assert bbox1.shape[0] == bbox2.shape[0]
|
49 |
+
assert bbox1.shape[1] == bbox1.shape[1] == 4
|
50 |
+
|
51 |
+
l1, t1, r1, b1 = convert_xywh_to_ltrb(bbox1.T)
|
52 |
+
l2, t2, r2, b2 = convert_xywh_to_ltrb(bbox2.T)
|
53 |
+
a1, a2 = (r1 - l1) * (b1 - t1), (r2 - l2) * (b2 - t2)
|
54 |
+
|
55 |
+
# intersection
|
56 |
+
l_max = np.maximum(l1, l2)
|
57 |
+
r_min = np.minimum(r1, r2)
|
58 |
+
t_max = np.maximum(t1, t2)
|
59 |
+
b_min = np.minimum(b1, b2)
|
60 |
+
cond = (l_max < r_min) & (t_max < b_min)
|
61 |
+
ai = np.where(cond, (r_min - l_max) * (b_min - t_max), np.zeros_like(a1[0]))
|
62 |
+
|
63 |
+
au = a1 + a2 - ai
|
64 |
+
iou = ai / au
|
65 |
+
|
66 |
+
if not generalized:
|
67 |
+
return iou
|
68 |
+
|
69 |
+
# outer region
|
70 |
+
l_min = np.minimum(l1, l2)
|
71 |
+
r_max = np.maximum(r1, r2)
|
72 |
+
t_min = np.minimum(t1, t2)
|
73 |
+
b_max = np.maximum(b1, b2)
|
74 |
+
ac = (r_max - l_min) * (b_max - t_min)
|
75 |
+
|
76 |
+
giou = iou - (ac - au) / ac
|
77 |
+
|
78 |
+
return giou
|
79 |
+
|
80 |
+
|
81 |
+
def _compute_maximum_iou_for_layout(layout1: Layout, layout2: Layout):
|
82 |
+
score = 0.0
|
83 |
+
bi, ci = layout1["bboxes"], layout1["categories"]
|
84 |
+
bj, cj = layout2["bboxes"], layout2["categories"]
|
85 |
+
N = len(bi)
|
86 |
+
|
87 |
+
for c in list(set(ci.tolist())):
|
88 |
+
_bi = bi[np.where(ci == c)]
|
89 |
+
_bj = bj[np.where(cj == c)]
|
90 |
+
n = len(_bi)
|
91 |
+
ii, jj = np.meshgrid(range(n), range(n))
|
92 |
+
ii, jj = ii.flatten(), jj.flatten()
|
93 |
+
iou = _compute_iou(_bi[ii], _bj[jj]).reshape(n, n)
|
94 |
+
# Note: maximize is supported only when scipy >= 1.4
|
95 |
+
ii, jj = linear_sum_assignment(iou, maximize=True)
|
96 |
+
score += iou[ii, jj].sum().item()
|
97 |
+
return score / N
|
98 |
+
|
99 |
+
|
100 |
+
def _compute_maximum_iou(
|
101 |
+
layouts_1_and_2: Tuple[List[Layout], List[Layout]]
|
102 |
+
) -> npt.NDArray[np.float64]:
|
103 |
+
assert len(layouts_1_and_2) == 2
|
104 |
+
layouts1, layouts2 = layouts_1_and_2
|
105 |
+
|
106 |
+
N, M = len(layouts1), len(layouts2)
|
107 |
+
ii, jj = np.meshgrid(range(N), range(M))
|
108 |
+
ii, jj = ii.flatten(), jj.flatten()
|
109 |
+
scores = np.asarray(
|
110 |
+
[
|
111 |
+
_compute_maximum_iou_for_layout(layouts1[i], layouts2[j])
|
112 |
+
for i, j in zip(ii, jj)
|
113 |
+
]
|
114 |
+
)
|
115 |
+
scores = scores.reshape(N, M)
|
116 |
+
ii, jj = linear_sum_assignment(scores, maximize=True)
|
117 |
+
return scores[ii, jj]
|
118 |
+
|
119 |
+
|
120 |
+
def _get_cond_to_layouts(layouts: List[Layout]) -> Dict[str, List[Layout]]:
|
121 |
+
out = defaultdict(list)
|
122 |
+
for layout in layouts:
|
123 |
+
bboxes_list = layout["bboxes"]
|
124 |
+
categories_list = layout["categories"]
|
125 |
+
assert len(bboxes_list) == len(categories_list)
|
126 |
+
|
127 |
+
for bboxes, categories in zip(bboxes_list, categories_list):
|
128 |
+
bboxes = np.array(bboxes)
|
129 |
+
cond_key = str(sorted(categories))
|
130 |
+
categories = np.array(categories)
|
131 |
+
layout_dict: Layout = {"bboxes": bboxes, "categories": categories}
|
132 |
+
out[cond_key].append(layout_dict)
|
133 |
+
return out
|
134 |
+
|
135 |
+
|
136 |
+
def compute_maximum_iou(args):
|
137 |
+
return [_compute_maximum_iou(a) for a in args]
|
138 |
+
|
139 |
+
|
140 |
+
class LayoutMaximumIoU(evaluate.Metric):
|
141 |
+
|
142 |
+
def _info(self) -> evaluate.EvaluationModuleInfo:
|
143 |
+
return evaluate.EvaluationModuleInfo(
|
144 |
+
description=_DESCRIPTION,
|
145 |
+
citation=_CITATION,
|
146 |
+
features=ds.Features(
|
147 |
+
{
|
148 |
+
"layouts1": ds.Sequence(
|
149 |
+
{
|
150 |
+
"bboxes": ds.Sequence(ds.Sequence((ds.Value("float64")))),
|
151 |
+
"categories": ds.Sequence(ds.Value("int64")),
|
152 |
+
}
|
153 |
+
),
|
154 |
+
"layouts2": ds.Sequence(
|
155 |
+
{
|
156 |
+
"bboxes": ds.Sequence(ds.Sequence((ds.Value("float64")))),
|
157 |
+
"categories": ds.Sequence(ds.Value("int64")),
|
158 |
+
}
|
159 |
+
),
|
160 |
+
}
|
161 |
+
),
|
162 |
+
codebase_urls=[
|
163 |
+
"https://github.com/CyberAgentAILab/layout-dm/blob/main/src/trainer/trainer/helpers/metric.py#L206-L247",
|
164 |
+
"https://github.com/CyberAgentAILab/layout-dm/blob/main/src/trainer/trainer/helpers/metric.py#L250-L297",
|
165 |
+
"https://github.com/CyberAgentAILab/layout-dm/blob/main/src/trainer/trainer/helpers/metric.py#L300-L314",
|
166 |
+
"https://github.com/CyberAgentAILab/layout-dm/blob/main/src/trainer/trainer/helpers/metric.py#L317-L329",
|
167 |
+
"https://github.com/CyberAgentAILab/layout-dm/blob/main/src/trainer/trainer/helpers/metric.py#L332-L340",
|
168 |
+
"https://github.com/CyberAgentAILab/layout-dm/blob/main/src/trainer/trainer/helpers/metric.py#L343-L371",
|
169 |
+
],
|
170 |
+
)
|
171 |
+
|
172 |
+
def _compute(
|
173 |
+
self,
|
174 |
+
*,
|
175 |
+
layouts1: List[Layout],
|
176 |
+
layouts2: List[Layout],
|
177 |
+
) -> float:
|
178 |
+
c2bl_1 = _get_cond_to_layouts(layouts1)
|
179 |
+
keys_1 = set(c2bl_1.keys())
|
180 |
+
c2bl_2 = _get_cond_to_layouts(layouts2)
|
181 |
+
keys_2 = set(c2bl_2.keys())
|
182 |
+
keys = list(keys_1.intersection(keys_2))
|
183 |
+
args = [(c2bl_1[key], c2bl_2[key]) for key in keys]
|
184 |
+
|
185 |
+
# to check actual number of layouts for evaluation
|
186 |
+
# ans = 0
|
187 |
+
# for x in args:
|
188 |
+
# ans += len(x[0])
|
189 |
+
|
190 |
+
scores = compute_maximum_iou(args)
|
191 |
+
scores = np.asarray(list(chain.from_iterable(scores)))
|
192 |
+
|
193 |
+
return scores.mean().item() if len(scores) != 0 else 0.0
|
requirements.txt
ADDED
@@ -0,0 +1,90 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
aiofiles==23.2.1 ; python_version >= "3.9" and python_version < "4.0"
|
2 |
+
aiohttp==3.9.3 ; python_version >= "3.9" and python_version < "4.0"
|
3 |
+
aiosignal==1.3.1 ; python_version >= "3.9" and python_version < "4.0"
|
4 |
+
altair==5.2.0 ; python_version >= "3.9" and python_version < "4.0"
|
5 |
+
annotated-types==0.6.0 ; python_version >= "3.9" and python_version < "4.0"
|
6 |
+
anyio==4.2.0 ; python_version >= "3.9" and python_version < "4.0"
|
7 |
+
arrow==1.3.0 ; python_version >= "3.9" and python_version < "4.0"
|
8 |
+
async-timeout==4.0.3 ; python_version >= "3.9" and python_version < "3.11"
|
9 |
+
attrs==23.2.0 ; python_version >= "3.9" and python_version < "4.0"
|
10 |
+
binaryornot==0.4.4 ; python_version >= "3.9" and python_version < "4.0"
|
11 |
+
certifi==2024.2.2 ; python_version >= "3.9" and python_version < "4.0"
|
12 |
+
chardet==5.2.0 ; python_version >= "3.9" and python_version < "4.0"
|
13 |
+
charset-normalizer==3.3.2 ; python_version >= "3.9" and python_version < "4.0"
|
14 |
+
click==8.1.7 ; python_version >= "3.9" and python_version < "4.0"
|
15 |
+
colorama==0.4.6 ; python_version >= "3.9" and python_version < "4.0"
|
16 |
+
contourpy==1.2.0 ; python_version >= "3.9" and python_version < "4.0"
|
17 |
+
cookiecutter==2.5.0 ; python_version >= "3.9" and python_version < "4.0"
|
18 |
+
cycler==0.12.1 ; python_version >= "3.9" and python_version < "4.0"
|
19 |
+
datasets==2.17.0 ; python_version >= "3.9" and python_version < "4.0"
|
20 |
+
dill==0.3.8 ; python_version >= "3.9" and python_version < "4.0"
|
21 |
+
evaluate[template]==0.4.1 ; python_version >= "3.9" and python_version < "4.0"
|
22 |
+
exceptiongroup==1.2.0 ; python_version >= "3.9" and python_version < "3.11"
|
23 |
+
fastapi==0.109.2 ; python_version >= "3.9" and python_version < "4.0"
|
24 |
+
ffmpy==0.3.1 ; python_version >= "3.9" and python_version < "4.0"
|
25 |
+
filelock==3.13.1 ; python_version >= "3.9" and python_version < "4.0"
|
26 |
+
fonttools==4.48.1 ; python_version >= "3.9" and python_version < "4.0"
|
27 |
+
frozenlist==1.4.1 ; python_version >= "3.9" and python_version < "4.0"
|
28 |
+
fsspec==2023.10.0 ; python_version >= "3.9" and python_version < "4.0"
|
29 |
+
fsspec[http]==2023.10.0 ; python_version >= "3.9" and python_version < "4.0"
|
30 |
+
gradio-client==0.10.0 ; python_version >= "3.9" and python_version < "4.0"
|
31 |
+
gradio==4.18.0 ; python_version >= "3.9" and python_version < "4.0"
|
32 |
+
h11==0.14.0 ; python_version >= "3.9" and python_version < "4.0"
|
33 |
+
httpcore==1.0.2 ; python_version >= "3.9" and python_version < "4.0"
|
34 |
+
httpx==0.26.0 ; python_version >= "3.9" and python_version < "4.0"
|
35 |
+
huggingface-hub==0.20.3 ; python_version >= "3.9" and python_version < "4.0"
|
36 |
+
idna==3.6 ; python_version >= "3.9" and python_version < "4.0"
|
37 |
+
importlib-resources==6.1.1 ; python_version >= "3.9" and python_version < "4.0"
|
38 |
+
jinja2==3.1.3 ; python_version >= "3.9" and python_version < "4.0"
|
39 |
+
jsonschema-specifications==2023.12.1 ; python_version >= "3.9" and python_version < "4.0"
|
40 |
+
jsonschema==4.21.1 ; python_version >= "3.9" and python_version < "4.0"
|
41 |
+
kiwisolver==1.4.5 ; python_version >= "3.9" and python_version < "4.0"
|
42 |
+
markdown-it-py==3.0.0 ; python_version >= "3.9" and python_version < "4.0"
|
43 |
+
markupsafe==2.1.5 ; python_version >= "3.9" and python_version < "4.0"
|
44 |
+
matplotlib==3.8.2 ; python_version >= "3.9" and python_version < "4.0"
|
45 |
+
mdurl==0.1.2 ; python_version >= "3.9" and python_version < "4.0"
|
46 |
+
multidict==6.0.5 ; python_version >= "3.9" and python_version < "4.0"
|
47 |
+
multiprocess==0.70.16 ; python_version >= "3.9" and python_version < "4.0"
|
48 |
+
numpy==1.26.4 ; python_version >= "3.9" and python_version < "4.0"
|
49 |
+
orjson==3.9.13 ; python_version >= "3.9" and python_version < "4.0"
|
50 |
+
packaging==23.2 ; python_version >= "3.9" and python_version < "4.0"
|
51 |
+
pandas==2.2.0 ; python_version >= "3.9" and python_version < "4.0"
|
52 |
+
pillow==10.2.0 ; python_version >= "3.9" and python_version < "4.0"
|
53 |
+
pyarrow-hotfix==0.6 ; python_version >= "3.9" and python_version < "4.0"
|
54 |
+
pyarrow==15.0.0 ; python_version >= "3.9" and python_version < "4.0"
|
55 |
+
pydantic-core==2.16.2 ; python_version >= "3.9" and python_version < "4.0"
|
56 |
+
pydantic==2.6.1 ; python_version >= "3.9" and python_version < "4.0"
|
57 |
+
pydub==0.25.1 ; python_version >= "3.9" and python_version < "4.0"
|
58 |
+
pygments==2.17.2 ; python_version >= "3.9" and python_version < "4.0"
|
59 |
+
pyparsing==3.1.1 ; python_version >= "3.9" and python_version < "4.0"
|
60 |
+
python-dateutil==2.8.2 ; python_version >= "3.9" and python_version < "4.0"
|
61 |
+
python-multipart==0.0.9 ; python_version >= "3.9" and python_version < "4.0"
|
62 |
+
python-slugify==8.0.4 ; python_version >= "3.9" and python_version < "4.0"
|
63 |
+
pytz==2024.1 ; python_version >= "3.9" and python_version < "4.0"
|
64 |
+
pyyaml==6.0.1 ; python_version >= "3.9" and python_version < "4.0"
|
65 |
+
referencing==0.33.0 ; python_version >= "3.9" and python_version < "4.0"
|
66 |
+
requests==2.31.0 ; python_version >= "3.9" and python_version < "4.0"
|
67 |
+
responses==0.18.0 ; python_version >= "3.9" and python_version < "4.0"
|
68 |
+
rich==13.7.0 ; python_version >= "3.9" and python_version < "4.0"
|
69 |
+
rpds-py==0.17.1 ; python_version >= "3.9" and python_version < "4.0"
|
70 |
+
ruff==0.2.1 ; python_version >= "3.9" and python_version < "4.0"
|
71 |
+
scipy==1.12.0 ; python_version >= "3.9" and python_version < "4.0"
|
72 |
+
semantic-version==2.10.0 ; python_version >= "3.9" and python_version < "4.0"
|
73 |
+
shellingham==1.5.4 ; python_version >= "3.9" and python_version < "4.0"
|
74 |
+
six==1.16.0 ; python_version >= "3.9" and python_version < "4.0"
|
75 |
+
sniffio==1.3.0 ; python_version >= "3.9" and python_version < "4.0"
|
76 |
+
starlette==0.36.3 ; python_version >= "3.9" and python_version < "4.0"
|
77 |
+
text-unidecode==1.3 ; python_version >= "3.9" and python_version < "4.0"
|
78 |
+
tomlkit==0.12.0 ; python_version >= "3.9" and python_version < "4.0"
|
79 |
+
toolz==0.12.1 ; python_version >= "3.9" and python_version < "4.0"
|
80 |
+
tqdm==4.66.2 ; python_version >= "3.9" and python_version < "4.0"
|
81 |
+
typer[all]==0.9.0 ; python_version >= "3.9" and python_version < "4.0"
|
82 |
+
types-python-dateutil==2.8.19.20240106 ; python_version >= "3.9" and python_version < "4.0"
|
83 |
+
typing-extensions==4.9.0 ; python_version >= "3.9" and python_version < "4.0"
|
84 |
+
tzdata==2024.1 ; python_version >= "3.9" and python_version < "4.0"
|
85 |
+
urllib3==2.2.0 ; python_version >= "3.9" and python_version < "4.0"
|
86 |
+
uvicorn==0.27.1 ; python_version >= "3.9" and python_version < "4.0"
|
87 |
+
websockets==11.0.3 ; python_version >= "3.9" and python_version < "4.0"
|
88 |
+
xxhash==3.4.1 ; python_version >= "3.9" and python_version < "4.0"
|
89 |
+
yarl==1.9.4 ; python_version >= "3.9" and python_version < "4.0"
|
90 |
+
zipp==3.17.0 ; python_version >= "3.9" and python_version < "3.10"
|