Charles Kabui
Add 'model/layout-parser/' from commit 'b9fad076596272e427612d5e848da1ba8ea06b97'
399308e
# Modified based on https://github.com/akarazniewicz/cocosplit/blob/master/cocosplit.py
import json | |
import argparse | |
import funcy | |
from sklearn.model_selection import train_test_split | |
# Command-line interface of the splitter.
# All three path options are mandatory: main() opens each of them
# unconditionally, so a missing one would otherwise only surface later as a
# TypeError from open(None, ...).
parser = argparse.ArgumentParser(
    description="Splits COCO annotations file into training and test sets."
)
parser.add_argument(
    "--annotation-path",
    metavar="coco_annotations",
    type=str,
    required=True,
    help="Path to COCO annotations file.",
)
parser.add_argument(
    "--train",
    type=str,
    required=True,
    help="Where to store COCO training annotations",
)
parser.add_argument(
    "--test",
    type=str,
    required=True,
    help="Where to store COCO test annotations",
)
parser.add_argument(
    "--split-ratio",
    dest="split_ratio",
    type=float,
    required=True,
    help="A percentage of a split; a number in (0, 1)",
)
parser.add_argument(
    "--having-annotations",
    dest="having_annotations",
    action="store_true",
    help="Ignore all images without annotations. Keep only these with at least one annotation",
)
def save_coco(file, tagged_data):
    """Write ``tagged_data`` as JSON to the path ``file``.

    The file is opened for writing as UTF-8 text; the JSON is indented by two
    spaces and object keys are emitted in sorted order.
    """
    with open(file, mode="wt", encoding="UTF-8") as out_stream:
        json.dump(tagged_data, out_stream, sort_keys=True, indent=2)
def filter_annotations(annotations, images):
    """Return the annotations that refer to one of the given images.

    Args:
        annotations: Iterable of annotation dicts; each must have an
            ``"image_id"`` entry convertible with ``int``.
        images: Iterable of image dicts; each must have an ``"id"`` entry
            convertible with ``int``.

    Returns:
        A new list with the annotations whose ``int(image_id)`` equals the
        ``int(id)`` of some image, in their original order.
    """
    # A set gives constant-time membership tests; scanning a list of ids for
    # every annotation made this quadratic.
    image_ids = {int(image["id"]) for image in images}
    return [
        annotation
        for annotation in annotations
        if int(annotation["image_id"]) in image_ids
    ]
def main(
    annotation_path,
    split_ratio,
    having_annotations,
    train_save_path,
    test_save_path,
    random_state=None,
):
    """Split a COCO annotations file into a training file and a test file.

    Images that have annotations and images that do not are split
    separately, each with the same ratio and random state, and the parts are
    then merged. Both output files keep every other top-level key of the
    input unchanged; only ``"images"`` and ``"annotations"`` are replaced.

    Args:
        annotation_path: Path of the COCO JSON file to read.
        split_ratio: Passed to ``train_test_split`` as ``train_size``.
        having_annotations: If true, images without any annotation are left
            out of both output files.
        train_save_path: Where the training JSON is written.
        test_save_path: Where the test JSON is written.
        random_state: Passed to ``train_test_split`` for both splits.

    Raises:
        KeyError: If the input has no ``"images"`` or ``"annotations"`` key.
        Whatever ``train_test_split`` raises when the annotated images cannot
        be split with the given ratio is propagated unchanged.
    """
    with open(annotation_path, "rt", encoding="UTF-8") as annotation_file:
        coco = json.load(annotation_file)

    images = coco["images"]
    annotations = coco["annotations"]

    # Ids are normalised with int() on both sides, the same way
    # filter_annotations compares them; a set keeps the lookups cheap.
    annotated_ids = {int(a["image_id"]) for a in annotations}

    annotated_images = [img for img in images if int(img["id"]) in annotated_ids]
    tr, ts = train_test_split(
        annotated_images, train_size=split_ratio, random_state=random_state
    )

    if not having_annotations:
        # Only split the unannotated images when they are actually kept.
        unannotated_images = [
            img for img in images if int(img["id"]) not in annotated_ids
        ]
        if len(unannotated_images) > 1:
            tr_wo_ann, ts_wo_ann = train_test_split(
                unannotated_images,
                train_size=split_ratio,
                random_state=random_state,
            )
        else:
            # Zero or one image cannot be divided into two non-empty parts
            # (train_test_split would raise), so a lone image goes to the
            # training set.
            tr_wo_ann, ts_wo_ann = unannotated_images, []

        # Merging the 2 image lists (i.e. with and without annotation)
        tr.extend(tr_wo_ann)
        ts.extend(ts_wo_ann)

    # Train Data
    coco.update({"images": tr, "annotations": filter_annotations(annotations, tr)})
    save_coco(train_save_path, coco)

    # Test Data
    coco.update({"images": ts, "annotations": filter_annotations(annotations, ts)})
    save_coco(test_save_path, coco)

    print(
        f"Saved {len(tr)} entries in {train_save_path} "
        f"and {len(ts)} in {test_save_path}"
    )
# Script entry point: parse the command line and run the split with a fixed
# random state of 24.
if __name__ == "__main__":
    cli_args = parser.parse_args()
    main(
        annotation_path=cli_args.annotation_path,
        split_ratio=cli_args.split_ratio,
        having_annotations=cli_args.having_annotations,
        train_save_path=cli_args.train,
        test_save_path=cli_args.test,
        random_state=24,
    )