import argparse
import glob
import os
import shutil

import cv2
import numpy as np
import PIL.Image as pil_img
import torch
import trimesh
from loguru import logger

import pyrender
from common import constants
from models.deco import DECO

# Headless EGL rendering; must be set before pyrender creates a context.
os.environ['PYOPENGL_PLATFORM'] = 'egl'

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')


def initiate_model(args):
    """Build the DECO model, load checkpoint weights and switch to eval mode.

    Args:
        args: parsed CLI namespace; only ``args.model_path`` is read.

    Returns:
        The DECO model on ``device``, in eval mode.
    """
    deco_model = DECO('hrnet', True, device)

    logger.info(f'Loading weights from {args.model_path}')
    # map_location keeps CPU-only machines working even when the
    # checkpoint was saved from a CUDA run.
    checkpoint = torch.load(args.model_path, map_location=device)
    deco_model.load_state_dict(checkpoint['deco'], strict=True)
    deco_model.eval()

    return deco_model


def render_image(scene, img_res, img=None, viewer=False):
    """Render a pyrender scene and return the image as a float array in [0, 1].

    If ``img`` is given (a tensor — NOTE(review): only the tensor path is
    implemented here), the rendered mesh is alpha-composited over it.
    If ``viewer`` is True an interactive viewer is opened instead and 0 is
    returned.
    """
    if viewer:
        pyrender.Viewer(scene, use_raymond_lighting=True)
        return 0

    r = pyrender.OffscreenRenderer(viewport_width=img_res,
                                   viewport_height=img_res,
                                   point_size=1.0)
    try:
        color, _ = r.render(scene, flags=pyrender.RenderFlags.RGBA)
    finally:
        # Release the EGL/GPU context — leaking one per call exhausts
        # contexts after a few images.
        r.delete()

    color = color.astype(np.float32) / 255.0

    if img is not None:
        # Composite: keep rendered pixels where alpha > 0, input elsewhere.
        valid_mask = (color[:, :, -1] > 0)[:, :, np.newaxis]
        input_img = img.detach().cpu().numpy()
        output_img = color[:, :, :-1] * valid_mask + (1 - valid_mask) * input_img
    else:
        output_img = color

    return output_img


def create_scene(mesh, img, focal_length=500, camera_center=250, img_res=500):
    """Render ``mesh`` from 6 viewpoints and tile them next to ``img``.

    Args:
        mesh: trimesh mesh (vertex colours already set by the caller).
        img: BGR uint8 image shown as the leftmost tile.
        focal_length, camera_center, img_res: intrinsics / output resolution.

    Returns:
        A PIL image with the input photo followed by 4 side views and
        2 top/bottom views, thumbnailed to at most 3000x3000.
    """
    scene = pyrender.Scene(bg_color=[1.0, 1.0, 1.0, 1.0],
                           ambient_light=(0.3, 0.3, 0.3))

    # Camera: identity rotation, pushed back 2.5 units along +z.
    camera_pose = np.eye(4)
    camera_rotation = np.eye(3, 3)
    camera_translation = np.array([0., 0, 2.5])
    camera_pose[:3, :3] = camera_rotation
    camera_pose[:3, 3] = camera_rotation @ camera_translation
    pyrencamera = pyrender.camera.IntrinsicsCamera(
        fx=focal_length, fy=focal_length,
        cx=camera_center, cy=camera_center)
    scene.add(pyrencamera, pose=camera_pose)

    # Four point lights placed around the mesh centroid.
    light = pyrender.PointLight(color=[1.0, 1.0, 1.0], intensity=1)
    light_pose = np.eye(4)
    for lp in [[1, 1, 1], [-1, 1, 1], [1, -1, 1], [-1, -1, 1]]:
        light_pose[:3, 3] = mesh.vertices.mean(0) + np.array(lp)
        scene.add(light, pose=light_pose)

    material = pyrender.MetallicRoughnessMaterial(
        metallicFactor=0.0,
        alphaMode='OPAQUE',
        baseColorFactor=(1.0, 1.0, 0.9, 1.0))

    mesh_images = []

    # Resize the input photo to the render height, preserving aspect ratio.
    img_height = img_res
    img_width = int(img_height * img.shape[1] / img.shape[0])
    img = cv2.resize(img, (img_width, img_height))
    mesh_images.append(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))

    def _render_view(angle, axis):
        """Render one rotated copy of the mesh and return it as RGB uint8."""
        out_mesh = mesh.copy()
        rot = trimesh.transformations.rotation_matrix(np.radians(angle), axis)
        out_mesh.apply_transform(rot)
        out_mesh = pyrender.Mesh.from_trimesh(out_mesh, material=material)
        scene.add(out_mesh, pose=np.eye(4), name='mesh')

        output_img = render_image(scene, img_res)
        output_img = pil_img.fromarray((output_img * 255).astype(np.uint8))
        output_img = np.asarray(output_img)[:, :, :3]

        # Remove the mesh so the next view starts from a clean scene.
        prev_mesh = scene.get_nodes(name='mesh').pop()
        scene.remove_node(prev_mesh)
        return output_img

    # Four side views (rotation about y) ...
    for sideview_angle in [0, 90, 180, 270]:
        mesh_images.append(_render_view(sideview_angle, [0, 1, 0]))
    # ... then top and bottom views (rotation about x).
    for topview_angle in [90, 270]:
        mesh_images.append(_render_view(topview_angle, [1, 0, 0]))

    IMG = np.hstack(mesh_images)
    IMG = pil_img.fromarray(IMG)
    IMG.thumbnail((3000, 3000))
    return IMG


