Spaces:
Sleeping
Sleeping
import os | |
import cv2 | |
import numpy as np | |
from tqdm import tqdm | |
import sys | |
import imagesize | |
import argparse | |
import torch | |
import pandas as pd | |
import json | |
import monai.metrics as metrics | |
# Locations of the HOT split files. Each is an .odgt file that
# combine_hot_prox_split parses line by line with json.loads.
HOT_TRAIN_SPLIT, HOT_VAL_SPLIT, HOT_TEST_SPLIT = (
    "/ps/scratch/ps_shared/ychen2/4shashank/split/hot_train.odgt",
    "/ps/scratch/ps_shared/ychen2/4shashank/split/hot_validation.odgt",
    "/ps/scratch/ps_shared/ychen2/4shashank/split/hot_test.odgt",
)
def metric(mask, pred, back=True):
    """Return the mean of MONAI's IoU between ``pred`` and the reference ``mask``.

    The call forwards ``(pred, mask, back, False)`` positionally to
    ``monai.metrics.compute_meaniou``. The third and fourth positions are
    presumably ``include_background`` and ``ignore_empty``; confirm against
    the installed MONAI version. The per-item/per-class result is averaged
    with ``.mean()``, so a tensor is presumably returned rather than a
    Python float.
    """
    return metrics.compute_meaniou(pred, mask, back, False).mean()
def _read_odgt(path):
    """Parse an ``.odgt`` file, i.e. one JSON object per line, into a list.

    Blank lines (such as a trailing newline at EOF) are skipped, because
    ``json.loads('')`` would raise.
    """
    with open(path, "r", encoding="utf-8") as f:
        return [json.loads(line) for line in f if line.strip()]


def combine_hot_prox_split(split):
    """Load the HOT annotation records for a dataset split.

    Args:
        split: One of ``'train'``, ``'val'``, ``'test'`` or ``'trainval'``.
            ``'trainval'`` returns the train records followed by the val
            records.

    Returns:
        list: The parsed JSON records, in file order.

    Raises:
        ValueError: If ``split`` is not one of the supported names.
    """
    if split == 'train':
        split_files = [HOT_TRAIN_SPLIT]
    elif split == 'val':
        split_files = [HOT_VAL_SPLIT]
    elif split == 'test':
        split_files = [HOT_TEST_SPLIT]
    elif split == 'trainval':
        split_files = [HOT_TRAIN_SPLIT, HOT_VAL_SPLIT]
    else:
        # Reject unknown splits (including None) explicitly rather than
        # failing later with an opaque unbound-variable error.
        raise ValueError(
            f"Unknown split {split!r}; expected 'train', 'val', 'test' or 'trainval'"
        )

    records = []
    for path in split_files:
        records.extend(_read_odgt(path))
    return records
def _load_binary_mask(path):
    """Read an image with OpenCV and binarize it (any non-zero value -> 1, else 0)."""
    mask = cv2.imread(path)
    if mask is None:
        # cv2.imread signals a missing/unreadable file by returning None.
        raise FileNotFoundError(f"Could not read mask image: {path}")
    return np.where(mask > 0, 1, 0)


def _save_binary_mask(vis_path, src_path, mask):
    """Write a 0/1 mask into ``vis_path`` as a 0/255 image named after ``src_path``."""
    # np.where yields int64, which OpenCV cannot encode; cast to 8-bit first.
    cv2.imwrite(os.path.join(vis_path, os.path.basename(src_path)),
                (mask * 255).astype(np.uint8))


def _to_nchw(mask):
    """Add a batch axis and move the last axis to position 1: (H, W, C) -> (1, C, H, W)."""
    return torch.from_numpy(mask)[None, :].permute(0, 3, 1, 2)


def _to_savez_array(values):
    """Return ``values`` as an ndarray for ``np.savez``, tolerating ragged entries.

    NumPy >= 1.24 refuses to build an array implicitly from ragged nested
    sequences (e.g. a mix of per-vertex label arrays and empty lists). In
    that case, fall back to a 1-D object array holding each entry as-is.
    """
    try:
        return np.asarray(values)
    except ValueError:
        packed = np.empty(len(values), dtype=object)
        for idx, value in enumerate(values):
            packed[idx] = value
        return packed


def _dca_contact_labels(dca_row, imgname, n_vertices, include_supporting):
    """Build a binary per-vertex contact vector from the first matching DCA csv row.

    The 'vertices' cell is evaluated into a mapping keyed by
    'hot/training/<imgname>'. Each value is a list of single-key dicts whose
    values are vertex indices. Vertices under the 'SUPPORTING' key are only
    counted when ``include_supporting`` is true.
    """
    # NOTE(security): eval executes arbitrary code from the csv. The cell
    # presumably holds a Python dict literal; consider ast.literal_eval if
    # the csv is not fully trusted.
    vertices = eval(dca_row['vertices'].values[0])
    contact_3d_list = vertices[os.path.join('hot/training/', imgname)]
    contact_3d_idx = set()  # de-duplicate vertex indices across labels
    for item in contact_3d_list:
        for label, idx in item.items():
            if include_supporting or label != 'SUPPORTING':
                contact_3d_idx.update(idx)
    contact_3d_labels = np.zeros(n_vertices)
    contact_3d_labels[list(contact_3d_idx)] = 1.
    return contact_3d_labels


