# Copyright (c) 2023, NVIDIA Corporation & Affiliates. All rights reserved.
#
# This work is made available under the Nvidia Source Code License-NC.
# To view a copy of this license, visit
# https://github.com/NVlabs/prismer/blob/main/LICENSE

import torch
import os
import json
import copy
import PIL.Image as Image
try:
    import ruamel_yaml as yaml
except ModuleNotFoundError:
    import ruamel.yaml as yaml

from experts.model_bank import load_expert_model
from experts.obj_detection.generate_dataset import Dataset, collate_fn
from accelerate import Accelerator
from tqdm import tqdm

model, transform = load_expert_model(task='obj_detection')
accelerator = Accelerator(mixed_precision='fp16')

# Use a context manager so the config file handle is closed after parsing.
# The full yaml.Loader is kept because this is a local, trusted config file.
with open('configs/experts.yaml', 'r') as config_file:
    config = yaml.load(config_file, Loader=yaml.Loader)
data_path = config['data_path']
save_path = config['save_path']

# Depth maps are read from <save_path>/depth/<last component of data_path>
# (presumably written by the depth expert beforehand -- confirm).
depth_path = os.path.join(save_path, 'depth', data_path.split('/')[-1])
batch_size = 32
dataset = Dataset(data_path, depth_path, transform)
data_loader = torch.utils.data.DataLoader(
    dataset=dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=4,
    pin_memory=True,
    collate_fn=collate_fn,
)

model, data_loader = accelerator.prepare(model, data_loader)


def get_mask_labels(depth, instance_boxes, instance_id):
    """Rasterise detected boxes into a single instance-index mask.

    Each box is drawn as a filled rectangle. A box whose rectangle has
    IoU > 0.95 with an already-kept rectangle is dropped as a duplicate.
    Where two kept rectangles overlap, the overlapping pixels are removed
    from one of them: a rectangle fully contained in the other is treated as
    being in front; otherwise the one with the larger mean ``depth`` value
    over its rectangle keeps the overlap.

    Args:
        depth: tensor indexed as ``depth[y, x]``; the returned mask has the
            same shape and dtype (assumed 2-D (H, W) -- TODO confirm).
        instance_boxes: per-detection ``(x1, y1, x2, y2)`` boxes, assumed to
            be in the pixel coordinates of ``depth``.
        instance_id: per-detection class ids aligned with ``instance_boxes``
            (each element must support ``.item()``).

    Returns:
        ``(final_mask, instance_labels)``: ``final_mask`` is 255 on pixels
        covered by no kept object and the kept-object index ``i`` elsewhere;
        ``instance_labels`` maps each index ``i`` to its class id.
    """
    obj_masks = []
    obj_ids = []
    for box, obj_id in zip(instance_boxes, instance_id):
        mask = torch.zeros_like(depth)
        x1, y1, x2, y2 = (int(coord.item()) for coord in box[:4])
        mask[y1:y2, x1:x2] = 1
        # IoU test against every rectangle kept so far.
        is_duplicate = any(
            ((mask + kept) == 2).sum() / ((mask + kept) > 0).sum() > 0.95
            for kept in obj_masks
        )
        if not is_duplicate:
            obj_masks.append(mask)
            obj_ids.append(obj_id)

    # obj_masks stays as the pristine rectangles used for the geometric and
    # depth tests; obj_masked_modified holds the copies that get carved up.
    # NOTE(review): repeated subtraction can push carved values below 0 or
    # above 1 (e.g. `mask2_ -= mask1_` when mask1_ is already negative);
    # only `> 0` is read afterwards -- confirm this is intended.
    obj_masked_modified = copy.deepcopy(obj_masks)
    for i in range(len(obj_masks) - 1):
        mask1, mask1_ = obj_masks[i], obj_masked_modified[i]
        for j in range(i + 1, len(obj_masks)):
            mask2, mask2_ = obj_masks[j], obj_masked_modified[j]
            overlap = ((mask1 + mask2) == 2).float()
            # case 1: if they don't intersect we don't touch them
            if overlap.sum() == 0:
                continue
            # case 2: the entire object 1 is inside of object 2, we say object 1 is in front of object 2
            if (overlap - mask1).sum() == 0:
                mask2_ -= mask1_
            # case 3: the entire object 2 is inside of object 1, we say object 2 is in front of object 1
            elif (overlap - mask2).sum() == 0:
                mask1_ -= mask2_
            # case 4: use mean depth over each rectangle to decide the order.
            # NOTE(review): larger depth value is treated as closer (e.g. an
            # inverse-depth map) -- confirm against the depth expert's output.
            else:
                depth1 = (depth * mask1).sum() / mask1.sum()
                depth2 = (depth * mask2).sum() / mask2.sum()
                if depth1 > depth2:  # object 1 is closer
                    mask2_ -= overlap
                elif depth1 < depth2:  # object 2 is closer
                    mask1_ -= overlap

    final_mask = torch.ones_like(depth) * 255
    instance_labels = {}
    for i, (carved, obj_id) in enumerate(zip(obj_masked_modified, obj_ids)):
        # Later indices overwrite earlier ones where carved masks still overlap.
        # NOTE(review): index 255 would be indistinguishable from background.
        final_mask = final_mask.masked_fill(carved > 0, i)
        instance_labels[i] = obj_id.item()
    return final_mask, instance_labels


with torch.no_grad():
    for test_data in tqdm(data_loader):
        test_pred = model(test_data)
        for sample, pred in zip(test_data, test_pred):
            fields = pred['instances'].get_fields()
            instance_boxes = fields['pred_boxes'].tensor
            instance_id = fields['pred_classes']
            final_mask, instance_labels = get_mask_labels(sample['depth'], instance_boxes, instance_id)

            # Outputs go to <save_path>/obj_detection/<grandparent dir>/<parent dir>/.
            img_path_split = sample['image_path'].split('/')
            im_save_path = os.path.join(save_path, 'obj_detection', img_path_split[-3], img_path_split[-2])
            os.makedirs(im_save_path, exist_ok=True)
            # Strip only the file name's final extension (the old split('.') on
            # the full path + str.replace could mangle or miss the suffix).
            stem = os.path.splitext(img_path_split[-1])[0]

            height, width = sample['true_height'], sample['true_width']
            final_mask = Image.fromarray(final_mask.cpu().numpy()).convert('L')
            # PIL's Image.resize takes size=(width, height).
            final_mask = final_mask.resize((width, height), resample=Image.Resampling.NEAREST)
            final_mask.save(os.path.join(im_save_path, stem + '.png'))
            with open(os.path.join(im_save_path, stem + '.json'), 'w') as fp:
                json.dump(instance_labels, fp)