# StructDiffusionDemo / app_v0.py
# Author: Weiyu Liu
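"""Gradio demo for StructDiffusion (app v0).

Loads a trained conditional pose diffusion model and a demo dataset of partial-view
point clouds paired with language goals. For a selected example, the app shows the
input scene and language command, samples object poses with the diffusion model, and
exports the rearranged scene as a GLB file for the 3D viewer.
"""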
import os
import argparse
import torch
import trimesh
import numpy as np
import pytorch_lightning as pl
import gradio as gr
from omegaconf import OmegaConf
import sys
sys.path.append('./src')
from StructDiffusion.data.semantic_arrangement_demo import SemanticArrangementDataset
from StructDiffusion.language.tokenizer import Tokenizer
from StructDiffusion.models.pl_models import ConditionalPoseDiffusionModel
from StructDiffusion.diffusion.sampler import Sampler
from StructDiffusion.diffusion.pose_conversion import get_struct_objs_poses
from StructDiffusion.utils.files import get_checkpoint_path_from_dir
from StructDiffusion.utils.rearrangement import show_pcs_with_trimesh
import StructDiffusion.utils.transformations as tra
def move_pc_and_create_scene_simple(obj_xyzs, struct_pose, pc_poses_in_struct):
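    """Move each object point cloud to its sampled goal pose.

    obj_xyzs: (B, N, P, 3 or 6) object point clouds (xyz, optionally with rgb)
    struct_pose: (B, 1, 4, 4) pose of the structure frame
    pc_poses_in_struct: (B, N, 4, 4) object poses within the structure frame

    Each object's goal pose is struct_pose @ pc_pose_in_struct; its points are moved by
    goal_pose @ inverse(current_pose), where the current pose is a pure translation to
    the object's centroid. The xyz channels of obj_xyzs are updated in place.
    """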
device = obj_xyzs.device
# obj_xyzs: B, N, P, 3 or 6
# struct_pose: B, 1, 4, 4
# pc_poses_in_struct: B, N, 4, 4
B, N, _, _ = pc_poses_in_struct.shape
_, _, P, _ = obj_xyzs.shape
current_pc_poses = torch.eye(4).repeat(B, N, 1, 1).to(device) # B, N, 4, 4
# print(torch.mean(obj_xyzs, dim=2).shape)
    current_pc_poses[:, :, :3, 3] = torch.mean(obj_xyzs[:, :, :, :3], dim=2)  # translations = object centroids (B, N, 3)
current_pc_poses = current_pc_poses.reshape(B * N, 4, 4) # B x N, 4, 4
struct_pose = struct_pose.repeat(1, N, 1, 1) # B, N, 4, 4
    struct_pose = struct_pose.reshape(B * N, 4, 4)  # B x N, 4, 4
pc_poses_in_struct = pc_poses_in_struct.reshape(B * N, 4, 4) # B x N, 4, 4
goal_pc_pose = struct_pose @ pc_poses_in_struct # B x N, 4, 4
# print("goal pc poses")
# print(goal_pc_pose)
goal_pc_transform = goal_pc_pose @ torch.inverse(current_pc_poses) # B x N, 4, 4
# # important: pytorch3d uses row-major ordering, need to transpose each transformation matrix
# transpose = tra3d.Transform3d(matrix=goal_pc_transform.transpose(1, 2))
# new_obj_xyzs = obj_xyzs.reshape(B * N, P, -1) # B x N, P, 3
# new_obj_xyzs[:, :, :3] = transpose.transform_points(new_obj_xyzs[:, :, :3])
    # a version that does not rely on pytorch3d
new_obj_xyzs = obj_xyzs.reshape(B * N, P, -1)[:, :, :3] # B x N, P, 3
new_obj_xyzs = torch.concat([new_obj_xyzs, torch.ones(B * N, P, 1).to(device)], dim=-1) # B x N, P, 4
    new_obj_xyzs = torch.einsum('bij,bkj->bki', goal_pc_transform, new_obj_xyzs)[:, :, :3]  # B x N, P, 3
# put it back to B, N, P, 3
obj_xyzs[:, :, :, :3] = new_obj_xyzs.reshape(B, N, P, -1)
return obj_xyzs
class Infer_Wrapper:
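    """Wraps the trained ConditionalPoseDiffusionModel and the demo dataset.

    Loads the checkpoint, language tokenizer, and SemanticArrangementDataset, and
    provides the callbacks used by the Gradio interface below.
    """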
def __init__(self, args, cfg):
        # load the checkpoint, tokenizer, dataset, and diffusion sampler
pl.seed_everything(args.eval_random_seed)
self.device = (torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu"))
checkpoint_dir = os.path.join(cfg.WANDB.save_dir, cfg.WANDB.project, args.checkpoint_id, "checkpoints")
checkpoint_path = get_checkpoint_path_from_dir(checkpoint_dir)
self.tokenizer = Tokenizer(cfg.DATASET.vocab_dir)
# override ignore_rgb for visualization
cfg.DATASET.ignore_rgb = False
self.dataset = SemanticArrangementDataset(tokenizer=self.tokenizer, **cfg.DATASET)
self.sampler = Sampler(ConditionalPoseDiffusionModel, checkpoint_path, self.device)
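        # ensure the folder used for exported GLB scenes exists
        # (assumes the app is launched from the repo root, matching the relative paths below)
        os.makedirs("./tmp_data", exist_ok=True)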
def visualize_scene(self, di, session_id):
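        """Return the natural-language command and an exported GLB of the input scene for example `di`."""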
raw_datum = self.dataset.get_raw_data(di)
language_command = self.tokenizer.convert_structure_params_to_natural_language(raw_datum["sentence"])
obj_xyz = raw_datum["pcs"]
scene = show_pcs_with_trimesh([xyz[:, :3] for xyz in obj_xyz], [xyz[:, 3:] for xyz in obj_xyz], return_scene=True)
scene.apply_transform(tra.euler_matrix(np.pi, 0, np.pi/2))
scene_filename = "./tmp_data/input_scene_{}.glb".format(session_id)
scene.export(scene_filename)
return language_command, scene_filename
def infer(self, di, session_id, progress=gr.Progress()):
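        """Sample object poses for example `di` with the diffusion sampler, move the point
        clouds to the sampled poses, and export the rearranged scene; returns the GLB path."""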
# di = np.random.choice(len(self.dataset))
raw_datum = self.dataset.get_raw_data(di)
print(self.tokenizer.convert_structure_params_to_natural_language(raw_datum["sentence"]))
datum = self.dataset.convert_to_tensors(raw_datum, self.tokenizer)
batch = self.dataset.single_datum_to_batch(datum, args.num_samples, self.device, inference_mode=True)
num_poses = datum["goal_poses"].shape[0]
xs = self.sampler.sample(batch, num_poses, progress)
struct_pose, pc_poses_in_struct = get_struct_objs_poses(xs[0])
new_obj_xyzs = move_pc_and_create_scene_simple(batch["pcs"], struct_pose, pc_poses_in_struct)
# vis
vis_obj_xyzs = new_obj_xyzs[:3]
if torch.is_tensor(vis_obj_xyzs):
if vis_obj_xyzs.is_cuda:
vis_obj_xyzs = vis_obj_xyzs.detach().cpu()
vis_obj_xyzs = vis_obj_xyzs.numpy()
vis_obj_xyz = vis_obj_xyzs[0]
scene = show_pcs_with_trimesh([xyz[:, :3] for xyz in vis_obj_xyz], [xyz[:, 3:] for xyz in vis_obj_xyz], return_scene=True)
scene.apply_transform(tra.euler_matrix(np.pi, 0, np.pi/2))
scene_filename = "./tmp_data/output_scene_{}.glb".format(session_id)
scene.export(scene_filename)
return scene_filename
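# Demo configuration: load the base config and the model config, then merge them with OmegaConf.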
args = OmegaConf.create()
args.base_config_file = "./configs/base.yaml"
args.config_file = "./configs/conditional_pose_diffusion.yaml"
args.checkpoint_id = "ConditionalPoseDiffusion"
args.eval_random_seed = 42
args.num_samples = 1
base_cfg = OmegaConf.load(args.base_config_file)
cfg = OmegaConf.load(args.config_file)
cfg = OmegaConf.merge(base_cfg, cfg)
infer_wrapper = Infer_Wrapper(args, cfg)
# version 0
# demo = gr.Interface(
# fn=infer_wrapper.run,
# inputs=gr.Slider(0, len(infer_wrapper.dataset)),
# # clear color range [0-1.0]
# outputs=gr.Model3D(clear_color=[0, 0, 0, 0], label="3D Model")
# )
#
# demo.launch()
# version 1
demo = gr.Blocks(theme=gr.themes.Soft())
with demo:
gr.Markdown("<p style='text-align:center;font-size:18px'><b>StructDiffusion Demo</b></p>")
gr.Markdown("<p style='text-align:center'>StructDiffusion combines a diffusion model and an object-centric transformer to construct structures given partial-view point clouds and high-level language goals.<br><a href='https://structdiffusion.github.io/'>Website</a> | <a href='https://github.com/StructDiffusion/StructDiffusion'>Code</a></p>")
session_id = gr.State(value=np.random.randint(0, 1000))
data_selection = gr.Number(label="Example No.", minimum=0, maximum=len(infer_wrapper.dataset) - 1, precision=0)
input_scene = gr.Model3D(clear_color=[0, 0, 0, 0], label="Input 3D Scene")
language_command = gr.Textbox(label="Input Language Command")
output_scene = gr.Model3D(clear_color=[0, 0, 0, 0], label="Generated 3D Structure")
b1 = gr.Button("Show Input Language and Scene")
b2 = gr.Button("Generate 3D Structure")
b1.click(infer_wrapper.visualize_scene, inputs=[data_selection, session_id], outputs=[language_command, input_scene])
b2.click(infer_wrapper.infer, inputs=[data_selection, session_id], outputs=output_scene)
demo.queue(concurrency_count=10)
demo.launch()