import os import random import json # read the json valid_list_json_path = './megadepth_valid_list.json' assert os.path.isfile(valid_list_json_path), 'Change to the valid list json' with open(valid_list_json_path, 'r') as f: all_list = json.load(f) # build scene - image dictionary scene_img_dict = {} for item in all_list: if not item[:4] in scene_img_dict: scene_img_dict[item[:4]] = [] scene_img_dict[item[:4]].append(item) train_split = [] val_split = [] test_split = [] for k in sorted(scene_img_dict.keys()): if int(k) == 204: val_split += scene_img_dict[k] elif int(k) <= 240 and int(k) != 204: train_split += scene_img_dict[k] else: test_split += scene_img_dict[k] # save split to json with open('megadepth_train.json', 'w') as outfile: json.dump(sorted(train_split), outfile, indent=4) with open('megadepth_val.json', 'w') as outfile: json.dump(sorted(val_split), outfile, indent=4) with open('megadepth_test.json', 'w') as outfile: json.dump(sorted(test_split), outfile, indent=4)