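"""Extract HOT training annotations into a single .npz dataset file.

For each image in the chosen HOT split, match the pre-fitted SMPL parameters
to the GT 2D contact polygon via part-mask IoU, attach 3D vertex contact
labels from the DCA mturk csv (to the best-matching person only), and save
SMPL params, camera intrinsics, and mask paths to an .npz file.
"""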
import os
import cv2
import numpy as np
from tqdm import tqdm
import ast
import argparse
import torch
import pandas as pd
import json

import monai.metrics as metrics

HOT_TRAIN_SPLIT = "/ps/scratch/ps_shared/ychen2/4shashank/split/hot_train.odgt"
HOT_VAL_SPLIT = "/ps/scratch/ps_shared/ychen2/4shashank/split/hot_validation.odgt"
HOT_TEST_SPLIT = "/ps/scratch/ps_shared/ychen2/4shashank/split/hot_test.odgt"

def metric(mask, pred, back=True):
    # Mean IoU between binarized BCHW masks; `back` maps to monai's
    # include_background flag, and ignore_empty is disabled.
    iou = metrics.compute_meaniou(pred, mask, back, False)
    iou = iou.mean()
    return iou
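# Minimal usage sketch (illustrative shapes): identical binarized BCHW masks
# give an IoU of 1.0:
#   m = torch.ones(1, 3, 64, 64)
#   metric(m, m)  # -> tensor(1.)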


def combine_hot_prox_split(split):
    def load_records(path):
        with open(path, "r") as f:
            return [json.loads(line.strip("\n")) for line in f.readlines()]

    if split == 'train':
        records = load_records(HOT_TRAIN_SPLIT)
    elif split == 'val':
        records = load_records(HOT_VAL_SPLIT)
    elif split == 'test':
        records = load_records(HOT_TEST_SPLIT)
    elif split == 'trainval':
        records = load_records(HOT_TRAIN_SPLIT) + load_records(HOT_VAL_SPLIT)
    else:
        raise ValueError(f"Unknown split: {split}")
    return records

def hot_extract(img_dataset_path, smpl_params_path, dca_csv_path, out_dir, split=None, vis_path=None, visualize=False, record_idx=None, include_supporting=True):

    n_vertices = 6890

    # structs we use
    imgnames_ = []
    poses_, shapes_, transls_ = [], [], []
    cams_k_ = []
    polygon_2d_contact_ = []
    contact_3d_labels_ = []
    scene_seg_, part_seg_ = [], []

    img_dir = os.path.join(img_dataset_path, 'images', 'training')
    smpl_params = np.load(smpl_params_path)
    # smpl_params = np.load(smpl_params_path, allow_pickle=True)
    # smpl_params = smpl_params['arr_0'].item()
    annotations_dir = img_dir.replace('images', 'annotations')
    records = combine_hot_prox_split(split)
    # shard the records into 4 sublists and keep the one at record_idx
    if record_idx is not None:
        records = np.array_split(records, 4)[record_idx]

    # load dca csv
    dca_csv = pd.read_csv(dca_csv_path)

    iou_thresh = 0

    num_with_3d_contact = 0

    focal_length_accumulator = []
    for i, record in enumerate(tqdm(records, dynamic_ncols=True)):
        imgpath = record['fpath_img']
        imgname = os.path.basename(imgpath)
        # save image in temp_images
        if visualize:
            img = cv2.imread(os.path.join(img_dir, imgname))
            cv2.imwrite(os.path.join(vis_path, os.path.basename(imgname)), img)

        # image size is stored in the split record, so the image itself need not be loaded
        img_w, img_h = record["width"], record["height"]

        # path to the GT 2D contact polygon mask
        polygon_2d_contact_path = os.path.join(annotations_dir, os.path.splitext(imgname)[0] + '.png')

        # Get 3D contact annotations from DCA mturk csv
        dca_row = dca_csv.loc[dca_csv['imgnames'] == imgname] # if no imgnames column, run scripts/datascripts/add_imgname_column_to_deco_csv.py
        if len(dca_row) == 0:
            contact_3d_labels = []
        else:
            num_with_3d_contact += 1
            supporting_object = dca_row['supporting_object'].values[0]
            # parse the repr'd dict of contact vertices (literal_eval is safer than eval)
            vertices = ast.literal_eval(dca_row['vertices'].values[0])
            contact_3d_list = vertices[os.path.join('hot/training/', imgname)]
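            # `contact_3d_list` is a list of one-key dicts mapping a body part
            # to labeled vertex indices, e.g. (illustrative values):
            #   [{'SUPPORTING': [12, 34]}, {'leftHand': [56, 78]}]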
            # aggregate vertex indices across all body-part keys
            contact_3d_idx = []
            for item in contact_3d_list:
                for k, v in item.items():
                    # optionally drop vertices in contact with the supporting object
                    if include_supporting or k != 'SUPPORTING':
                        contact_3d_idx.extend(v)
            # remove repeated vertex indices
            contact_3d_idx = list(set(contact_3d_idx))
            contact_3d_labels = np.zeros(n_vertices)  # SMPL has 6890 vertices
            contact_3d_labels[contact_3d_idx] = 1.

        # find the SMPL fits whose imgname matches this image
        inds = np.where(smpl_params['imgname'] == os.path.join(img_dir, imgname))[0]
        select_inds = []
        ious = []
        for ind in inds:
            # part mask
            part_path = smpl_params['part_seg'][ind]
            # load the part_mask
            part_mask = cv2.imread(part_path)
            # binarize the part mask
            part_mask = np.where(part_mask > 0, 1, 0)
            # save part mask
            if visualize:
                cv2.imwrite(os.path.join(vis_path, os.path.basename(part_path)), part_mask*255)

            # load gt polygon mask
            polygon_2d_contact = cv2.imread(polygon_2d_contact_path)
            # binarize the gt polygon mask
            polygon_2d_contact = np.where(polygon_2d_contact > 0, 1, 0)

            # save gt polygon mask in temp_images
            if visualize:
                cv2.imwrite(os.path.join(vis_path, os.path.basename(polygon_2d_contact_path)), polygon_2d_contact*255)

            # HWC -> BCHW tensors for monai
            polygon_2d_contact = torch.from_numpy(polygon_2d_contact)[None, :].permute(0, 3, 1, 2)
            part_mask = torch.from_numpy(part_mask)[None, :].permute(0, 3, 1, 2)
            # compute iou with part mask and gt polygon mask
            iou = metric(polygon_2d_contact, part_mask)
            if iou > iou_thresh:
                ious.append(iou)
                select_inds.append(ind)

        # keep the fit with the maximum IoU; skip the image if none overlaps the polygon
        if len(select_inds) > 0:
            max_iou_ind = select_inds[np.argmax(ious)]
        else:
            continue

        for ind in select_inds:
            # part mask
            part_path = smpl_params['part_seg'][ind]

            # scene mask
            scene_path = smpl_params['scene_seg'][ind]

            # get smpl params
            pose = smpl_params['pose'][ind]
            shape = smpl_params['shape'][ind]
            transl = smpl_params['global_t'][ind]
            focal_length = smpl_params['focal_l'][ind]
            camC = np.array([[img_w // 2, img_h // 2]])  # principal point at the image center

            # build the 3x3 camera intrinsics matrix K
            K = np.eye(3, dtype=np.float64)
            K[0, 0] = focal_length
            K[1, 1] = focal_length
            K[:2, 2:] = camC.T
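            # i.e. K = [[f, 0, cx],
            #           [0, f, cy],
            #           [0, 0,  1]]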

            # store data
            imgnames_.append(os.path.join(img_dir, imgname))
            polygon_2d_contact_.append(polygon_2d_contact_path)
            # heuristic: assign the 3D contact labels only to the person whose part mask has the maximum IoU with the HOT contact polygon
            if ind == max_iou_ind:
                contact_3d_labels_.append(contact_3d_labels)
            else:
                contact_3d_labels_.append([])
            scene_seg_.append(scene_path)
            part_seg_.append(part_path)
            poses_.append(pose.squeeze())
            transls_.append(transl.squeeze())
            shapes_.append(shape.squeeze())
            cams_k_.append(K.tolist())
        # record one focal length per image (from the last selected fit)
        focal_length_accumulator.append(focal_length)

    print('Average focal length: ', np.mean(focal_length_accumulator))
    print('Median focal length: ', np.median(focal_length_accumulator))
    print('Std Dev focal length: ', np.std(focal_length_accumulator))

    # store the data struct
    os.makedirs(out_dir, exist_ok=True)
    if record_idx is not None:
        out_file = os.path.join(out_dir, f'hot_noprox_supporting_{str(include_supporting)}_{split}_{record_idx}.npz')
    else:
        out_file = os.path.join(out_dir, f'hot_noprox_supporting_{str(include_supporting)}_{split}_combined.npz')
    np.savez(out_file, imgname=imgnames_,
                       pose=poses_,
                       transl=transls_,
                       shape=shapes_,
                       cam_k=cams_k_,
                       polygon_2d_contact=polygon_2d_contact_,
                       contact_label=contact_3d_labels_,
                       scene_seg=scene_seg_,
                       part_seg=part_seg_
             )
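    # NOTE (assumption): contact_label mixes empty lists with (6890,) arrays, so
    # numpy stores it as an object array; reading it back likely requires
    # np.load(out_file, allow_pickle=True).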
    print(f'Total number of rows: {len(imgnames_)}')
    print('Saved to ', out_file)
    print(f'Number of images with 3D contact labels: {num_with_3d_contact}')

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--img_dataset_path', type=str, default='/ps/project/datasets/HOT/Contact_Data/')
    parser.add_argument('--smpl_params_path', type=str, default='/ps/scratch/ps_shared/stripathi/deco/4agniv/hot/hot.npz')
    parser.add_argument('--dca_csv_path', type=str, default='/ps/scratch/ps_shared/stripathi/deco/4agniv/hot/dca.csv')
    parser.add_argument('--out_dir', type=str, default='/is/cluster/work/stripathi/pycharm_remote/dca_contact/data/dataset_extras')
    parser.add_argument('--vis_path', type=str, default='/is/cluster/work/stripathi/pycharm_remote/dca_contact/temp_images')
    parser.add_argument('--visualize', action='store_true', default=False)
    parser.add_argument('--include_supporting', action='store_true', default=False)
    parser.add_argument('--record_idx', type=int, default=None)
    parser.add_argument('--split', type=str, default='train')
    args = parser.parse_args()

    hot_extract(img_dataset_path=args.img_dataset_path,
                smpl_params_path=args.smpl_params_path,
                dca_csv_path=args.dca_csv_path,
                out_dir=args.out_dir,
                vis_path=args.vis_path,
                visualize=args.visualize,
                split=args.split,
                record_idx=args.record_idx,
                include_supporting=args.include_supporting)
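
# Example invocation (the defaults above are cluster-specific; adjust the paths
# for your setup; the script filename here is hypothetical):
#   python hot_extract.py --split trainval --include_supporting --record_idx 0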