# cad-recode / app.py — copied from the Hugging Face Space file view.
# Author: filapro · commit 0085c87 ("Update app.py", verified) · 10.3 kB
import os
import spaces
import trimesh
import traceback
import numpy as np
import gradio as gr
from functools import partial
from multiprocessing import Process, Queue
import torch
from torch import nn
from transformers import (
AutoTokenizer, Qwen2ForCausalLM, Qwen2Model, PreTrainedModel)
from transformers.modeling_outputs import CausalLMOutputWithPast
class FourierPointEncoder(nn.Module):
    """Embed each point into ``hidden_size`` using Fourier features of its xyz.

    The first 3 channels of a point (xyz) are expanded with sin/cos at 8
    power-of-two frequencies; the raw xyz and the remaining channels are
    appended unchanged, and the result is linearly projected. The projection
    expects 54 inputs = 3 (xyz) + 2 * 3 * 8 (sin/cos features) + 3 extra
    channels, so points must have 6 channels in the last dimension (in this
    file the extra 3 are the face normals from ``mesh_to_point_cloud``).
    """

    def __init__(self, hidden_size):
        super().__init__()
        # 1, 2, 4, ..., 128. Non-persistent: recomputed here, not stored in
        # checkpoints.
        self.register_buffer(
            'frequencies',
            torch.pow(2.0, torch.arange(8, dtype=torch.float32)),
            persistent=False)
        self.projection = nn.Linear(54, hidden_size)

    def forward(self, points):
        """Map ``points[..., 6]`` to embeddings ``[..., hidden_size]``."""
        xyz = points[..., :3]
        extra = points[..., 3:]
        # (..., 3, 8) -> (..., 24): all frequencies of x, then of y, then of z.
        angles = (xyz.unsqueeze(-1) * self.frequencies).view(*xyz.shape[:-1], -1)
        features = torch.cat((xyz, angles.sin(), angles.cos(), extra), dim=-1)
        return self.projection(features)
