"""Export Visual Genome scene-graph relations as WebDataset shards.

Reads the image metadata, the HDF5 scene-graph file and the label dictionary,
then writes one sample per (subject box, object box, predicate) triple.
"""
import json
import os

import h5py
import numpy as np
import orjson
import webdataset as wds
from tqdm import tqdm, trange

from utils import MAXCOUNT, NAMING, check_sample

OUT_DIR = "/gpfs/u/home/LMCG/LMCGljnn/scratch-shared/junyan/raw/instruct/vg_relation"
BOX_SCALE = 512


def load_image_filenames(image_file, image_dir):
    """
    Loads the image filenames from visual genome from the JSON file that contains them.
    This matches the preprocessing in scene-graph-TF-release/data_tools/vg_to_imdb.py.
    :param image_file: JSON file. Elements contain the params "image_id", "height"
                       and "width".
    :param image_dir: directory where the VisualGenome images are located
    :return: List of [filename, height, width] entries, one per good image that
             exists on disk, in the order of the JSON file
    """
    with open(image_file, 'r') as f:
        im_data = json.load(f)

    # Images that are skipped unconditionally.
    corrupted_ims = {'1592.jpg', '1722.jpg', '4616.jpg', '4617.jpg'}
    fns = []
    for img in tqdm(im_data):
        basename = '{}.jpg'.format(img['image_id'])
        height = int(img['height'])
        width = int(img['width'])
        if basename in corrupted_ims:
            continue

        filename = os.path.join(image_dir, basename)
        if os.path.exists(filename):
            fns.append([filename, height, width])

    # Sanity check on the dataset layout: the caller indexes this list with a
    # per-image mask, so a missing image would silently shift every entry.
    # NOTE(review): 108073 is presumably the full image count minus the
    # corrupted ones -- confirm against the dataset release.
    assert len(fns) == 108073
    return fns


