shunk031 commited on
Commit
d65e682
β€’
1 Parent(s): 41c4309

deploy: 9b091bed6e92624a368d339225a2373c3207ec09

Browse files
Files changed (3) hide show
  1. README.md +5 -5
  2. layout_overlap.py +182 -0
  3. requirements.txt +89 -0
README.md CHANGED
@@ -1,10 +1,10 @@
1
  ---
2
- title: Layout Overlap
3
- emoji: πŸƒ
4
- colorFrom: gray
5
- colorTo: pink
6
  sdk: gradio
7
- sdk_version: 4.18.0
8
  app_file: app.py
9
  pinned: false
10
  ---
 
1
  ---
2
+ title: Layout Alignment
3
+ emoji: πŸ“Š
4
+ colorFrom: pink
5
+ colorTo: purple
6
  sdk: gradio
7
+ sdk_version: 4.17.0
8
  app_file: app.py
9
  pinned: false
10
  ---
layout_overlap.py ADDED
@@ -0,0 +1,182 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, List, Tuple, TypedDict, Union
2
+
3
+ import datasets as ds
4
+ import evaluate
5
+ import numpy as np
6
+ import numpy.typing as npt
7
+
8
+ _DESCRIPTION = """\
9
+ Computes some alignment metrics that are different to each other in previous works.
10
+ """
11
+
12
+ _CITATION = """\
13
+ @inproceedings{lee2020neural,
14
+ title={Neural design network: Graphic layout generation with constraints},
15
+ author={Lee, Hsin-Ying and Jiang, Lu and Essa, Irfan and Le, Phuong B and Gong, Haifeng and Yang, Ming-Hsuan and Yang, Weilong},
16
+ booktitle={Computer Vision--ECCV 2020: 16th European Conference, Glasgow, UK, August 23--28, 2020, Proceedings, Part III 16},
17
+ pages={491--506},
18
+ year={2020},
19
+ organization={Springer}
20
+ }
21
+
22
+ @article{li2020attribute,
23
+ title={Attribute-conditioned layout gan for automatic graphic design},
24
+ author={Li, Jianan and Yang, Jimei and Zhang, Jianming and Liu, Chang and Wang, Christina and Xu, Tingfa},
25
+ journal={IEEE Transactions on Visualization and Computer Graphics},
26
+ volume={27},
27
+ number={10},
28
+ pages={4039--4048},
29
+ year={2020},
30
+ publisher={IEEE}
31
+ }
32
+
33
+ @inproceedings{kikuchi2021constrained,
34
+ title={Constrained graphic layout generation via latent optimization},
35
+ author={Kikuchi, Kotaro and Simo-Serra, Edgar and Otani, Mayu and Yamaguchi, Kota},
36
+ booktitle={Proceedings of the 29th ACM International Conference on Multimedia},
37
+ pages={88--96},
38
+ year={2021}
39
+ }
40
+ """
41
+
42
+
43
+ def convert_xywh_to_ltrb(
44
+ batch_bbox: npt.NDArray[np.float64],
45
+ ) -> Tuple[
46
+ npt.NDArray[np.float64],
47
+ npt.NDArray[np.float64],
48
+ npt.NDArray[np.float64],
49
+ npt.NDArray[np.float64],
50
+ ]:
51
+ xc, yc, w, h = batch_bbox
52
+ x1 = xc - w / 2
53
+ y1 = yc - h / 2
54
+ x2 = xc + w / 2
55
+ y2 = yc + h / 2
56
+ return (x1, y1, x2, y2)
57
+
58
+
59
+ class A(TypedDict):
60
+ a1: npt.NDArray[np.float64]
61
+ ai: npt.NDArray[np.float64]
62
+
63
+
64
+ class LayoutOverlap(evaluate.Metric):
65
+ def _info(self) -> evaluate.EvaluationModuleInfo:
66
+ return evaluate.MetricInfo(
67
+ description=_DESCRIPTION,
68
+ citation=_CITATION,
69
+ features=ds.Features(
70
+ {
71
+ "batch_bbox": ds.Sequence(ds.Sequence(ds.Value("float64"))),
72
+ "batch_mask": ds.Sequence(ds.Value("bool")),
73
+ }
74
+ ),
75
+ codebase_urls=[
76
+ "https://github.com/ktrk115/const_layout/blob/master/metric.py#L167-L188",
77
+ "https://github.com/CyberAgentAILab/layout-dm/blob/main/src/trainer/trainer/helpers/metric.py#L98-L147",
78
+ ],
79
+ )
80
+
81
+ def __calculate_a1_ai(self, batch_bbox: npt.NDArray[np.float64]) -> A:
82
+
83
+ l1, t1, r1, b1 = convert_xywh_to_ltrb(batch_bbox[:, :, :, None])
84
+ l2, t2, r2, b2 = convert_xywh_to_ltrb(batch_bbox[:, :, None, :])
85
+ a1 = (r1 - l1) * (b1 - t1)
86
+
87
+ # shape: (B, S, S)
88
+ l_max = np.maximum(l1, l2)
89
+ r_min = np.minimum(r1, r2)
90
+ t_max = np.maximum(t1, t2)
91
+ b_min = np.minimum(b1, b2)
92
+ cond = (l_max < r_min) & (t_max < b_min)
93
+ ai = np.where(cond, (r_min - l_max) * (b_min - t_max), 0.0)
94
+
95
+ return {"a1": a1, "ai": ai}
96
+
97
+ def _compute_ac_layout_gan(
98
+ self,
99
+ S: int,
100
+ ai: npt.NDArray[np.float64],
101
+ a1: npt.NDArray[np.float64],
102
+ batch_mask: npt.NDArray[np.bool_],
103
+ ) -> npt.NDArray[np.float64]:
104
+
105
+ # shape: (B, S) -> (B, S, S)
106
+ batch_mask = ~batch_mask[:, None, :] | ~batch_mask[:, :, None]
107
+ indices = np.arange(S)
108
+ batch_mask[:, indices, indices] = True
109
+ ai[batch_mask] = 0.0
110
+
111
+ # shape: (B, S, S)
112
+ ar = np.nan_to_num(ai / a1)
113
+ score = ar.sum(axis=(1, 2))
114
+
115
+ return score
116
+
117
+ def _compute_layout_gan_pp(
118
+ self,
119
+ score_ac_layout_gan: npt.NDArray[np.float64],
120
+ batch_mask: npt.NDArray[np.bool_],
121
+ ):
122
+ # shape: (B, S) -> (B,)
123
+ batch_mask = batch_mask.sum(axis=1)
124
+
125
+ # shape: (B,)
126
+ score_normalized = score_ac_layout_gan / batch_mask
127
+ score_normalized[np.isnan(score_normalized)] = 0.0
128
+ return score_normalized
129
+
130
+ def _compute_layout_gan(
131
+ self, S: int, B: int, ai: npt.NDArray[np.float64]
132
+ ) -> npt.NDArray[np.float64]:
133
+ indices = np.arange(S)
134
+ ii, jj = np.meshgrid(indices, indices, indexing="ij")
135
+
136
+ # shape: ii (S, S) -> (1, S, S), jj (S, S) -> (1, S, S)
137
+ # shape: (1, S, S) -> (B, S, S)
138
+ ai[np.repeat((ii[None, :] >= jj[None, :]), axis=0, repeats=B)] = 0.0
139
+
140
+ # shape: (B, S, S) -> (B,)
141
+ score = ai.sum(axis=(1, 2))
142
+
143
+ return score
144
+
145
+ def _compute(
146
+ self,
147
+ *,
148
+ batch_bbox: Union[npt.NDArray[np.float64], List[List[int]]],
149
+ batch_mask: Union[npt.NDArray[np.bool_], List[List[bool]]],
150
+ ) -> Dict[str, npt.NDArray[np.float64]]:
151
+
152
+ # shape: (B, model_max_length, C)
153
+ batch_bbox = np.array(batch_bbox)
154
+ # shape: (B, model_max_length)
155
+ batch_mask = np.array(batch_mask)
156
+
157
+ # S: model_max_length
158
+ B, S, C = batch_bbox.shape
159
+
160
+ # shape: batch_bbox (B, S, C), batch_mask (B, S) -> (B, S, 1) -> (B, S, C)
161
+ batch_bbox[np.repeat(~batch_mask[:, :, None], axis=2, repeats=C)] = 0.0
162
+ # shape: (C, B, S)
163
+ batch_bbox = batch_bbox.transpose(2, 0, 1)
164
+
165
+ A = self.__calculate_a1_ai(batch_bbox)
166
+
167
+ # shape: (B,)
168
+ score_ac_layout_gan = self._compute_ac_layout_gan(
169
+ S=S, batch_mask=batch_mask, **A
170
+ )
171
+ # shape: (B,)
172
+ score_layout_gan_pp = self._compute_layout_gan_pp(
173
+ score_ac_layout_gan=score_ac_layout_gan, batch_mask=batch_mask
174
+ )
175
+ # shape: (B,)
176
+ score_layout_gan = self._compute_layout_gan(B=B, S=S, ai=A["ai"])
177
+
178
+ return {
179
+ "overlap-ACLayoutGAN": score_ac_layout_gan,
180
+ "overlap-LayoutGAN++": score_layout_gan_pp,
181
+ "overlap-LayoutGAN": score_layout_gan,
182
+ }
requirements.txt ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ aiofiles==23.2.1 ; python_version >= "3.9" and python_version < "4.0"
2
+ aiohttp==3.9.3 ; python_version >= "3.9" and python_version < "4.0"
3
+ aiosignal==1.3.1 ; python_version >= "3.9" and python_version < "4.0"
4
+ altair==5.2.0 ; python_version >= "3.9" and python_version < "4.0"
5
+ annotated-types==0.6.0 ; python_version >= "3.9" and python_version < "4.0"
6
+ anyio==4.2.0 ; python_version >= "3.9" and python_version < "4.0"
7
+ arrow==1.3.0 ; python_version >= "3.9" and python_version < "4.0"
8
+ async-timeout==4.0.3 ; python_version >= "3.9" and python_version < "3.11"
9
+ attrs==23.2.0 ; python_version >= "3.9" and python_version < "4.0"
10
+ binaryornot==0.4.4 ; python_version >= "3.9" and python_version < "4.0"
11
+ certifi==2024.2.2 ; python_version >= "3.9" and python_version < "4.0"
12
+ chardet==5.2.0 ; python_version >= "3.9" and python_version < "4.0"
13
+ charset-normalizer==3.3.2 ; python_version >= "3.9" and python_version < "4.0"
14
+ click==8.1.7 ; python_version >= "3.9" and python_version < "4.0"
15
+ colorama==0.4.6 ; python_version >= "3.9" and python_version < "4.0"
16
+ contourpy==1.2.0 ; python_version >= "3.9" and python_version < "4.0"
17
+ cookiecutter==2.5.0 ; python_version >= "3.9" and python_version < "4.0"
18
+ cycler==0.12.1 ; python_version >= "3.9" and python_version < "4.0"
19
+ datasets==2.17.0 ; python_version >= "3.9" and python_version < "4.0"
20
+ dill==0.3.8 ; python_version >= "3.9" and python_version < "4.0"
21
+ evaluate[template]==0.4.1 ; python_version >= "3.9" and python_version < "4.0"
22
+ exceptiongroup==1.2.0 ; python_version >= "3.9" and python_version < "3.11"
23
+ fastapi==0.109.2 ; python_version >= "3.9" and python_version < "4.0"
24
+ ffmpy==0.3.1 ; python_version >= "3.9" and python_version < "4.0"
25
+ filelock==3.13.1 ; python_version >= "3.9" and python_version < "4.0"
26
+ fonttools==4.48.1 ; python_version >= "3.9" and python_version < "4.0"
27
+ frozenlist==1.4.1 ; python_version >= "3.9" and python_version < "4.0"
28
+ fsspec==2023.10.0 ; python_version >= "3.9" and python_version < "4.0"
29
+ fsspec[http]==2023.10.0 ; python_version >= "3.9" and python_version < "4.0"
30
+ gradio-client==0.9.0 ; python_version >= "3.9" and python_version < "4.0"
31
+ gradio==4.17.0 ; python_version >= "3.9" and python_version < "4.0"
32
+ h11==0.14.0 ; python_version >= "3.9" and python_version < "4.0"
33
+ httpcore==1.0.2 ; python_version >= "3.9" and python_version < "4.0"
34
+ httpx==0.26.0 ; python_version >= "3.9" and python_version < "4.0"
35
+ huggingface-hub==0.20.3 ; python_version >= "3.9" and python_version < "4.0"
36
+ idna==3.6 ; python_version >= "3.9" and python_version < "4.0"
37
+ importlib-resources==6.1.1 ; python_version >= "3.9" and python_version < "4.0"
38
+ jinja2==3.1.3 ; python_version >= "3.9" and python_version < "4.0"
39
+ jsonschema-specifications==2023.12.1 ; python_version >= "3.9" and python_version < "4.0"
40
+ jsonschema==4.21.1 ; python_version >= "3.9" and python_version < "4.0"
41
+ kiwisolver==1.4.5 ; python_version >= "3.9" and python_version < "4.0"
42
+ markdown-it-py==3.0.0 ; python_version >= "3.9" and python_version < "4.0"
43
+ markupsafe==2.1.5 ; python_version >= "3.9" and python_version < "4.0"
44
+ matplotlib==3.8.2 ; python_version >= "3.9" and python_version < "4.0"
45
+ mdurl==0.1.2 ; python_version >= "3.9" and python_version < "4.0"
46
+ multidict==6.0.5 ; python_version >= "3.9" and python_version < "4.0"
47
+ multiprocess==0.70.16 ; python_version >= "3.9" and python_version < "4.0"
48
+ numpy==1.26.4 ; python_version >= "3.9" and python_version < "4.0"
49
+ orjson==3.9.13 ; python_version >= "3.9" and python_version < "4.0"
50
+ packaging==23.2 ; python_version >= "3.9" and python_version < "4.0"
51
+ pandas==2.2.0 ; python_version >= "3.9" and python_version < "4.0"
52
+ pillow==10.2.0 ; python_version >= "3.9" and python_version < "4.0"
53
+ pyarrow-hotfix==0.6 ; python_version >= "3.9" and python_version < "4.0"
54
+ pyarrow==15.0.0 ; python_version >= "3.9" and python_version < "4.0"
55
+ pydantic-core==2.16.2 ; python_version >= "3.9" and python_version < "4.0"
56
+ pydantic==2.6.1 ; python_version >= "3.9" and python_version < "4.0"
57
+ pydub==0.25.1 ; python_version >= "3.9" and python_version < "4.0"
58
+ pygments==2.17.2 ; python_version >= "3.9" and python_version < "4.0"
59
+ pyparsing==3.1.1 ; python_version >= "3.9" and python_version < "4.0"
60
+ python-dateutil==2.8.2 ; python_version >= "3.9" and python_version < "4.0"
61
+ python-multipart==0.0.7 ; python_version >= "3.9" and python_version < "4.0"
62
+ python-slugify==8.0.4 ; python_version >= "3.9" and python_version < "4.0"
63
+ pytz==2024.1 ; python_version >= "3.9" and python_version < "4.0"
64
+ pyyaml==6.0.1 ; python_version >= "3.9" and python_version < "4.0"
65
+ referencing==0.33.0 ; python_version >= "3.9" and python_version < "4.0"
66
+ requests==2.31.0 ; python_version >= "3.9" and python_version < "4.0"
67
+ responses==0.18.0 ; python_version >= "3.9" and python_version < "4.0"
68
+ rich==13.7.0 ; python_version >= "3.9" and python_version < "4.0"
69
+ rpds-py==0.17.1 ; python_version >= "3.9" and python_version < "4.0"
70
+ ruff==0.2.1 ; python_version >= "3.9" and python_version < "4.0"
71
+ semantic-version==2.10.0 ; python_version >= "3.9" and python_version < "4.0"
72
+ shellingham==1.5.4 ; python_version >= "3.9" and python_version < "4.0"
73
+ six==1.16.0 ; python_version >= "3.9" and python_version < "4.0"
74
+ sniffio==1.3.0 ; python_version >= "3.9" and python_version < "4.0"
75
+ starlette==0.36.3 ; python_version >= "3.9" and python_version < "4.0"
76
+ text-unidecode==1.3 ; python_version >= "3.9" and python_version < "4.0"
77
+ tomlkit==0.12.0 ; python_version >= "3.9" and python_version < "4.0"
78
+ toolz==0.12.1 ; python_version >= "3.9" and python_version < "4.0"
79
+ tqdm==4.66.1 ; python_version >= "3.9" and python_version < "4.0"
80
+ typer[all]==0.9.0 ; python_version >= "3.9" and python_version < "4.0"
81
+ types-python-dateutil==2.8.19.20240106 ; python_version >= "3.9" and python_version < "4.0"
82
+ typing-extensions==4.9.0 ; python_version >= "3.9" and python_version < "4.0"
83
+ tzdata==2023.4 ; python_version >= "3.9" and python_version < "4.0"
84
+ urllib3==2.2.0 ; python_version >= "3.9" and python_version < "4.0"
85
+ uvicorn==0.27.0.post1 ; python_version >= "3.9" and python_version < "4.0"
86
+ websockets==11.0.3 ; python_version >= "3.9" and python_version < "4.0"
87
+ xxhash==3.4.1 ; python_version >= "3.9" and python_version < "4.0"
88
+ yarl==1.9.4 ; python_version >= "3.9" and python_version < "4.0"
89
+ zipp==3.17.0 ; python_version >= "3.9" and python_version < "3.10"