class CADRecode(Qwen2ForCausalLM):
    """Qwen2 causal LM whose prompt is prefixed with point-cloud embeddings.

    ``self.model`` / ``self.lm_head`` form a regular Qwen2 decoder and LM head;
    ``self.point_encoder`` turns every input point into one embedding of size
    ``config.hidden_size``. On the first forward pass (empty KV cache), prompt
    positions whose ``attention_mask`` entry is -1 have their token embeddings
    overwritten with those point embeddings.
    """

    def __init__(self, config):
        # Deliberately calls PreTrainedModel.__init__ instead of
        # Qwen2ForCausalLM.__init__ and builds the submodules by hand, so the
        # extra point encoder can be added next to the decoder and LM head.
        PreTrainedModel.__init__(self, config)
        self.model = Qwen2Model(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        # The point encoder's parameters are created in float32.
        # NOTE(review): afterwards the *process-wide* torch default dtype is
        # left at bfloat16, so any later tensor created without an explicit
        # dtype becomes bfloat16. Presumably this matches a bfloat16
        # checkpoint — confirm before relying on it.
        torch.set_default_dtype(torch.float32)
        self.point_encoder = FourierPointEncoder(config.hidden_size)
        torch.set_default_dtype(torch.bfloat16)

    def forward(self,
                input_ids=None,
                attention_mask=None,
                point_cloud=None,
                position_ids=None,
                past_key_values=None,
                inputs_embeds=None,
                labels=None,
                use_cache=None,
                output_attentions=None,
                output_hidden_states=None,
                return_dict=None,
                cache_position=None):
        """Run the decoder, injecting point embeddings on the first step.

        Args:
            input_ids: prompt/token ids; placeholder positions for points may
                hold any id (their embeddings are overwritten).
            attention_mask: 1 for real tokens and -1 for point placeholders.
                NOTE: modified IN PLACE on the first step (-1 -> 1). Later
                decoding steps presumably see this rewritten mask — confirm
                before changing this to a copy.
            point_cloud: 3-D tensor fed to ``self.point_encoder`` (the
                encoder fixes the last dim at 6; presumably
                (batch, n_points, 6) — TODO confirm). The number of -1 mask
                entries must equal the number of encoded points, or the
                masked assignment below fails. Only used while the KV cache
                is empty.
            labels: optional targets; if given, a shifted next-token
                cross-entropy loss is returned.
            Remaining arguments are forwarded to ``Qwen2Model``.

        Returns:
            ``CausalLMOutputWithPast`` (float32 logits), or a tuple
            ``([loss,] logits, *extra_outputs)`` when ``return_dict`` is false.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # concatenate point and text embeddings
        # Only on the first (prefill) step: afterwards the points already live
        # in the KV cache.
        if past_key_values is None or past_key_values.get_seq_length() == 0:
            assert inputs_embeds is None
            inputs_embeds = self.model.embed_tokens(input_ids)
            point_embeds = self.point_encoder(point_cloud).bfloat16()
            # Flatten to (total_points, hidden) and scatter into the -1 slots.
            inputs_embeds[attention_mask == -1] = point_embeds.reshape(-1, point_embeds.shape[2])
            attention_mask[attention_mask == -1] = 1
            # Pass embeddings instead of ids, and let the decoder derive
            # positions itself.
            input_ids = None
            position_ids = None

        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position)

        hidden_states = outputs[0]
        logits = self.lm_head(hidden_states)
        # Upcast logits to float32 for the loss / sampling.
        logits = logits.float()

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = nn.CrossEntropyLoss()
            shift_logits = shift_logits.view(-1, self.config.vocab_size)
            shift_labels = shift_labels.view(-1)
            # Enable model parallelism
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions)

    def prepare_inputs_for_generation(self, *args, **kwargs):
        # Forward the ``point_cloud`` kwarg given to ``generate()`` into every
        # step's model inputs (the parent implementation does not know it).
        # Raises KeyError if ``generate()`` is called without ``point_cloud``.
        model_inputs = super().prepare_inputs_for_generation(*args, **kwargs)
        model_inputs['point_cloud'] = kwargs['point_cloud']
        return model_inputs
def mesh_to_point_cloud(mesh, n_points=256):
    """Sample an oriented point cloud from the surface of ``mesh``.

    Args:
        mesh: trimesh mesh to sample; its ``face_normals`` supply the normals.
        n_points: number of surface samples to draw.

    Returns:
        np.ndarray whose rows are a sample position followed by the normal of
        the face it was drawn from, sorted by z, then y, then x.
    """
    samples, face_idx = trimesh.sample.sample_surface(mesh, n_points)
    normals = mesh.face_normals[face_idx]
    cloud = np.hstack((np.asarray(samples), normals))
    # np.lexsort uses the LAST key row as the primary key, so the key rows
    # (x, y, z) order the points by z first, then y, then x.
    order = np.lexsort(cloud[:, :3].T)
    return cloud[order]
def py_string_to_mesh_file(py_string, mesh_path, queue):
    """Execute generated CAD code and export the resulting shape as a mesh.

    Meant to run inside a child process. ``py_string`` must leave a global
    named ``r`` whose ``.val()`` result offers ``tessellate``. On any failure
    the formatted traceback is put on ``queue``; on success nothing is queued.

    SECURITY: ``py_string`` is model output and is run with ``exec`` against
    this module's globals — only ever call this in an isolated process.
    """
    try:
        exec(py_string, globals())
        solid = globals()['r'].val()
        # Tessellation tolerances (presumably linear and angular) — fixed here.
        tess_vertices, tess_faces = solid.tessellate(0.001, 0.1)
        xyz = [(vert.x, vert.y, vert.z) for vert in tess_vertices]
        trimesh.Trimesh(xyz, tess_faces).export(mesh_path)
    except:  # noqa: E722 -- deliberately catches everything the generated code may raise
        queue.put(traceback.format_exc())
def py_string_to_mesh_file_safe(py_string, mesh_path, timeout=5):
    """Run LLM-generated CAD code in a child process and export its mesh.

    CadQuery code predicted by the LLM may be unsafe, hang or leak memory,
    so it is executed by ``py_string_to_mesh_file`` in a separate
    ``Process`` that is terminated after ``timeout`` seconds.

    Args:
        py_string: generated Python source (must define a global ``r``).
        mesh_path: destination file for the exported mesh.
        timeout: seconds to wait for the child before terminating it.

    Raises:
        gr.Error: if the child is still running after ``timeout`` seconds,
            reports an exception, or exits abnormally without reporting one.
    """
    queue = Queue()
    process = Process(
        target=py_string_to_mesh_file,
        args=(py_string, mesh_path, queue))
    process.start()
    process.join(timeout)
    if process.is_alive():
        process.terminate()
        process.join()
        raise gr.Error(f'Process is alive after {timeout} seconds')
    if not queue.empty():
        # The child caught an exception and sent its traceback.
        raise gr.Error(queue.get())
    if process.exitcode != 0:
        # A hard crash (e.g. a segfault in native CAD code) never reaches the
        # child's except-clause, so nothing is queued. Without this check the
        # caller would go on to serve a missing or stale mesh_path.
        raise gr.Error(f'Process exited abnormally with exit code {process.exitcode}')
def run_point_cloud(in_mesh_path, seed):
    """Normalize the uploaded mesh and sample a point cloud from it.

    Args:
        in_mesh_path: path of the uploaded mesh file (None if nothing was
            uploaded).
        seed: seed for numpy's global RNG, which drives surface sampling.

    Returns:
        (point_cloud, pcd_path): the sampled points (positions + normals)
        and an .obj file with the positions for the point-cloud viewer.

    Raises:
        gr.Error: if no mesh was given, or with the traceback of any failure.
    """
    if in_mesh_path is None:
        raise gr.Error('Please upload a mesh or select one of the examples first.')
    try:
        # force='mesh' collapses multi-geometry files (e.g. .glb/.gltf scenes)
        # into a single Trimesh; sampling needs faces and face normals.
        mesh = trimesh.load(in_mesh_path, force='mesh')
        # Center the bounding box at the origin and scale its longest side to
        # 2, i.e. fit the mesh into [-1, 1]^3.
        mesh.apply_translation(-(mesh.bounds[0] + mesh.bounds[1]) / 2.0)
        mesh.apply_scale(2.0 / max(mesh.extents))
        # Slider values may arrive as floats; np.random.seed rejects those.
        np.random.seed(None if seed is None else int(seed))
        point_cloud = mesh_to_point_cloud(mesh)
        pcd_path = '/tmp/pcd.obj'
        trimesh.points.PointCloud(point_cloud[:, :3]).export(pcd_path)
        return point_cloud, pcd_path
    except Exception:
        raise gr.Error(traceback.format_exc())
@spaces.GPU(duration=20)
def run_cad_recode(point_cloud):
    """Generate CadQuery Python code for a point cloud with CAD-Recode.

    Args:
        point_cloud: numpy array from ``run_point_cloud``, one row per point.

    Returns:
        (py_string, py_string): the generated code, once for the pipeline
        state and once for the code viewer.

    Raises:
        gr.Error: with the formatted traceback of any failure.
    """
    start_marker = '<|im_start|>'
    end_marker = '<|endoftext|>'
    try:
        num_points = len(point_cloud)
        # One placeholder per point (its embedding is replaced by the point
        # encoder output where attention_mask == -1), then the start token.
        input_ids = [tokenizer.pad_token_id] * num_points + [tokenizer(start_marker)['input_ids'][0]]
        attention_mask = [-1] * num_points + [1]
        model = cad_recode.cuda()
        with torch.no_grad():
            batch_ids = model.generate(
                input_ids=torch.tensor(input_ids).unsqueeze(0).to(model.device),
                attention_mask=torch.tensor(attention_mask).unsqueeze(0).to(model.device),
                point_cloud=torch.tensor(point_cloud.astype(np.float32)).unsqueeze(0).to(model.device),
                max_new_tokens=768,
                pad_token_id=tokenizer.pad_token_id).cpu()
        decoded = tokenizer.batch_decode(batch_ids)[0]
        begin = decoded.find(start_marker)
        begin = 0 if begin == -1 else begin + len(start_marker)
        end = decoded.find(end_marker, begin)
        if end == -1:
            # Generation hit max_new_tokens without emitting EOS: keep the
            # whole tail instead of slicing off its last character.
            end = len(decoded)
        py_string = decoded[begin:end]
        return py_string, py_string
    except Exception:
        raise gr.Error(traceback.format_exc())
def run_mesh(py_string):
    """Turn generated CAD code into a mesh file for the result viewer.

    Args:
        py_string: CadQuery Python code produced by ``run_cad_recode``.

    Returns:
        Path of the exported STL file.

    Raises:
        gr.Error: the user-facing error from the sandboxed run, unchanged, or
            the traceback of any other failure.
    """
    try:
        out_mesh_path = '/tmp/mesh.stl'
        py_string_to_mesh_file_safe(py_string, out_mesh_path)
        return out_mesh_path
    except gr.Error:
        # Already user-facing (timeout / child traceback). Re-wrapping it
        # would nest the original message inside a second traceback.
        raise
    except Exception:
        raise gr.Error(traceback.format_exc())
def run():
    """Build the CAD-Recode Gradio demo and launch it.

    Layout: a row of three 3D viewers (input mesh, sampled point cloud,
    result CAD model) above a row with the controls, the generated code and
    an empty column. The Run button chains run_point_cloud ->
    run_cad_recode -> run_mesh.
    """
    with gr.Blocks() as demo:
        with gr.Row():
            gr.Markdown('## CAD-Recode Demo\n'
                        'Upload mesh or select from examples and press Run! Mesh ⇾ 256 points ⇾ Python code by CAD-Recode ⇾ CAD model.')
        # Viewers for the input, the intermediate point cloud and the result.
        with gr.Row(equal_height=True):
            in_model = gr.Model3D(label='1. Input Mesh', interactive=True)
            point_model = gr.Model3D(label='2. Sampled Point Cloud', display_mode='point_cloud', interactive=False)
            out_model = gr.Model3D(
                label='4. Result CAD Model', interactive=False
            )
        with gr.Row():
            # Controls: sampling seed, bundled examples (mesh + seed), Run.
            with gr.Column():
                with gr.Row():
                    seed_slider = gr.Slider(label='Random Seed', value=42, interactive=True)
                with gr.Row():
                    gr.Examples(
                        examples=[
                            ['./data/49215_5368e45e_0000.stl', 42],
                            ['./data/00882236.stl', 6],
                            ['./data/User Library-engrenage.stl', 18],
                            ['./data/00010900.stl', 42],
                            ['./data/21492_8bd34fc1_0008.stl', 42],
                            ['./data/00375556.stl', 96],
                            ['./data/49121_adb01620_0000.stl', 42]],
                        example_labels=[
                            'fusion360_table1', 'deepcad_star', 'cc3d_gear', 'deepcad_barrels',
                            'fusion360_gear', 'deepcad_house', 'fusion360_table2'],
                        inputs=[in_model, seed_slider],
                        cache_examples=False)
                with gr.Row():
                    run_button = gr.Button('Run')
            with gr.Column():
                out_code = gr.Code(language='python', label='3. Generated Python Code', wrap_lines=True, interactive=False)
            # Intentionally empty; presumably keeps the code column aligned
            # under the middle viewer — confirm.
            with gr.Column():
                pass
        # Carries the intermediate result between stages: first the sampled
        # point cloud array, then the generated code string.
        state = gr.State()
        # Each stage is attached with .success(), so (per Gradio's event API)
        # it should run only when the previous stage finished without error.
        run_button.click(
            run_point_cloud,
            inputs=[in_model, seed_slider],
            outputs=[state, point_model]
        ).success(
            run_cad_recode,
            inputs=[state],
            outputs=[state, out_code]
        ).success(
            run_mesh,
            inputs=[state],
            outputs=[out_model]
        )
    demo.launch(show_error=True)
# Guarded entry point. py_string_to_mesh_file_safe starts multiprocessing
# children. Under the 'spawn'/'forkserver' start methods each child re-imports
# this module, and without the guard every child would reload the model and
# relaunch the app. tokenizer and cad_recode are still module globals here, so
# run_cad_recode can use them.
if __name__ == '__main__':
    # Set before any tokenizer is created or a process is forked.
    os.environ['TOKENIZERS_PARALLELISM'] = 'False'
    tokenizer = AutoTokenizer.from_pretrained(
        'Qwen/Qwen2-1.5B',
        pad_token='<|im_end|>',
        padding_side='left')
    cad_recode = CADRecode.from_pretrained(
        'filapro/cad-recode',
        torch_dtype='auto',
        attn_implementation='flash_attention_2').eval()
    run()