def main(args):
    if os.path.isdir(args.img_src):
        images = glob.iglob(args.img_src + '/*', recursive=True)
    else:
        images = [args.img_src]

    deco_model = initiate_model(args)

    smpl_path = os.path.join(constants.SMPL_MODEL_DIR, 'smpl_neutral_tpose.ply')

    for img_name in images:
        img = cv2.imread(img_name)
        # BUG FIX: the third positional argument of cv2.resize is `dst`,
        # not the interpolation flag — it must be passed by keyword.
        img = cv2.resize(img, (256, 256), interpolation=cv2.INTER_CUBIC)
        img = img.transpose(2, 0, 1) / 255.0
        img = img[np.newaxis, :, :, :]
        img = torch.tensor(img, dtype=torch.float32).to(device)

        cont, _, _ = deco_model(img)
        cont = cont.detach().cpu().numpy().squeeze()
        # Indices of SMPL vertices predicted to be in contact (prob >= 0.5).
        cont_smpl = [indx for indx, i in enumerate(cont) if i >= 0.5]

        img = img.detach().cpu().numpy()
        img = np.transpose(img[0], (1, 2, 0))
        img = (img * 255).astype(np.uint8)

        body_model_smpl = trimesh.load(smpl_path, process=False)
        # Base colour everywhere, contact colour on predicted vertices.
        body_model_smpl.visual.vertex_colors[:] = args.mesh_colour
        body_model_smpl.visual.vertex_colors[cont_smpl] = args.annot_colour

        rend = create_scene(body_model_smpl, img)

        # splitext (not split('.')) keeps filenames containing dots intact.
        stem = os.path.splitext(os.path.basename(img_name))[0]

        os.makedirs(os.path.join(args.out_dir, 'Renders'), exist_ok=True)
        rend.save(os.path.join(args.out_dir, 'Renders', stem + '.png'))

        out_dir = os.path.join(args.out_dir, 'Preds', stem)
        os.makedirs(out_dir, exist_ok=True)

        logger.info(f'Saving mesh to {out_dir}')
        shutil.copyfile(img_name, os.path.join(out_dir, os.path.basename(img_name)))
        body_model_smpl.export(os.path.join(out_dir, 'pred.obj'))


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--img_src',
                        help='Source of image(s). Can be file or directory',
                        default='./demo_out', type=str)
    parser.add_argument('--out_dir',
                        help='Where to store images',
                        default='./demo_out', type=str)
    parser.add_argument('--model_path',
                        help='Path to best model weights',
                        default='./checkpoints/Release_Checkpoint/deco_best.pth',
                        type=str)
    parser.add_argument('--mesh_colour',
                        help='Colour of the mesh',
                        nargs='+', type=int, default=[130, 130, 130, 255])
    parser.add_argument('--annot_colour',
                        help='Colour of the mesh',
                        nargs='+', type=int, default=[0, 255, 0, 255])
    args = parser.parse_args()

    main(args)