|
import sys
import tempfile
from pathlib import Path

import gradio as gr
from huggingface_hub import hf_hub_download
|
|
|
|
|
|
|
|
|
|
|
# Directory holding the CADFusion sources. The path is relative, so it is
# resolved against the current working directory when the script starts.
CADFUSION_DIR = Path("CADFusion")

# Make that directory importable by adding its absolute path to sys.path.
sys.path.append(str(CADFUSION_DIR.resolve()))
|
|
|
|
|
# Kept below the sys.path update: the cadfusion package is expected to be
# found through the CADFusion directory appended there.
from cadfusion.inference import InferenceRunner  # noqa: E402
from cadfusion.utils.config import load_config  # noqa: E402
|
|
|
|
|
|
|
|
|
# Both artifacts come from the same Hub repository and revision; naming them
# once keeps the two downloads from drifting apart.
MODEL_REPO_ID = "microsoft/CADFusion"
MODEL_REVISION = "v1_1"

# hf_hub_download returns the local path of the downloaded file.
CONFIG_PATH = hf_hub_download(
    repo_id=MODEL_REPO_ID,
    filename="adapter_config.json",
    revision=MODEL_REVISION,
)

CHECKPOINT_PATH = hf_hub_download(
    repo_id=MODEL_REPO_ID,
    filename="adapter_model.safetensors",
    revision=MODEL_REVISION,
)
|
|
|
|
|
|
|
|
|
# Build the inference runner once at import time; generate_cad reuses this
# module-level instance for every request.
config = load_config(CONFIG_PATH)
runner = InferenceRunner(config=config)
runner.load_checkpoint(CHECKPOINT_PATH)
|
|
|
|
|
|
|
|
|
def generate_cad(prompt: str) -> str:
    """Generate an STL mesh file from a natural-language CAD description.

    Args:
        prompt: Free-text description of the part to generate.

    Returns:
        Path to the exported ``output.stl`` file, written into a fresh
        temporary directory.

    Raises:
        gr.Error: If the prompt is blank, or if inference or export fails.
            Gradio shows the message in the UI. Returning the message as a
            plain string would make the ``gr.File`` output treat it as a
            file path.
    """
    if not prompt or not prompt.strip():
        raise gr.Error("Please enter a CAD description.")

    try:
        outputs = runner.infer_from_text(prompt)
        mesh = outputs.to_trimesh()
        # A per-request directory keeps concurrent requests from overwriting
        # each other's output while preserving the "output.stl" file name.
        # NOTE: these temporary directories are not cleaned up here.
        out_file = str(Path(tempfile.mkdtemp(prefix="cadfusion_")) / "output.stl")
        mesh.export(out_file)
    except Exception as e:
        # Top-level UI boundary: surface any failure to the user, keeping
        # the original exception chained for the server log.
        raise gr.Error(f"Error during inference: {e}") from e
    return out_file
|
|
|
|
|
|
|
|
|
# Single-textbox UI: the prompt is passed to generate_cad and its result is
# shown through a file output component.
demo = gr.Interface(
    fn=generate_cad,
    inputs=gr.Textbox(label="Enter CAD description"),
    outputs=gr.File(label="Generated STL Model"),
    title="CADFusion - Text to CAD",
    description="Enter a natural language description and generate a 3D CAD mesh (STL).",
)
|
|
|
# Start the Gradio app only when this file is run as a script, not when it
# is imported.
if __name__ == "__main__":
    demo.launch()
|
|