DECO / inference.py
ac5113's picture
added files
99a05f0
raw
history blame
7.31 kB
import torch
import os
import glob
import argparse
import numpy as np
import cv2
import PIL.Image as pil_img
from loguru import logger
import shutil
import trimesh
import pyrender
from models.deco import DECO
from common import constants
os.environ['PYOPENGL_PLATFORM'] = 'egl'
if torch.cuda.is_available():
device = torch.device('cuda')
else:
device = torch.device('cpu')
def initiate_model(args):
deco_model = DECO('hrnet', True, device)
logger.info(f'Loading weights from {args.model_path}')
checkpoint = torch.load(args.model_path)
deco_model.load_state_dict(checkpoint['deco'], strict=True)
deco_model.eval()
return deco_model
def render_image(scene, img_res, img=None, viewer=False):
'''
Render the given pyrender scene and return the image. Can also overlay the mesh on an image.
'''
if viewer:
pyrender.Viewer(scene, use_raymond_lighting=True)
return 0
else:
r = pyrender.OffscreenRenderer(viewport_width=img_res,
viewport_height=img_res,
point_size=1.0)
color, _ = r.render(scene, flags=pyrender.RenderFlags.RGBA)
color = color.astype(np.float32) / 255.0
if img is not None:
valid_mask = (color[:, :, -1] > 0)[:, :, np.newaxis]
input_img = img.detach().cpu().numpy()
output_img = (color[:, :, :-1] * valid_mask +
(1 - valid_mask) * input_img)
else:
output_img = color
return output_img
def create_scene(mesh, img, focal_length=500, camera_center=250, img_res=500):
# Setup the scene
scene = pyrender.Scene(bg_color=[1.0, 1.0, 1.0, 1.0],
ambient_light=(0.3, 0.3, 0.3))
# add mesh for camera
camera_pose = np.eye(4)
camera_rotation = np.eye(3, 3)
camera_translation = np.array([0., 0, 2.5])
camera_pose[:3, :3] = camera_rotation
camera_pose[:3, 3] = camera_rotation @ camera_translation
pyrencamera = pyrender.camera.IntrinsicsCamera(
fx=focal_length, fy=focal_length,
cx=camera_center, cy=camera_center)
scene.add(pyrencamera, pose=camera_pose)
# create and add light
light = pyrender.PointLight(color=[1.0, 1.0, 1.0], intensity=1)
light_pose = np.eye(4)
for lp in [[1, 1, 1], [-1, 1, 1], [1, -1, 1], [-1, -1, 1]]:
light_pose[:3, 3] = mesh.vertices.mean(0) + np.array(lp)
# out_mesh.vertices.mean(0) + np.array(lp)
scene.add(light, pose=light_pose)
# add body mesh
material = pyrender.MetallicRoughnessMaterial(
metallicFactor=0.0,
alphaMode='OPAQUE',
baseColorFactor=(1.0, 1.0, 0.9, 1.0))
mesh_images = []
# resize input image to fit the mesh image height
img_height = img_res
img_width = int(img_height * img.shape[1] / img.shape[0])
img = cv2.resize(img, (img_width, img_height))
mesh_images.append(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
for sideview_angle in [0, 90, 180, 270]:
out_mesh = mesh.copy()
rot = trimesh.transformations.rotation_matrix(
np.radians(sideview_angle), [0, 1, 0])
out_mesh.apply_transform(rot)
out_mesh = pyrender.Mesh.from_trimesh(
out_mesh,
material=material)
mesh_pose = np.eye(4)
scene.add(out_mesh, pose=mesh_pose, name='mesh')
output_img = render_image(scene, img_res)
output_img = pil_img.fromarray((output_img * 255).astype(np.uint8))
output_img = np.asarray(output_img)[:, :, :3]
mesh_images.append(output_img)
# delete the previous mesh
prev_mesh = scene.get_nodes(name='mesh').pop()
scene.remove_node(prev_mesh)
# show upside down view
for topview_angle in [90, 270]:
out_mesh = mesh.copy()
rot = trimesh.transformations.rotation_matrix(
np.radians(topview_angle), [1, 0, 0])
out_mesh.apply_transform(rot)
out_mesh = pyrender.Mesh.from_trimesh(
out_mesh,
material=material)
mesh_pose = np.eye(4)
scene.add(out_mesh, pose=mesh_pose, name='mesh')
output_img = render_image(scene, img_res)
output_img = pil_img.fromarray((output_img * 255).astype(np.uint8))
output_img = np.asarray(output_img)[:, :, :3]
mesh_images.append(output_img)
# delete the previous mesh
prev_mesh = scene.get_nodes(name='mesh').pop()
scene.remove_node(prev_mesh)
# stack images
IMG = np.hstack(mesh_images)
IMG = pil_img.fromarray(IMG)
IMG.thumbnail((3000, 3000))
return IMG
def main(args):
if os.path.isdir(args.img_src):
images = glob.iglob(args.img_src + '/*', recursive=True)
else:
images = [args.img_src]
deco_model = initiate_model(args)
smpl_path = os.path.join(constants.SMPL_MODEL_DIR, 'smpl_neutral_tpose.ply')
for img_name in images:
img = cv2.imread(img_name)
img = cv2.resize(img, (256, 256), cv2.INTER_CUBIC)
img = img.transpose(2,0,1)/255.0
img = img[np.newaxis,:,:,:]
img = torch.tensor(img, dtype = torch.float32).to(device)
cont, _, _ = deco_model(img)
cont = cont.detach().cpu().numpy().squeeze()
cont_smpl = []
for indx, i in enumerate(cont):
if i >= 0.5:
cont_smpl.append(indx)
img = img.detach().cpu().numpy()
img = np.transpose(img[0], (1, 2, 0))
img = img * 255
img = img.astype(np.uint8)
contact_smpl = np.zeros((1, 1, 6890))
contact_smpl[0][0][cont_smpl] = 1
body_model_smpl = trimesh.load(smpl_path, process=False)
for vert in range(body_model_smpl.visual.vertex_colors.shape[0]):
body_model_smpl.visual.vertex_colors[vert] = args.mesh_colour
body_model_smpl.visual.vertex_colors[cont_smpl] = args.annot_colour
rend = create_scene(body_model_smpl, img)
os.makedirs(os.path.join(args.out_dir, 'Renders'), exist_ok=True)
rend.save(os.path.join(args.out_dir, 'Renders', os.path.basename(img_name).split('.')[0] + '.png'))
out_dir = os.path.join(args.out_dir, 'Preds', os.path.basename(img_name).split('.')[0])
os.makedirs(out_dir, exist_ok=True)
logger.info(f'Saving mesh to {out_dir}')
shutil.copyfile(img_name, os.path.join(out_dir, os.path.basename(img_name)))
body_model_smpl.export(os.path.join(out_dir, 'pred.obj'))
if __name__=='__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--img_src', help='Source of image(s). Can be file or directory', default='./demo_out', type=str)
parser.add_argument('--out_dir', help='Where to store images', default='./demo_out', type=str)
parser.add_argument('--model_path', help='Path to best model weights', default='./checkpoints/Release_Checkpoint/deco_best.pth', type=str)
parser.add_argument('--mesh_colour', help='Colour of the mesh', nargs='+', type=int, default=[130, 130, 130, 255])
parser.add_argument('--annot_colour', help='Colour of the mesh', nargs='+', type=int, default=[0, 255, 0, 255])
args = parser.parse_args()
main(args)