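# Scraped from the Hugging Face file viewer (page text, not program code):
# Author: radames (HF staff)
# Commit: c7f097c ("initial commit")
# File size: 3.36 kB (viewer showed "No virus" and raw / history / blame links)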
import io
import os
import torch
from skimage.io import imread
import numpy as np
import cv2
from tqdm import tqdm_notebook as tqdm
import base64
from IPython.display import HTML
# Util function for loading meshes
from pytorch3d.io import load_objs_as_meshes
from IPython.display import HTML
from base64 import b64encode
# Data structures and functions for rendering
from pytorch3d.structures import Meshes
from pytorch3d.renderer import (
look_at_view_transform,
OpenGLOrthographicCameras,
PointLights,
DirectionalLights,
Materials,
RasterizationSettings,
MeshRenderer,
MeshRasterizer,
SoftPhongShader,
HardPhongShader,
TexturesVertex
)
def set_renderer(device="cuda:0", image_size=512):
    """Build a PyTorch3D mesh renderer with an orthographic camera and one point light.

    The default view comes from ``look_at_view_transform(2.0, 0, 180)``. Callers
    can override the view on each call by passing ``R``/``T`` to the returned
    renderer.

    Args:
        device: Torch device (or device string) to render on. Defaults to
            ``"cuda:0"``, as before.
        image_size: Size in pixels of the square rasterized image.

    Returns:
        A ``MeshRenderer`` made of a ``MeshRasterizer`` and a ``HardPhongShader``.
    """
    device = torch.device(device)
    if device.type == "cuda":
        # Make this the current CUDA device, as the original code always did.
        torch.cuda.set_device(device)

    # Initialize an OpenGL *orthographic* camera (not perspective).
    R, T = look_at_view_transform(2.0, 0, 180)
    cameras = OpenGLOrthographicCameras(device=device, R=R, T=T)

    raster_settings = RasterizationSettings(
        image_size=image_size,
        blur_radius=0.0,
        faces_per_pixel=1,
        # Left as None so PyTorch3D chooses its own binning settings.
        bin_size=None,
        max_faces_per_bin=None,
    )

    lights = PointLights(device=device, location=((2.0, 2.0, 2.0),))

    renderer = MeshRenderer(
        rasterizer=MeshRasterizer(
            cameras=cameras,
            raster_settings=raster_settings,
        ),
        shader=HardPhongShader(
            device=device,
            cameras=cameras,
            lights=lights,
        ),
    )
    return renderer
def get_verts_rgb_colors(obj_path):
    """Read per-vertex RGB colours from an OBJ file.

    Only vertex records of the form ``v x y z r g b`` are used. All other lines
    are ignored, including plain ``v x y z`` lines, normals, comments and faces.
    Tokens are split on any whitespace, so trailing spaces and CRLF endings are
    handled.

    Args:
        obj_path: Path to the ``.obj`` file.

    Returns:
        A float32 array of shape ``(1, V, 3)``, where ``V`` is the number of
        coloured vertex lines. The shape is ``(1, 0, 3)`` when there are none.
    """
    rgb_colors = []
    with open(obj_path) as f:
        for line in f:
            tokens = line.split()
            # 'v' + xyz + rgb. Require the 'v' tag so other 7-token records
            # (e.g. a hexagonal face "f a b c d e f") are not read as colours.
            if len(tokens) == 7 and tokens[0] == "v":
                rgb_colors.append([float(t) for t in tokens[4:7]])
    # reshape keeps the (1, V, 3) layout even when no colours were found.
    return np.array(rgb_colors, dtype=np.float32).reshape(-1, 3)[None, :, :]
def _to_bgr_uint8(images):
    """Turn the first image of a renderer batch into a BGR uint8 array for OpenCV."""
    # Keep only the RGB channels (drops any extra, e.g. alpha), clip to [0, 1],
    # flip RGB -> BGR for OpenCV, then scale to 0..255.
    rgb = np.clip(images[0, ..., :3].cpu().numpy(), 0.0, 1.0)
    return (rgb[:, :, ::-1] * 255).astype("uint8")


def generate_video_from_obj(obj_path, video_path, renderer, device="cuda:0",
                            num_frames=90, fps=20.0):
    """Render a turntable MP4 of an OBJ mesh, with and without vertex colours.

    Each frame puts the vertex-coloured render (left) next to a uniform 0.75
    grey render (right). The camera orbits at distance 1.8 and elevation 0.
    The azimuth steps by ``360 / num_frames`` degrees per frame, which is
    4 degrees with the default 90 frames.

    Args:
        obj_path: OBJ file whose ``v`` lines carry ``x y z r g b``.
        video_path: Output MP4 path.
        renderer: A renderer that accepts ``R``/``T`` overrides, e.g. from
            ``set_renderer``. It should live on ``device``.
        device: Torch device for the mesh and camera tensors.
        num_frames: Number of frames in one full orbit.
        fps: Frame rate of the written video.

    Raises:
        ValueError: If the number of colour rows does not match the number of
            mesh vertices.
    """
    device = torch.device(device)
    if device.type == "cuda":
        torch.cuda.set_device(device)

    # Per-vertex colours, plus a flat grey version for the untextured render.
    verts_rgb_colors = torch.from_numpy(get_verts_rgb_colors(obj_path)).to(device)
    textures = TexturesVertex(verts_features=verts_rgb_colors)
    wo_textures = TexturesVertex(verts_features=torch.ones_like(verts_rgb_colors) * 0.75)

    mesh = load_objs_as_meshes([obj_path], device=device)
    # Public accessors instead of the private _verts_list / _faces_list.
    verts = mesh.verts_list()
    faces = mesh.faces_list()
    if verts_rgb_colors.shape[1] != verts[0].shape[0]:
        raise ValueError(
            "%s: %d vertex colours for %d vertices"
            % (obj_path, verts_rgb_colors.shape[1], verts[0].shape[0])
        )
    mesh_w_tex = Meshes(verts, faces, textures)
    mesh_wo_tex = Meshes(verts, faces, wo_textures)

    writer = None
    try:
        for i in tqdm(range(num_frames)):
            R, T = look_at_view_transform(1.8, 0, i * 360.0 / num_frames, device=device)
            frame = np.concatenate(
                [
                    _to_bgr_uint8(renderer(mesh_w_tex, R=R, T=T)),
                    _to_bgr_uint8(renderer(mesh_wo_tex, R=R, T=T)),
                ],
                axis=1,
            )
            if writer is None:
                # Size the writer from the real frame. OpenCV may silently drop
                # frames whose size differs from the size the writer was opened with.
                height, width = frame.shape[:2]
                fourcc = cv2.VideoWriter_fourcc(*"MP4V")
                writer = cv2.VideoWriter(video_path, fourcc, fps, (width, height))
            writer.write(frame)
    finally:
        # Always finalize the container, even if rendering fails partway.
        if writer is not None:
            writer.release()
def video(path, width=500):
    """Return an IPython ``HTML`` object that embeds an MP4 file inline.

    The whole file is read and base64-encoded into a ``data:`` URL, so the
    notebook output holds a copy of the video.

    Args:
        path: Path to the MP4 file.
        width: Display width of the ``<video>`` element. Defaults to 500.

    Returns:
        ``IPython.display.HTML`` with a looping ``<video>`` player that shows controls.
    """
    # Context manager so the file handle is closed right away.
    with open(path, "rb") as f:
        mp4 = f.read()
    data_url = "data:video/mp4;base64," + b64encode(mp4).decode()
    return HTML(
        '<video width=%s controls loop> <source src="%s" type="video/mp4"></video>'
        % (width, data_url)
    )