def load_graphs(graphs_file, mode='train', num_im=-1, num_val_im=0, filter_empty_rels=True,
                filter_non_overlap=False):
    """
    Load the file containing the GT boxes and relations, as well as the dataset split
    :param graphs_file: HDF5
    :param mode: (train, val, or test). 'train' and 'val' both read split 0;
                 'test' reads split 2.
    :param num_im: Number of images we want (-1 keeps all of them)
    :param num_val_im: Number of validation images. 'val' keeps the first
                       num_val_im selected images, 'train' keeps the rest.
    :param filter_empty_rels: If True, images without any relationship are dropped.
    :param filter_non_overlap: Not implemented; passing True raises NotImplementedError.
    :return: split_mask: boolean numpy array over all images, True for the images we're using
             boxes: List where each element is a [num_gt, 4] array of ground truth boxes
                    (x1, y1, x2, y2)
             gt_classes: List where each element is a [num_gt] array of classes
             relationships: List where each element is a [num_r, 3] array of
                    (box_ind_1, box_ind_2, predicate) relationships
    :raises ValueError: if mode is not one of train / val / test
    :raises NotImplementedError: if filter_non_overlap is set
    """
    if mode not in ('train', 'val', 'test'):
        raise ValueError('{} invalid'.format(mode))
    if filter_non_overlap:
        # Overlap filtering was never implemented in this file (it needs a
        # bbox_overlaps helper that is not defined here), so fail before
        # reading anything from disk.
        raise NotImplementedError

    # Every dataset is copied into memory inside the `with` block so the
    # HDF5 handle can be released before the per-image loop runs.
    with h5py.File(graphs_file, 'r') as roi_h5:
        data_split = roi_h5['split'][:]
        split = 2 if mode == 'test' else 0
        split_mask = data_split == split

        # Filter out images without bounding boxes
        split_mask &= roi_h5['img_to_first_box'][:] >= 0
        if filter_empty_rels:
            split_mask &= roi_h5['img_to_first_rel'][:] >= 0

        image_index = np.where(split_mask)[0]
        if num_im > -1:
            image_index = image_index[:num_im]
        if num_val_im > 0:
            if mode == 'val':
                image_index = image_index[:num_val_im]
            elif mode == 'train':
                image_index = image_index[num_val_im:]

        # Rebuild the mask so it reflects the num_im / num_val_im truncation.
        split_mask = np.zeros_like(data_split).astype(bool)
        split_mask[image_index] = True

        # Get box information
        all_labels = roi_h5['labels'][:, 0]
        all_boxes = roi_h5['boxes_{}'.format(BOX_SCALE)][:]  # will index later
        assert np.all(all_boxes[:, :2] >= 0)  # sanity check
        assert np.all(all_boxes[:, 2:] > 0)  # no empty box

        # convert from xc, yc, w, h to x1, y1, x2, y2
        # (assigned in place, so results are cast back to the stored dtype)
        all_boxes[:, :2] = all_boxes[:, :2] - all_boxes[:, 2:] / 2
        all_boxes[:, 2:] = all_boxes[:, :2] + all_boxes[:, 2:]

        im_to_first_box = roi_h5['img_to_first_box'][:][split_mask]
        im_to_last_box = roi_h5['img_to_last_box'][:][split_mask]
        im_to_first_rel = roi_h5['img_to_first_rel'][:][split_mask]
        im_to_last_rel = roi_h5['img_to_last_rel'][:][split_mask]

        # load relation labels
        _relations = roi_h5['relationships'][:]
        _relation_predicates = roi_h5['predicates'][:, 0]

    assert (im_to_first_rel.shape[0] == im_to_last_rel.shape[0])
    assert (_relations.shape[0] == _relation_predicates.shape[0])  # sanity check

    # Get everything by image.
    boxes = []
    gt_classes = []
    relationships = []
    for i in trange(len(image_index)):
        first_box, last_box = im_to_first_box[i], im_to_last_box[i]
        first_rel, last_rel = im_to_first_rel[i], im_to_last_rel[i]

        # The *_last_* indices are inclusive, hence the "+ 1".
        boxes_i = all_boxes[first_box:last_box + 1, :]
        gt_classes_i = all_labels[first_box:last_box + 1]

        if first_rel >= 0:
            predicates = _relation_predicates[first_rel:last_rel + 1]
            # Relations store global box indices; make them relative to this image.
            obj_idx = _relations[first_rel:last_rel + 1] - first_box
            assert np.all(obj_idx >= 0)
            assert np.all(obj_idx < boxes_i.shape[0])
            rels = np.column_stack((obj_idx, predicates))
        else:
            assert not filter_empty_rels
            rels = np.zeros((0, 3), dtype=np.int32)

        boxes.append(boxes_i)
        gt_classes.append(gt_classes_i)
        relationships.append(rels)

    return split_mask, boxes, gt_classes, relationships


def load_info(info_file):
    """
    Loads the file containing the visual genome label meanings
    :param info_file: JSON with "label_to_idx" and "predicate_to_idx" mappings
    :return: ind_to_classes: list of classes sorted by their index
             ind_to_predicates: list of predicates sorted by their index
    """
    with open(info_file, 'r') as f:
        info = json.load(f)

    # Index 0 is reserved for the background entry in both vocabularies.
    info['label_to_idx']['__background__'] = 0
    info['predicate_to_idx']['__background__'] = 0

    class_to_ind = info['label_to_idx']
    predicate_to_ind = info['predicate_to_idx']
    ind_to_classes = sorted(class_to_ind, key=lambda k: class_to_ind[k])
    ind_to_predicates = sorted(predicate_to_ind, key=lambda k: predicate_to_ind[k])

    return ind_to_classes, ind_to_predicates


