Spaces:
Sleeping
Sleeping
# Copyright (c) 2023, NVIDIA Corporation & Affiliates. All rights reserved. | |
# | |
# This work is made available under the Nvidia Source Code License-NC. | |
# To view a copy of this license, visit | |
# https://github.com/NVlabs/prismer/blob/main/LICENSE | |
import torch | |
import os | |
import json | |
import copy | |
import PIL.Image as Image | |
try: | |
import ruamel_yaml as yaml | |
except ModuleNotFoundError: | |
import ruamel.yaml as yaml | |
from experts.model_bank import load_expert_model | |
from experts.obj_detection.generate_dataset import Dataset, collate_fn | |
from accelerate import Accelerator | |
from tqdm import tqdm | |
# Load the object-detection expert together with the transform that is handed
# to the dataset below.
model, transform = load_expert_model(task='obj_detection')
accelerator = Accelerator(mixed_precision='fp16')

# Read the experts config through a context manager so the file handle is
# closed deterministically instead of being left to the garbage collector.
# NOTE(review): yaml.Loader is not the safe loader; keep this config file
# trusted, or switch to a safe loader if it can ever come from outside.
with open('configs/experts.yaml', 'r') as config_file:
    config = yaml.load(config_file, Loader=yaml.Loader)

data_path = config['data_path']
save_path = config['save_path']
# Depth maps are looked up under <save_path>/depth/<last component of data_path>.
depth_path = os.path.join(save_path, 'depth', data_path.split('/')[-1])

batch_size = 32
dataset = Dataset(data_path, depth_path, transform)
data_loader = torch.utils.data.DataLoader(
    dataset=dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=4,
    pin_memory=True,
    collate_fn=collate_fn,
)

# Let accelerate place the model and wrap the loader for the current device(s).
model, data_loader = accelerator.prepare(model, data_loader)
def get_mask_labels(depth, instance_boxes, instance_id):
    """Turn detected boxes into one instance-index map, ordering overlaps by depth.

    Every box is rasterised into a rectangular 0/1 mask shaped like ``depth``.
    A box whose intersection-over-union with an already kept box exceeds 0.95
    is dropped as a duplicate. Overlaps between the kept boxes are then
    resolved pair by pair:

    * no intersection: both masks are left alone;
    * one box lies entirely inside the other: the inner box is treated as
      being in front, so it is carved out of the outer one;
    * partial overlap: the box with the larger mean ``depth`` value is
      treated as closer and keeps the overlapping region (this presumes
      ``depth`` grows towards the camera, e.g. inverse depth -- confirm
      against the depth expert's output).

    Args:
        depth: Depth map tensor. Assumed to be a 2-D floating-point
            ``(H, W)`` tensor -- TODO confirm against the dataset.
        instance_boxes: Tensor of boxes, one row per instance, whose first
            four entries are ``x1, y1, x2, y2`` in pixel coordinates.
        instance_id: Tensor of per-instance class ids, aligned with
            ``instance_boxes``.

    Returns:
        A ``(final_mask, instance_labels)`` tuple. ``final_mask`` is a tensor
        like ``depth`` holding the kept-instance index at every covered pixel
        and 255 everywhere else. ``instance_labels`` maps each kept-instance
        index to its class id as a plain Python number.
    """
    obj_masks = []
    obj_ids = []
    for i, box in enumerate(instance_boxes):
        # Clamp at 0: a negative coordinate used directly as a slice bound
        # would be read as an offset from the far edge and select the wrong
        # (usually empty) region. Bounds past the far edge are already
        # handled by slicing itself.
        x1, y1, x2, y2 = (max(int(v), 0) for v in box[:4].tolist())
        mask = torch.zeros_like(depth)
        mask[y1:y2, x1:x2] = 1

        # Near-identical boxes (IoU > 0.95) are treated as the same object.
        is_duplicate = False
        for kept in obj_masks:
            stacked = mask + kept
            if (stacked == 2).sum() / (stacked > 0).sum() > 0.95:
                is_duplicate = True
                break

        if not is_duplicate:
            obj_masks.append(mask)
            obj_ids.append(instance_id[i])

    # Geometry tests always use the untouched masks; only these copies are
    # carved, so an earlier pair's result never changes a later pair's test.
    resolved = [m.clone() for m in obj_masks]
    for i in range(len(obj_masks) - 1):
        mask1, mask1_ = obj_masks[i], resolved[i]
        for j in range(i + 1, len(obj_masks)):
            mask2, mask2_ = obj_masks[j], resolved[j]
            overlap = ((mask1 + mask2) == 2).float()

            # case 1: if they don't intersect we don't touch them
            if overlap.sum() == 0:
                continue

            # case 2: object 1 lies entirely inside object 2, so object 1 is
            # taken to be in front of object 2
            if (overlap - mask1).sum() == 0:
                mask2_ -= mask1_
            # case 3: object 2 lies entirely inside object 1, so object 2 is
            # taken to be in front of object 1
            elif (overlap - mask2).sum() == 0:
                mask1_ -= mask2_
            # case 4: partial overlap, use mean depth to decide the order
            else:
                mean1 = (depth * mask1).sum() / mask1.sum()
                mean2 = (depth * mask2).sum() / mask2.sum()
                if mean1 > mean2:
                    # object 1 is closer
                    mask2_ -= overlap
                elif mean1 < mean2:
                    # object 2 is closer
                    mask1_ -= overlap
                # Equal means: neither mask is carved; the later instance
                # wins the overlap when the final map is painted below.

    # 255 marks pixels that belong to no instance.
    final_mask = torch.ones_like(depth) * 255
    instance_labels = {}
    for idx, (mask, class_id) in enumerate(zip(resolved, obj_ids)):
        # Carved masks can go negative after repeated subtraction, hence > 0.
        final_mask = final_mask.masked_fill(mask > 0, idx)
        instance_labels[idx] = class_id.item()
    return final_mask, instance_labels
