"""Gradio demo: run an uploaded mesh through the trained model and show the result.

The model weights are loaded once at import time. The ``gr.Interface`` takes an
uploaded 3D mesh, feeds its vertices and vertex normals through the model, and
returns the mesh the model outputs as an ``.obj`` file for display.
"""

import os
import tempfile

import gradio as gr
import numpy as np  # noqa: F401  (not used in this script; kept as in the original)
import open3d as o3d  # noqa: F401  (not used in this script; kept as in the original)
import pytorch_lightning as pl  # noqa: F401  (not used in this script; kept as in the original)
import torch as th
import trimesh as tm

from models.model import Model

CHECKPOINT_PATH = "./checkpoints/epoch=99-step=6000.ckpt"
# Build example paths from components so they work on both Windows and POSIX
# (the previous hard-coded "files\\..." strings only resolved on Windows).
EXAMPLES_DIR = os.path.join(os.path.dirname(__file__), "files")


def _load_model(checkpoint_path):
    """Build a ``Model`` and load its weights from ``checkpoint_path``.

    The checkpoint is expected to hold the weights under the ``"state_dict"``
    key. Tensors are mapped to the CPU so a checkpoint saved on a GPU still
    loads on a CPU-only machine. The model is put in eval mode for inference.
    """
    net = Model()
    ckpt = th.load(checkpoint_path, map_location="cpu")
    net.load_state_dict(ckpt["state_dict"])
    # Without eval(), layers such as dropout/batch-norm (if Model has any)
    # would behave as in training.
    net.eval()
    return net


model = _load_model(CHECKPOINT_PATH)


def process_mesh(mesh_file_name):
    """Run the mesh at ``mesh_file_name`` through the model and save the output.

    Args:
        mesh_file_name: Path to the uploaded mesh file, as provided by
            ``gr.Model3D``. May be ``None``/empty if nothing was uploaded.

    Returns:
        Path to a newly written ``.obj`` file containing the model's output
        mesh (its vertices, faces and vertex normals).

    Raises:
        gr.Error: If no mesh file was provided.
    """
    if not mesh_file_name:
        raise gr.Error("Please upload a 3D model first.")

    mesh = tm.load_mesh(mesh_file_name)
    v = th.tensor(mesh.vertices, dtype=th.float)
    n = th.tensor(mesh.vertex_normals, dtype=th.float)

    with th.no_grad():
        # The model is called with a leading batch dimension of size 1; the
        # fourth output is not used by this demo.
        v, f, n, _ = model(v.unsqueeze(0), n.unsqueeze(0))

    out_mesh = tm.Trimesh(
        vertices=v.squeeze(0).cpu().numpy(),
        faces=f.squeeze(0).cpu().numpy(),
        vertex_normals=n.squeeze(0).cpu().numpy(),
    )

    # Use a unique file per request so concurrent requests do not overwrite
    # each other's output (a fixed "./sample.obj" would be shared).
    fd, obj_path = tempfile.mkstemp(suffix=".obj")
    os.close(fd)
    out_mesh.export(obj_path)
    return obj_path


demo = gr.Interface(
    fn=process_mesh,
    inputs=gr.Model3D(),
    outputs=gr.Model3D(clear_color=[0.0, 0.0, 0.0, 0.0], label="3D Model"),
    examples=[
        [os.path.join(EXAMPLES_DIR, "bunny_n1_hi_50.obj")],
        [os.path.join(EXAMPLES_DIR, "child_n2_80.obj")],
        [os.path.join(EXAMPLES_DIR, "eight_n3_70.obj")],
    ],
)

if __name__ == "__main__":
    demo.launch()