if __name__ == "__main__":
    root = "/gpfs/u/home/LMCG/LMCGljnn/scratch/datasets/raw/vg"
    filenames = load_image_filenames(os.path.join(root, "image_data.json"), os.path.join(root, "VG_100K"))
    split_mask, boxes, gt_classes, relationships = load_graphs(
        graphs_file=os.path.join(root, "VG-SGG.h5"),
        mode="train",
    )

    # Keep only the images selected by the split mask.
    # NOTE(review): assumes the HDF5 image order matches the filtered
    # image_data.json order -- confirm against the preprocessing script.
    filenames = [filenames[i] for i, mask in enumerate(split_mask) if mask]

    ind_to_classes, ind_to_predicates = load_info(os.path.join(root, "VG-SGG-dicts.json"))
    assert len(filenames) == len(boxes)
    assert len(filenames) == len(gt_classes)
    assert len(filenames) == len(relationships)

    uuid = 0
    os.makedirs(OUT_DIR, exist_ok=True)
    with tqdm() as pbar, wds.ShardWriter(os.path.join(OUT_DIR, NAMING), maxcount=MAXCOUNT) as sink:
        for box, box_class, relationship, (filename, height, width) in zip(
            boxes, gt_classes, relationships, filenames
        ):
            # Boxes are stored for an image whose longer side is BOX_SCALE;
            # dividing by the rescaled (w, h, w, h) yields coordinates in [0, 1].
            size = float(BOX_SCALE) / max(height, width)
            size = np.array([width, height, width, height]) * size
            box = (box.astype(float) / size).clip(0, 1)

            for relation in relationship:
                box1_id = relation[0]
                box2_id = relation[1]
                predicate = ind_to_predicates[relation[2]]
                box1 = [box[box1_id], ind_to_classes[box_class[box1_id]]]
                box2 = [box[box2_id], ind_to_classes[box_class[box2_id]]]
                data = [box1, box2, predicate]

                dataset = "vg_relation"
                image_path = filename
                key = f"{dataset}_{uuid}"
                uuid += 1
                pbar.update()

                sample = {
                    "__key__": key,
                    "image_path.txt": image_path,
                    "dataset.txt": dataset,
                    "data.pyd": data,
                }
                check_sample(sample)
                sink.write(sample)


# Earlier prototype that read relationships.json directly; kept for reference.
# if __name__ == "__main__":
#     root = "/gpfs/u/home/LMCG/LMCGljnn/scratch/datasets/raw/vg"
#     relationships = orjson.loads(open("/gpfs/u/home/LMCG/LMCGljnn/scratch/datasets/raw/vg/relationships.json").read())
#     image_data = orjson.loads(open("/gpfs/u/home/LMCG/LMCGljnn/scratch/datasets/raw/vg/image_data.json").read())
#     image_id_to_filename = {}
#     image_id_to_wh = {}
#     for image in tqdm(image_data):
#         image_id = image["image_id"]
#         subfolder, filename = image['url'].split("/")[-2:]
#         image_id_to_filename[image_id] = os.path.join(root, subfolder, filename)
#         image_id_to_wh[image_id] = (image["width"], image["height"])
#     unique_predicates = []
#     # with wds.ShardWriter(os.path.join(OUT_DIR, "%05d.tar"), maxcount=500) as sink:
#     for relation_per_image in tqdm(relationships):
#         image_id = relation_per_image["image_id"]
#         for relation in relation_per_image["relationships"]:
#             predicate = relation["predicate"]
#             unique_predicates.append(predicate)
#             object = {
#                 "name": relation["object"]["name"],
#                 "x": relation["object"]["x"],
#                 "y": relation["object"]["y"],
#                 "w": relation["object"]["w"],
#                 "h": relation["object"]["h"],
#             }
#             subject = {
#                 "name": relation["subject"]["name"],
#                 "x": relation["subject"]["x"],
#                 "y": relation["subject"]["y"],
#                 "w": relation["subject"]["w"],
#                 "h": relation["subject"]["h"],
#             }