# Run the detector over every batch and write, for each image, a PNG instance
# map plus a JSON file mapping instance index -> predicted class id.
with torch.no_grad():
    for test_data in tqdm(data_loader):
        test_pred = model(test_data)

        for k, pred in enumerate(test_pred):
            sample = test_data[k]
            fields = pred['instances'].get_fields()
            instance_boxes = fields['pred_boxes'].tensor
            instance_id = fields['pred_classes']
            depth = sample['depth']
            final_mask, instance_labels = get_mask_labels(depth, instance_boxes, instance_id)

            # Mirror the last two directory levels of the input path under
            # <save_path>/obj_detection/.
            img_path_split = sample['image_path'].split('/')
            im_save_path = os.path.join(save_path, 'obj_detection', img_path_split[-3], img_path_split[-2])
            os.makedirs(im_save_path, exist_ok=True)

            # Strip only the final extension of the file name itself. The
            # previous split('.')/replace() approach rewrote every occurrence
            # of the extension and broke on paths whose directories contain
            # a dot while the file has no extension.
            file_stem = os.path.splitext(img_path_split[-1])[0]
            out_base = os.path.join(im_save_path, file_stem)

            height, width = sample['true_height'], sample['true_width']
            final_mask = Image.fromarray(final_mask.cpu().numpy()).convert('L')
            # NOTE(review): PIL's resize expects (width, height) but receives
            # (true_height, true_width) here. This is only correct if the
            # dataset stores those two values swapped (e.g. taken straight
            # from PIL's Image.size) -- confirm against Dataset.__getitem__.
            final_mask = final_mask.resize((height, width), resample=Image.Resampling.NEAREST)
            final_mask.save(out_base + '.png')

            with open(out_base + '.json', 'w') as fp:
                json.dump(instance_labels, fp)