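"""Data preprocessing for the vid2avatar baseline.

The script processes a sequence under ./raw_data/<seq> (frames, ROMP results,
and OpenPose keypoints) in one of three modes:

  - mask:   render the raw ROMP SMPL estimates and save initial body masks.
  - refine: optimize per-frame SMPL shape/pose/translation against OpenPose 2D
            keypoints with a temporal prior, saving refined overlays, masks,
            per-frame .pkl parameter files, and the mean body shape.
  - final:  resize images and masks, normalize the body to the origin per frame,
            and export poses.npy, mean_shape.npy, normalize_trans.npy and
            cameras.npz into ../data/<seq> for training.

A plausible run order (a sketch, not prescribed anywhere in this file; adjust
the flags to your data):

    python preprocessing.py --source custom --seq <seq> --gender MALE --mode mask
    python preprocessing.py --source custom --seq <seq> --gender MALE --mode refine
    python preprocessing.py --source custom --seq <seq> --gender MALE --mode final
"""
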
import numpy as np
import pickle as pkl
import torch
import trimesh
import cv2
import os
from tqdm import tqdm
import glob
import argparse
from preprocessing_utils import (smpl_to_pose, PerspectiveCamera, Renderer, render_trimesh,
                                 estimate_translation_cv2, transform_smpl)
from loss import joints_2d_loss, pose_temporal_loss, get_loss_weights


def main(args):
    device = torch.device("cuda:0")
    seq = args.seq
    gender = args.gender
    DIR = './raw_data'
    img_dir = f'{DIR}/{seq}/frames'
    romp_file_dir = f'{DIR}/{seq}/ROMP'
    img_paths = sorted(glob.glob(f"{img_dir}/*.png"))
    romp_file_paths = sorted(glob.glob(f"{romp_file_dir}/*.npz"))

    from smplx import SMPL
    smpl_model = SMPL('../code/lib/smpl/smpl_model', gender=gender).to(device)

    # Camera intrinsics: a rough pinhole guess (focal length = longest image side,
    # principal point at the image center) for custom videos, otherwise read from
    # the NeuMan / DeepCap calibration files.
    input_img = cv2.imread(img_paths[0])
    if args.source == 'custom':
        focal_length = max(input_img.shape[0], input_img.shape[1])
        cam_intrinsics = np.array([[focal_length, 0., input_img.shape[1] // 2],
                                   [0., focal_length, input_img.shape[0] // 2],
                                   [0., 0., 1.]])
    elif args.source == 'neuman':
        NeuMan_DIR = ''  # path to NeuMan dataset
        with open(f'{NeuMan_DIR}/{seq}/sparse/cameras.txt') as f:
            lines = f.readlines()
        cam_params = lines[3].split()
        cam_intrinsics = np.array([[float(cam_params[4]), 0., float(cam_params[6])],
                                   [0., float(cam_params[5]), float(cam_params[7])],
                                   [0., 0., 1.]])
    elif args.source == 'deepcap':
        DeepCap_DIR = ''  # path to DeepCap dataset
        with open(f'{DeepCap_DIR}/monocularCalibrationBM.calibration') as f:
            lines = f.readlines()
        cam_params = lines[5].split()
        cam_intrinsics = np.array([[float(cam_params[1]), 0., float(cam_params[3])],
                                   [0., float(cam_params[6]), float(cam_params[7])],
                                   [0., 0., 1.]])
    else:
        print('Unknown --source. Please choose one of: custom, neuman, deepcap. More sources will be added in the future.')
        raise NotImplementedError

    renderer = Renderer(img_size=[input_img.shape[0], input_img.shape[1]], cam_intrinsic=cam_intrinsics)

    if args.mode == 'mask':
        if not os.path.exists(f'{DIR}/{seq}/init_mask'):
            os.makedirs(f'{DIR}/{seq}/init_mask')
    elif args.mode == 'refine':
        if not os.path.exists(f'{DIR}/{seq}/init_refined_smpl'):
            os.makedirs(f'{DIR}/{seq}/init_refined_smpl')
        if not os.path.exists(f'{DIR}/{seq}/init_refined_mask'):
            os.makedirs(f'{DIR}/{seq}/init_refined_mask')
        if not os.path.exists(f'{DIR}/{seq}/init_refined_smpl_files'):
            os.makedirs(f'{DIR}/{seq}/init_refined_smpl_files')

        openpose_dir = f'{DIR}/{seq}/openpose'
        openpose_paths = sorted(glob.glob(f"{openpose_dir}/*.npy"))

        opt_num_iters = 150
        weight_dict = get_loss_weights()
        cam = PerspectiveCamera(focal_length_x=torch.tensor(cam_intrinsics[0, 0], dtype=torch.float32),
                                focal_length_y=torch.tensor(cam_intrinsics[1, 1], dtype=torch.float32),
                                center=torch.tensor(cam_intrinsics[0:2, 2]).unsqueeze(0)).to(device)
        mean_shape = []
        # indices mapping SMPL joints to the OpenPose (coco25) keypoint order
        smpl2op_mapping = torch.tensor(smpl_to_pose(model_type='smpl', use_hands=False, use_face=False,
                                                    use_face_contour=False, openpose_format='coco25'),
                                       dtype=torch.long).cuda()
    elif args.mode == 'final':
        refined_smpl_dir = f'{DIR}/{seq}/init_refined_smpl_files'
        refined_smpl_mask_dir = f'{DIR}/{seq}/init_refined_mask'
        refined_smpl_paths = sorted(glob.glob(f"{refined_smpl_dir}/*.pkl"))
        refined_smpl_mask_paths = sorted(glob.glob(f"{refined_smpl_mask_dir}/*.png"))

        save_dir = f'../data/{seq}'
        if not os.path.exists(os.path.join(save_dir, 'image')):
            os.makedirs(os.path.join(save_dir, 'image'))
        if not os.path.exists(os.path.join(save_dir, 'mask')):
            os.makedirs(os.path.join(save_dir, 'mask'))

        scale_factor = args.scale_factor
        smpl_shape = np.load(f'{DIR}/{seq}/mean_shape.npy')
        T_hip = smpl_model.get_T_hip(betas=torch.tensor(smpl_shape)[None].float().to(device)).squeeze().cpu().numpy()

        # scale the intrinsics consistently with the downscaled images
        K = np.eye(4)
        K[:3, :3] = cam_intrinsics
        K[0, 0] = K[0, 0] / scale_factor
        K[1, 1] = K[1, 1] / scale_factor
        K[0, 2] = K[0, 2] / scale_factor
        K[1, 2] = K[1, 2] / scale_factor

        dial_kernel = np.ones((20, 20), np.uint8)

        output_trans = []
        output_pose = []
        output_P = {}
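
    # Per-frame state shared by all modes. A single static camera with identity
    # extrinsics is assumed throughout; actor_id / last_j3d implement a simple
    # nearest-joints tracker so the same person is kept when ROMP returns
    # multiple detections.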
    last_j3d = None
    actor_id = 0
    cam_extrinsics = np.eye(4)
    R = torch.tensor(cam_extrinsics[:3, :3])[None].float()
    T = torch.tensor(cam_extrinsics[:3, 3])[None].float()

    for idx, img_path in enumerate(tqdm(img_paths)):
        input_img = cv2.imread(img_path)

        if args.mode == 'mask' or args.mode == 'refine':
            seq_file = np.load(romp_file_paths[idx], allow_pickle=True)['results'][()]

            # tracking in case of two persons or wrong ROMP detection
            if len(seq_file['smpl_thetas']) >= 2:
                dist = []
                if idx == 0:
                    last_j3d = seq_file['joints'][actor_id]
                for i in range(len(seq_file['smpl_thetas'])):
                    dist.append(np.linalg.norm(seq_file['joints'][i].mean(0) - last_j3d.mean(0, keepdims=True)))
                actor_id = np.argmin(dist)

            smpl_verts = seq_file['verts'][actor_id]
            pj2d_org = seq_file['pj2d_org'][actor_id]
            joints3d = seq_file['joints'][actor_id]
            last_j3d = joints3d.copy()

            # estimate a global translation from the 3D joints and ROMP's 2D joints
            # under the camera intrinsics, and apply it to the vertices
            tra_pred = estimate_translation_cv2(joints3d, pj2d_org, proj_mat=cam_intrinsics)
            smpl_verts += tra_pred

            if args.mode == 'refine':
                openpose = np.load(openpose_paths[idx])
                openpose_j2d = torch.tensor(openpose[:, :2][None], dtype=torch.float32, requires_grad=False, device=device)
                openpose_conf = torch.tensor(openpose[:, -1][None], dtype=torch.float32, requires_grad=False, device=device)

                smpl_shape = seq_file['smpl_betas'][actor_id][:10]
                smpl_pose = seq_file['smpl_thetas'][actor_id]
                smpl_trans = tra_pred

                opt_betas = torch.tensor(smpl_shape[None], dtype=torch.float32, requires_grad=True, device=device)
                opt_pose = torch.tensor(smpl_pose[None], dtype=torch.float32, requires_grad=True, device=device)
                opt_trans = torch.tensor(smpl_trans[None], dtype=torch.float32, requires_grad=True, device=device)

                opt_params = [{'params': opt_betas, 'lr': 1e-3},
                              {'params': opt_pose, 'lr': 1e-3},
                              {'params': opt_trans, 'lr': 1e-3}]
                optimizer = torch.optim.Adam(opt_params, lr=2e-3, betas=(0.9, 0.999))

                if idx == 0:
                    last_pose = [opt_pose.detach().clone()]
                loop = tqdm(range(opt_num_iters))
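                # Jointly refine betas / pose / translation with Adam: J2D_Loss pulls the
                # reprojected SMPL joints towards the OpenPose detections, while
                # Temporal_Loss keeps the pose close to the previous frame's estimate.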
                for it in loop:
                    optimizer.zero_grad()
                    smpl_output = smpl_model(betas=opt_betas,
                                             body_pose=opt_pose[:, 3:],
                                             global_orient=opt_pose[:, :3],
                                             transl=opt_trans)
                    smpl_verts = smpl_output.vertices.data.cpu().numpy().squeeze()
                    smpl_joints_2d = cam(torch.index_select(smpl_output.joints, 1, smpl2op_mapping))

                    loss = dict()
                    loss['J2D_Loss'] = joints_2d_loss(openpose_j2d, smpl_joints_2d, openpose_conf)
                    loss['Temporal_Loss'] = pose_temporal_loss(last_pose[0], opt_pose)

                    w_loss = dict()
                    for k in loss:
                        w_loss[k] = weight_dict[k](loss[k], it)
                    tot_loss = torch.stack(list(w_loss.values())).sum()
                    tot_loss.backward()
                    optimizer.step()

                    l_str = 'Iter: %d' % it
                    for k in loss:
                        l_str += ', %s: %0.4f' % (k, weight_dict[k](loss[k], it).mean().item())
                    loop.set_description(l_str)
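
            # Render the (possibly refined) SMPL mesh with the static camera; the alpha
            # channel of the rendering becomes the person mask, cropped back to the
            # original (non-square) image size.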
            smpl_mesh = trimesh.Trimesh(smpl_verts, smpl_model.faces, process=False)
            R = torch.tensor(cam_extrinsics[:3, :3])[None].float()
            T = torch.tensor(cam_extrinsics[:3, 3])[None].float()
            rendered_image = render_trimesh(renderer, smpl_mesh, R, T, 'n')
            if input_img.shape[0] < input_img.shape[1]:
                rendered_image = rendered_image[abs(input_img.shape[0] - input_img.shape[1]) // 2:(input_img.shape[0] + input_img.shape[1]) // 2, ...]
            else:
                rendered_image = rendered_image[:, abs(input_img.shape[0] - input_img.shape[1]) // 2:(input_img.shape[0] + input_img.shape[1]) // 2]
            valid_mask = (rendered_image[:, :, -1] > 0)[:, :, np.newaxis]

            if args.mode == 'mask':
                cv2.imwrite(os.path.join(f'{DIR}/{seq}/init_mask', '%04d.png' % idx), valid_mask * 255)
            elif args.mode == 'refine':
                output_img = (rendered_image[:, :, :-1] * valid_mask + input_img * (1 - valid_mask)).astype(np.uint8)
                cv2.imwrite(os.path.join(f'{DIR}/{seq}/init_refined_smpl', '%04d.png' % idx), output_img)
                cv2.imwrite(os.path.join(f'{DIR}/{seq}/init_refined_mask', '%04d.png' % idx), valid_mask * 255)

                # keep the refined pose as the temporal prior for the next frame
                last_pose.pop(0)
                last_pose.append(opt_pose.detach().clone())

                smpl_dict = {}
                smpl_dict['pose'] = opt_pose.data.squeeze().cpu().numpy()
                smpl_dict['trans'] = opt_trans.data.squeeze().cpu().numpy()
                smpl_dict['shape'] = opt_betas.data.squeeze().cpu().numpy()

                mean_shape.append(smpl_dict['shape'])
                pkl.dump(smpl_dict, open(os.path.join(f'{DIR}/{seq}/init_refined_smpl_files', '%04d.pkl' % idx), 'wb'))

        elif args.mode == 'final':
            input_img = cv2.resize(input_img, (input_img.shape[1] // scale_factor, input_img.shape[0] // scale_factor))
            seq_file = pkl.load(open(refined_smpl_paths[idx], 'rb'))

            mask = cv2.imread(refined_smpl_mask_paths[idx])
            mask = cv2.resize(mask, (mask.shape[1] // scale_factor, mask.shape[0] // scale_factor))
            # dilate mask to obtain a coarse bbox
            mask = cv2.dilate(mask, dial_kernel)

            cv2.imwrite(os.path.join(save_dir, 'image/%04d.png' % idx), input_img)
            cv2.imwrite(os.path.join(save_dir, 'mask/%04d.png' % idx), mask)

            smpl_pose = seq_file['pose']
            smpl_trans = seq_file['trans']

            # transform the spaces such that our camera has the same orientation as the OpenGL camera
            target_extrinsic = np.eye(4)
            target_extrinsic[1:3] *= -1
            target_extrinsic, smpl_pose, smpl_trans = transform_smpl(cam_extrinsics, target_extrinsic, smpl_pose, smpl_trans, T_hip)
            smpl_output = smpl_model(betas=torch.tensor(smpl_shape)[None].float().to(device),
                                     body_pose=torch.tensor(smpl_pose[3:])[None].float().to(device),
                                     global_orient=torch.tensor(smpl_pose[:3])[None].float().to(device),
                                     transl=torch.tensor(smpl_trans)[None].float().to(device))
            smpl_verts = smpl_output.vertices.data.cpu().numpy().squeeze()

            # we need to center the human for every frame due to the potentially large global movement
            v_max = smpl_verts.max(axis=0)
            v_min = smpl_verts.min(axis=0)
            normalize_shift = -(v_max + v_min) / 2.
            trans = smpl_trans + normalize_shift

            # compensate the camera translation so the projection of the shifted body is unchanged
            target_extrinsic[:3, -1] = target_extrinsic[:3, -1] - (target_extrinsic[:3, :3] @ normalize_shift)

            P = K @ target_extrinsic
            output_trans.append(trans)
            output_pose.append(smpl_pose)
            output_P[f"cam_{idx}"] = P

    if args.mode == 'refine':
        mean_shape = np.array(mean_shape)
        np.save(f'{DIR}/{seq}/mean_shape.npy', mean_shape.mean(0))
    if args.mode == 'final':
        np.save(os.path.join(save_dir, 'poses.npy'), np.array(output_pose))
        np.save(os.path.join(save_dir, 'mean_shape.npy'), smpl_shape)
        np.save(os.path.join(save_dir, 'normalize_trans.npy'), np.array(output_trans))
        np.savez(os.path.join(save_dir, "cameras.npz"), **output_P)


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description="Preprocessing data")
    # video source
    parser.add_argument('--source', type=str, default='custom', help="custom video or dataset video")
    # sequence name
    parser.add_argument('--seq', type=str)
    # gender
    parser.add_argument('--gender', type=str, help="gender of the actor: MALE or FEMALE")
    # mode
    parser.add_argument('--mode', type=str, help="preprocessing mode: mask, refine, or final")
    # scale factor for the input image
    parser.add_argument('--scale_factor', type=int, default=2, help="scale factor for the input image")
    args = parser.parse_args()
    main(args)