|
|
|
|
|
import json |
|
import argparse |
|
import funcy |
|
from sklearn.model_selection import train_test_split |
|
|
|
def _split_ratio(value):
    """argparse ``type`` callback: parse *value* as a float strictly inside (0, 1).

    Raises:
        argparse.ArgumentTypeError: If the number is outside (0, 1) (NaN included),
            so argparse reports a clean usage error instead of the split failing later.
    """
    ratio = float(value)  # a ValueError here is reported by argparse as an invalid value
    if not 0.0 < ratio < 1.0:
        raise argparse.ArgumentTypeError(
            'split ratio must be a number in (0, 1), got {!r}'.format(value))
    return ratio


parser = argparse.ArgumentParser(description='Splits COCO annotations file into training and test sets.')

# The three paths are always used by main(); without them the script would only
# fail later with a TypeError from open(None), so argparse enforces them up front.
parser.add_argument('--annotation_path', metavar='coco_annotations', type=str, required=True,
                    help='Path to COCO annotations file.')

parser.add_argument('--train', type=str, required=True,
                    help='Where to store COCO training annotations')

parser.add_argument('--test', type=str, required=True,
                    help='Where to store COCO test annotations')

parser.add_argument('--s', dest='split_ratio', type=_split_ratio, required=True,
                    help="A percentage of a split; a number in (0, 1)")

parser.add_argument('--having-annotations', dest='having_annotations', action='store_true',
                    help='Ignore all images without annotations. Keep only these with at least one annotation')
|
|
|
def save_coco(file, tagged_data):
    """Write *tagged_data* to the path *file* as indented JSON.

    The document is written as UTF-8 text with two-space indentation and
    sorted keys; any existing file at *file* is overwritten.
    """
    dump_options = {'indent': 2, 'sort_keys': True}
    with open(file, 'wt', encoding='UTF-8') as target:
        json.dump(tagged_data, target, **dump_options)
|
|
|
def filter_annotations(annotations, images):
    """Return the annotations whose ``image_id`` matches the ``id`` of one of *images*.

    Both ids are normalised with ``int()`` before comparing, so an id stored as
    ``"3"`` matches one stored as ``3``.

    Args:
        annotations: Iterable of annotation dicts, each with an ``'image_id'`` key.
        images: Iterable of image dicts, each with an ``'id'`` key.

    Returns:
        A new list containing the matching annotations, in their original order.
    """
    # Build the lookup once as a set: each membership test is O(1), so the
    # filter is linear instead of O(len(images) * len(annotations)).
    image_ids = {int(image['id']) for image in images}
    return [ann for ann in annotations if int(ann['image_id']) in image_ids]
|
|
|
def main(annotation_path,
         split_ratio,
         having_annotations,
         train_save_path,
         test_save_path,
         random_state=None):
    """Split a COCO annotations file into a training file and a test file.

    Images are partitioned with ``train_test_split``. Each output file keeps
    every other top-level key of the input unchanged and receives only the
    annotations whose ``image_id`` belongs to one of its images.

    Args:
        annotation_path: Path of the COCO JSON file to read.
        split_ratio: Fraction of images placed in the training set
            (forwarded to ``train_test_split`` as ``train_size``).
        having_annotations: When true, images with no annotation are
            dropped before splitting.
        train_save_path: Where the training annotations are written.
        test_save_path: Where the test annotations are written.
        random_state: Seed forwarded to ``train_test_split``; ``None`` leaves
            the split unseeded.

    Raises:
        ValueError: If no images are left to split.
    """
    with open(annotation_path, 'rt', encoding='UTF-8') as annotations_file:
        coco = json.load(annotations_file)

    images = coco['images']
    annotations = coco['annotations']

    if having_annotations:
        # Normalise ids with int() on both sides, as filter_annotations does.
        # Comparing a raw (possibly string) image id against int ids would
        # silently drop every image. A set keeps each lookup O(1).
        annotated_ids = {int(ann['image_id']) for ann in annotations}
        images = [image for image in images if int(image['id']) in annotated_ids]

    if not images:
        raise ValueError('No images to split in {}'.format(annotation_path))

    train_images, test_images = train_test_split(
        images, train_size=split_ratio, random_state=random_state)

    # Build each output from a shallow copy of the loaded document so the
    # remaining top-level keys are carried over without mutating `coco`.
    train_coco = dict(coco, images=train_images,
                      annotations=filter_annotations(annotations, train_images))
    save_coco(train_save_path, train_coco)

    test_coco = dict(coco, images=test_images,
                     annotations=filter_annotations(annotations, test_images))
    save_coco(test_save_path, test_coco)

    print("Saved {} entries in {} and {} in {}".format(
        len(train_images), train_save_path, len(test_images), test_save_path))
|
|
|
|
|
if __name__ == "__main__":
    # Command-line entry point: parse the flags and hand them to main() by name.
    args = parser.parse_args()

    main(
        annotation_path=args.annotation_path,
        split_ratio=args.split_ratio,
        having_annotations=args.having_annotations,
        train_save_path=args.train,
        test_save_path=args.test,
        # Fixed seed so repeated runs on the same input give the same split.
        random_state=24,
    )