mckabue commited on
Commit
7378eef
·
verified ·
1 Parent(s): 1c0c18d

RE_UPLOAD-REBUILD-RESTART

Browse files
model/layout-model-training/tools/convert_prima_to_coco.py ADDED
@@ -0,0 +1,225 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os, re, json
2
+ import imagesize
3
+ from glob import glob
4
+ from bs4 import BeautifulSoup
5
+ import numpy as np
6
+ from PIL import Image
7
+ import argparse
8
+ from tqdm import tqdm
9
+ import sys
10
+ sys.path.append('..')
11
+ from utils import cocosplit
12
+
13
+ class NpEncoder(json.JSONEncoder):
14
+ def default(self, obj):
15
+ if isinstance(obj, np.integer):
16
+ return int(obj)
17
+ elif isinstance(obj, np.floating):
18
+ return float(obj)
19
+ elif isinstance(obj, np.ndarray):
20
+ return obj.tolist()
21
+ else:
22
+ return super(NpEncoder, self).default(obj)
23
+
24
+ def cvt_coords_to_array(obj):
25
+
26
+ return np.array(
27
+ [(float(pt['x']), float(pt['y']))
28
+ for pt in obj.find_all("Point")]
29
+ )
30
+
31
+ def cal_ployarea(points):
32
+ x = points[:,0]
33
+ y = points[:,1]
34
+ return 0.5*np.abs(np.dot(x,np.roll(y,1))-np.dot(y,np.roll(x,1)))
35
+
36
+ def _create_category(schema=0):
37
+
38
+ if schema==0:
39
+
40
+ categories = \
41
+ [{"supercategory": "layout", "id": 0, "name": "Background"},
42
+ {"supercategory": "layout", "id": 1, "name": "TextRegion"},
43
+ {"supercategory": "layout", "id": 2, "name": "ImageRegion"},
44
+ {"supercategory": "layout", "id": 3, "name": "TableRegion"},
45
+ {"supercategory": "layout", "id": 4, "name": "MathsRegion"},
46
+ {"supercategory": "layout", "id": 5, "name": "SeparatorRegion"},
47
+ {"supercategory": "layout", "id": 6, "name": "OtherRegion"}]
48
+
49
+ find_categories = lambda name: \
50
+ [val["id"] for val in categories if val['name'] == name][0]
51
+
52
+ conversion = \
53
+ {
54
+ 'TextRegion': find_categories("TextRegion"),
55
+ 'TableRegion': find_categories("TableRegion"),
56
+ 'MathsRegion': find_categories("MathsRegion"),
57
+ 'ChartRegion': find_categories("ImageRegion"),
58
+ 'GraphicRegion': find_categories("ImageRegion"),
59
+ 'ImageRegion': find_categories("ImageRegion"),
60
+ 'LineDrawingRegion':find_categories("OtherRegion"),
61
+ 'SeparatorRegion': find_categories("SeparatorRegion"),
62
+ 'NoiseRegion': find_categories("OtherRegion"),
63
+ 'FrameRegion': find_categories("OtherRegion"),
64
+ }
65
+
66
+ return categories, conversion
67
+
68
+ _categories, _categories_conversion = _create_category(schema=0)
69
+
70
+ _info = {
71
+ "description": "PRIMA Layout Analysis Dataset",
72
+ "url": "https://www.primaresearch.org/datasets/Layout_Analysis",
73
+ "version": "1.0",
74
+ "year": 2010,
75
+ "contributor": "PRIMA Research",
76
+ "date_created": "2020/09/01",
77
+ }
78
+
79
+ def _load_soup(filename):
80
+ with open(filename, "r") as fp:
81
+ soup = BeautifulSoup(fp.read(),'xml')
82
+
83
+ return soup
84
+
85
+ def _image_template(image_id, image_path):
86
+
87
+ width, height = imagesize.get(image_path)
88
+
89
+ return {
90
+ "file_name": os.path.basename(image_path),
91
+ "height": height,
92
+ "width": width,
93
+ "id": int(image_id)
94
+ }
95
+
96
+ def _anno_template(anno_id, image_id, pts, obj_tag):
97
+
98
+ x_1, x_2 = pts[:,0].min(), pts[:,0].max()
99
+ y_1, y_2 = pts[:,1].min(), pts[:,1].max()
100
+ height = y_2 - y_1
101
+ width = x_2 - x_1
102
+
103
+ return {
104
+ "segmentation": [pts.flatten().tolist()],
105
+ "area": cal_ployarea(pts),
106
+ "iscrowd": 0,
107
+ "image_id": image_id,
108
+ "bbox": [x_1, y_1, width, height],
109
+ "category_id": _categories_conversion[obj_tag],
110
+ "id": anno_id
111
+ }
112
+
113
+ class PRIMADataset():
114
+
115
+ def __init__(self, base_path, anno_path='XML',
116
+ image_path='Images'):
117
+
118
+ self.base_path = base_path
119
+ self.anno_path = os.path.join(base_path, anno_path)
120
+ self.image_path = os.path.join(base_path, image_path)
121
+
122
+ self._ids = self.find_all_image_ids()
123
+
124
+ def __len__(self):
125
+ return len(self.ids)
126
+
127
+ def __getitem__(self, idx):
128
+ return self.load_image_and_annotaiton(idx)
129
+
130
+ def find_all_annotation_files(self):
131
+ return glob(os.path.join(self.anno_path, '*.xml'))
132
+
133
+ def find_all_image_ids(self):
134
+ replacer = lambda s: os.path.basename(s).replace('pc-', '').replace('.xml', '')
135
+ return [replacer(s) for s in self.find_all_annotation_files()]
136
+
137
+ def load_image_and_annotaiton(self, idx):
138
+
139
+ image_id = self._ids[idx]
140
+
141
+ image_path = os.path.join(self.image_path, f'{image_id}.tif')
142
+ image = Image.open(image_path)
143
+
144
+ anno = self.load_annotation(idx)
145
+
146
+ return image, anno
147
+
148
+ def load_annotation(self, idx):
149
+ image_id = self._ids[idx]
150
+
151
+ anno_path = os.path.join(self.anno_path, f'pc-{image_id}.xml')
152
+ # A dirtly hack to load the files w/wo pc- simualtaneously
153
+ if not os.path.exists(anno_path):
154
+ anno_path = os.path.join(self.anno_path, f'{image_id}.xml')
155
+ assert os.path.exists(anno_path), "Invalid path"
156
+ anno = _load_soup(anno_path)
157
+
158
+ return anno
159
+
160
+ def convert_to_COCO(self, save_path):
161
+
162
+ all_image_infos = []
163
+ all_anno_infos = []
164
+ anno_id = 0
165
+
166
+ for idx, image_id in enumerate(tqdm(self._ids)):
167
+
168
+ # We use the idx as the image id
169
+
170
+ image_path = os.path.join(self.image_path, f'{image_id}.tif')
171
+ image_info = _image_template(idx, image_path)
172
+ all_image_infos.append(image_info)
173
+
174
+ anno = self.load_annotation(idx)
175
+
176
+ for item in anno.find_all(re.compile(".*Region")):
177
+
178
+ pts = cvt_coords_to_array(item.Coords)
179
+ if 0 not in pts.shape:
180
+ # Sometimes there will be polygons with less
181
+ # than 4 edges, and they could not be appropriately
182
+ # handled by the COCO format. So we just drop them.
183
+ if pts.shape[0] >= 4:
184
+ anno_info = _anno_template(anno_id, idx, pts, item.name)
185
+ all_anno_infos.append(anno_info)
186
+ anno_id += 1
187
+
188
+
189
+ final_annotation = {
190
+ "info": _info,
191
+ "licenses": [],
192
+ "images": all_image_infos,
193
+ "annotations": all_anno_infos,
194
+ "categories": _categories}
195
+
196
+ with open(save_path, 'w') as fp:
197
+ json.dump(final_annotation, fp, cls=NpEncoder)
198
+
199
+ return final_annotation
200
+
201
+
202
+ parser = argparse.ArgumentParser()
203
+
204
+ parser.add_argument('--prima_datapath', type=str, default='./data/prima', help='the path to the prima data folders')
205
+ parser.add_argument('--anno_savepath', type=str, default='./annotations.json', help='the path to save the new annotations')
206
+
207
+
208
+ if __name__ == "__main__":
209
+ args = parser.parse_args()
210
+
211
+ print("Start running the conversion script")
212
+
213
+ print(f"Loading the information from the path {args.prima_datapath}")
214
+ dataset = PRIMADataset(args.prima_datapath)
215
+
216
+ print(f"Saving the annotation to {args.anno_savepath}")
217
+ res = dataset.convert_to_COCO(args.anno_savepath)
218
+
219
+ cocosplit.main(
220
+ args.anno_savepath,
221
+ split_ratio=0.8,
222
+ having_annotations=True,
223
+ train_save_path=args.anno_savepath.replace('.json', '-train.json'),
224
+ test_save_path=args.anno_savepath.replace('.json', '-val.json'),
225
+ random_state=24)