def hot_extract(img_dataset_path, smpl_params_path, dca_csv_path, out_dir, split=None, vis_path=None,
                visualize=False, record_idx=None, include_supporting=True, num_record_chunks=4):
    """Collect HOT images, fitted SMPL params and DCA 3D contact labels into one npz.

    For each record of ``split``:
      * every SMPL fit for the image is scored by the IoU between its
        binarized part-segmentation mask and the HOT contact-polygon mask;
      * fits with IoU > 0 are kept, one output row each;
      * the DCA 3D contact labels are attached only to the kept fit with
        the highest IoU. All other rows get ``[]``.

    Args:
        img_dataset_path: HOT root containing 'images/training' and
            'annotations/training'.
        smpl_params_path: npz with keys 'imgname', 'part_seg', 'scene_seg',
            'pose', 'shape', 'global_t' and 'focal_l', indexed per fitted person.
        dca_csv_path: csv with 'imgnames' and 'vertices' columns.
        out_dir: Directory the output npz is written to (created if needed).
        split: Split name passed to ``combine_hot_prox_split``.
        vis_path: Directory for debug images. Required when ``visualize``.
        visualize: If true, dump the image, part masks and polygon mask to
            ``vis_path``.
        record_idx: If given, process only chunk ``record_idx`` of the records
            split into ``num_record_chunks`` chunks (the output name gets this
            index instead of 'combined').
        include_supporting: Whether 'SUPPORTING' vertices count as contact.
        num_record_chunks: Number of chunks used with ``record_idx``.
    """
    n_vertices = 6890  # SMPL mesh vertex count
    # Output columns, one entry per selected person.
    imgnames_ = []
    poses_, shapes_, transls_ = [], [], []
    cams_k_ = []
    polygon_2d_contact_ = []
    contact_3d_labels_ = []
    scene_seg_, part_seg_ = [], []

    img_dir = os.path.join(img_dataset_path, 'images', 'training')
    # Join directly instead of str.replace('images', ...), which would also
    # rewrite any 'images' substring inside img_dataset_path.
    annotations_dir = os.path.join(img_dataset_path, 'annotations', 'training')

    if visualize:
        if vis_path is None:
            raise ValueError("vis_path is required when visualize=True")
        os.makedirs(vis_path, exist_ok=True)

    # NpzFile re-reads an array from the archive on every key access. Pull
    # each needed array out once, and close the file afterwards.
    with np.load(smpl_params_path) as npz:
        smpl_params = {key: npz[key] for key in
                       ('imgname', 'part_seg', 'scene_seg', 'pose', 'shape', 'global_t', 'focal_l')}

    # Image path -> ascending indices of its fitted people. This gives an
    # O(1) lookup per record instead of scanning the whole imgname array.
    person_inds_by_img = {}
    for ind, name in enumerate(smpl_params['imgname']):
        person_inds_by_img.setdefault(name, []).append(ind)

    records = combine_hot_prox_split(split)
    if record_idx is not None:
        records = np.array_split(records, num_record_chunks)[record_idx]

    dca_csv = pd.read_csv(dca_csv_path)

    iou_thresh = 0
    num_with_3d_contact = 0  # counts csv matches, even if the image is later skipped
    focal_length_accumulator = []
    for record in tqdm(records, dynamic_ncols=True):
        imgname = os.path.basename(record['fpath_img'])
        img_path = os.path.join(img_dir, imgname)
        if visualize:
            cv2.imwrite(os.path.join(vis_path, imgname), cv2.imread(img_path))

        img_w, img_h = record["width"], record["height"]
        polygon_2d_contact_path = os.path.join(annotations_dir, os.path.splitext(imgname)[0] + '.png')

        # 3D contact annotations from the DCA mturk csv. If there is no
        # imgnames column, run scripts/datascripts/add_imgname_column_to_deco_csv.py
        dca_row = dca_csv.loc[dca_csv['imgnames'] == imgname]
        if len(dca_row) == 0:
            contact_3d_labels = []
        else:
            num_with_3d_contact += 1
            contact_3d_labels = _dca_contact_labels(dca_row, imgname, n_vertices, include_supporting)

        inds = person_inds_by_img.get(img_path, [])
        if not inds:
            continue

        # The GT polygon mask is the same for every person in the image.
        polygon_2d_contact = _load_binary_mask(polygon_2d_contact_path)
        if visualize:
            _save_binary_mask(vis_path, polygon_2d_contact_path, polygon_2d_contact)
        polygon_tensor = _to_nchw(polygon_2d_contact)

        select_inds, ious = [], []
        for ind in inds:
            part_path = smpl_params['part_seg'][ind]
            part_mask = _load_binary_mask(part_path)
            if visualize:
                _save_binary_mask(vis_path, part_path, part_mask)
            iou = float(metric(polygon_tensor, _to_nchw(part_mask)))
            if iou > iou_thresh:
                ious.append(iou)
                select_inds.append(ind)

        if not select_inds:
            continue
        # Heuristic: the DCA 3D contact labels belong to the person whose
        # part mask overlaps the HOT contact polygon the most.
        max_iou_ind = select_inds[int(np.argmax(ious))]

        for ind in select_inds:
            focal_length = smpl_params['focal_l'][ind]
            # Pinhole intrinsics: shared focal length, principal point at
            # the image centre.
            K = np.eye(3, dtype=np.float64)
            K[0, 0] = focal_length
            K[1, 1] = focal_length
            K[:2, 2] = [img_w // 2, img_h // 2]

            imgnames_.append(img_path)
            polygon_2d_contact_.append(polygon_2d_contact_path)
            contact_3d_labels_.append(contact_3d_labels if ind == max_iou_ind else [])
            scene_seg_.append(smpl_params['scene_seg'][ind])
            part_seg_.append(smpl_params['part_seg'][ind])
            poses_.append(smpl_params['pose'][ind].squeeze())
            transls_.append(smpl_params['global_t'][ind].squeeze())
            shapes_.append(smpl_params['shape'][ind].squeeze())
            cams_k_.append(K.tolist())
            focal_length_accumulator.append(focal_length)

    print('Average focal length: ', np.mean(focal_length_accumulator))
    print('Median focal length: ', np.median(focal_length_accumulator))
    print('Std Dev focal length: ', np.std(focal_length_accumulator))

    os.makedirs(out_dir, exist_ok=True)
    suffix = record_idx if record_idx is not None else 'combined'
    out_file = os.path.join(out_dir, f'hot_noprox_supporting_{include_supporting}_{split}_{suffix}.npz')
    np.savez(out_file,
             imgname=imgnames_,
             pose=poses_,
             transl=transls_,
             shape=shapes_,
             cam_k=cams_k_,
             polygon_2d_contact=polygon_2d_contact_,
             # Mix of (n_vertices,) arrays and [] -> pack explicitly.
             contact_label=_to_savez_array(contact_3d_labels_),
             scene_seg=scene_seg_,
             part_seg=part_seg_)
    print(f'Total number of rows: {len(imgnames_)}')
    print('Saved to ', out_file)
    print(f'Number of images with 3D contact labels: {num_with_3d_contact}')
if __name__ == '__main__':
    # Command-line entry point: forwards every option to hot_extract.
    parser = argparse.ArgumentParser(
        description='Collect HOT images, fitted SMPL params and DCA 3D contact labels into an npz.')
    parser.add_argument('--img_dataset_path', type=str, default='/ps/project/datasets/HOT/Contact_Data/')
    parser.add_argument('--smpl_params_path', type=str, default='/ps/scratch/ps_shared/stripathi/deco/4agniv/hot/hot.npz')
    parser.add_argument('--dca_csv_path', type=str, default='/ps/scratch/ps_shared/stripathi/deco/4agniv/hot/dca.csv')
    parser.add_argument('--out_dir', type=str, default='/is/cluster/work/stripathi/pycharm_remote/dca_contact/data/dataset_extras')
    parser.add_argument('--vis_path', type=str, default='/is/cluster/work/stripathi/pycharm_remote/dca_contact/temp_images')
    parser.add_argument('--visualize', action='store_true', default=False,
                        help='dump images and masks to --vis_path for inspection')
    parser.add_argument('--include_supporting', action='store_true', default=False,
                        help="count 'SUPPORTING' vertices as 3D contact")
    parser.add_argument('--record_idx', type=int, default=None,
                        help='process only this chunk of the split (records are split into 4 chunks)')
    # Validate the split up front so a typo fails before the large inputs are loaded.
    parser.add_argument('--split', type=str, default='train',
                        choices=['train', 'val', 'test', 'trainval'])
    args = parser.parse_args()
    hot_extract(img_dataset_path=args.img_dataset_path,
                smpl_params_path=args.smpl_params_path,
                dca_csv_path=args.dca_csv_path,
                out_dir=args.out_dir,
                vis_path=args.vis_path,
                visualize=args.visualize,
                split=args.split,
                record_idx=args.record_idx,
                include_supporting=args.include_supporting)