File size: 8,628 Bytes
99a05f0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 |
import os
import cv2
import numpy as np
from tqdm import tqdm
import sys
import imagesize
import argparse
import torch
import pandas as pd
import json
import monai.metrics as metrics
# Directory holding the PROX split manifests (.odgt files, read one JSON record
# per line by combine_hot_prox_split).
_PROX_SPLIT_DIR = "/ps/scratch/ps_shared/ychen2/4shashank/split"
PROX_TRAIN_SPLIT = _PROX_SPLIT_DIR + "/prox_train.odgt"
PROX_VAL_SPLIT = _PROX_SPLIT_DIR + "/prox_validation.odgt"
PROX_TEST_SPLIT = _PROX_SPLIT_DIR + "/prox_test.odgt"
def metric(mask, pred, back=True):
    """Return the mean IoU between a predicted mask and a reference mask.

    Thin wrapper around ``monai.metrics.compute_meaniou``. Note the argument
    order is swapped relative to MONAI: ``mask`` comes first here, while MONAI
    takes the prediction first.

    Args:
        mask: reference (ground-truth) mask tensor.
        pred: predicted mask tensor.
        back: forwarded as MONAI's third positional argument (presumably
            ``include_background`` -- confirm against the installed MONAI).

    Returns:
        The ``.mean()`` of the IoU values MONAI returns.
    """
    # The fourth positional argument is always False (presumably MONAI's
    # ``ignore_empty``).
    score = metrics.compute_meaniou(pred, mask, back, False)
    return score.mean()
def combine_hot_prox_split(split):
    """Load the records of one PROX split manifest.

    Each manifest is an ``.odgt`` file holding one JSON record per line.

    Args:
        split: ``'train'``, ``'val'`` or ``'test'``; selects PROX_TRAIN_SPLIT,
            PROX_VAL_SPLIT or PROX_TEST_SPLIT respectively.

    Returns:
        List of the decoded JSON records, in file order.

    Raises:
        ValueError: if ``split`` is not one of the three known names.
    """
    # Resolve the split name before touching any file so an unknown split
    # fails with a clear error instead of an unbound ``records`` later on.
    if split == 'train':
        split_path = PROX_TRAIN_SPLIT
    elif split == 'val':
        split_path = PROX_VAL_SPLIT
    elif split == 'test':
        split_path = PROX_TEST_SPLIT
    else:
        raise ValueError(
            f"Unknown split {split!r}; expected 'train', 'val' or 'test'")

    with open(split_path, "r") as f:
        # Skip blank lines (e.g. a trailing empty line): json.loads rejects them.
        records = [json.loads(line.strip("\n")) for line in f if line.strip()]
    return records
# Build the HOT/PROX contact annotation file for one dataset split.
#
# For each record returned by combine_hot_prox_split(split), the visible code:
#   * takes the image size from the record's 'width'/'height' fields;
#   * looks the image up in the DCA csv ('imgnames' column) and converts its
#     3D contact vertex indices into a 0/1 vector of length n_vertices;
#   * finds every smpl_params entry whose 'imgname' equals the image path and
#     scores it with metric() (mean IoU) between its binarized part mask and
#     the binarized HOT 2D contact mask;
#   * for the selected entries, reads pose / shape / global_t / focal_l and
#     builds a 3x3 intrinsics matrix K.
# It then prints focal-length statistics and writes the collected lists to
# <out_dir>/hot_prox_<split>.npz with np.savez.
#
# Args:
#   img_dataset_path: dataset root; images are read from <root>/images/training,
#       contact masks from the same path with 'images' replaced by
#       'annotations' (file <stem>.png).
#   smpl_params_path: .npz loaded with np.load; the arrays read here are
#       'imgname', 'part_seg', 'scene_seg', 'pose', 'shape', 'global_t' and
#       'focal_l', all indexed by the same entry index.
#   dca_csv_path: csv loaded with pandas; the columns read here are 'imgnames',
#       'supporting_object' and 'vertices'.
#   out_dir: output directory, created if missing.
#   split: forwarded to combine_hot_prox_split and used in the output file name.
#   vis_path: directory for debug images; only used when visualize is True
#       (os.path.join(None, ...) raises TypeError if it is left as None).
#   visualize: if True, dump the image and binarized masks into vis_path.
#   downsample_factor: not referenced by any line visible here.
#
# NOTE(review): this listing has lost its indentation and some lines (empty
# loop bodies, missing appends, a truncated np.savez call), so the branch
# structure is partly inferred; the comments describe only the visible lines.
def hot_extract(img_dataset_path, smpl_params_path, dca_csv_path, out_dir, split=None, vis_path=None, visualize=False, downsample_factor=4):
n_vertices = 6890
# structs we use
# (per-row output accumulators; only the append to imgnames_ is visible here)
imgnames_ = []
poses_, shapes_, transls_ = [], [], []
cams_k_ = []
polygon_2d_contact_ = []
contact_3d_labels_ = []
scene_seg_, part_seg_ = [], []
img_dir = os.path.join(img_dataset_path, 'images', 'training')
smpl_params = np.load(smpl_params_path)
# smpl_params = np.load(smpl_params_path, allow_pickle=True)
# smpl_params = smpl_params['arr_0'].item()
# NOTE: str.replace swaps every 'images' substring, including any that occurs
# inside img_dataset_path itself.
annotations_dir = img_dir.replace('images', 'annotations')
records = combine_hot_prox_split(split)
# load dca csv
dca_csv = pd.read_csv(dca_csv_path)
# a candidate entry's mean IoU must be strictly above this (see `if iou > iou_thresh`)
iou_thresh = 0
num_with_3d_contact = 0
# NOTE(review): no append to this list is visible here; if it stays empty, the
# np.mean / np.median / np.std prints at the end give nan with a RuntimeWarning.
focal_length_accumulator = []
# one pass per split record (`i` is not used in the visible lines)
for i, record in enumerate(tqdm(records, dynamic_ncols=True)):
imgpath = record['fpath_img']
imgname = os.path.basename(imgpath)
# optionally copy the source image into vis_path for inspection
# (imgname is already a basename, so the second basename() is a no-op)
if visualize:
img = cv2.imread(os.path.join(img_dir, imgname))
cv2.imwrite(os.path.join(vis_path, os.path.basename(imgname)), img)
# image size comes from the split record; no image read is needed for it
img_w, img_h = record["width"], record["height"]
# HOT 2D contact polygon mask: annotations dir, same file stem, .png extension
polygon_2d_contact_path = os.path.join(annotations_dir, os.path.splitext(imgname)[0] + '.png')
# Get 3D contact annotations from DCA mturk csv
dca_row = dca_csv.loc[dca_csv['imgnames'] == imgname] # if no imgnames column, run scripts/datascripts/
# NOTE(review): no `else`/`continue` is visible after this branch. If the image
# has no DCA row, `dca_row[...].values[0]` below raises IndexError. Also, the
# increment of num_with_3d_contact directly follows the empty-row assignment,
# yet the final print reports it as the count of images *with* 3D labels. An
# `else:` was probably lost from this listing; confirm against the original file.
if len(dca_row) == 0:
contact_3d_labels = []
num_with_3d_contact += 1
# supporting_object is not used by any visible line
supporting_object = dca_row['supporting_object'].values[0]
# NOTE(review): eval() executes whatever text is in the csv cell. If the cells
# hold plain Python literals, ast.literal_eval is a safe drop-in; at minimum,
# confirm the csv is trusted.
vertices = eval(dca_row['vertices'].values[0])
# the evaluated object is keyed by 'hot/training/<imgname>'
contact_3d_list = vertices[os.path.join('hot/training/', imgname)]
# Aggregate values in all keys
# NOTE(review): the body of the inner `for k, v` loop (presumably adding v to
# contact_3d_idx) is not visible in this listing.
contact_3d_idx = []
for item in contact_3d_list:
# one iteration loop as it is a list of one dict key value
for k, v in item.items():
# removed repeated values
contact_3d_idx = list(set(contact_3d_idx))
# per-vertex 0/1 contact label vector
contact_3d_labels = np.zeros(n_vertices) # smpl has 6890 vertices
contact_3d_labels[contact_3d_idx] = 1.
# indices of smpl_params entries recorded for this image (matched on the full image path)
inds = np.where(smpl_params['imgname'] == os.path.join(img_dir, imgname))[0]
select_inds = []
ious = []
# score each candidate entry against the HOT 2D contact mask
for ind in inds:
# part mask
part_path = smpl_params['part_seg'][ind]
# load the part_mask
# NOTE(review): cv2.imread returns None for a missing/unreadable file, and
# `None > 0` on the next line then raises TypeError.
part_mask = cv2.imread(part_path)
# binarize the part mask (any non-zero pixel -> 1)
part_mask = np.where(part_mask > 0, 1, 0)
# save part mask
# NOTE(review): np.where yields a platform-int array (int64 on most platforms);
# cv2.imwrite may reject that depth -- consider .astype(np.uint8) before writing.
if visualize:
cv2.imwrite(os.path.join(vis_path, os.path.basename(part_path)), part_mask*255)
# load gt polygon mask
# NOTE(review): this mask depends only on the image, not on `ind`; it could be
# loaded and binarized once per image, before this loop.
polygon_2d_contact = cv2.imread(polygon_2d_contact_path)
# binarize the gt polygon mask
polygon_2d_contact = np.where(polygon_2d_contact > 0, 1, 0)
# save gt polygon mask in temp_images
if visualize:
cv2.imwrite(os.path.join(vis_path, os.path.basename(polygon_2d_contact_path)), polygon_2d_contact*255)
# add a leading batch axis, then permute (1, H, W, C) -> (1, C, H, W); this
# requires 3-D arrays, which cv2.imread returns in its default colour mode
polygon_2d_contact = torch.from_numpy(polygon_2d_contact)[None,:].permute(0,3,1,2)
part_mask = torch.from_numpy(part_mask)[None,:].permute(0,3,1,2)
# compute iou with part mask and gt polygon mask
# (metric's argument order: reference mask first, prediction second)
iou = metric(polygon_2d_contact, part_mask)
# NOTE(review): the lines that record `ind` / `iou` into select_inds / ious under
# this condition are not visible in this listing.
if iou > iou_thresh:
# get select_ind with maximum iou
if len(select_inds) > 0:
max_iou_ind = select_inds[np.argmax(ious)]
# NOTE: this loop reuses the name `ind` from the scoring loop above.
for ind in select_inds:
# part mask
part_path = smpl_params['part_seg'][ind]
# scene mask
scene_path = smpl_params['scene_seg'][ind]
# get smpl params
pose = smpl_params['pose'][ind]
shape = smpl_params['shape'][ind]
transl = smpl_params['global_t'][ind]
focal_length = smpl_params['focal_l'][ind]
# principal point at the (floored) image centre, shape (1, 2)
camC = np.array([[img_w//2, img_h//2]])
# pinhole intrinsics: focal length on the diagonal, principal point in the
# last column (camC.T is (2, 1) and fills K[0, 2] and K[1, 2])
K = np.eye(3, dtype=np.float64)
K[0, 0] = focal_length
K[1, 1] = focal_length
K[:2, 2:] = camC.T
# store data
# NOTE(review): the appends to the other accumulators are not visible here.
imgnames_.append(os.path.join(img_dir, imgname))
# we use the heuristic that the 3D contact labeled is for the person with maximum iou with HOT contacts
# NOTE(review): the body of this branch is not visible in this listing.
if ind == max_iou_ind:
print('Average focal length: ', np.mean(focal_length_accumulator))
print('Median focal length: ', np.median(focal_length_accumulator))
print('Std Dev focal length: ', np.std(focal_length_accumulator))
# store the data struct
# (the remaining np.savez keyword arguments are cut off in this listing)
os.makedirs(out_dir, exist_ok=True)
out_file = os.path.join(out_dir, f'hot_prox_{split}.npz')
np.savez(out_file, imgname=imgnames_,
print(f'Total number of rows: {len(imgnames_)}')
print('Saved to ', out_file)
print(f'Number of images with 3D contact labels: {num_with_3d_contact}')
# Command-line entry point. The defaults are absolute cluster paths (/ps/...,
# /is/...) and presumably need overriding on any other machine.
if __name__ == '__main__':
parser = argparse.ArgumentParser()
# dataset root containing images/training (and annotations/training)
parser.add_argument('--img_dataset_path', type=str, default='/ps/project/datasets/HOT/Contact_Data/')
# .npz with the per-entry SMPL fits consumed by hot_extract
parser.add_argument('--smpl_params_path', type=str, default='/ps/scratch/ps_shared/stripathi/deco/4agniv/hot/hot.npz')
# DCA csv with 'imgnames', 'supporting_object' and 'vertices' columns
parser.add_argument('--dca_csv_path', type=str, default='/ps/scratch/ps_shared/stripathi/deco/4agniv/hot/dca.csv')
# directory that receives hot_prox_<split>.npz
parser.add_argument('--out_dir', type=str, default='/is/cluster/work/stripathi/pycharm_remote/dca_contact/data/dataset_extras')
# directory for debug images, only written when --visualize is set
parser.add_argument('--vis_path', type=str, default='/is/cluster/work/stripathi/pycharm_remote/dca_contact/temp_images')
parser.add_argument('--visualize', action='store_true', default=False)
# 'train', 'val' or 'test' (the split names combine_hot_prox_split handles)
parser.add_argument('--split', type=str, default='train')
args = parser.parse_args()
# NOTE(review): the call that passes `args` on to hot_extract is not visible in this listing.