Upload 50 files
Browse files- CADFusion/.gitignore +171 -0
- CADFusion/CODE_OF_CONDUCT.md +9 -0
- CADFusion/LICENSE +21 -0
- CADFusion/README.md +194 -0
- CADFusion/SECURITY.md +41 -0
- CADFusion/SUPPORT.md +25 -0
- CADFusion/data/sl_data/convert.py +125 -0
- CADFusion/data/sl_data/sl_data.zip +3 -0
- CADFusion/data/vf_data/example_vf_data.zip +3 -0
- CADFusion/ds_config.yaml +22 -0
- CADFusion/pyproject.toml +38 -0
- CADFusion/scripts/alternate_VF.sh +47 -0
- CADFusion/scripts/alternate_VF_quadra_gpu.sh +50 -0
- CADFusion/scripts/generate_samples.sh +44 -0
- CADFusion/scripts/make_dpo_data.sh +5 -0
- CADFusion/scripts/preprocess_skexgen.sh +28 -0
- CADFusion/scripts/train_loop.sh +42 -0
- CADFusion/scripts/train_with_shuffling.sh +20 -0
- CADFusion/src/data_preprocessing/call_openai.py +37 -0
- CADFusion/src/data_preprocessing/captioning.py +101 -0
- CADFusion/src/data_preprocessing/convert.py +120 -0
- CADFusion/src/dpo/llava_utils.py +95 -0
- CADFusion/src/dpo/make_dpo_dataset.py +162 -0
- CADFusion/src/dpo/openai_utils.py +88 -0
- CADFusion/src/rendering_utils/geometry/arc.py +32 -0
- CADFusion/src/rendering_utils/geometry/circle.py +27 -0
- CADFusion/src/rendering_utils/geometry/curve.py +13 -0
- CADFusion/src/rendering_utils/geometry/geom_utils.py +95 -0
- CADFusion/src/rendering_utils/geometry/line.py +24 -0
- CADFusion/src/rendering_utils/geometry/obj_parser.py +276 -0
- CADFusion/src/rendering_utils/geometry/obj_utils.py +93 -0
- CADFusion/src/rendering_utils/img_renderer.py +84 -0
- CADFusion/src/rendering_utils/parser.py +478 -0
- CADFusion/src/rendering_utils/parser_visual.py +110 -0
- CADFusion/src/rendering_utils/ptl_sampler.py +88 -0
- CADFusion/src/rendering_utils/utils/obj_reconverter.py +437 -0
- CADFusion/src/rendering_utils/utils/util.py +72 -0
- CADFusion/src/test/VLM_score.py +95 -0
- CADFusion/src/test/chamfer_dist.py +308 -0
- CADFusion/src/test/dist_eval.py +351 -0
- CADFusion/src/test/f1_eval.py +74 -0
- CADFusion/src/test/generate.ipynb +291 -0
- CADFusion/src/test/inference.py +106 -0
- CADFusion/src/test/utils.py +86 -0
- CADFusion/src/test/visual_utils/__init__.py +0 -0
- CADFusion/src/test/visual_utils/parser.py +478 -0
- CADFusion/src/train/CAD_dataset.py +89 -0
- CADFusion/src/train/dpo.py +79 -0
- CADFusion/src/train/llama_finetune.py +127 -0
- CADFusion/src/train/utils.py +86 -0
CADFusion/.gitignore
ADDED
|
@@ -0,0 +1,171 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Byte-compiled / optimized / DLL files
|
| 2 |
+
__pycache__/
|
| 3 |
+
*.py[cod]
|
| 4 |
+
*$py.class
|
| 5 |
+
|
| 6 |
+
# C extensions
|
| 7 |
+
*.so
|
| 8 |
+
|
| 9 |
+
# Distribution / packaging
|
| 10 |
+
.Python
|
| 11 |
+
build/
|
| 12 |
+
develop-eggs/
|
| 13 |
+
dist/
|
| 14 |
+
downloads/
|
| 15 |
+
eggs/
|
| 16 |
+
.eggs/
|
| 17 |
+
lib/
|
| 18 |
+
lib64/
|
| 19 |
+
parts/
|
| 20 |
+
sdist/
|
| 21 |
+
var/
|
| 22 |
+
wheels/
|
| 23 |
+
share/python-wheels/
|
| 24 |
+
*.egg-info/
|
| 25 |
+
.installed.cfg
|
| 26 |
+
*.egg
|
| 27 |
+
MANIFEST
|
| 28 |
+
|
| 29 |
+
# PyInstaller
|
| 30 |
+
# Usually these files are written by a python script from a template
|
| 31 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
| 32 |
+
*.manifest
|
| 33 |
+
*.spec
|
| 34 |
+
|
| 35 |
+
# Installer logs
|
| 36 |
+
pip-log.txt
|
| 37 |
+
pip-delete-this-directory.txt
|
| 38 |
+
|
| 39 |
+
# Unit test / coverage reports
|
| 40 |
+
htmlcov/
|
| 41 |
+
.tox/
|
| 42 |
+
.nox/
|
| 43 |
+
.coverage
|
| 44 |
+
.coverage.*
|
| 45 |
+
.cache
|
| 46 |
+
nosetests.xml
|
| 47 |
+
coverage.xml
|
| 48 |
+
*.cover
|
| 49 |
+
*.py,cover
|
| 50 |
+
.hypothesis/
|
| 51 |
+
.pytest_cache/
|
| 52 |
+
cover/
|
| 53 |
+
|
| 54 |
+
# Translations
|
| 55 |
+
*.mo
|
| 56 |
+
*.pot
|
| 57 |
+
|
| 58 |
+
# Django stuff:
|
| 59 |
+
*.log
|
| 60 |
+
local_settings.py
|
| 61 |
+
db.sqlite3
|
| 62 |
+
db.sqlite3-journal
|
| 63 |
+
|
| 64 |
+
# Flask stuff:
|
| 65 |
+
instance/
|
| 66 |
+
.webassets-cache
|
| 67 |
+
|
| 68 |
+
# Scrapy stuff:
|
| 69 |
+
.scrapy
|
| 70 |
+
|
| 71 |
+
# Sphinx documentation
|
| 72 |
+
docs/_build/
|
| 73 |
+
|
| 74 |
+
# PyBuilder
|
| 75 |
+
.pybuilder/
|
| 76 |
+
target/
|
| 77 |
+
|
| 78 |
+
# Jupyter Notebook
|
| 79 |
+
.ipynb_checkpoints
|
| 80 |
+
|
| 81 |
+
# IPython
|
| 82 |
+
profile_default/
|
| 83 |
+
ipython_config.py
|
| 84 |
+
|
| 85 |
+
# pyenv
|
| 86 |
+
# For a library or package, you might want to ignore these files since the code is
|
| 87 |
+
# intended to run in multiple environments; otherwise, check them in:
|
| 88 |
+
# .python-version
|
| 89 |
+
|
| 90 |
+
# pipenv
|
| 91 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
| 92 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
| 93 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
| 94 |
+
# install all needed dependencies.
|
| 95 |
+
#Pipfile.lock
|
| 96 |
+
|
| 97 |
+
# UV
|
| 98 |
+
# Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
|
| 99 |
+
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
| 100 |
+
# commonly ignored for libraries.
|
| 101 |
+
#uv.lock
|
| 102 |
+
|
| 103 |
+
# poetry
|
| 104 |
+
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
| 105 |
+
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
| 106 |
+
# commonly ignored for libraries.
|
| 107 |
+
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
| 108 |
+
#poetry.lock
|
| 109 |
+
|
| 110 |
+
# pdm
|
| 111 |
+
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
| 112 |
+
#pdm.lock
|
| 113 |
+
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
|
| 114 |
+
# in version control.
|
| 115 |
+
# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
|
| 116 |
+
.pdm.toml
|
| 117 |
+
.pdm-python
|
| 118 |
+
.pdm-build/
|
| 119 |
+
|
| 120 |
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
| 121 |
+
__pypackages__/
|
| 122 |
+
|
| 123 |
+
# Celery stuff
|
| 124 |
+
celerybeat-schedule
|
| 125 |
+
celerybeat.pid
|
| 126 |
+
|
| 127 |
+
# SageMath parsed files
|
| 128 |
+
*.sage.py
|
| 129 |
+
|
| 130 |
+
# Environments
|
| 131 |
+
.env
|
| 132 |
+
.venv
|
| 133 |
+
env/
|
| 134 |
+
venv/
|
| 135 |
+
ENV/
|
| 136 |
+
env.bak/
|
| 137 |
+
venv.bak/
|
| 138 |
+
|
| 139 |
+
# Spyder project settings
|
| 140 |
+
.spyderproject
|
| 141 |
+
.spyproject
|
| 142 |
+
|
| 143 |
+
# Rope project settings
|
| 144 |
+
.ropeproject
|
| 145 |
+
|
| 146 |
+
# mkdocs documentation
|
| 147 |
+
/site
|
| 148 |
+
|
| 149 |
+
# mypy
|
| 150 |
+
.mypy_cache/
|
| 151 |
+
.dmypy.json
|
| 152 |
+
dmypy.json
|
| 153 |
+
|
| 154 |
+
# Pyre type checker
|
| 155 |
+
.pyre/
|
| 156 |
+
|
| 157 |
+
# pytype static type analyzer
|
| 158 |
+
.pytype/
|
| 159 |
+
|
| 160 |
+
# Cython debug symbols
|
| 161 |
+
cython_debug/
|
| 162 |
+
|
| 163 |
+
# PyCharm
|
| 164 |
+
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
| 165 |
+
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
| 166 |
+
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
| 167 |
+
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
| 168 |
+
#.idea/
|
| 169 |
+
|
| 170 |
+
# PyPI configuration file
|
| 171 |
+
.pypirc
|
CADFusion/CODE_OF_CONDUCT.md
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Microsoft Open Source Code of Conduct
|
| 2 |
+
|
| 3 |
+
This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/).
|
| 4 |
+
|
| 5 |
+
Resources:
|
| 6 |
+
|
| 7 |
+
- [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/)
|
| 8 |
+
- [Microsoft Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/)
|
| 9 |
+
- Contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with questions or concerns
|
CADFusion/LICENSE
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
MIT License
|
| 2 |
+
|
| 3 |
+
Copyright (c) Microsoft Corporation.
|
| 4 |
+
|
| 5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 6 |
+
of this software and associated documentation files (the "Software"), to deal
|
| 7 |
+
in the Software without restriction, including without limitation the rights
|
| 8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 9 |
+
copies of the Software, and to permit persons to whom the Software is
|
| 10 |
+
furnished to do so, subject to the following conditions:
|
| 11 |
+
|
| 12 |
+
The above copyright notice and this permission notice shall be included in all
|
| 13 |
+
copies or substantial portions of the Software.
|
| 14 |
+
|
| 15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 21 |
+
SOFTWARE
|
CADFusion/README.md
ADDED
|
@@ -0,0 +1,194 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# CADFusion
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
This repo is the official implementation of paper **[ICML 2025] Text-to-CAD Generation Through Infusing Visual Feedback in Large Language Models** by *Ruiyu Wang, Yu Yuan, Shizhao Sun, Jiang Bian*.
|
| 5 |
+
|
| 6 |
+
[Paper](https://arxiv.org/abs/2501.19054) | [Video](https://www.youtube-nocookie.com/embed/LK8LAzR0v5M?si=FD1Vg9wjkROTKjDV) | [Huggingface](https://huggingface.co/microsoft/CADFusion)
|
| 7 |
+
|
| 8 |
+
CADFusion is a text-to-CAD generation framework that leverages visual feedback to enhance the performance of large language models (LLMs) in generating CAD models from textual descriptions. It consists of two main components: sequential learning and visual learning. The sequential learning component fine-tunes LLMs on a text-to-CAD dataset, while the visual learning component alternates between training a visual feedback model and fine-tuning the LLM with the generated visual feedback.
|
| 9 |
+
|
| 10 |
+
## Installation
|
| 11 |
+
|
| 12 |
+
- Create a conda environment and install the generic dependencies.
|
| 13 |
+
|
| 14 |
+
```
|
| 15 |
+
name=<your-env-name>
|
| 16 |
+
conda create -n $name python=3.9
|
| 17 |
+
conda activate $name
|
| 18 |
+
python -m pip install -e .
|
| 19 |
+
```
|
| 20 |
+
|
| 21 |
+
- Install the additional dependencies for training.
|
| 22 |
+
|
| 23 |
+
```
|
| 24 |
+
python -m pip install -e .["train"]
|
| 25 |
+
```
|
| 26 |
+
|
| 27 |
+
- Install the additional dependencies for evaluation and rendering.
|
| 28 |
+
|
| 29 |
+
```
|
| 30 |
+
python -m pip install -e .["render"]
|
| 31 |
+
conda install -c conda-forge pythonocc-core=7.7.0
|
| 32 |
+
python -m pip install git+https://github.com/otaheri/chamfer_distance@dc9987dcf70888d387d96893ba1fb9ba9a333992
|
| 33 |
+
python -m pip install -e .["eval"]
|
| 34 |
+
```
|
| 35 |
+
|
| 36 |
+
## Data Preparation
|
| 37 |
+
CADFusion is trained by alternating the **Sequential Learning (SL)** stage and the **Visual Feedback (VF)** stage.
|
| 38 |
+
We introduce how to prepare the training data for these two stages in the below.
|
| 39 |
+
|
| 40 |
+
### Data for Sequential Learning
|
| 41 |
+
|
| 42 |
+
#### Approach 1: use human-annotated textual descriptions provided by us
|
| 43 |
+
We provide human-annoated textual descriptions and their correspoding CAD model IDs in [Skexgen](https://github.com/samxuxiang/SkexGen) under `data/sl_data/sl_data.zip`. It should contain the following files after unzipping:
|
| 44 |
+
```
|
| 45 |
+
data/sl_data
|
| 46 |
+
├── train.json
|
| 47 |
+
├── val.json
|
| 48 |
+
├── test.json
|
| 49 |
+
```
|
| 50 |
+
To use our annotated data, download the SkexGen data, unzip it as the reference dataset and run the convertion script to get the dataset. In detail, run the following command:
|
| 51 |
+
```
|
| 52 |
+
# make sure you are in the root directory of this repo and have the 'data/sl_data/sl_data.zip' unzipped
|
| 53 |
+
gdown --id 1so_CCGLIhqGEDQxMoiR--A4CQk4MjuOp
|
| 54 |
+
unzip cad_data.zip
|
| 55 |
+
python3 data/sl_data/convert.py
|
| 56 |
+
```
|
| 57 |
+
The `train.json`, `val.json` and `test.json` under `data/sl_data` are the datasets.
|
| 58 |
+
|
| 59 |
+
#### Approach 2: create human-annotated textual descriptions by yourself
|
| 60 |
+
We provide a script to execute all the preprocessing steps until human annotation.
|
| 61 |
+
```
|
| 62 |
+
./scripts/preprocess_skexgen.sh
|
| 63 |
+
```
|
| 64 |
+
If you want to customize the internal steps, expand the following section for more details.
|
| 65 |
+
<details>
|
| 66 |
+
<summary>Start from scratch (click to expand).</summary>
|
| 67 |
+
|
| 68 |
+
1. Download the [SkexGen](https://github.com/samxuxiang/SkexGen) data by: [Google Drive link](https://drive.google.com/file/d/1so_CCGLIhqGEDQxMoiR--A4CQk4MjuOp/view).
|
| 69 |
+
|
| 70 |
+
```
|
| 71 |
+
gdown --id 1so_CCGLIhqGEDQxMoiR--A4CQk4MjuOp
|
| 72 |
+
unzip cad_data.zip
|
| 73 |
+
```
|
| 74 |
+
|
| 75 |
+
2. Convert the SkexGen data into sequences. Note that `train_deduplicate_s.pkl`, `val.pkl` and `test.pkl` should be converted separately.
|
| 76 |
+
```
|
| 77 |
+
python3 src/data_preprocessing/convert.py --in_path <skexgen_path> --out_path <sequence_path>
|
| 78 |
+
```
|
| 79 |
+
|
| 80 |
+
3. Render the sequences into images. *Note that running the last step on linux requires the installation of an x server (e.g. `xvfb`). See [this discussion.](https://github.com/tpaviot/pythonocc-core/issues/1302#issuecomment-2053526444)*
|
| 81 |
+
```
|
| 82 |
+
python3 src/rendering_utils/parser.py --in-path <sequence_path> --out-path <visual_object_folder>
|
| 83 |
+
timeout 180 python3 src/rendering_utils/parser_visual.py --data_folder <visual_object_folder>
|
| 84 |
+
python3 src/rendering_utils/img_renderer.py --input_dir <visual_object_folder> --output_dir <image_folder>
|
| 85 |
+
```
|
| 86 |
+
|
| 87 |
+
4. Annotate these data with LLM captioning.
|
| 88 |
+
```
|
| 89 |
+
# Generic:
|
| 90 |
+
python3 src/data_preprocessing/captioning.py --image-folder-path <image_folder> --out-path <sl_data_path>
|
| 91 |
+
|
| 92 |
+
```
|
| 93 |
+
* We use openai and azure system for LLM calling. You are welcome to use your own LLMs and prompts by changing `line 21, 22` of `src/data_preprocessing/captioning.py` with your own client definition and function calls.
|
| 94 |
+
</details>
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
### Data for Visual Feedback
|
| 98 |
+
|
| 99 |
+
The Visual Feedback dataset should be automatically generated from the Visual Feedback pipeline described in the Training section.
|
| 100 |
+
We provide an example under `data/vf_data/example_vf_data.json` to help people understand how it should look like.
|
| 101 |
+
You can retrieve this file by unzipping `data/vf_data/example_vf_data.zip`.
|
| 102 |
+
We do not recommend using this example data as the training data, as the policy update should depend on its own generations.
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
## Training
|
| 106 |
+
Our training receipe contains two parts. In the first part, we conduct initial sequential learning. In the second part, we conduct alternate training between sequential learning and visual feedback.
|
| 107 |
+
### Initial Sequential Learning
|
| 108 |
+
We use the following script to train the model in the sequential learning stage.
|
| 109 |
+
```
|
| 110 |
+
./scripts/train_with_shuffling.sh <run_name>
|
| 111 |
+
```
|
| 112 |
+
|
| 113 |
+
You are also welcome to customize the training procedure. A normal training script on multiple GPUs is provided. Change `num_processes` in `ds_config.yaml` to specify how many GPUs will be used.
|
| 114 |
+
```
|
| 115 |
+
CUDA_VISIBLE_DEVICES=<gpu_ids> accelerate launch --config_file ds_config.yaml src/train/llama_finetune.py \
|
| 116 |
+
--num-epochs <num_epochs> --run-name <run_name> --data-path <train_data> --eval-data-path <eval_data> \
|
| 117 |
+
--device-map accelerate --model-name llama3 --expdir <model_saving_path>
|
| 118 |
+
```
|
| 119 |
+
|
| 120 |
+
In our work we shuffle the dataset per x epochs. To train model with this implementation, inspect and modify `scripts/train_with_shuffling.sh`.
|
| 121 |
+
|
| 122 |
+
### Alternate Training between Sequential Learning and Visual Feedback
|
| 123 |
+
We provide a script for executing our alternate training round. See `scripts/alternate_VF.sh`.
|
| 124 |
+
```
|
| 125 |
+
./scripts/alternate_VF.sh # change the value of base_name in the script as instructed
|
| 126 |
+
```
|
| 127 |
+
We also provide a script for training on multiple gpus for saving time: `scripts/alternate_VF_quadra_gpu.sh`. In our setting, we use 4 GPUs for training. You can change the script to use more GPUs if you have them available.
|
| 128 |
+
|
| 129 |
+
If you only want to conduct a single round of visual learning, run
|
| 130 |
+
```
|
| 131 |
+
python src/train/dpo.py --run-name <dpo_run_name> --pretrained-path <pretrained_model_path> --data-path <dpo_data_Path> --output-path <model_saving_path>
|
| 132 |
+
```
|
| 133 |
+
By default it runs dpo for 3 epochs, but you can change by adding flag `--num-epochs x`.
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
## Model Checkpoints
|
| 137 |
+
We provide two versions.
|
| 138 |
+
v1.0 has 5 rounds of alternate training and is used for evaluation in our paper.
|
| 139 |
+
v1.1 has 9 rounds of alternate training and is considered to have better performance than v1.0.
|
| 140 |
+
- [CADFusion v1.0](https://huggingface.co/microsoft/CADFusion/tree/main/v1_0)
|
| 141 |
+
- [CADFusion v1.1](https://huggingface.co/microsoft/CADFusion/tree/main/v1_1)
|
| 142 |
+
|
| 143 |
+
You should download, unzip and place them under the `exp/model_ckpt` folder for using.
|
| 144 |
+
|
| 145 |
+
## Inference & Visualization
|
| 146 |
+
Use `scripts/generate_samples.sh`.
|
| 147 |
+
```
|
| 148 |
+
./scripts/generate_samples.sh <run_name> test --full
|
| 149 |
+
```
|
| 150 |
+
You can find samples generated in `exp/model_generation/<run_name>.jsonl` and rendered figures under the `exp/figures/<run_name>` folder. The point clouds, .obj files, .step and .stl files are saved under `exp/visual_objects/<run_name>` directory for your own usage and evaluation.
|
| 151 |
+
|
| 152 |
+
## Evaluation
|
| 153 |
+
Use the functions in `src/test`. This includes the Chamfer Distance (`chamfer_dist.py`), Minimum Matching Distance, Coverage, Jensen-Shannon Divergence (`dist_eval.py`), and the VLM score (`VLM_score.py`).
|
| 154 |
+
|
| 155 |
+
For VLM Score, we use Azure OpenAI API to access the GPT-4o model for scoring the CAD objects.
|
| 156 |
+
In this way, you should log in your own azure account before using this module.
|
| 157 |
+
If your are using other LLM/VLM service and feel difficult to adapt to our setup, we provide the prompt in the python module that is available for you to integrate into your own testing pipeline.
|
| 158 |
+
|
| 159 |
+
###
|
| 160 |
+
|
| 161 |
+
## Acknowledgements
|
| 162 |
+
We would like to acknowledge that the CAD rendering and distributional metrics in this repository is partially based on and adapted from the [SkexGen](https://github.com/samxuxiang/SkexGen) project.
|
| 163 |
+
|
| 164 |
+
## Citation
|
| 165 |
+
If you find our work useful, please cite the following paper
|
| 166 |
+
```
|
| 167 |
+
@inproceedings{wang2025texttocad,
|
| 168 |
+
title = {Text-to-CAD Generation Through Infusing Visual Feedback in Large Language Models},
|
| 169 |
+
author = {Wang, Ruiyu and Yuan, Yu and Sun, Shizhao and Bian, Jiang},
|
| 170 |
+
booktitle = {International Conference on Machine Learning},
|
| 171 |
+
year={2025}
|
| 172 |
+
}
|
| 173 |
+
```
|
| 174 |
+
## Contributing
|
| 175 |
+
|
| 176 |
+
This project welcomes contributions and suggestions. Most contributions require you to agree to a
|
| 177 |
+
Contributor License Agreement (CLA) declaring that you have the right to, and actually do, grant us
|
| 178 |
+
the rights to use your contribution. For details, visit https://cla.opensource.microsoft.com.
|
| 179 |
+
|
| 180 |
+
When you submit a pull request, a CLA bot will automatically determine whether you need to provide
|
| 181 |
+
a CLA and decorate the PR appropriately (e.g., status check, comment). Simply follow the instructions
|
| 182 |
+
provided by the bot. You will only need to do this once across all repos using our CLA.
|
| 183 |
+
|
| 184 |
+
This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/).
|
| 185 |
+
For more information see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) or
|
| 186 |
+
contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with any additional questions or comments.
|
| 187 |
+
|
| 188 |
+
## Trademarks
|
| 189 |
+
|
| 190 |
+
This project may contain trademarks or logos for projects, products, or services. Authorized use of Microsoft
|
| 191 |
+
trademarks or logos is subject to and must follow
|
| 192 |
+
[Microsoft's Trademark & Brand Guidelines](https://www.microsoft.com/en-us/legal/intellectualproperty/trademarks/usage/general).
|
| 193 |
+
Use of Microsoft trademarks or logos in modified versions of this project must not cause confusion or imply Microsoft sponsorship.
|
| 194 |
+
Any use of third-party trademarks or logos are subject to those third-party's policies.
|
CADFusion/SECURITY.md
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<!-- BEGIN MICROSOFT SECURITY.MD V0.0.9 BLOCK -->
|
| 2 |
+
|
| 3 |
+
## Security
|
| 4 |
+
|
| 5 |
+
Microsoft takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organizations, which include [Microsoft](https://github.com/Microsoft), [Azure](https://github.com/Azure), [DotNet](https://github.com/dotnet), [AspNet](https://github.com/aspnet) and [Xamarin](https://github.com/xamarin).
|
| 6 |
+
|
| 7 |
+
If you believe you have found a security vulnerability in any Microsoft-owned repository that meets [Microsoft's definition of a security vulnerability](https://aka.ms/security.md/definition), please report it to us as described below.
|
| 8 |
+
|
| 9 |
+
## Reporting Security Issues
|
| 10 |
+
|
| 11 |
+
**Please do not report security vulnerabilities through public GitHub issues.**
|
| 12 |
+
|
| 13 |
+
Instead, please report them to the Microsoft Security Response Center (MSRC) at [https://msrc.microsoft.com/create-report](https://aka.ms/security.md/msrc/create-report).
|
| 14 |
+
|
| 15 |
+
If you prefer to submit without logging in, send email to [secure@microsoft.com](mailto:secure@microsoft.com). If possible, encrypt your message with our PGP key; please download it from the [Microsoft Security Response Center PGP Key page](https://aka.ms/security.md/msrc/pgp).
|
| 16 |
+
|
| 17 |
+
You should receive a response within 24 hours. If for some reason you do not, please follow up via email to ensure we received your original message. Additional information can be found at [microsoft.com/msrc](https://www.microsoft.com/msrc).
|
| 18 |
+
|
| 19 |
+
Please include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue:
|
| 20 |
+
|
| 21 |
+
* Type of issue (e.g. buffer overflow, SQL injection, cross-site scripting, etc.)
|
| 22 |
+
* Full paths of source file(s) related to the manifestation of the issue
|
| 23 |
+
* The location of the affected source code (tag/branch/commit or direct URL)
|
| 24 |
+
* Any special configuration required to reproduce the issue
|
| 25 |
+
* Step-by-step instructions to reproduce the issue
|
| 26 |
+
* Proof-of-concept or exploit code (if possible)
|
| 27 |
+
* Impact of the issue, including how an attacker might exploit the issue
|
| 28 |
+
|
| 29 |
+
This information will help us triage your report more quickly.
|
| 30 |
+
|
| 31 |
+
If you are reporting for a bug bounty, more complete reports can contribute to a higher bounty award. Please visit our [Microsoft Bug Bounty Program](https://aka.ms/security.md/msrc/bounty) page for more details about our active programs.
|
| 32 |
+
|
| 33 |
+
## Preferred Languages
|
| 34 |
+
|
| 35 |
+
We prefer all communications to be in English.
|
| 36 |
+
|
| 37 |
+
## Policy
|
| 38 |
+
|
| 39 |
+
Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://aka.ms/security.md/cvd).
|
| 40 |
+
|
| 41 |
+
<!-- END MICROSOFT SECURITY.MD BLOCK -->
|
CADFusion/SUPPORT.md
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# TODO: The maintainer of this repo has not yet edited this file
|
| 2 |
+
|
| 3 |
+
**REPO OWNER**: Do you want Customer Service & Support (CSS) support for this product/project?
|
| 4 |
+
|
| 5 |
+
- **No CSS support:** Fill out this template with information about how to file issues and get help.
|
| 6 |
+
- **Yes CSS support:** Fill out an intake form at [aka.ms/onboardsupport](https://aka.ms/onboardsupport). CSS will work with/help you to determine next steps.
|
| 7 |
+
- **Not sure?** Fill out an intake as though the answer were "Yes". CSS will help you decide.
|
| 8 |
+
|
| 9 |
+
*Then remove this first heading from this SUPPORT.MD file before publishing your repo.*
|
| 10 |
+
|
| 11 |
+
# Support
|
| 12 |
+
|
| 13 |
+
## How to file issues and get help
|
| 14 |
+
|
| 15 |
+
This project uses GitHub Issues to track bugs and feature requests. Please search the existing
|
| 16 |
+
issues before filing new issues to avoid duplicates. For new issues, file your bug or
|
| 17 |
+
feature request as a new Issue.
|
| 18 |
+
|
| 19 |
+
For help and questions about using this project, please **REPO MAINTAINER: INSERT INSTRUCTIONS HERE
|
| 20 |
+
FOR HOW TO ENGAGE REPO OWNERS OR COMMUNITY FOR HELP. COULD BE A STACK OVERFLOW TAG OR OTHER
|
| 21 |
+
CHANNEL. WHERE WILL YOU HELP PEOPLE?**.
|
| 22 |
+
|
| 23 |
+
## Microsoft Support Policy
|
| 24 |
+
|
| 25 |
+
Support for this **PROJECT or PRODUCT** is limited to the resources listed above.
|
CADFusion/data/sl_data/convert.py
ADDED
|
@@ -0,0 +1,125 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import pickle
|
| 3 |
+
|
| 4 |
+
SKETCH_R = 1
|
| 5 |
+
RADIUS_R = 1
|
| 6 |
+
EXTRUDE_R = 1.0
|
| 7 |
+
SCALE_R = 1.4
|
| 8 |
+
OFFSET_R = 0.9
|
| 9 |
+
PIX_PAD = 4
|
| 10 |
+
CMD_PAD = 3
|
| 11 |
+
COORD_PAD = 4
|
| 12 |
+
EXT_PAD = 1
|
| 13 |
+
EXTRA_PAD = 1
|
| 14 |
+
R_PAD = 2
|
| 15 |
+
|
| 16 |
+
def create_curve_str(se_xy, se_cmd):
|
| 17 |
+
curve_str = ""
|
| 18 |
+
xy_offset = 0
|
| 19 |
+
if se_cmd == 0: # line
|
| 20 |
+
curve_str = " line," + ",".join(str(x) for x in se_xy[0])
|
| 21 |
+
xy_offset = 2
|
| 22 |
+
elif se_cmd == 1: # arc
|
| 23 |
+
curve_str = " arc," + ",".join(str(x) for x in se_xy[0:2].flatten())
|
| 24 |
+
xy_offset = 3
|
| 25 |
+
elif se_cmd == 2: # circle
|
| 26 |
+
curve_str = " circle," + ",".join(str(x) for x in se_xy[0:4].flatten())
|
| 27 |
+
xy_offset = 5
|
| 28 |
+
curve_str += " <curve_end>"
|
| 29 |
+
return curve_str, xy_offset
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def create_sketch_str(se_xy, se_cmd):
|
| 33 |
+
sketch_str = ""
|
| 34 |
+
len_xy, len_cmd = len(se_xy), len(se_cmd)
|
| 35 |
+
xy_idx = 0
|
| 36 |
+
for cmd_item in se_cmd: # for each command
|
| 37 |
+
if 0 <= cmd_item <= 2: # curve
|
| 38 |
+
curve_str, xy_offset = create_curve_str(se_xy[xy_idx:], cmd_item)
|
| 39 |
+
sketch_str += curve_str
|
| 40 |
+
xy_idx += xy_offset
|
| 41 |
+
elif cmd_item == -1: # loop
|
| 42 |
+
sketch_str += " <loop_end>"
|
| 43 |
+
xy_idx += 1
|
| 44 |
+
elif cmd_item == -2: # face
|
| 45 |
+
sketch_str += " <face_end>"
|
| 46 |
+
xy_idx += 1
|
| 47 |
+
elif cmd_item == -3: # sketch
|
| 48 |
+
sketch_str += " <sketch_end>"
|
| 49 |
+
xy_idx += 1
|
| 50 |
+
else:
|
| 51 |
+
raise ValueError("Invalid command: " + str(cmd_item))
|
| 52 |
+
if xy_idx != len_xy:
|
| 53 |
+
raise ValueError("xy_idx != len_xy")
|
| 54 |
+
return sketch_str
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def create_extrude_str(se_ext):
|
| 58 |
+
extrude_str = ""
|
| 59 |
+
# extrude operation
|
| 60 |
+
if se_ext[14] == 1:
|
| 61 |
+
extrude_str += "add"
|
| 62 |
+
elif se_ext[14] == 2:
|
| 63 |
+
extrude_str += "cut"
|
| 64 |
+
elif se_ext[14] == 3:
|
| 65 |
+
extrude_str += "intersect"
|
| 66 |
+
else:
|
| 67 |
+
raise ValueError("Invalid extrude operation: " + str(se_ext[14]))
|
| 68 |
+
# other extrude parameters
|
| 69 |
+
extrude_str = (
|
| 70 |
+
extrude_str + "," + ",".join(str(x - EXT_PAD) for x in se_ext[0:5])
|
| 71 |
+
) # ext_v, ext_T
|
| 72 |
+
extrude_str = (
|
| 73 |
+
extrude_str + "," + ",".join(str(x - R_PAD) for x in se_ext[5:14])
|
| 74 |
+
) # ext_R
|
| 75 |
+
extrude_str = (
|
| 76 |
+
extrude_str + "," + ",".join(str(x - EXT_PAD) for x in se_ext[15:18])
|
| 77 |
+
) # scale, offset
|
| 78 |
+
# extrude end
|
| 79 |
+
extrude_str += " <extrude_end>"
|
| 80 |
+
return extrude_str
|
| 81 |
+
|
| 82 |
+
def create_command_sequence(item):
|
| 83 |
+
se_str = ""
|
| 84 |
+
num_se = item["num_se"]
|
| 85 |
+
for se_idx in range(num_se): # for each sketch-extrude
|
| 86 |
+
xy, cmd, ext = (
|
| 87 |
+
item["se_xy"][se_idx] - COORD_PAD,
|
| 88 |
+
item["se_cmd"][se_idx] - CMD_PAD,
|
| 89 |
+
item["se_ext"][se_idx],
|
| 90 |
+
)
|
| 91 |
+
se_str = se_str + " " + create_sketch_str(xy, cmd).strip()
|
| 92 |
+
se_str = se_str + " " + create_extrude_str(ext).strip()
|
| 93 |
+
return se_str.strip()
|
| 94 |
+
|
| 95 |
+
with open("data/sl_data/train.json", "r") as f:
|
| 96 |
+
train_data = json.load(f)
|
| 97 |
+
with open("data/sl_data/test.json", "r") as f:
|
| 98 |
+
test_data = json.load(f)
|
| 99 |
+
with open("data/sl_data/val.json", "r") as f:
|
| 100 |
+
val_data = json.load(f)
|
| 101 |
+
|
| 102 |
+
with open("cad_data/train_deduplicate_s.pkl", "rb") as f:
|
| 103 |
+
sk_data = pickle.load(f)
|
| 104 |
+
|
| 105 |
+
for item in train_data:
|
| 106 |
+
serial_num = item['serial_num']
|
| 107 |
+
description = item['description']
|
| 108 |
+
item["command_sequence"] = create_command_sequence(sk_data[serial_num])
|
| 109 |
+
|
| 110 |
+
for item in test_data:
|
| 111 |
+
serial_num = item['serial_num']
|
| 112 |
+
description = item['description']
|
| 113 |
+
item["command_sequence"] = create_command_sequence(sk_data[serial_num])
|
| 114 |
+
|
| 115 |
+
for item in val_data:
|
| 116 |
+
serial_num = item['serial_num']
|
| 117 |
+
description = item['description']
|
| 118 |
+
item["command_sequence"] = create_command_sequence(sk_data[serial_num])
|
| 119 |
+
|
| 120 |
+
with open("data/sl_data/train.json", "w+") as f:
|
| 121 |
+
json.dump(train_data, f, indent=4)
|
| 122 |
+
with open("data/sl_data/test.json", "w+") as f:
|
| 123 |
+
json.dump(test_data, f, indent=4)
|
| 124 |
+
with open("data/sl_data/val.json", "w+") as f:
|
| 125 |
+
json.dump(val_data, f, indent=4)
|
CADFusion/data/sl_data/sl_data.zip
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:a080e00591a07420d916e82365d8602ebeab00ffd909f87bc9911b231f2f5ea0
|
| 3 |
+
size 1084518
|
CADFusion/data/vf_data/example_vf_data.zip
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:907df4efd2ceafd9d8c336dfbf62d1754f692c0aab72b1b212ea7b844125e702
|
| 3 |
+
size 2142
|
CADFusion/ds_config.yaml
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
compute_environment: LOCAL_MACHINE
|
| 2 |
+
debug: false
|
| 3 |
+
deepspeed_config:
|
| 4 |
+
gradient_accumulation_steps: 1
|
| 5 |
+
gradient_clipping: 1.0
|
| 6 |
+
offload_optimizer_device: none
|
| 7 |
+
offload_param_device: none
|
| 8 |
+
zero3_init_flag: true
|
| 9 |
+
zero_stage: 2
|
| 10 |
+
distributed_type: DEEPSPEED
|
| 11 |
+
downcast_bf16: 'no'
|
| 12 |
+
machine_rank: 0
|
| 13 |
+
main_training_function: main
|
| 14 |
+
mixed_precision: fp16
|
| 15 |
+
num_machines: 1
|
| 16 |
+
num_processes: 1
|
| 17 |
+
rdzv_backend: static
|
| 18 |
+
same_network: true
|
| 19 |
+
tpu_env: []
|
| 20 |
+
tpu_use_cluster: false
|
| 21 |
+
tpu_use_sudo: false
|
| 22 |
+
use_cpu: false
|
CADFusion/pyproject.toml
ADDED
|
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[build-system]
|
| 2 |
+
requires = ["setuptools>=61.0"]
|
| 3 |
+
build-backend = "setuptools.build_meta"
|
| 4 |
+
|
| 5 |
+
[project]
|
| 6 |
+
name = "CADFusion"
|
| 7 |
+
version = "1.0.0"
|
| 8 |
+
description = "Enhancing Text-to-CAD generation via sequential learning and visual feedback."
|
| 9 |
+
readme = "README.md"
|
| 10 |
+
requires-python = ">=3.8"
|
| 11 |
+
classifiers = [
|
| 12 |
+
"Programming Language :: Python :: 3",
|
| 13 |
+
"License :: OSI Approved :: Apache Software License",
|
| 14 |
+
]
|
| 15 |
+
dependencies = [
|
| 16 |
+
"torch==2.7.1",
|
| 17 |
+
"transformers==4.50.0",
|
| 18 |
+
"huggingface_hub==0.26.0",
|
| 19 |
+
"peft==0.9.0",
|
| 20 |
+
"accelerate==0.28.0",
|
| 21 |
+
"psutil==5.9.8",
|
| 22 |
+
"pillow==10.4.0",
|
| 23 |
+
"datasets==3.1.0",
|
| 24 |
+
"trl==0.11.4",
|
| 25 |
+
"gdown==5.2.0"
|
| 26 |
+
]
|
| 27 |
+
|
| 28 |
+
[project.optional-dependencies]
|
| 29 |
+
train = ["wandb==0.16.4", "deepspeed==0.15.0"]
|
| 30 |
+
render = ["trimesh==4.4.9", "plyfile==1.0.3"]
|
| 31 |
+
eval = ["openai==1.75.0", "azure-identity==1.21.0", "scikit-learn==1.3.2"]
|
| 32 |
+
build = ["build", "twine"]
|
| 33 |
+
|
| 34 |
+
[tool.setuptools.packages.find]
|
| 35 |
+
exclude = ["assets*", "benchmark*", "docs", "dist*", "playground*", "scripts*", "tests*"]
|
| 36 |
+
|
| 37 |
+
[tool.wheel]
|
| 38 |
+
exclude = ["assets*", "benchmark*", "docs", "dist*", "playground*", "scripts*", "tests*"]
|
CADFusion/scripts/alternate_VF.sh
ADDED
|
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# set it to your data path
|
| 2 |
+
data_path=data/sl_data
|
| 3 |
+
# by default set it to CADFusion/exp
|
| 4 |
+
exp_path=exp/model_ckpt
|
| 5 |
+
# by default set it to CADFusion/data
|
| 6 |
+
vf_path=data/vf_data
|
| 7 |
+
train_data=$data_path/train.json
|
| 8 |
+
eval_data=$data_path/val.json
|
| 9 |
+
|
| 10 |
+
# This script requires your SL run named as xxxx0, because for each VF stage, the final digit increments
|
| 11 |
+
# to show the number of VF rounds finished.
|
| 12 |
+
# e.g. SL name: CAD-0
|
| 13 |
+
# base_name: CAD- (remove the last digit, the script autofills it)
|
| 14 |
+
# VF run 1: CAD-1 (automatically)
|
| 15 |
+
# VF run 2: CAD-2 (automatically)
|
| 16 |
+
# ...
|
| 17 |
+
base_name=model_name_you_trained_for_SL_with_last_digit_removed
|
| 18 |
+
|
| 19 |
+
run_name=${base_name}0
|
| 20 |
+
./scripts/generate_samples.sh $run_name test "--full --device-map auto"
|
| 21 |
+
./scripts/generate_samples.sh $run_name train "--sample-len 1000 --device-map auto"
|
| 22 |
+
|
| 23 |
+
./scripts/make_dpo_data.sh $run_name --score-only
|
| 24 |
+
./scripts/make_dpo_data.sh $run_name-train "--gpu 0"
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
for LOOP in 1 2 3 4 5
|
| 28 |
+
do
|
| 29 |
+
echo "Starting VF round $LOOP"
|
| 30 |
+
run_name=$base_name$LOOP
|
| 31 |
+
dpo_training_path=$vf_path/$base_name$((LOOP-1))-train.json
|
| 32 |
+
dpo_run_name=$base_name$LOOP-dpo
|
| 33 |
+
dpo_save_path=$exp_path/$dpo_run_name
|
| 34 |
+
sft_run_name=$base_name$LOOP
|
| 35 |
+
|
| 36 |
+
python src/train/dpo.py --run-name $dpo_run_name --pretrained-path $exp_path/$base_name$((LOOP-1)) --data-path $dpo_training_path --output-path $dpo_save_path
|
| 37 |
+
python src/train/llama_finetune.py --num-epochs 1 --run-name $sft_run_name --data-path $train_data --eval-data-path $eval_data --eval-freq 3000 --pretrained-path $dpo_save_path --expdir $exp_path
|
| 38 |
+
|
| 39 |
+
./scripts/generate_samples.sh $dpo_run_name test "--full --device-map auto"
|
| 40 |
+
./scripts/generate_samples.sh $run_name test "--full --device-map auto"
|
| 41 |
+
./scripts/generate_samples.sh $run_name train "--sample-len 1000 --device-map auto"
|
| 42 |
+
|
| 43 |
+
./scripts/make_dpo_data.sh $dpo_run_name --score-only
|
| 44 |
+
./scripts/make_dpo_data.sh $run_name "--score-only --gpu 0"
|
| 45 |
+
./scripts/make_dpo_data.sh $run_name-train "--gpu 0"
|
| 46 |
+
|
| 47 |
+
done
|
CADFusion/scripts/alternate_VF_quadra_gpu.sh
ADDED
|
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# set it to your data path
|
| 2 |
+
data_path=data/sl_data
|
| 3 |
+
# by default set it to CADFusion/exp
|
| 4 |
+
exp_path=exp/model_ckpt
|
| 5 |
+
# by default set it to CADFusion/data
|
| 6 |
+
vf_path=data/vf_data
|
| 7 |
+
train_data=$data_path/train.json
|
| 8 |
+
eval_data=$data_path/val.json
|
| 9 |
+
|
| 10 |
+
# This script requires your SL run named as xxxx0, because for each VF stage, the final digit increments
|
| 11 |
+
# to show the number of VF rounds finished.
|
| 12 |
+
# e.g. SL name: CAD-0
|
| 13 |
+
# base_name: CAD- (remove the last digit, the script autofills it)
|
| 14 |
+
# VF run 1: CAD-1 (automatically)
|
| 15 |
+
# VF run 2: CAD-2 (automatically)
|
| 16 |
+
# ...
|
| 17 |
+
base_name=model_name_you_trained_for_SL_with_last_digit_removed
|
| 18 |
+
|
| 19 |
+
run_name=${base_name}0
|
| 20 |
+
CUDA_VISIBLE_DEVICES=0,1 ./scripts/generate_samples.sh $run_name test "--full --device-map auto" &
|
| 21 |
+
CUDA_VISIBLE_DEVICES=2,3 ./scripts/generate_samples.sh $run_name train "--sample-len 10 --device-map auto"
|
| 22 |
+
wait
|
| 23 |
+
|
| 24 |
+
./scripts/make_dpo_data.sh $run_name --score-only &
|
| 25 |
+
./scripts/make_dpo_data.sh $run_name-train "--gpu 1"
|
| 26 |
+
wait
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
for LOOP in 1 2 3 4 5
|
| 30 |
+
do
|
| 31 |
+
echo "Starting VF round $LOOP"
|
| 32 |
+
run_name=$base_name$LOOP
|
| 33 |
+
dpo_training_path=$vf_path/$base_name$((LOOP-1))-train.json
|
| 34 |
+
dpo_run_name=$base_name$LOOP-dpo
|
| 35 |
+
dpo_save_path=$exp_path/$dpo_run_name
|
| 36 |
+
sft_run_name=$base_name$LOOP
|
| 37 |
+
|
| 38 |
+
python src/train/dpo.py --run-name $dpo_run_name --pretrained-path $exp_path/$base_name$((LOOP-1)) --data-path $dpo_training_path --output-path $dpo_save_path
|
| 39 |
+
python src/train/llama_finetune.py --num-epochs 1 --run-name $sft_run_name --data-path $train_data --eval-data-path $eval_data --eval-freq 3000 --pretrained-path $dpo_save_path --expdir $exp_path
|
| 40 |
+
|
| 41 |
+
CUDA_VISIBLE_DEVICES=0 ./scripts/generate_samples.sh $dpo_run_name test "--full --device-map auto" &
|
| 42 |
+
CUDA_VISIBLE_DEVICES=1 ./scripts/generate_samples.sh $run_name test "--full --device-map auto" &
|
| 43 |
+
CUDA_VISIBLE_DEVICES=2,3 ./scripts/generate_samples.sh $run_name train "--sample-len 1000 --device-map auto"
|
| 44 |
+
wait
|
| 45 |
+
|
| 46 |
+
./scripts/make_dpo_data.sh $dpo_run_name --score-only &
|
| 47 |
+
./scripts/make_dpo_data.sh $run_name "--score-only --gpu 1" &
|
| 48 |
+
./scripts/make_dpo_data.sh $run_name-train "--gpu 2"
|
| 49 |
+
wait
|
| 50 |
+
done
|
CADFusion/scripts/generate_samples.sh
ADDED
|
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
train_data_path=data/sl_data/train.json
|
| 2 |
+
test_data_path=data/sl_data/test.json
|
| 3 |
+
run_name=$1
|
| 4 |
+
temperature=0.9
|
| 5 |
+
|
| 6 |
+
if [ -z "$2" ]
|
| 7 |
+
then
|
| 8 |
+
data_path=$test_data_path
|
| 9 |
+
else
|
| 10 |
+
if [ $2 = "train" ]; then
|
| 11 |
+
data_path=$train_data_path
|
| 12 |
+
run_name=$1-train
|
| 13 |
+
else
|
| 14 |
+
data_path=$test_data_path
|
| 15 |
+
temperature=0.3
|
| 16 |
+
fi
|
| 17 |
+
fi
|
| 18 |
+
|
| 19 |
+
model_path=exp/model_ckpt/$1
|
| 20 |
+
inference_path=exp/model_generation/$run_name.jsonl
|
| 21 |
+
visual_obj_path=exp/visual_objects/$run_name
|
| 22 |
+
output_figure_path=exp/figures/$run_name
|
| 23 |
+
log_path=exp/logs/$run_name
|
| 24 |
+
|
| 25 |
+
mkdir -p $log_path
|
| 26 |
+
mkdir -p exp/model_generation
|
| 27 |
+
|
| 28 |
+
echo "--------------------Inferencing--------------------" > $log_path/inference.txt
|
| 29 |
+
rm $inference_path
|
| 30 |
+
python3 src/test/inference.py --pretrained-path $model_path --in-path $data_path --out-path $inference_path --num-samples 5 --temperature $temperature --model-name llama3 > $log_path/inference.txt $3
|
| 31 |
+
|
| 32 |
+
echo "--------------------Parsing CAD objects--------------------" > $log_path/parsing_cad.txt
|
| 33 |
+
rm -rf $visual_obj_path
|
| 34 |
+
python3 src/rendering_utils/parser.py --in-path $inference_path --out-path $visual_obj_path > $log_path/parsing_cad.txt
|
| 35 |
+
|
| 36 |
+
echo "--------------------Parsing visual objects--------------------" > $log_path/parsing_visual.txt
|
| 37 |
+
python3 src/rendering_utils/parser_visual.py --data_folder $visual_obj_path > $log_path/parsing_visual.txt
|
| 38 |
+
python3 src/rendering_utils/ptl_sampler.py --in_dir $visual_obj_path --out_dir ptl > $log_path/sampling_ptl.out
|
| 39 |
+
|
| 40 |
+
echo "--------------------Rendering--------------------" > $log_path/rendering.txt
|
| 41 |
+
rm -rf $output_figure_path
|
| 42 |
+
export DISPLAY=:99
|
| 43 |
+
Xvfb :99 -screen 0 640x480x24 &
|
| 44 |
+
python3 src/rendering_utils/img_renderer.py --input_dir $visual_obj_path --output_dir $output_figure_path > $log_path/rendering.txt
|
CADFusion/scripts/make_dpo_data.sh
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
source_path=exp/model_generation/$1.jsonl
|
| 2 |
+
figure_path=exp/figures/$1/
|
| 3 |
+
save_path=data/vf_data/$1.json
|
| 4 |
+
|
| 5 |
+
python src/dpo/make_dpo_dataset.py --source-data-path $source_path --figure-path $figure_path --save-path $save_path --num-samples 5 $2
|
CADFusion/scripts/preprocess_skexgen.sh
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
gdown --id 1so_CCGLIhqGEDQxMoiR--A4CQk4MjuOp
|
| 2 |
+
unzip cad_data.zip
|
| 3 |
+
|
| 4 |
+
# convert data into sequence and save in json
|
| 5 |
+
mkdir data
|
| 6 |
+
mkdir data/raw
|
| 7 |
+
python3 src/data_preprocessing/convert.py --in-path cad_data/train_deduplicate_s.pkl --out-path data/raw/train.json
|
| 8 |
+
python3 src/data_preprocessing/convert.py --in-path cad_data/val.pkl --out-path data/raw/val.json
|
| 9 |
+
python3 src/data_preprocessing/convert.py --in-path cad_data/test.pkl --out-path data/raw/test.json
|
| 10 |
+
|
| 11 |
+
# render the image for each entry in order to retrieve textual information by captioning:
|
| 12 |
+
mkdir exp
|
| 13 |
+
mkdir exp/visual_objects
|
| 14 |
+
mkdir exp/figures
|
| 15 |
+
for file in test val train; do
|
| 16 |
+
python3 src/rendering_utils/parser.py --in-path data/raw/$file.json --out-path exp/visual_objects/$file
|
| 17 |
+
timeout 180 python3 src/rendering_utils/parser_visual.py --data_folder exp/visual_objects/$file
|
| 18 |
+
|
| 19 |
+
export DISPLAY=:99
|
| 20 |
+
Xvfb :99 -screen 0 640x480x24 &
|
| 21 |
+
python3 src/rendering_utils/img_renderer.py --input_dir exp/visual_objects/$file --output_dir exp/figures/$file
|
| 22 |
+
done
|
| 23 |
+
|
| 24 |
+
# caption the images to generate descriptions
|
| 25 |
+
mkdir data/sl_data
|
| 26 |
+
python3 src/data_preprocessing/captioning.py --image-folder-path exp/figures/train --out-path data/sl_data/train.json
|
| 27 |
+
python3 src/data_preprocessing/captioning.py --image-folder-path exp/figures/val --out-path data/sl_data/val.json
|
| 28 |
+
python3 src/data_preprocessing/captioning.py --image-folder-path exp/figures/test --out-path data/sl_data/test.json
|
CADFusion/scripts/train_loop.sh
ADDED
|
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# by default set it to CADFusion/data
|
| 2 |
+
data_path=/your/path/to/data/folder
|
| 3 |
+
# by default set it to CADFusion/exp
|
| 4 |
+
exp_path=/your/path/to/exp/folder
|
| 5 |
+
# by default set it to CADFusion/data
|
| 6 |
+
exp_path=/your/path/to/vf_data/folder
|
| 7 |
+
train_data=$data_path/train.json
|
| 8 |
+
eval_data=$data_path/eval.json
|
| 9 |
+
|
| 10 |
+
base_name=model_name_you_trained_for_SL
|
| 11 |
+
|
| 12 |
+
run_name=${base_name}0
|
| 13 |
+
CUDA_VISIBLE_DEVICES=0,1 ./scripts/inference.sh $run_name test "--full --device-map auto" &
|
| 14 |
+
CUDA_VISIBLE_DEVICES=2,3 ./scripts/inference.sh $run_name train "--sample-len 1000 --device-map auto"
|
| 15 |
+
wait
|
| 16 |
+
|
| 17 |
+
./scripts/make_dpo_data.sh $run_name --score-only &
|
| 18 |
+
./scripts/make_dpo_data.sh $run_name-train "--gpu 1"
|
| 19 |
+
wait
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
for LOOP in 1 2 3 4 5
|
| 23 |
+
do
|
| 24 |
+
run_name=$base_name$LOOP
|
| 25 |
+
dpo_training_path=$vf_path/$base_name$((LOOP-1))-train.json
|
| 26 |
+
dpo_run_name=$base_name$LOOP-dpo
|
| 27 |
+
dpo_save_path=$exp_path/$dpo_run_name
|
| 28 |
+
sft_run_name=$base_name$LOOP
|
| 29 |
+
|
| 30 |
+
python src/train/dpo.py --run-name $dpo_run_name --pretrained-path $exp_path/$base_name$((LOOP-1)) --data-path $dpo_training_path --output-path $dpo_save_path
|
| 31 |
+
python src/train/llama_finetune.py --num-epochs 1 --run-name $sft_run_name --data-path $train_data --eval-data-path $eval_data --eval-freq 3000 --pretrained-path $dpo_save_path --expdir $exp_path
|
| 32 |
+
|
| 33 |
+
CUDA_VISIBLE_DEVICES=0 ./scripts/inference.sh $dpo_run_name test "--full --device-map auto" &
|
| 34 |
+
CUDA_VISIBLE_DEVICES=1 ./scripts/inference.sh $run_name test "--full --device-map auto" &
|
| 35 |
+
CUDA_VISIBLE_DEVICES=2,3 ./scripts/inference.sh $run_name train "--sample-len 1000 --device-map auto"
|
| 36 |
+
wait
|
| 37 |
+
|
| 38 |
+
./scripts/make_dpo_data.sh $dpo_run_name --score-only &
|
| 39 |
+
./scripts/make_dpo_data.sh $run_name "--score-only --gpu 1" &
|
| 40 |
+
./scripts/make_dpo_data.sh $run_name-train "--gpu 2"
|
| 41 |
+
wait
|
| 42 |
+
done
|
CADFusion/scripts/train_with_shuffling.sh
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# set it to your data path
|
| 2 |
+
data_path=data/sl_data
|
| 3 |
+
# set it to your experiment path
|
| 4 |
+
exp_path=exp/model_ckpt
|
| 5 |
+
train_data=$data_path/train.json
|
| 6 |
+
eval_data=$data_path/val.json
|
| 7 |
+
shuffle_dataset_between_x_epochs=2
|
| 8 |
+
mkdir -p $exp_path
|
| 9 |
+
|
| 10 |
+
# round 0
|
| 11 |
+
accelerate launch --config_file ds_config.yaml src/train/llama_finetune.py --lora-rank 32 --lora-alpha 32 \
|
| 12 |
+
--num-epochs $shuffle_dataset_between_x_epochs --run-name $1 --data-path $train_data --eval-data-path $eval_data \
|
| 13 |
+
--device-map accelerate --eval-freq 1000 --save-freq 50000 --model-name llama3 --expdir $exp_path
|
| 14 |
+
|
| 15 |
+
for round in 1 2 3 4 5 6 7 8 9
|
| 16 |
+
do
|
| 17 |
+
python src/train/llama_finetune.py --lora-rank 32 --pretrained-path $exp_path/$1 --lora-alpha 32 \
|
| 18 |
+
--num-epochs $shuffle_dataset_between_x_epochs --run-name $1 --data-path $train_data --eval-data-path $eval_data \
|
| 19 |
+
--eval-freq 4000 --save-freq 50000 --expdir $exp_path
|
| 20 |
+
done
|
CADFusion/src/data_preprocessing/call_openai.py
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from openai import AzureOpenAI
|
| 2 |
+
from azure.identity import AzureCliCredential, get_bearer_token_provider
|
| 3 |
+
|
| 4 |
+
import time
|
| 5 |
+
|
| 6 |
+
def setup_client():
|
| 7 |
+
scope = "api://trapi/.default"
|
| 8 |
+
credential = get_bearer_token_provider(AzureCliCredential(), scope)
|
| 9 |
+
|
| 10 |
+
api_version = '2024-12-01-preview'
|
| 11 |
+
deployment_name = 'gpt-4o_2024-08-06'
|
| 12 |
+
instance = 'gcr/shared/' # See https://aka.ms/trapi/models for the instance name, remove /openai (library adds it implicitly)
|
| 13 |
+
endpoint = f'https://trapi.research.microsoft.com/{instance}'
|
| 14 |
+
|
| 15 |
+
client = AzureOpenAI(
|
| 16 |
+
azure_endpoint=endpoint,
|
| 17 |
+
azure_ad_token_provider=credential,
|
| 18 |
+
api_version=api_version,
|
| 19 |
+
)
|
| 20 |
+
return client, deployment_name
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def call_openai(client, deployment, prompt):
|
| 24 |
+
output = None
|
| 25 |
+
while output is None:
|
| 26 |
+
try:
|
| 27 |
+
time.sleep(0.5)
|
| 28 |
+
completion = client.chat.completions.create(
|
| 29 |
+
model = deployment,
|
| 30 |
+
messages = prompt,
|
| 31 |
+
)
|
| 32 |
+
output = completion.choices[0].message.content
|
| 33 |
+
except Exception as e:
|
| 34 |
+
print("API error:", e)
|
| 35 |
+
time.sleep(1)
|
| 36 |
+
output = None
|
| 37 |
+
return output
|
CADFusion/src/data_preprocessing/captioning.py
ADDED
|
@@ -0,0 +1,101 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import requests
|
| 3 |
+
import base64
|
| 4 |
+
import json
|
| 5 |
+
import time
|
| 6 |
+
from mimetypes import guess_type
|
| 7 |
+
from tqdm import tqdm
|
| 8 |
+
# from parse_sequence import parse_sequence
|
| 9 |
+
# from parse_visual import run_parallel
|
| 10 |
+
# from parse_image import render_file
|
| 11 |
+
from call_openai import setup_client, call_openai
|
| 12 |
+
import argparse
|
| 13 |
+
|
| 14 |
+
parser = argparse.ArgumentParser()
|
| 15 |
+
parser.add_argument('--image-folder-path', type=str, default='exp/figures/test', help='Path to the input folder')
|
| 16 |
+
parser.add_argument('--out-path', type=str, default='data/raw', help='Path to the output file')
|
| 17 |
+
args = parser.parse_args()
|
| 18 |
+
file_path = args.image_folder_path
|
| 19 |
+
out_path = args.out_path
|
| 20 |
+
|
| 21 |
+
client, deployment_name = setup_client()
|
| 22 |
+
call_client = call_openai
|
| 23 |
+
|
| 24 |
+
def local_image_to_data_url(image_path):
|
| 25 |
+
# Encode a local image into data URL
|
| 26 |
+
mime_type, _ = guess_type(image_path)
|
| 27 |
+
if mime_type is None:
|
| 28 |
+
mime_type = 'application/octet-stream'
|
| 29 |
+
with open(image_path, "rb") as image_file:
|
| 30 |
+
base64_encoded_data = base64.b64encode(image_file.read()).decode('utf-8')
|
| 31 |
+
return f"data:{mime_type};base64,{base64_encoded_data}"
|
| 32 |
+
|
| 33 |
+
def call_model_1(prompt, image_path):
|
| 34 |
+
message_text = [
|
| 35 |
+
{"role":"system","content":"You are an AI assistant that helps people find information."},
|
| 36 |
+
{"role":"user","content":[
|
| 37 |
+
{
|
| 38 |
+
"type": "text",
|
| 39 |
+
"text": prompt
|
| 40 |
+
},
|
| 41 |
+
{
|
| 42 |
+
"type": "image_url",
|
| 43 |
+
"image_url": {"url": local_image_to_data_url(image_path)}
|
| 44 |
+
}
|
| 45 |
+
]}
|
| 46 |
+
]
|
| 47 |
+
return call_client(client, deployment_name, message_text)
|
| 48 |
+
|
| 49 |
+
def call_model_2(prompt1, image_path, output1, prompt2):
|
| 50 |
+
message_text = [
|
| 51 |
+
{"role":"system","content":"You are an AI assistant that helps people find information."},
|
| 52 |
+
{"role":"user","content":[
|
| 53 |
+
{
|
| 54 |
+
"type": "text",
|
| 55 |
+
"text": prompt1
|
| 56 |
+
},
|
| 57 |
+
{
|
| 58 |
+
"type": "image_url",
|
| 59 |
+
"image_url": {"url": local_image_to_data_url(image_path)}
|
| 60 |
+
}
|
| 61 |
+
]},
|
| 62 |
+
{"role":"assistant","content":output1},
|
| 63 |
+
{"role":"user","content":prompt2}
|
| 64 |
+
]
|
| 65 |
+
return call_client(client, deployment_name, message_text)
|
| 66 |
+
|
| 67 |
+
files = [f for f in os.listdir(args.image_folder_path) if os.path.isfile(os.path.join(args.image_folder_path, f))]
|
| 68 |
+
files.sort()
|
| 69 |
+
results = []
|
| 70 |
+
for filename in tqdm(files):
|
| 71 |
+
time.sleep(0.5)
|
| 72 |
+
output1 = None
|
| 73 |
+
output2 = None
|
| 74 |
+
image_path = os.path.join(file_path, filename)
|
| 75 |
+
# Send request
|
| 76 |
+
prompt1 = """Propose a series of questions about the 3D shape and give the answers. The first question should ask for a detailed description and others should focus on the specific geometric properties, number, size proportions and positional relationship, and other details."""
|
| 77 |
+
prompt2 = """Based on the dialogue, please give a final description of the 3D shape. No more than 70 words."""
|
| 78 |
+
while output1 is None or str(output1).startswith("I'm sorry"):
|
| 79 |
+
try:
|
| 80 |
+
output1 = call_model_1(prompt1, image_path)
|
| 81 |
+
except requests.RequestException as e:
|
| 82 |
+
print(f"Request failed: {e}")
|
| 83 |
+
time.sleep(1)
|
| 84 |
+
output1 = None
|
| 85 |
+
while output2 is None or str(output2).startswith("I'm sorry"):
|
| 86 |
+
try:
|
| 87 |
+
output2 = call_model_2(prompt1, image_path, output1, prompt2)
|
| 88 |
+
except requests.RequestException as e:
|
| 89 |
+
print(f"Request failed: {e}")
|
| 90 |
+
time.sleep(1)
|
| 91 |
+
output2 = None
|
| 92 |
+
|
| 93 |
+
result = {
|
| 94 |
+
"pic_name":filename,
|
| 95 |
+
"questions": output1,
|
| 96 |
+
"description":output2
|
| 97 |
+
}
|
| 98 |
+
results.append(result)
|
| 99 |
+
|
| 100 |
+
with open(out_path, 'w+', encoding='utf-8') as f:
|
| 101 |
+
json.dump(results, f, ensure_ascii=False, indent=4)
|
CADFusion/src/data_preprocessing/convert.py
ADDED
|
@@ -0,0 +1,120 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pickle
|
| 2 |
+
import argparse
|
| 3 |
+
import json
|
| 4 |
+
# hyperparameters from SkexGen project
|
| 5 |
+
SKETCH_R = 1
|
| 6 |
+
RADIUS_R = 1
|
| 7 |
+
EXTRUDE_R = 1.0
|
| 8 |
+
SCALE_R = 1.4
|
| 9 |
+
OFFSET_R = 0.9
|
| 10 |
+
PIX_PAD = 4
|
| 11 |
+
CMD_PAD = 3
|
| 12 |
+
COORD_PAD = 4
|
| 13 |
+
EXT_PAD = 1
|
| 14 |
+
EXTRA_PAD = 1
|
| 15 |
+
R_PAD = 2
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def create_curve_str(se_xy, se_cmd):
|
| 19 |
+
curve_str = ""
|
| 20 |
+
xy_offset = 0
|
| 21 |
+
if se_cmd == 0: # line
|
| 22 |
+
curve_str = " line," + ",".join(str(x) for x in se_xy[0])
|
| 23 |
+
xy_offset = 2
|
| 24 |
+
elif se_cmd == 1: # arc
|
| 25 |
+
curve_str = " arc," + ",".join(str(x) for x in se_xy[0:2].flatten())
|
| 26 |
+
xy_offset = 3
|
| 27 |
+
elif se_cmd == 2: # circle
|
| 28 |
+
curve_str = " circle," + ",".join(str(x) for x in se_xy[0:4].flatten())
|
| 29 |
+
xy_offset = 5
|
| 30 |
+
curve_str += " <curve_end>"
|
| 31 |
+
return curve_str, xy_offset
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def create_sketch_str(se_xy, se_cmd):
|
| 35 |
+
sketch_str = ""
|
| 36 |
+
len_xy, len_cmd = len(se_xy), len(se_cmd)
|
| 37 |
+
xy_idx = 0
|
| 38 |
+
for cmd_item in se_cmd: # for each command
|
| 39 |
+
if 0 <= cmd_item <= 2: # curve
|
| 40 |
+
curve_str, xy_offset = create_curve_str(se_xy[xy_idx:], cmd_item)
|
| 41 |
+
sketch_str += curve_str
|
| 42 |
+
xy_idx += xy_offset
|
| 43 |
+
elif cmd_item == -1: # loop
|
| 44 |
+
sketch_str += " <loop_end>"
|
| 45 |
+
xy_idx += 1
|
| 46 |
+
elif cmd_item == -2: # face
|
| 47 |
+
sketch_str += " <face_end>"
|
| 48 |
+
xy_idx += 1
|
| 49 |
+
elif cmd_item == -3: # sketch
|
| 50 |
+
sketch_str += " <sketch_end>"
|
| 51 |
+
xy_idx += 1
|
| 52 |
+
else:
|
| 53 |
+
raise ValueError("Invalid command: " + str(cmd_item))
|
| 54 |
+
if xy_idx != len_xy:
|
| 55 |
+
raise ValueError("xy_idx != len_xy")
|
| 56 |
+
return sketch_str
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def create_extrude_str(se_ext):
|
| 60 |
+
extrude_str = ""
|
| 61 |
+
# extrude operation
|
| 62 |
+
if se_ext[14] == 1:
|
| 63 |
+
extrude_str += "add"
|
| 64 |
+
elif se_ext[14] == 2:
|
| 65 |
+
extrude_str += "cut"
|
| 66 |
+
elif se_ext[14] == 3:
|
| 67 |
+
extrude_str += "intersect"
|
| 68 |
+
else:
|
| 69 |
+
raise ValueError("Invalid extrude operation: " + str(se_ext[14]))
|
| 70 |
+
# other extrude parameters
|
| 71 |
+
extrude_str = (
|
| 72 |
+
extrude_str + "," + ",".join(str(x - EXT_PAD) for x in se_ext[0:5])
|
| 73 |
+
) # ext_v, ext_T
|
| 74 |
+
extrude_str = (
|
| 75 |
+
extrude_str + "," + ",".join(str(x - R_PAD) for x in se_ext[5:14])
|
| 76 |
+
) # ext_R
|
| 77 |
+
extrude_str = (
|
| 78 |
+
extrude_str + "," + ",".join(str(x - EXT_PAD) for x in se_ext[15:18])
|
| 79 |
+
) # scale, offset
|
| 80 |
+
# extrude end
|
| 81 |
+
extrude_str += " <extrude_end>"
|
| 82 |
+
return extrude_str
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
def convert(in_path, out_path):
|
| 86 |
+
with open(in_path, "rb") as f:
|
| 87 |
+
data = pickle.load(f)
|
| 88 |
+
print("Data loaded: " + str(len(data)) + " samples")
|
| 89 |
+
|
| 90 |
+
results = []
|
| 91 |
+
for item in data: # for each data
|
| 92 |
+
se_str = ""
|
| 93 |
+
num_se = item["num_se"]
|
| 94 |
+
for se_idx in range(num_se): # for each sketch-extrude
|
| 95 |
+
xy, cmd, ext = (
|
| 96 |
+
item["se_xy"][se_idx] - COORD_PAD,
|
| 97 |
+
item["se_cmd"][se_idx] - CMD_PAD,
|
| 98 |
+
item["se_ext"][se_idx],
|
| 99 |
+
)
|
| 100 |
+
se_str = se_str + " " + create_sketch_str(xy, cmd).strip()
|
| 101 |
+
se_str = se_str + " " + create_extrude_str(ext).strip()
|
| 102 |
+
results.append(se_str.strip())
|
| 103 |
+
|
| 104 |
+
# with open(out_path, "wb") as f:
|
| 105 |
+
# pickle.dump(results, f)
|
| 106 |
+
# print("Data converted: " + str(len(results)) + " samples")
|
| 107 |
+
with open(out_path, "w") as f:
|
| 108 |
+
json.dump(results, f, indent=4)
|
| 109 |
+
print("Data converted: " + str(len(results)) + " samples")
|
| 110 |
+
# with open(out_path, "w") as f: # Open in text mode
|
| 111 |
+
# for result in results:
|
| 112 |
+
# f.write(result + "\n")
|
| 113 |
+
|
| 114 |
+
if __name__ == "__main__":
|
| 115 |
+
parser = argparse.ArgumentParser()
|
| 116 |
+
parser.add_argument("--in-path", type=str, required=True)
|
| 117 |
+
parser.add_argument("--out-path", type=str, required=True)
|
| 118 |
+
args = parser.parse_args()
|
| 119 |
+
|
| 120 |
+
convert(args.in_path, args.out_path)
|
CADFusion/src/dpo/llava_utils.py
ADDED
|
@@ -0,0 +1,95 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import base64
|
| 3 |
+
import time
|
| 4 |
+
import json
|
| 5 |
+
import requests
|
| 6 |
+
from mimetypes import guess_type
|
| 7 |
+
from transformers import pipeline
|
| 8 |
+
from transformers import LlavaNextProcessor
|
| 9 |
+
from transformers import AutoProcessor, LlavaOnevisionForConditionalGeneration
|
| 10 |
+
import torch
|
| 11 |
+
from PIL import Image
|
| 12 |
+
dev='cuda:0'
|
| 13 |
+
|
| 14 |
+
# processor = AutoProcessor.from_pretrained("llava-hf/llava-onevision-qwen2-7b-ov-hf")
|
| 15 |
+
# model = LlavaOnevisionForConditionalGeneration.from_pretrained("llava-hf/llava-onevision-qwen2-7b-ov-hf", torch_dtype=torch.float16, low_cpu_mem_usage=True)
|
| 16 |
+
# model.to(device)
|
| 17 |
+
|
| 18 |
+
def restart_model(device):
|
| 19 |
+
global dev
|
| 20 |
+
dev = device
|
| 21 |
+
processor = AutoProcessor.from_pretrained("llava-hf/llava-onevision-qwen2-7b-ov-hf")
|
| 22 |
+
model = LlavaOnevisionForConditionalGeneration.from_pretrained("llava-hf/llava-onevision-qwen2-7b-ov-hf", torch_dtype=torch.float16, low_cpu_mem_usage=True)
|
| 23 |
+
model.to(device)
|
| 24 |
+
return model, processor
|
| 25 |
+
|
| 26 |
+
def ask_llm_on_figure(data, model, processor):
|
| 27 |
+
"""
|
| 28 |
+
The layout of a typical data item
|
| 29 |
+
{
|
| 30 |
+
"index": 1,
|
| 31 |
+
"pic_name": "000000_001_final.png",
|
| 32 |
+
"ground_truth": "line,9,9 <curve_end> line,9,53 <curve_end> line,53,53 <curve_end> line,53,9 <curve_end> <loop_end> circle,31,29,31,20,35,25,27,25 <curve_end> <loop_end> circle,31,41,31,32,35,37,27,37 <curve_end> <loop_end> <face_end> <sketch_end> add,31,32,31,31,31,0,1,0,0,0,1,1,0,0,62,31,31 <extrude_end>",
|
| 33 |
+
"description": "Create a rectangular panel with two circular through-holes centrally aligned on the vertical axis.",
|
| 34 |
+
"prompt": "Below is a description of a 3D shape:\nCreate a rectangular panel with two circular through-holes centrally aligned on the vertical axis.\nGenerate a Computer-Aided Design(CAD) command sequence of the 3D shape:\n",
|
| 35 |
+
"output": "line,se,9 <curve_end> line,ne,9 <curve_end> line,ne,53 <curve_end> line,se,53 <curve_end> <loop_end> circle,22,41,22, Twenty1 ,31,30,12,30 <curve_end> <loop_end> circle,40,21,40, Ten2 ,50,32,29,32 <curve_end> <loop_end> <face_end> <sketch_end> add,31,33,31,31,31,1,0,0,0,0,1,0,-1,0,62,31,31 <extrude_end>"
|
| 36 |
+
},
|
| 37 |
+
"""
|
| 38 |
+
url = data['figure_path']
|
| 39 |
+
image = Image.open(url)
|
| 40 |
+
description = data['description']
|
| 41 |
+
# data_scale = 10
|
| 42 |
+
# measurement = 'the degree of correspondence between them'
|
| 43 |
+
|
| 44 |
+
prompt = 'You are a harsh grader for new CAD designers\' works. The following is a text description of a CAD figure that they designed and an image of a CAD instance.' +\
|
| 45 |
+
f'\nDescription: {description}\n ' + \
|
| 46 |
+
f'Comment on this work for \n '+\
|
| 47 |
+
'1. If the overall shape remains correct; \n '+\
|
| 48 |
+
'2. If the number of components are correct, especially the circular holes; \n '+\
|
| 49 |
+
'3. If the distribution of the components are natural, i.e. they are not clustered together or collide with each other.\n'+\
|
| 50 |
+
'After that, give a score out of 10. Do not comment on issues such as texture, smoothness and colors'
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
conversation = [
|
| 54 |
+
{
|
| 55 |
+
"role": "user",
|
| 56 |
+
"content": [
|
| 57 |
+
{"type": "text", "text": prompt},
|
| 58 |
+
{"type": "image"},
|
| 59 |
+
],
|
| 60 |
+
},
|
| 61 |
+
]
|
| 62 |
+
prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
|
| 63 |
+
inputs = processor(images=image, text=prompt, return_tensors="pt",).to(dev, torch.float16)
|
| 64 |
+
|
| 65 |
+
# autoregressively complete prompt
|
| 66 |
+
output = model.generate(**inputs, max_new_tokens=256, pad_token_id=processor.tokenizer.eos_token_id)
|
| 67 |
+
output = processor.decode(output[0], skip_special_tokens=True)
|
| 68 |
+
idx = output.index('assistant\n')
|
| 69 |
+
response = output[idx+10:]
|
| 70 |
+
return(response)
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
def ask_llm(data, model, processor):
|
| 74 |
+
description = data['gpt_label']
|
| 75 |
+
|
| 76 |
+
prompt = 'The following is an evaluation of an CAD object.' +\
|
| 77 |
+
f'\n evaluation: {description}\n' +\
|
| 78 |
+
'Extract the integer score of the evaluation. The score is between 0 to 10. Return the number only.'
|
| 79 |
+
|
| 80 |
+
conversation = [
|
| 81 |
+
{
|
| 82 |
+
"role": "user",
|
| 83 |
+
"content": [
|
| 84 |
+
{"type": "text", "text": prompt},
|
| 85 |
+
],
|
| 86 |
+
},
|
| 87 |
+
]
|
| 88 |
+
prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
|
| 89 |
+
inputs = processor(text=prompt, return_tensors="pt",).to(dev, torch.float16)
|
| 90 |
+
|
| 91 |
+
output = model.generate(**inputs, max_new_tokens=16, pad_token_id=processor.tokenizer.eos_token_id)
|
| 92 |
+
output = processor.decode(output[0], skip_special_tokens=True)
|
| 93 |
+
idx = output.index('assistant\n')
|
| 94 |
+
response = output[idx+10:]
|
| 95 |
+
return(response)
|
CADFusion/src/dpo/make_dpo_dataset.py
ADDED
|
@@ -0,0 +1,162 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import os
|
| 3 |
+
import time
|
| 4 |
+
import argparse
|
| 5 |
+
|
| 6 |
+
from openai_utils import ask_gpt_on_figure, ask_gpt
|
| 7 |
+
from llava_utils import ask_llm, ask_llm_on_figure, restart_model
|
| 8 |
+
from tqdm import tqdm
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
if __name__ == '__main__':
|
| 12 |
+
parser = argparse.ArgumentParser()
|
| 13 |
+
parser.add_argument("--source-data-path", type=str, required=True)
|
| 14 |
+
parser.add_argument("--figure-path", type=str, required=True)
|
| 15 |
+
parser.add_argument("--save-path", type=str, required=True)
|
| 16 |
+
parser.add_argument("--num-samples", type=int, required=True)
|
| 17 |
+
parser.add_argument("--gpu", type=int, default=0)
|
| 18 |
+
parser.add_argument("--score-only", action="store_true", default=False)
|
| 19 |
+
parser.add_argument("--gpt", action="store_true", default=False)
|
| 20 |
+
args = parser.parse_args()
|
| 21 |
+
|
| 22 |
+
source_path = args.source_data_path
|
| 23 |
+
folder_path = args.figure_path
|
| 24 |
+
save_path = args.save_path
|
| 25 |
+
num_samples = args.num_samples
|
| 26 |
+
device=f'cuda:{args.gpu}'
|
| 27 |
+
if args.gpt:
|
| 28 |
+
func1, func2 = ask_gpt_on_figure, ask_gpt
|
| 29 |
+
model = None
|
| 30 |
+
processor = None
|
| 31 |
+
else:
|
| 32 |
+
func1, func2 = ask_llm_on_figure, ask_llm
|
| 33 |
+
model, processor = restart_model(device)
|
| 34 |
+
|
| 35 |
+
with open(source_path, 'r') as f:
|
| 36 |
+
test_data = json.load(f)
|
| 37 |
+
|
| 38 |
+
####### Stage 1 #######
|
| 39 |
+
# for model generations that are able to render pictures,
|
| 40 |
+
# ask gpt to rate the generation quality.
|
| 41 |
+
for data in tqdm(test_data):
|
| 42 |
+
file_id = str(data['index']).zfill(6)
|
| 43 |
+
file = None
|
| 44 |
+
for f in os.listdir(folder_path):
|
| 45 |
+
if f.startswith(file_id):
|
| 46 |
+
file = folder_path + f
|
| 47 |
+
data['figure_path'] = file
|
| 48 |
+
error_cnt = 0
|
| 49 |
+
while 1:
|
| 50 |
+
try:
|
| 51 |
+
data['gpt_label'] = func1(data, model, processor)
|
| 52 |
+
break
|
| 53 |
+
except Exception as e:
|
| 54 |
+
print(e)
|
| 55 |
+
if args.gpt:
|
| 56 |
+
time.sleep(3)
|
| 57 |
+
else:
|
| 58 |
+
if error_cnt == 5:
|
| 59 |
+
exit()
|
| 60 |
+
model, processor = restart_model(device)
|
| 61 |
+
error_cnt += 1
|
| 62 |
+
with open(save_path, 'w+') as f:
|
| 63 |
+
json.dump(test_data, f, indent=4)
|
| 64 |
+
|
| 65 |
+
with open(save_path, 'r') as f:
|
| 66 |
+
test_data = json.load(f)
|
| 67 |
+
####### Stage 2 #######
|
| 68 |
+
# clean up the dataset to summarize the generation quality estimation to a numerical score, and
|
| 69 |
+
# remove the failed ones, i.e. the generations that cannot render
|
| 70 |
+
for data in tqdm(test_data):
|
| 71 |
+
if "gpt_label" in data.keys():
|
| 72 |
+
error_cnt = 0
|
| 73 |
+
while 1:
|
| 74 |
+
try:
|
| 75 |
+
score = func2(data, model, processor)
|
| 76 |
+
print(score)
|
| 77 |
+
break
|
| 78 |
+
except Exception as e:
|
| 79 |
+
print(e)
|
| 80 |
+
if args.gpt:
|
| 81 |
+
time.sleep(3)
|
| 82 |
+
else:
|
| 83 |
+
if error_cnt == 5:
|
| 84 |
+
exit()
|
| 85 |
+
model, processor = restart_model(device)
|
| 86 |
+
error_cnt += 1
|
| 87 |
+
try:
|
| 88 |
+
data['gpt_score'] = int(score)
|
| 89 |
+
except:
|
| 90 |
+
print(f'ERROR: {score}')
|
| 91 |
+
pass
|
| 92 |
+
|
| 93 |
+
saved_data = [data for data in test_data if 'gpt_score' in data.keys()]
|
| 94 |
+
with open(save_path, 'w+') as f:
|
| 95 |
+
json.dump(saved_data, f, indent=4)
|
| 96 |
+
|
| 97 |
+
if args.score_only:
|
| 98 |
+
exit()
|
| 99 |
+
|
| 100 |
+
####### Stage 3 #######
|
| 101 |
+
# 1. group up the scored generations by their description: we do not compare
|
| 102 |
+
# generation results that come from different origin prompts
|
| 103 |
+
temp_data = []
|
| 104 |
+
max_idx = test_data[-1]['index']
|
| 105 |
+
sample_size = max_idx // num_samples + 1
|
| 106 |
+
# a. select if any above 6
|
| 107 |
+
|
| 108 |
+
# for i in range(sample_size):
|
| 109 |
+
# next_sample = test_data[i*num_samples:(i+1)*num_samples]
|
| 110 |
+
# next_sample = [item for item in next_sample if 'gpt_score' in item.keys()]
|
| 111 |
+
# above_score = [item['gpt_score'] >= 6 for item in next_sample]
|
| 112 |
+
# if any(above_score):
|
| 113 |
+
# temp_data.extend(next_sample)
|
| 114 |
+
# temp_data = [data for data in temp_data if 'gpt_score' in data.keys()]
|
| 115 |
+
|
| 116 |
+
# b. select if avg above 6
|
| 117 |
+
|
| 118 |
+
# for i in range(sample_size):
|
| 119 |
+
# next_sample = test_data[i*num_samples:(i+1)*num_samples]
|
| 120 |
+
# next_sample = [item for item in next_sample if 'gpt_score' in item.keys()]
|
| 121 |
+
# if len(next_sample) == 0:
|
| 122 |
+
# continue
|
| 123 |
+
# scores = sum(item['gpt_score'] for item in next_sample) / len(next_sample)
|
| 124 |
+
# if scores >= 6:
|
| 125 |
+
# temp_data.extend(next_sample)
|
| 126 |
+
# temp_data = [data for data in temp_data if 'gpt_score' in data.keys()]
|
| 127 |
+
|
| 128 |
+
# c. select if individual above 6
|
| 129 |
+
test_data = saved_data
|
| 130 |
+
for item in test_data:
|
| 131 |
+
if 'gpt_score' not in item.keys():
|
| 132 |
+
continue
|
| 133 |
+
if item['gpt_score'] >= 6:
|
| 134 |
+
temp_data.append(item)
|
| 135 |
+
print(test_data[-1]['index'], max_idx)
|
| 136 |
+
|
| 137 |
+
grouped = [[] for _ in range(max_idx)]
|
| 138 |
+
for item in temp_data:
|
| 139 |
+
idx = item['index']
|
| 140 |
+
grouped[idx // num_samples].append(item)
|
| 141 |
+
grouped = [item for item in grouped if len(item) > 0]
|
| 142 |
+
|
| 143 |
+
# 2. within each group, make pairs where the chosens have higher score than the rejected ones.
|
| 144 |
+
# TODO: find a way to balance the data generated from each group
|
| 145 |
+
final_data = []
|
| 146 |
+
for group in grouped:
|
| 147 |
+
for item1 in group:
|
| 148 |
+
for item2 in group:
|
| 149 |
+
if item2['gpt_score'] > item1['gpt_score']:
|
| 150 |
+
info_dict = {
|
| 151 |
+
"description": item1['description'],
|
| 152 |
+
"prompt": item1['prompt'],
|
| 153 |
+
"chosen": item2['output'],
|
| 154 |
+
"rejected": item1['output']
|
| 155 |
+
}
|
| 156 |
+
final_data.append(info_dict)
|
| 157 |
+
# uncomment this break if you do not want too many data.
|
| 158 |
+
# break
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
with open(save_path, 'w+') as f:
|
| 162 |
+
json.dump(final_data, f, indent=4)
|
CADFusion/src/dpo/openai_utils.py
ADDED
|
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import base64
|
| 3 |
+
import time
|
| 4 |
+
import json
|
| 5 |
+
|
| 6 |
+
from mimetypes import guess_type
|
| 7 |
+
from openai import AzureOpenAI
|
| 8 |
+
from azure.identity import DefaultAzureCredential, get_bearer_token_provider
|
| 9 |
+
|
| 10 |
+
END_POINT = '<endpoint>'
|
| 11 |
+
MODEL_NAME = 'gpt-4o_2024-08-06'
|
| 12 |
+
API_VER = '2024-02-01'
|
| 13 |
+
|
| 14 |
+
def local_image_to_data_url(image_path):
|
| 15 |
+
# Encode a local image into data URL
|
| 16 |
+
mime_type, _ = guess_type(image_path)
|
| 17 |
+
if mime_type is None:
|
| 18 |
+
mime_type = 'application/octet-stream'
|
| 19 |
+
with open(image_path, "rb") as image_file:
|
| 20 |
+
base64_encoded_data = base64.b64encode(image_file.read()).decode('utf-8')
|
| 21 |
+
return f"data:{mime_type};base64,{base64_encoded_data}"
|
| 22 |
+
|
| 23 |
+
def ask_gpt_on_figure(data, _, __):
|
| 24 |
+
endpoint = END_POINT
|
| 25 |
+
token_provider = get_bearer_token_provider(
|
| 26 |
+
DefaultAzureCredential(),
|
| 27 |
+
"https://cognitiveservices.azure.com/.default"
|
| 28 |
+
)
|
| 29 |
+
deployment_name = MODEL_NAME
|
| 30 |
+
|
| 31 |
+
client = AzureOpenAI(
|
| 32 |
+
azure_ad_token_provider=token_provider,
|
| 33 |
+
azure_endpoint=endpoint,
|
| 34 |
+
api_version=API_VER
|
| 35 |
+
)
|
| 36 |
+
description = data['description']
|
| 37 |
+
data_scale = 10
|
| 38 |
+
measurement = 'if the figure corresponds to the given description'
|
| 39 |
+
|
| 40 |
+
prompt = 'The following is a text description of a 3D CAD figure and an image of a CAD instance. ' +\
|
| 41 |
+
f'Measure {measurement}, and give a score in the scale of {data_scale}. Do not comment on issues such as texture, smoothness and colors' +\
|
| 42 |
+
f'\n description: {description}\n'
|
| 43 |
+
image_path = data['figure_path']
|
| 44 |
+
response = client.chat.completions.create(
|
| 45 |
+
model=deployment_name,
|
| 46 |
+
messages=[
|
| 47 |
+
{'role': 'system', 'content': 'You are a helpful assistant'},
|
| 48 |
+
{'role': 'user', 'content': [
|
| 49 |
+
{'type': 'text', 'text': prompt},
|
| 50 |
+
{'type': 'image_url', 'image_url': {'url': local_image_to_data_url(image_path)}},
|
| 51 |
+
]}
|
| 52 |
+
]
|
| 53 |
+
)
|
| 54 |
+
time.sleep(3)
|
| 55 |
+
return(response.choices[0].message.content)
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def ask_gpt(data, _, __):
|
| 59 |
+
endpoint = END_POINT
|
| 60 |
+
token_provider = get_bearer_token_provider(
|
| 61 |
+
DefaultAzureCredential(),
|
| 62 |
+
"https://cognitiveservices.azure.com/.default"
|
| 63 |
+
)
|
| 64 |
+
deployment_name = MODEL_NAME
|
| 65 |
+
|
| 66 |
+
client = AzureOpenAI(
|
| 67 |
+
azure_ad_token_provider=token_provider,
|
| 68 |
+
azure_endpoint=endpoint,
|
| 69 |
+
api_version=API_VER
|
| 70 |
+
)
|
| 71 |
+
description = data['gpt_label']
|
| 72 |
+
|
| 73 |
+
prompt = 'The following is an evaluation of an CAD object.' +\
|
| 74 |
+
f'\n evaluation: {description}\n' +\
|
| 75 |
+
'Extract the integer score of the evaluation. The score is between 0 to 10. Return the number only.'
|
| 76 |
+
|
| 77 |
+
response = client.chat.completions.create(
|
| 78 |
+
model=deployment_name,
|
| 79 |
+
messages=[
|
| 80 |
+
{'role': 'system', 'content': 'You are a helpful assistant'},
|
| 81 |
+
{'role': 'user', 'content': [
|
| 82 |
+
{'type': 'text', 'text': prompt},
|
| 83 |
+
]}
|
| 84 |
+
]
|
| 85 |
+
)
|
| 86 |
+
# print(response.choices[0].message.content)
|
| 87 |
+
time.sleep(3)
|
| 88 |
+
return(response.choices[0].message.content)
|
CADFusion/src/rendering_utils/geometry/arc.py
ADDED
|
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import math
|
| 3 |
+
from geometry.curve import Curve
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class Arc(Curve):
|
| 7 |
+
def __init__(self, point_indices, point_data, is_outer):
|
| 8 |
+
assert len(point_indices) == 4, "Arc must be defined by 3 points"
|
| 9 |
+
assert point_data is not None
|
| 10 |
+
super(Arc, self).__init__(point_indices, point_data)
|
| 11 |
+
self.type = 'arc'
|
| 12 |
+
self.is_outer = is_outer
|
| 13 |
+
self.start = self.point_geom[0, :]
|
| 14 |
+
self.mid = self.point_geom[1, :]
|
| 15 |
+
self.center = self.point_geom[2, :]
|
| 16 |
+
self.end = self.point_geom[3, :]
|
| 17 |
+
|
| 18 |
+
self.r1 = math.sqrt( (self.start[0] - self.center[0])**2 + (self.start[1] - self.center[1])**2 )
|
| 19 |
+
self.r2 = math.sqrt( (self.end[0] - self.center[0])**2 + (self.end[1] - self.center[1])**2 )
|
| 20 |
+
self.radius = (self.r1+self.r2)/2
|
| 21 |
+
|
| 22 |
+
self.start_idx = point_indices[0]
|
| 23 |
+
self.mid_idx = point_indices[1]
|
| 24 |
+
self.center_idx = point_indices[2]
|
| 25 |
+
self.end_idx = point_indices[3]
|
| 26 |
+
|
| 27 |
+
self.bbox = self.verts_to_bbox(np.vstack([self.start, self.end, self.mid]))
|
| 28 |
+
self.bottom_left = np.array([self.bbox[0], self.bbox[2]])
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
|
CADFusion/src/rendering_utils/geometry/circle.py
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
from geometry.curve import Curve
|
| 3 |
+
import pdb
|
| 4 |
+
|
| 5 |
+
class Circle(Curve):
|
| 6 |
+
def __init__(self, point_indices, point_data, is_outer):
|
| 7 |
+
assert len(point_indices) == 2, "Circle must be defined by 1 points"
|
| 8 |
+
assert point_data is not None
|
| 9 |
+
super(Circle, self).__init__(point_indices, point_data)
|
| 10 |
+
self.type = 'circle'
|
| 11 |
+
self.center = self.point_geom[0, :]
|
| 12 |
+
self.radius = self.point_geom[1, 0]
|
| 13 |
+
self.center_idx = point_indices[0]
|
| 14 |
+
self.radius_idx = point_indices[1]
|
| 15 |
+
self.is_outer = is_outer
|
| 16 |
+
|
| 17 |
+
self.pt1 = np.array([self.center[0], self.center[1]+self.radius])
|
| 18 |
+
self.pt2 = np.array([self.center[0], self.center[1]-self.radius])
|
| 19 |
+
self.pt3 = np.array([self.center[0]+self.radius, self.center[1]])
|
| 20 |
+
self.pt4 = np.array([self.center[0]-self.radius, self.center[1]])
|
| 21 |
+
self.bbox = self.verts_to_bbox(np.vstack([self.pt1, self.pt2, self.pt3, self.pt4]))
|
| 22 |
+
self.bottom_left = np.array([self.bbox[0], self.bbox[2]])
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
|
CADFusion/src/rendering_utils/geometry/curve.py
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
class Curve():
|
| 5 |
+
def __init__(self, point_indices, point_data):
|
| 6 |
+
self.point_indices = point_indices
|
| 7 |
+
self.point_geom = point_data[point_indices, 0:2]
|
| 8 |
+
|
| 9 |
+
def verts_to_bbox(self, verts):
|
| 10 |
+
xs = [v[0] for v in verts]
|
| 11 |
+
ys = [v[1] for v in verts]
|
| 12 |
+
bbox = [min(xs), max(xs), min(ys), max(ys)]
|
| 13 |
+
return bbox
|
CADFusion/src/rendering_utils/geometry/geom_utils.py
ADDED
|
@@ -0,0 +1,95 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
import numpy as np
|
| 3 |
+
|
| 4 |
+
def angle_from_vector_to_x(vec):
|
| 5 |
+
assert vec.size == 2
|
| 6 |
+
# We need to find a unit vector
|
| 7 |
+
angle = 0.0
|
| 8 |
+
|
| 9 |
+
l = np.linalg.norm(vec)
|
| 10 |
+
uvec = vec/l
|
| 11 |
+
|
| 12 |
+
# 2 | 1
|
| 13 |
+
#-------
|
| 14 |
+
# 3 | 4
|
| 15 |
+
if uvec[0] >=0:
|
| 16 |
+
if uvec[1] >= 0:
|
| 17 |
+
# Qadrant 1
|
| 18 |
+
angle = math.asin(uvec[1])
|
| 19 |
+
else:
|
| 20 |
+
# Qadrant 4
|
| 21 |
+
angle = 2.0*math.pi - math.asin(-uvec[1])
|
| 22 |
+
else:
|
| 23 |
+
if vec[1] >= 0:
|
| 24 |
+
# Qadrant 2
|
| 25 |
+
angle = math.pi - math.asin(uvec[1])
|
| 26 |
+
else:
|
| 27 |
+
# Qadrant 3
|
| 28 |
+
angle = math.pi + math.asin(-uvec[1])
|
| 29 |
+
return angle
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def convert_angle_to_1to360_range(angle_rad):
|
| 33 |
+
"""
|
| 34 |
+
Converts the given angle in radians into 1-360 degrees range
|
| 35 |
+
"""
|
| 36 |
+
angle = math.degrees(angle_rad)
|
| 37 |
+
# Lifted from: https://stackoverflow.com/questions/12234574/calculating-if-an-angle-is-between-two-angles
|
| 38 |
+
angle=(int(angle) % 360) + (angle-math.trunc(angle)) # converts angle to range -360 + 360
|
| 39 |
+
if angle > 0.0:
|
| 40 |
+
return angle
|
| 41 |
+
else:
|
| 42 |
+
return angle + 360.0
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def angle_is_between(angle_rad, a_rad, b_rad):
|
| 46 |
+
"""
|
| 47 |
+
Checks if angle is in between the range of a and b
|
| 48 |
+
(All angles must be given in radians)
|
| 49 |
+
"""
|
| 50 |
+
angle = convert_angle_to_1to360_range(angle_rad)
|
| 51 |
+
a = convert_angle_to_1to360_range(a_rad)
|
| 52 |
+
b = convert_angle_to_1to360_range(b_rad)
|
| 53 |
+
if a < b:
|
| 54 |
+
return a <= angle and angle <= b
|
| 55 |
+
return a <= angle or angle <= b
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def quantize_verts(verts, n_bits=8):
|
| 59 |
+
"""Convert vertices in [-1., 1.] to discrete values in [0, n_bits**2 - 1]."""
|
| 60 |
+
min_range = -0.5
|
| 61 |
+
max_range = 0.5
|
| 62 |
+
range_quantize = 2 ** n_bits - 1
|
| 63 |
+
verts_quantize = (verts - min_range) * range_quantize / (max_range - min_range)
|
| 64 |
+
return verts_quantize.astype("int32")
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def dequantize_verts(verts, n_bits=8, add_noise=False):
|
| 68 |
+
"""Convert quantized vertices to floats."""
|
| 69 |
+
min_range = -0.5
|
| 70 |
+
max_range = 0.5
|
| 71 |
+
range_quantize = 2 ** n_bits - 1
|
| 72 |
+
verts = verts.astype("float32")
|
| 73 |
+
verts = verts * (max_range - min_range) / range_quantize + min_range
|
| 74 |
+
if add_noise:
|
| 75 |
+
verts += np.random.uniform(size=verts.shape) * (1 / range_quantize)
|
| 76 |
+
return verts
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
def center_vertices(vertices):
|
| 80 |
+
"""Translate the vertices so that bounding box is centered at zero."""
|
| 81 |
+
vert_min = vertices.min(axis=0)
|
| 82 |
+
vert_max = vertices.max(axis=0)
|
| 83 |
+
vert_center = 0.5 * (vert_min + vert_max)
|
| 84 |
+
return vertices - vert_center, vert_center
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
def scale_vertices(vertices):
|
| 88 |
+
"""Scale the vertices so that the long diagonal of the bounding box is one."""
|
| 89 |
+
vert_min = vertices.min(axis=0)
|
| 90 |
+
vert_max = vertices.max(axis=0)
|
| 91 |
+
extents = vert_max - vert_min
|
| 92 |
+
scale = np.sqrt(np.sum(extents ** 2))
|
| 93 |
+
return vertices / scale, scale
|
| 94 |
+
|
| 95 |
+
|
CADFusion/src/rendering_utils/geometry/line.py
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
from geometry.curve import Curve
|
| 3 |
+
|
| 4 |
+
class Line(Curve):
|
| 5 |
+
def __init__(self, point_indices, point_data, is_outer):
|
| 6 |
+
assert len(point_indices) == 2, "Line must be defined by two points"
|
| 7 |
+
assert point_data is not None
|
| 8 |
+
super(Line, self).__init__(point_indices, point_data)
|
| 9 |
+
pt0 = self.point_geom[0, :]
|
| 10 |
+
pt1 = self.point_geom[1, :]
|
| 11 |
+
self.type = 'line'
|
| 12 |
+
self.start = pt0
|
| 13 |
+
self.end = pt1
|
| 14 |
+
self.start_idx = point_indices[0]
|
| 15 |
+
self.end_idx = point_indices[1]
|
| 16 |
+
self.is_outer = is_outer
|
| 17 |
+
|
| 18 |
+
self.bbox = self.verts_to_bbox(np.vstack([pt0, pt1]))
|
| 19 |
+
self.bottom_left = np.array([self.bbox[0], self.bbox[2]])
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
|
CADFusion/src/rendering_utils/geometry/obj_parser.py
ADDED
|
@@ -0,0 +1,276 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import sys
|
| 3 |
+
import numpy as np
|
| 4 |
+
|
| 5 |
+
from geometry.arc import Arc
|
| 6 |
+
from geometry.circle import Circle
|
| 7 |
+
from geometry.line import Line
|
| 8 |
+
|
| 9 |
+
from geometry import geom_utils
|
| 10 |
+
import pdb
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class OBJParser:
|
| 14 |
+
"""
|
| 15 |
+
A class to read an OBJ file containing the sketch data
|
| 16 |
+
and hand it back in a form which is easy to work with.
|
| 17 |
+
"""
|
| 18 |
+
def __init__(self, pathname=None):
|
| 19 |
+
self.pathname = pathname
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def convert_vertices(self, vertices):
|
| 23 |
+
"""Convert all the vertices to .obj format"""
|
| 24 |
+
vertex_strings = ""
|
| 25 |
+
for pt in vertices:
|
| 26 |
+
# e.g. v 0.123 0.234 0.345 1.0
|
| 27 |
+
vertex_string = f"v {pt[0]} {pt[1]}\n"
|
| 28 |
+
vertex_strings += vertex_string
|
| 29 |
+
return vertex_strings
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def convert_curves(self, faces):
|
| 33 |
+
curve_strings = ""
|
| 34 |
+
total_curve = 0
|
| 35 |
+
|
| 36 |
+
# Faces (multiple closed regions)
|
| 37 |
+
for group_idx, loops in enumerate(faces):
|
| 38 |
+
curve_strings += f"\nface\n"
|
| 39 |
+
# Multiple loops (inner and outer)
|
| 40 |
+
for loop in loops:
|
| 41 |
+
if loop[0].is_outer:
|
| 42 |
+
curve_strings += f"out\n"
|
| 43 |
+
else:
|
| 44 |
+
curve_strings += f"in\n"
|
| 45 |
+
# All curves in one loop
|
| 46 |
+
for curve in loop:
|
| 47 |
+
total_curve += 1
|
| 48 |
+
if curve.type == 'line':
|
| 49 |
+
curve_strings += f"l {curve.start_idx} {curve.end_idx}\n"
|
| 50 |
+
elif curve.type == 'circle':
|
| 51 |
+
curve_strings += f"c {curve.center_idx} {curve.radius_idx}\n"
|
| 52 |
+
elif curve.type == 'arc':
|
| 53 |
+
curve_strings += f"a {curve.start_idx} {curve.mid_idx} {curve.center_idx} {curve.end_idx}\n"
|
| 54 |
+
|
| 55 |
+
return curve_strings, total_curve
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def parse3d(self, point3d):
|
| 59 |
+
x = point3d[0]
|
| 60 |
+
y = point3d[1]
|
| 61 |
+
z = point3d[2]
|
| 62 |
+
return str(x)+' '+str(y)+' '+str(z)
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def write_obj2(self, file, vertices, faces, meta_info, scale=None):
|
| 66 |
+
""" Write to .obj file """
|
| 67 |
+
vertex_strings = self.convert_vertices(vertices)
|
| 68 |
+
curve_strings, total_curve = self.convert_curves(faces)
|
| 69 |
+
|
| 70 |
+
with open(file, "w") as fh:
|
| 71 |
+
# Write Meta info
|
| 72 |
+
fh.write("# WaveFront *.obj file\n")
|
| 73 |
+
fh.write(f"# Vertices: {len(vertices)}\n")
|
| 74 |
+
fh.write(f"# Curves: {total_curve}\n")
|
| 75 |
+
fh.write("\n")
|
| 76 |
+
|
| 77 |
+
# Write vertex and curve
|
| 78 |
+
fh.write(vertex_strings)
|
| 79 |
+
fh.write("\n")
|
| 80 |
+
fh.write(curve_strings)
|
| 81 |
+
fh.write("\n")
|
| 82 |
+
|
| 83 |
+
#Write extrude value
|
| 84 |
+
fh.write("ExtrudeOperation: " + meta_info['set_op']+"\n")
|
| 85 |
+
extrude_string = 'Extrude '
|
| 86 |
+
for value in meta_info['extrude_value']:
|
| 87 |
+
extrude_string += str(value)+' '
|
| 88 |
+
fh.write(extrude_string)
|
| 89 |
+
fh.write("\n")
|
| 90 |
+
|
| 91 |
+
#Write refe plane transformation
|
| 92 |
+
p_orig = self.parse3d(meta_info['t_orig'])
|
| 93 |
+
x_axis = self.parse3d(meta_info['t_x'])
|
| 94 |
+
y_axis = self.parse3d(meta_info['t_y'])
|
| 95 |
+
z_axis = self.parse3d(meta_info['t_z'])
|
| 96 |
+
fh.write('T_origin '+p_orig)
|
| 97 |
+
fh.write("\n")
|
| 98 |
+
fh.write('T_xaxis '+x_axis)
|
| 99 |
+
fh.write("\n")
|
| 100 |
+
fh.write('T_yaxis '+y_axis)
|
| 101 |
+
fh.write("\n")
|
| 102 |
+
fh.write('T_zaxis '+z_axis)
|
| 103 |
+
fh.write("\n")
|
| 104 |
+
|
| 105 |
+
# Normalized object
|
| 106 |
+
if scale is not None:
|
| 107 |
+
fh.write('Scale '+str(scale))
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
def write_obj(self, file, curve_strings, total_curve, vertex_strings, total_v, meta_info, scale=None):
|
| 111 |
+
""" Write to .obj file """
|
| 112 |
+
#vertex_strings = self.convert_vertices(vertices)
|
| 113 |
+
#curve_strings, total_curve = self.convert_curves(faces)
|
| 114 |
+
|
| 115 |
+
with open(file, "w") as fh:
|
| 116 |
+
# Write Meta info
|
| 117 |
+
fh.write("# WaveFront *.obj file\n")
|
| 118 |
+
fh.write(f"# Vertices: {total_v}\n")
|
| 119 |
+
fh.write(f"# Curves: {total_curve}\n")
|
| 120 |
+
fh.write("\n")
|
| 121 |
+
|
| 122 |
+
# Write vertex and curve
|
| 123 |
+
fh.write(vertex_strings)
|
| 124 |
+
fh.write("\n")
|
| 125 |
+
fh.write(curve_strings)
|
| 126 |
+
fh.write("\n")
|
| 127 |
+
|
| 128 |
+
#Write extrude value
|
| 129 |
+
fh.write("ExtrudeOperation: " + meta_info['set_op']+"\n")
|
| 130 |
+
extrude_string = 'Extrude '
|
| 131 |
+
for value in meta_info['extrude_value']:
|
| 132 |
+
extrude_string += str(value)+' '
|
| 133 |
+
fh.write(extrude_string)
|
| 134 |
+
fh.write("\n")
|
| 135 |
+
|
| 136 |
+
#Write refe plane transformation
|
| 137 |
+
p_orig = self.parse3d(meta_info['t_orig'])
|
| 138 |
+
x_axis = self.parse3d(meta_info['t_x'])
|
| 139 |
+
y_axis = self.parse3d(meta_info['t_y'])
|
| 140 |
+
z_axis = self.parse3d(meta_info['t_z'])
|
| 141 |
+
fh.write('T_origin '+p_orig)
|
| 142 |
+
fh.write("\n")
|
| 143 |
+
fh.write('T_xaxis '+x_axis)
|
| 144 |
+
fh.write("\n")
|
| 145 |
+
fh.write('T_yaxis '+y_axis)
|
| 146 |
+
fh.write("\n")
|
| 147 |
+
fh.write('T_zaxis '+z_axis)
|
| 148 |
+
fh.write("\n")
|
| 149 |
+
|
| 150 |
+
# Normalized object
|
| 151 |
+
if scale is not None:
|
| 152 |
+
fh.write('Scale '+str(scale))
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
def parse_file(self, scale=1.0):
|
| 156 |
+
"""
|
| 157 |
+
Parse obj file
|
| 158 |
+
Return
|
| 159 |
+
vertex 2D location numpy
|
| 160 |
+
curve list (geometry class)
|
| 161 |
+
extrude parameters
|
| 162 |
+
"""
|
| 163 |
+
|
| 164 |
+
assert self.pathname is not None, "File is None"
|
| 165 |
+
assert self.pathname.exists(), "No such file"
|
| 166 |
+
|
| 167 |
+
# Parse file
|
| 168 |
+
vertex_list = []
|
| 169 |
+
loops = []
|
| 170 |
+
closed_loop = []
|
| 171 |
+
|
| 172 |
+
# Read vertice
|
| 173 |
+
with open(self.pathname) as obj_file:
|
| 174 |
+
for line in obj_file:
|
| 175 |
+
tokens = line.split()
|
| 176 |
+
if not tokens:
|
| 177 |
+
continue
|
| 178 |
+
line_type = tokens[0]
|
| 179 |
+
# Vertex
|
| 180 |
+
if line_type == "v":
|
| 181 |
+
vertex_list.append([float(x) for x in tokens[1:]])
|
| 182 |
+
vertices = np.array(vertex_list, dtype=np.float64) * scale
|
| 183 |
+
|
| 184 |
+
# Read curves
|
| 185 |
+
faces = []
|
| 186 |
+
loops = []
|
| 187 |
+
loop = []
|
| 188 |
+
|
| 189 |
+
# Read in all lines
|
| 190 |
+
lines = []
|
| 191 |
+
with open(self.pathname) as obj_file:
|
| 192 |
+
for line in obj_file:
|
| 193 |
+
lines.append(line)
|
| 194 |
+
|
| 195 |
+
# Parse all lines
|
| 196 |
+
faces = []
|
| 197 |
+
for str_idx, line in enumerate(lines):
|
| 198 |
+
tokens = line.split()
|
| 199 |
+
if not tokens:
|
| 200 |
+
continue
|
| 201 |
+
line_type = tokens[0]
|
| 202 |
+
|
| 203 |
+
# Start of a new face
|
| 204 |
+
if line_type == "face":
|
| 205 |
+
faces.append(self.read_face(lines, str_idx+1, vertices))
|
| 206 |
+
|
| 207 |
+
# Read meta data
|
| 208 |
+
meta_data = line.strip('# ').strip(' \n').split(' ')
|
| 209 |
+
meta_name = meta_data[0]
|
| 210 |
+
|
| 211 |
+
if meta_name == 'Extrude':
|
| 212 |
+
extrude_values = [float(x) for x in meta_data[1:]]
|
| 213 |
+
extrude_values = [x*scale for x in extrude_values]
|
| 214 |
+
elif meta_name == 'T_origin':
|
| 215 |
+
t_orig = [float(x) for x in meta_data[1:]]
|
| 216 |
+
t_orig = [x*scale for x in t_orig]
|
| 217 |
+
elif meta_name == 'T_xaxis':
|
| 218 |
+
t_x = [float(x) for x in meta_data[1:]]
|
| 219 |
+
elif meta_name == 'T_yaxis':
|
| 220 |
+
t_y = [float(x) for x in meta_data[1:]]
|
| 221 |
+
elif meta_name == 'T_zaxis':
|
| 222 |
+
t_z = [float(x) for x in meta_data[1:]]
|
| 223 |
+
elif meta_name == 'ExtrudeOperation:':
|
| 224 |
+
set_op = meta_data[1]
|
| 225 |
+
|
| 226 |
+
meta_info = {'extrude_value': extrude_values,
|
| 227 |
+
'set_op': set_op,
|
| 228 |
+
't_orig': t_orig,
|
| 229 |
+
't_x': t_x,
|
| 230 |
+
't_y': t_y,
|
| 231 |
+
't_z': t_z,
|
| 232 |
+
}
|
| 233 |
+
|
| 234 |
+
return vertices, faces, meta_info
|
| 235 |
+
|
| 236 |
+
|
| 237 |
+
|
| 238 |
+
def read_face(self, lines, str_idx, vertices):
|
| 239 |
+
loops = []
|
| 240 |
+
loop = []
|
| 241 |
+
for line in lines[str_idx:]:
|
| 242 |
+
tokens = line.split()
|
| 243 |
+
if not tokens:
|
| 244 |
+
continue
|
| 245 |
+
line_type = tokens[0]
|
| 246 |
+
|
| 247 |
+
if line_type == 'face':
|
| 248 |
+
break
|
| 249 |
+
|
| 250 |
+
# Start of a new loop
|
| 251 |
+
if line_type == "out" or line_type == "in":
|
| 252 |
+
if len(loop) > 0:
|
| 253 |
+
loops.append(loop)
|
| 254 |
+
loop = []
|
| 255 |
+
is_outer = (line_type == 'out')
|
| 256 |
+
|
| 257 |
+
# Line
|
| 258 |
+
if line_type == 'l':
|
| 259 |
+
c_tok = tokens[1:]
|
| 260 |
+
curve = Line([int(c_tok[0]), int(c_tok[1])], vertices, is_outer=is_outer)
|
| 261 |
+
loop.append(curve)
|
| 262 |
+
|
| 263 |
+
# Arc
|
| 264 |
+
if line_type == 'a':
|
| 265 |
+
c_tok = tokens[1:]
|
| 266 |
+
curve = Arc([int(c_tok[0]), int(c_tok[1]), int(c_tok[2]), int(c_tok[3])], vertices, is_outer=is_outer)
|
| 267 |
+
loop.append(curve)
|
| 268 |
+
|
| 269 |
+
# Circle
|
| 270 |
+
if line_type == 'c':
|
| 271 |
+
c_tok = tokens[1:]
|
| 272 |
+
curve = Circle([int(c_tok[0]), int(c_tok[1])], vertices, is_outer=is_outer)
|
| 273 |
+
loop.append(curve)
|
| 274 |
+
|
| 275 |
+
loops.append(loop)
|
| 276 |
+
return loops
|
CADFusion/src/rendering_utils/geometry/obj_utils.py
ADDED
|
@@ -0,0 +1,93 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import numpy as np
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
import pdb
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
def read_wire_obj(obj_path):
|
| 8 |
+
"""Read vertices and lines from .obj file defining a wire body."""
|
| 9 |
+
vertex_list = []
|
| 10 |
+
loops = []
|
| 11 |
+
|
| 12 |
+
# Read vertice and curves
|
| 13 |
+
with open(obj_path) as obj_file:
|
| 14 |
+
|
| 15 |
+
for line in obj_file:
|
| 16 |
+
tokens = line.split()
|
| 17 |
+
if not tokens:
|
| 18 |
+
continue
|
| 19 |
+
|
| 20 |
+
line_type = tokens[0]
|
| 21 |
+
|
| 22 |
+
if line_type == "v":
|
| 23 |
+
vertex_list.append([float(x) for x in tokens[1:]])
|
| 24 |
+
|
| 25 |
+
if line_type == "g":
|
| 26 |
+
pdb.set_trace()
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
# Read meta data
|
| 32 |
+
meta_data = line.strip('# ').strip(' \n').split(' ')
|
| 33 |
+
meta_name = meta_data[0]
|
| 34 |
+
if meta_name == 'Extrude':
|
| 35 |
+
extrude_values= [float(x) for x in meta_data[1:]]
|
| 36 |
+
elif meta_name == 'T_origin':
|
| 37 |
+
t_orig = [float(x) for x in meta_data[1:]]
|
| 38 |
+
elif meta_name == 'T_xaxis':
|
| 39 |
+
t_x = [float(x) for x in meta_data[1:]]
|
| 40 |
+
elif meta_name == 'T_yaxis':
|
| 41 |
+
t_y = [float(x) for x in meta_data[1:]]
|
| 42 |
+
elif meta_name == 'T_zaxis':
|
| 43 |
+
t_z = [float(x) for x in meta_data[1:]]
|
| 44 |
+
elif meta_name == 'ExtrudeOperation:':
|
| 45 |
+
set_op = meta_data[1]
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
vertices = np.array(vertex_list)
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
meta_info = {'extrude_value': extrude_values,
|
| 53 |
+
'set_op': set_op,
|
| 54 |
+
't_orig': t_orig,
|
| 55 |
+
't_x': t_x,
|
| 56 |
+
't_y': t_y,
|
| 57 |
+
't_z': t_z}
|
| 58 |
+
|
| 59 |
+
total_in_outs.append(in_outs)
|
| 60 |
+
|
| 61 |
+
return np.array(flat_vertices_list, dtype=np.float32), flat_hyperedge, total_in_outs, meta_info
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def write_wire_obj(vertices, faces, file_path, transpose=True, scale=1.0):
|
| 65 |
+
"""Write vertices and hyperedges to obj."""
|
| 66 |
+
vertex_dimension = vertices.shape[1]
|
| 67 |
+
assert vertex_dimension in (2, 3)
|
| 68 |
+
if transpose and vertex_dimension == 3:
|
| 69 |
+
# Permute 3D vertices where z comes first followed by x and y
|
| 70 |
+
vertices = vertices[:, [1, 2, 0]]
|
| 71 |
+
vertices *= scale
|
| 72 |
+
if faces is not None:
|
| 73 |
+
if len(faces) > 0:
|
| 74 |
+
if min(min(faces)) == 0:
|
| 75 |
+
f_add = 1
|
| 76 |
+
else:
|
| 77 |
+
f_add = 0
|
| 78 |
+
with open(file_path, "w") as f:
|
| 79 |
+
for v in vertices:
|
| 80 |
+
if vertex_dimension == 2:
|
| 81 |
+
f.write("v {} {} {}\n".format(v[0], v[1], 0.0))
|
| 82 |
+
else:
|
| 83 |
+
f.write("v {} {} {}\n".format(v[0], v[1], v[2]))
|
| 84 |
+
for face in faces:
|
| 85 |
+
line = "l"
|
| 86 |
+
for i in face:
|
| 87 |
+
# Pradeep: always adding 1 to the face index makes sense to me. Not sure why
|
| 88 |
+
# PolyGen does this conditionally (see L95 above)
|
| 89 |
+
# Something to note.
|
| 90 |
+
line += " {}".format(i + 1)
|
| 91 |
+
line += "\n"
|
| 92 |
+
f.write(line)
|
| 93 |
+
|
CADFusion/src/rendering_utils/img_renderer.py
ADDED
|
@@ -0,0 +1,84 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
from OCC.Core.Graphic3d import *
|
| 3 |
+
from OCC.Display.OCCViewer import Viewer3d
|
| 4 |
+
from OCC.Extend.DataExchange import read_step_file
|
| 5 |
+
from OCC.Extend.TopologyUtils import TopologyExplorer
|
| 6 |
+
from OCC.Core.Quantity import Quantity_Color, Quantity_TOC_RGB, Quantity_NOC_WHITE
|
| 7 |
+
from OCC.Core.V3d import V3d_DirectionalLight
|
| 8 |
+
from OCC.Core.gp import gp_Dir
|
| 9 |
+
from glob import glob
|
| 10 |
+
import pathlib
|
| 11 |
+
from tqdm import tqdm
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def render(shape, filename, width=1024, height=768, face_color_rgb=(0.2, 0.2, 0.2), edge_color_rgb=(0, 0, 0), show_face_boundary=True):
|
| 15 |
+
viewer = Viewer3d()
|
| 16 |
+
viewer.Create(phong_shading=True, create_default_lights=True)
|
| 17 |
+
viewer.set_bg_gradient_color([255, 255, 255], [255, 255, 255])
|
| 18 |
+
viewer.SetModeShaded()
|
| 19 |
+
viewer.hide_triedron()
|
| 20 |
+
viewer.EnableAntiAliasing()
|
| 21 |
+
dir_light = V3d_DirectionalLight(gp_Dir(0, 0.5, -1), Quantity_Color(Quantity_NOC_WHITE))
|
| 22 |
+
dir_light.SetEnabled(True)
|
| 23 |
+
dir_light.SetIntensity(500.0)
|
| 24 |
+
viewer.Viewer.AddLight(dir_light)
|
| 25 |
+
viewer.Viewer.SetLightOn()
|
| 26 |
+
|
| 27 |
+
viewer.default_drawer.EnableDrawHiddenLine()
|
| 28 |
+
viewer.default_drawer.SetFaceBoundaryDraw(show_face_boundary)
|
| 29 |
+
ais_context = viewer.GetContext()
|
| 30 |
+
dc = ais_context.DeviationCoefficient()
|
| 31 |
+
da = ais_context.DeviationAngle()
|
| 32 |
+
factor = 10
|
| 33 |
+
ais_context.SetDeviationCoefficient(dc / factor)
|
| 34 |
+
ais_context.SetDeviationAngle(da / factor)
|
| 35 |
+
topexp = TopologyExplorer(shape)
|
| 36 |
+
for face in topexp.faces():
|
| 37 |
+
if face is not None:
|
| 38 |
+
viewer.DisplayShape(face, color=Quantity_Color(*face_color_rgb, Quantity_TOC_RGB))
|
| 39 |
+
for edge in topexp.edges():
|
| 40 |
+
if edge is not None:
|
| 41 |
+
viewer.DisplayShape(edge, color=Quantity_Color(*edge_color_rgb, Quantity_TOC_RGB))
|
| 42 |
+
viewer.FitAll()
|
| 43 |
+
viewer.SetSize(width, height)
|
| 44 |
+
viewer.View.Dump(str(filename))
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def main():
|
| 48 |
+
p = argparse.ArgumentParser()
|
| 49 |
+
p.add_argument("--input_dir", type=str, required=True, help="Input folder of STP/STEP files")
|
| 50 |
+
p.add_argument("--output_dir", type=str, required=True, help="Output folder of PNG files")
|
| 51 |
+
p.add_argument("--width", type=int, default=1024, help="Width of image")
|
| 52 |
+
p.add_argument("--height", type=int, default=768, help="Height of image")
|
| 53 |
+
|
| 54 |
+
args = p.parse_args()
|
| 55 |
+
|
| 56 |
+
files = []
|
| 57 |
+
cad_folders = sorted(glob(args.input_dir+'/*/'))
|
| 58 |
+
for folder in cad_folders:
|
| 59 |
+
input_path = pathlib.Path(folder)
|
| 60 |
+
files += list(input_path.glob("*.st*p"))
|
| 61 |
+
print(len(files))
|
| 62 |
+
# files = files[36000:] # debug only (* remove *)
|
| 63 |
+
output_path = pathlib.Path(args.output_dir)
|
| 64 |
+
if not output_path.exists():
|
| 65 |
+
output_path.mkdir(parents=True, exist_ok=True)
|
| 66 |
+
|
| 67 |
+
i = 0
|
| 68 |
+
j = 0
|
| 69 |
+
for fn in tqdm(files):
|
| 70 |
+
j += 1
|
| 71 |
+
try:
|
| 72 |
+
shape = read_step_file(str(fn))
|
| 73 |
+
# render(shape, output_path.joinpath(f'{j:06d}' + ".png"), args.width, args.height)
|
| 74 |
+
render(shape, output_path.joinpath(fn.stem[:6] + ".png"), args.width, args.height)
|
| 75 |
+
except Exception as e:
|
| 76 |
+
i += 1
|
| 77 |
+
# raise e
|
| 78 |
+
print(e)
|
| 79 |
+
continue
|
| 80 |
+
print("error number: ", i)
|
| 81 |
+
print("total number: ", j)
|
| 82 |
+
|
| 83 |
+
if __name__ == "__main__":
|
| 84 |
+
main()
|
CADFusion/src/rendering_utils/parser.py
ADDED
|
@@ -0,0 +1,478 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
from collections import OrderedDict
|
| 3 |
+
import re
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
import argparse
|
| 6 |
+
import os
|
| 7 |
+
import json
|
| 8 |
+
import math
|
| 9 |
+
|
| 10 |
+
# hyperparameters from SkexGen project
|
| 11 |
+
SKETCH_R = 1
|
| 12 |
+
RADIUS_R = 1
|
| 13 |
+
EXTRUDE_R = 1.0
|
| 14 |
+
SCALE_R = 1.4
|
| 15 |
+
OFFSET_R = 0.9
|
| 16 |
+
PIX_PAD = 4
|
| 17 |
+
CMD_PAD = 3
|
| 18 |
+
COORD_PAD = 4
|
| 19 |
+
EXT_PAD = 1
|
| 20 |
+
EXTRA_PAD = 1
|
| 21 |
+
R_PAD = 2
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class CADparser:
|
| 25 |
+
"""Parse CAD sequence to CAD object."""
|
| 26 |
+
|
| 27 |
+
def __init__(self, bit):
|
| 28 |
+
self.vertex_dict = OrderedDict()
|
| 29 |
+
self.bit = bit
|
| 30 |
+
|
| 31 |
+
def perform(self, cad_seq):
|
| 32 |
+
# divide into sketch and extrude
|
| 33 |
+
sketches, extrudes = self.get_SE(cad_seq)
|
| 34 |
+
if sketches is None or extrudes is None:
|
| 35 |
+
return None
|
| 36 |
+
# sequentially parse each pair of SE into obj
|
| 37 |
+
se_datas = []
|
| 38 |
+
for sketch, extrude in zip(sketches, extrudes):
|
| 39 |
+
extrude_param, scale, offset = self.parse_extrude(extrude)
|
| 40 |
+
if extrude_param is None or scale is None or offset is None:
|
| 41 |
+
return None
|
| 42 |
+
vertex_str, se_str = self.parse_sketch(sketch, scale, offset)
|
| 43 |
+
if vertex_str is None or se_str is None:
|
| 44 |
+
return None
|
| 45 |
+
se_datas.append(
|
| 46 |
+
{"vertex": vertex_str, "curve": se_str, "extrude": extrude_param}
|
| 47 |
+
)
|
| 48 |
+
self.vertex_dict.clear()
|
| 49 |
+
|
| 50 |
+
return se_datas
|
| 51 |
+
|
| 52 |
+
def parse_sketch(self, sketch, scale, offset):
|
| 53 |
+
faces = self.get_faces(sketch)
|
| 54 |
+
if len(faces) == 0:
|
| 55 |
+
return None, None
|
| 56 |
+
se_str = ""
|
| 57 |
+
for face_idx, face in enumerate(faces): # each face
|
| 58 |
+
face_str = "face\n"
|
| 59 |
+
loops = self.get_loops(face)
|
| 60 |
+
if len(loops) == 0:
|
| 61 |
+
return None, None
|
| 62 |
+
for loop_idx, loop in enumerate(loops): # each loop
|
| 63 |
+
curves = self.get_curves(loop)
|
| 64 |
+
if len(curves) == 0:
|
| 65 |
+
return None, None
|
| 66 |
+
next_curves = curves[1:]
|
| 67 |
+
next_curves += curves[:1]
|
| 68 |
+
cur_str = []
|
| 69 |
+
for curve, next_curve in zip(curves, next_curves): # each curve
|
| 70 |
+
if not self.obj_curve(curve, next_curve, cur_str, scale, offset):
|
| 71 |
+
return None, None
|
| 72 |
+
loop_str = ""
|
| 73 |
+
for c in cur_str:
|
| 74 |
+
loop_str += f"{c}\n"
|
| 75 |
+
if loop_idx == 0:
|
| 76 |
+
face_str += f"out\n{loop_str}\n"
|
| 77 |
+
else:
|
| 78 |
+
face_str += f"in\n{loop_str}\n"
|
| 79 |
+
se_str += face_str
|
| 80 |
+
vertex_str = self.convert_vertices()
|
| 81 |
+
return vertex_str, se_str
|
| 82 |
+
|
| 83 |
+
def parse_extrude(self, extrude):
|
| 84 |
+
ext = extrude.split(",")
|
| 85 |
+
if len(ext) != 18:
|
| 86 |
+
return None, None, None
|
| 87 |
+
|
| 88 |
+
# operation str to int
|
| 89 |
+
ext_op = {"add": 1, "cut": 2, "intersect": 3}.get(ext[0], None)
|
| 90 |
+
if ext_op is None:
|
| 91 |
+
return None, None, None
|
| 92 |
+
# dequantize ext_v, ext_T, scale and offset
|
| 93 |
+
ext_v, ext_T, scale, offset = self.dequantize_extrude_params(ext)
|
| 94 |
+
# get ext_R
|
| 95 |
+
ext_R = np.array(ext[6:15], dtype=int)
|
| 96 |
+
|
| 97 |
+
extrude_param = {"value": ext_v, "T": ext_T, "R": ext_R, "op": ext_op}
|
| 98 |
+
return extrude_param, scale, offset
|
| 99 |
+
|
| 100 |
+
def obj_curve(self, curve, next_curve, cur_str, scale, offset):
|
| 101 |
+
cur = curve.split(",")
|
| 102 |
+
next_cur = next_curve.split(",")
|
| 103 |
+
if cur[0] == "circle":
|
| 104 |
+
if len(cur) != 9:
|
| 105 |
+
return False
|
| 106 |
+
p1, p2, p3, p4 = self.dequantize_circle_points(
|
| 107 |
+
cur, next_cur, scale, offset)
|
| 108 |
+
center = np.asarray([0.5 * (p1[0] + p2[0]), 0.5 * (p3[1] + p4[1])])
|
| 109 |
+
radius = (np.linalg.norm(p1 - p2) + np.linalg.norm(p3 - p4)) / 4.0
|
| 110 |
+
|
| 111 |
+
center = center * scale + offset
|
| 112 |
+
radius = radius * scale
|
| 113 |
+
|
| 114 |
+
center_idx = self.save_vertex(center[0], center[1], "p")
|
| 115 |
+
radius_idx = self.save_vertex(radius, 0.0, "r")
|
| 116 |
+
cur_str.append(f"c {center_idx} {radius_idx}")
|
| 117 |
+
elif cur[0] == "arc":
|
| 118 |
+
if len(cur) != 5:
|
| 119 |
+
return False
|
| 120 |
+
if (
|
| 121 |
+
cur[1:3] == cur[3:5]
|
| 122 |
+
or cur[1:3] == next_cur[1:3]
|
| 123 |
+
or cur[3:5] == next_cur[3:5]
|
| 124 |
+
): # invalid arc
|
| 125 |
+
return False
|
| 126 |
+
start_v, mid_v, end_v = self.dequantize_arc_points(
|
| 127 |
+
cur, next_cur, scale, offset
|
| 128 |
+
)
|
| 129 |
+
try:
|
| 130 |
+
center, _, _, _ = find_arc_geometry(start_v, mid_v, end_v)
|
| 131 |
+
except Exception:
|
| 132 |
+
return False
|
| 133 |
+
start_v = start_v * scale + offset
|
| 134 |
+
mid_v = mid_v * scale + offset
|
| 135 |
+
end_v = end_v * scale + offset
|
| 136 |
+
center = center * scale + offset
|
| 137 |
+
|
| 138 |
+
center_idx = self.save_vertex(center[0], center[1], "p")
|
| 139 |
+
start_idx = self.save_vertex(start_v[0], start_v[1], "p")
|
| 140 |
+
mid_idx = self.save_vertex(mid_v[0], mid_v[1], "p")
|
| 141 |
+
end_idx = self.save_vertex(end_v[0], end_v[1], "p")
|
| 142 |
+
cur_str.append(f"a {start_idx} {mid_idx} {center_idx} {end_idx}")
|
| 143 |
+
elif cur[0] == "line":
|
| 144 |
+
if len(cur) != 3:
|
| 145 |
+
return False
|
| 146 |
+
if cur[1:3] == next_cur[1:3]:
|
| 147 |
+
return False
|
| 148 |
+
start_v, end_v = self.dequantize_line_points(
|
| 149 |
+
cur, next_cur, scale, offset)
|
| 150 |
+
start_v = start_v * scale + offset
|
| 151 |
+
end_v = end_v * scale + offset
|
| 152 |
+
|
| 153 |
+
start_idx = self.save_vertex(start_v[0], start_v[1], "p")
|
| 154 |
+
end_idx = self.save_vertex(end_v[0], end_v[1], "p")
|
| 155 |
+
cur_str.append(f"l {start_idx} {end_idx}")
|
| 156 |
+
else:
|
| 157 |
+
return False
|
| 158 |
+
return True
|
| 159 |
+
|
| 160 |
+
def get_SE(self, cad_seq):
|
| 161 |
+
# sketches: 1) between sequence start and sketch_end,
|
| 162 |
+
sketches_from_start = re.findall(r"^(.+?)(?=<sketch_end>)", cad_seq)
|
| 163 |
+
# sketches: 2) between extrude_end and sketch_end
|
| 164 |
+
sketches_after_extrude = re.findall(
|
| 165 |
+
r"(?<=<extrude_end>)(.+?)(?=<sketch_end>)", cad_seq
|
| 166 |
+
)
|
| 167 |
+
sketches = [x.strip() for x in sketches_from_start] + [
|
| 168 |
+
x.strip() for x in sketches_after_extrude
|
| 169 |
+
]
|
| 170 |
+
# extrudes: between sketch_end and extrude_end
|
| 171 |
+
extrudes = [
|
| 172 |
+
x.strip() for x in re.findall(r"<sketch_end>(.+?)<extrude_end>", cad_seq)
|
| 173 |
+
]
|
| 174 |
+
if len(sketches) != len(extrudes):
|
| 175 |
+
return None, None
|
| 176 |
+
return sketches, extrudes
|
| 177 |
+
|
| 178 |
+
def get_faces(self, sketch):
|
| 179 |
+
faces = sketch.split("<face_end>")
|
| 180 |
+
return [x.strip() for x in faces if x.strip() != ""]
|
| 181 |
+
|
| 182 |
+
def get_loops(self, face):
|
| 183 |
+
loops = face.split("<loop_end>")
|
| 184 |
+
return [x.strip() for x in loops if x.strip() != ""]
|
| 185 |
+
|
| 186 |
+
def get_curves(self, loop):
|
| 187 |
+
curves = loop.split("<curve_end>")
|
| 188 |
+
return [x.strip() for x in curves if x.strip() != ""]
|
| 189 |
+
|
| 190 |
+
def dequantize_circle_points(self, curve, next_curve, scale, offset):
|
| 191 |
+
p1 = dequantize_verts(
|
| 192 |
+
np.array(curve[1:3], dtype=int),
|
| 193 |
+
n_bits=self.bit,
|
| 194 |
+
min_range=-SKETCH_R,
|
| 195 |
+
max_range=SKETCH_R,
|
| 196 |
+
add_noise=False,
|
| 197 |
+
)
|
| 198 |
+
p2 = dequantize_verts(
|
| 199 |
+
np.array(curve[3:5], dtype=int),
|
| 200 |
+
n_bits=self.bit,
|
| 201 |
+
min_range=-SKETCH_R,
|
| 202 |
+
max_range=SKETCH_R,
|
| 203 |
+
add_noise=False,
|
| 204 |
+
)
|
| 205 |
+
p3 = dequantize_verts(
|
| 206 |
+
np.array(curve[5:7], dtype=int),
|
| 207 |
+
n_bits=self.bit,
|
| 208 |
+
min_range=-SKETCH_R,
|
| 209 |
+
max_range=SKETCH_R,
|
| 210 |
+
add_noise=False,
|
| 211 |
+
)
|
| 212 |
+
p4 = dequantize_verts(
|
| 213 |
+
np.array(curve[7:9], dtype=int),
|
| 214 |
+
n_bits=self.bit,
|
| 215 |
+
min_range=-SKETCH_R,
|
| 216 |
+
max_range=SKETCH_R,
|
| 217 |
+
add_noise=False,
|
| 218 |
+
)
|
| 219 |
+
return p1, p2, p3, p4
|
| 220 |
+
|
| 221 |
+
def dequantize_arc_points(self, curve, next_curve, scale, offset):
|
| 222 |
+
start_v = dequantize_verts(
|
| 223 |
+
np.array(curve[1:3], dtype=int),
|
| 224 |
+
n_bits=self.bit,
|
| 225 |
+
min_range=-SKETCH_R,
|
| 226 |
+
max_range=SKETCH_R,
|
| 227 |
+
add_noise=False,
|
| 228 |
+
)
|
| 229 |
+
mid_v = dequantize_verts(
|
| 230 |
+
np.array(curve[3:5], dtype=int),
|
| 231 |
+
n_bits=self.bit,
|
| 232 |
+
min_range=-SKETCH_R,
|
| 233 |
+
max_range=SKETCH_R,
|
| 234 |
+
add_noise=False,
|
| 235 |
+
)
|
| 236 |
+
end_v = dequantize_verts(
|
| 237 |
+
np.array(next_curve[1:3], dtype=int),
|
| 238 |
+
n_bits=self.bit,
|
| 239 |
+
min_range=-SKETCH_R,
|
| 240 |
+
max_range=SKETCH_R,
|
| 241 |
+
add_noise=False,
|
| 242 |
+
)
|
| 243 |
+
return start_v, mid_v, end_v
|
| 244 |
+
|
| 245 |
+
def dequantize_line_points(self, curve, next_curve, scale, offset):
|
| 246 |
+
start_v = dequantize_verts(
|
| 247 |
+
np.array(curve[1:3], dtype=int),
|
| 248 |
+
n_bits=self.bit,
|
| 249 |
+
min_range=-SKETCH_R,
|
| 250 |
+
max_range=SKETCH_R,
|
| 251 |
+
add_noise=False,
|
| 252 |
+
)
|
| 253 |
+
end_v = dequantize_verts(
|
| 254 |
+
np.array(next_curve[1:3], dtype=int),
|
| 255 |
+
n_bits=self.bit,
|
| 256 |
+
min_range=-SKETCH_R,
|
| 257 |
+
max_range=SKETCH_R,
|
| 258 |
+
add_noise=False,
|
| 259 |
+
)
|
| 260 |
+
return start_v, end_v
|
| 261 |
+
|
| 262 |
+
def dequantize_extrude_params(self, extrude):
|
| 263 |
+
ext_v = dequantize_verts(
|
| 264 |
+
np.array(extrude[1:3], dtype=int),
|
| 265 |
+
n_bits=self.bit,
|
| 266 |
+
min_range=-EXTRUDE_R,
|
| 267 |
+
max_range=EXTRUDE_R,
|
| 268 |
+
add_noise=False,
|
| 269 |
+
)
|
| 270 |
+
ext_T = dequantize_verts(
|
| 271 |
+
np.array(extrude[3:6], dtype=int),
|
| 272 |
+
n_bits=self.bit,
|
| 273 |
+
min_range=-EXTRUDE_R,
|
| 274 |
+
max_range=EXTRUDE_R,
|
| 275 |
+
add_noise=False,
|
| 276 |
+
)
|
| 277 |
+
scale = dequantize_verts(
|
| 278 |
+
np.array(extrude[15], dtype=int),
|
| 279 |
+
n_bits=self.bit,
|
| 280 |
+
min_range=0.0,
|
| 281 |
+
max_range=SCALE_R,
|
| 282 |
+
add_noise=False,
|
| 283 |
+
)
|
| 284 |
+
offset = dequantize_verts(
|
| 285 |
+
np.array(extrude[16:18], dtype=int),
|
| 286 |
+
n_bits=self.bit,
|
| 287 |
+
min_range=-OFFSET_R,
|
| 288 |
+
max_range=OFFSET_R,
|
| 289 |
+
add_noise=False,
|
| 290 |
+
)
|
| 291 |
+
return ext_v, ext_T, scale, offset
|
| 292 |
+
|
| 293 |
+
def save_vertex(self, h_x, h_y, text):
|
| 294 |
+
unique_key = f"{text}:x{h_x}y{h_y}"
|
| 295 |
+
index = 0
|
| 296 |
+
for key in self.vertex_dict.keys():
|
| 297 |
+
# Vertex location already exist in dict
|
| 298 |
+
if unique_key == key:
|
| 299 |
+
return index
|
| 300 |
+
index += 1
|
| 301 |
+
# Vertex location does not exist in dict
|
| 302 |
+
self.vertex_dict[unique_key] = [h_x, h_y]
|
| 303 |
+
return index
|
| 304 |
+
|
| 305 |
+
def convert_vertices(self):
|
| 306 |
+
"""Convert all the vertices to .obj format"""
|
| 307 |
+
vertex_strings = ""
|
| 308 |
+
for pt in self.vertex_dict.values():
|
| 309 |
+
# e.g. v 0.123 0.234 0.345 1.0
|
| 310 |
+
vertex_string = f"v {pt[0]} {pt[1]}\n"
|
| 311 |
+
vertex_strings += vertex_string
|
| 312 |
+
return vertex_strings
|
| 313 |
+
|
| 314 |
+
|
| 315 |
+
def find_arc_geometry(a, b, c):
|
| 316 |
+
A = b[0] - a[0]
|
| 317 |
+
B = b[1] - a[1]
|
| 318 |
+
C = c[0] - a[0]
|
| 319 |
+
D = c[1] - a[1]
|
| 320 |
+
|
| 321 |
+
E = A*(a[0] + b[0]) + B*(a[1] + b[1])
|
| 322 |
+
F = C*(a[0] + c[0]) + D*(a[1] + c[1])
|
| 323 |
+
|
| 324 |
+
G = 2.0*(A*(c[1] - b[1])-B*(c[0] - b[0]))
|
| 325 |
+
|
| 326 |
+
if G == 0:
|
| 327 |
+
raise Exception("zero G")
|
| 328 |
+
|
| 329 |
+
p_0 = (D*E - B*F) / G
|
| 330 |
+
p_1 = (A*F - C*E) / G
|
| 331 |
+
|
| 332 |
+
center = np.array([p_0, p_1])
|
| 333 |
+
radius = np.linalg.norm(center - a)
|
| 334 |
+
|
| 335 |
+
angles = []
|
| 336 |
+
for xx in [a, b, c]:
|
| 337 |
+
angle = angle_from_vector_to_x(xx - center)
|
| 338 |
+
angles.append(angle)
|
| 339 |
+
|
| 340 |
+
ab = b-a
|
| 341 |
+
ac = c-a
|
| 342 |
+
cp = np.cross(ab, ac)
|
| 343 |
+
if cp >= 0:
|
| 344 |
+
start_angle_rads = angles[0]
|
| 345 |
+
end_angle_rads = angles[2]
|
| 346 |
+
else:
|
| 347 |
+
start_angle_rads = angles[2]
|
| 348 |
+
end_angle_rads = angles[0]
|
| 349 |
+
|
| 350 |
+
return center, radius, start_angle_rads, end_angle_rads
|
| 351 |
+
|
| 352 |
+
|
| 353 |
+
def angle_from_vector_to_x(vec):
|
| 354 |
+
assert vec.size == 2
|
| 355 |
+
# We need to find a unit vector
|
| 356 |
+
angle = 0.0
|
| 357 |
+
|
| 358 |
+
l = np.linalg.norm(vec)
|
| 359 |
+
uvec = vec/l
|
| 360 |
+
|
| 361 |
+
# 2 | 1
|
| 362 |
+
# -------
|
| 363 |
+
# 3 | 4
|
| 364 |
+
if uvec[0] >= 0:
|
| 365 |
+
if uvec[1] >= 0:
|
| 366 |
+
# Qadrant 1
|
| 367 |
+
angle = math.asin(uvec[1])
|
| 368 |
+
else:
|
| 369 |
+
# Qadrant 4
|
| 370 |
+
angle = 2.0*math.pi - math.asin(-uvec[1])
|
| 371 |
+
else:
|
| 372 |
+
if vec[1] >= 0:
|
| 373 |
+
# Qadrant 2
|
| 374 |
+
angle = math.pi - math.asin(uvec[1])
|
| 375 |
+
else:
|
| 376 |
+
# Qadrant 3
|
| 377 |
+
angle = math.pi + math.asin(-uvec[1])
|
| 378 |
+
return angle
|
| 379 |
+
|
| 380 |
+
|
| 381 |
+
def dequantize_verts(verts, n_bits=8, min_range=-0.5, max_range=0.5, add_noise=False):
|
| 382 |
+
"""Convert quantized vertices to floats."""
|
| 383 |
+
range_quantize = 2**n_bits - 1
|
| 384 |
+
verts = verts.astype("float32")
|
| 385 |
+
verts = verts * (max_range - min_range) / range_quantize + min_range
|
| 386 |
+
return verts
|
| 387 |
+
|
| 388 |
+
|
| 389 |
+
def write_obj_sample(save_folder, data):
|
| 390 |
+
for idx, write_data in enumerate(data):
|
| 391 |
+
obj_name = Path(save_folder).stem + "_" + \
|
| 392 |
+
str(idx).zfill(3) + "_param.obj"
|
| 393 |
+
obj_file = Path(save_folder) / obj_name
|
| 394 |
+
extrude_param = write_data["extrude"]
|
| 395 |
+
vertex_strings = write_data["vertex"]
|
| 396 |
+
curve_strings = write_data["curve"]
|
| 397 |
+
|
| 398 |
+
"""Write an .obj file with the curves and verts"""
|
| 399 |
+
if extrude_param["op"] == 1: # 'add'
|
| 400 |
+
set_op = "NewBodyFeatureOperation"
|
| 401 |
+
elif extrude_param["op"] == 2: # 'cut'
|
| 402 |
+
set_op = "CutFeatureOperation"
|
| 403 |
+
elif extrude_param["op"] == 3: # 'cut'
|
| 404 |
+
set_op = "IntersectFeatureOperation"
|
| 405 |
+
|
| 406 |
+
with open(obj_file, "w") as fh:
|
| 407 |
+
# Write Meta info
|
| 408 |
+
fh.write("# WaveFront *.obj file\n")
|
| 409 |
+
fh.write("# ExtrudeOperation: " + set_op + "\n")
|
| 410 |
+
fh.write("\n")
|
| 411 |
+
|
| 412 |
+
# Write vertex and curve
|
| 413 |
+
fh.write(vertex_strings)
|
| 414 |
+
fh.write("\n")
|
| 415 |
+
fh.write(curve_strings)
|
| 416 |
+
fh.write("\n")
|
| 417 |
+
|
| 418 |
+
# Write extrude value
|
| 419 |
+
extrude_string = "Extrude "
|
| 420 |
+
for value in extrude_param["value"]:
|
| 421 |
+
extrude_string += str(value) + " "
|
| 422 |
+
fh.write(extrude_string)
|
| 423 |
+
fh.write("\n")
|
| 424 |
+
|
| 425 |
+
# Write refe plane value
|
| 426 |
+
p_orig = parse3d_sample(extrude_param["T"])
|
| 427 |
+
x_axis = parse3d_sample(extrude_param["R"][0:3])
|
| 428 |
+
y_axis = parse3d_sample(extrude_param["R"][3:6])
|
| 429 |
+
z_axis = parse3d_sample(extrude_param["R"][6:9])
|
| 430 |
+
fh.write("T_origin " + p_orig)
|
| 431 |
+
fh.write("\n")
|
| 432 |
+
fh.write("T_xaxis " + x_axis)
|
| 433 |
+
fh.write("\n")
|
| 434 |
+
fh.write("T_yaxis " + y_axis)
|
| 435 |
+
fh.write("\n")
|
| 436 |
+
fh.write("T_zaxis " + z_axis)
|
| 437 |
+
|
| 438 |
+
|
| 439 |
+
def parse3d_sample(point3d):
|
| 440 |
+
x = point3d[0]
|
| 441 |
+
y = point3d[1]
|
| 442 |
+
z = point3d[2]
|
| 443 |
+
return str(x) + " " + str(y) + " " + str(z)
|
| 444 |
+
|
| 445 |
+
|
| 446 |
+
if __name__ == "__main__":
|
| 447 |
+
parser = argparse.ArgumentParser()
|
| 448 |
+
parser.add_argument("--in-path", type=str, required=True)
|
| 449 |
+
parser.add_argument("--out-path", type=str, required=True)
|
| 450 |
+
args = parser.parse_args()
|
| 451 |
+
|
| 452 |
+
# with open(args.in_path, "r") as f:
|
| 453 |
+
# data = f.readlines()
|
| 454 |
+
with open(args.in_path, 'r') as file:
|
| 455 |
+
data = file.read()
|
| 456 |
+
|
| 457 |
+
data = json.loads(data)
|
| 458 |
+
|
| 459 |
+
num_valid_str = 0
|
| 460 |
+
for idx, item in enumerate(data):
|
| 461 |
+
try:
|
| 462 |
+
cad_parser = CADparser(bit=6)
|
| 463 |
+
# print(idx)
|
| 464 |
+
if type(item) == str:
|
| 465 |
+
parsed_data = cad_parser.perform(item)
|
| 466 |
+
elif type(item) == dict:
|
| 467 |
+
parsed_data = cad_parser.perform(item['output'])
|
| 468 |
+
else:
|
| 469 |
+
raise ValueError("Invalid data type")
|
| 470 |
+
out_path = os.path.join(args.out_path, str(idx).zfill(6))
|
| 471 |
+
os.makedirs(out_path, exist_ok=True)
|
| 472 |
+
if parsed_data is not None:
|
| 473 |
+
num_valid_str += 1
|
| 474 |
+
write_obj_sample(out_path, parsed_data)
|
| 475 |
+
except Exception as e:
|
| 476 |
+
print(e)
|
| 477 |
+
pass
|
| 478 |
+
print(f"Number of valid CAD strings: {num_valid_str}/{len(data)}")
|
CADFusion/src/rendering_utils/parser_visual.py
ADDED
|
@@ -0,0 +1,110 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import argparse
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
from tqdm import tqdm
|
| 5 |
+
from multiprocessing import Pool
|
| 6 |
+
from glob import glob
|
| 7 |
+
from utils.obj_reconverter import OBJReconverter
|
| 8 |
+
from OCC.Core.BRepCheck import BRepCheck_Analyzer
|
| 9 |
+
from geometry.obj_parser import OBJParser
|
| 10 |
+
from utils.util import write_stl_file
|
| 11 |
+
from OCC.Extend.DataExchange import write_step_file
|
| 12 |
+
|
| 13 |
+
import signal
|
| 14 |
+
from contextlib import contextmanager
|
| 15 |
+
@contextmanager
|
| 16 |
+
def timeout(time):
|
| 17 |
+
# Register a function to raise a TimeoutError on the signal.
|
| 18 |
+
signal.signal(signal.SIGALRM, raise_timeout)
|
| 19 |
+
# Schedule the signal to be sent after ``time``.
|
| 20 |
+
signal.alarm(time)
|
| 21 |
+
try:
|
| 22 |
+
yield
|
| 23 |
+
except TimeoutError:
|
| 24 |
+
raise Exception("time out")
|
| 25 |
+
finally:
|
| 26 |
+
# Unregister the signal so it won't be triggered
|
| 27 |
+
# if the timeout is not reached.
|
| 28 |
+
signal.signal(signal.SIGALRM, signal.SIG_IGN)
|
| 29 |
+
def raise_timeout(signum, frame):
|
| 30 |
+
raise TimeoutError
|
| 31 |
+
|
| 32 |
+
NUM_TRHEADS = 36
|
| 33 |
+
|
| 34 |
+
def find_files(folder, extension):
|
| 35 |
+
return sorted([Path(os.path.join(folder, f)) for f in os.listdir(folder) if f.endswith(extension)])
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def run_parallel(project_folder):
|
| 39 |
+
output_folder = project_folder
|
| 40 |
+
|
| 41 |
+
param_objs = find_files(project_folder, 'param.obj')
|
| 42 |
+
|
| 43 |
+
cur_solid = None
|
| 44 |
+
extrude_idx = 0
|
| 45 |
+
for obj in param_objs:
|
| 46 |
+
try:
|
| 47 |
+
with timeout(30):
|
| 48 |
+
parser = OBJParser(obj)
|
| 49 |
+
_, faces, meta_info = parser.parse_file(1.0)
|
| 50 |
+
converter = OBJReconverter()
|
| 51 |
+
ext_solid, _, _ = converter.parse_obj(faces, meta_info)
|
| 52 |
+
set_op = meta_info["set_op"]
|
| 53 |
+
if set_op == "NewBodyFeatureOperation" or set_op == "JoinFeatureOperation":
|
| 54 |
+
if cur_solid is None:
|
| 55 |
+
cur_solid = ext_solid
|
| 56 |
+
else:
|
| 57 |
+
cur_solid = converter.my_op(cur_solid, ext_solid, 'fuse')
|
| 58 |
+
elif set_op == "CutFeatureOperation":
|
| 59 |
+
cur_solid = converter.my_op(cur_solid, ext_solid, 'cut')
|
| 60 |
+
elif set_op == "IntersectFeatureOperation":
|
| 61 |
+
cur_solid = converter.my_op(cur_solid, ext_solid, 'common')
|
| 62 |
+
else:
|
| 63 |
+
raise Exception("Unknown operation type")
|
| 64 |
+
|
| 65 |
+
analyzer = BRepCheck_Analyzer(cur_solid)
|
| 66 |
+
if not analyzer.IsValid():
|
| 67 |
+
raise Exception("brep check failed")
|
| 68 |
+
|
| 69 |
+
extrude_idx += 1
|
| 70 |
+
|
| 71 |
+
except Exception as ex:
|
| 72 |
+
print(ex)
|
| 73 |
+
msg = [project_folder, str(ex)[:100]]
|
| 74 |
+
return None
|
| 75 |
+
try:
|
| 76 |
+
with timeout(30):
|
| 77 |
+
stl_name = Path(output_folder).stem + '_'+ str(extrude_idx).zfill(3) + "_final.stl"
|
| 78 |
+
output_path = os.path.join(output_folder, stl_name)
|
| 79 |
+
write_stl_file(cur_solid, output_path, linear_deflection=0.001, angular_deflection=0.5)
|
| 80 |
+
|
| 81 |
+
step_name = Path(output_folder).stem + '_'+ str(extrude_idx).zfill(3) + "_final.step"
|
| 82 |
+
output_path = os.path.join(output_folder, step_name)
|
| 83 |
+
write_step_file(cur_solid, output_path)
|
| 84 |
+
|
| 85 |
+
except Exception as ex:
|
| 86 |
+
print(ex)
|
| 87 |
+
msg = [project_folder, str(ex)[:500]]
|
| 88 |
+
return None
|
| 89 |
+
|
| 90 |
+
return cur_solid
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
if __name__ == "__main__":
|
| 94 |
+
parser = argparse.ArgumentParser()
|
| 95 |
+
parser.add_argument("--data_folder", type=str, required=True)
|
| 96 |
+
parser.add_argument("--single-file", action='store_true', default=False)
|
| 97 |
+
args = parser.parse_args()
|
| 98 |
+
|
| 99 |
+
if args.single_file:
|
| 100 |
+
# If single file, just run the function on that file
|
| 101 |
+
run_parallel(args.data_folder)
|
| 102 |
+
exit(0)
|
| 103 |
+
else:
|
| 104 |
+
solids = []
|
| 105 |
+
# cad_folders = sorted(glob(args.data_folder+'/*'))[50000:] # why after 50000?
|
| 106 |
+
cad_folders = sorted(glob(args.data_folder+'/*'))
|
| 107 |
+
# print("len of cad_folder:", len(cad_folders))
|
| 108 |
+
convert_iter = Pool(NUM_TRHEADS).imap(run_parallel, cad_folders)
|
| 109 |
+
for solid in tqdm(convert_iter, total=len(cad_folders)):
|
| 110 |
+
pass
|
CADFusion/src/rendering_utils/ptl_sampler.py
ADDED
|
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import argparse
|
| 3 |
+
import ntpath
|
| 4 |
+
from tqdm import tqdm
|
| 5 |
+
import multiprocessing
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
from glob import glob
|
| 8 |
+
import trimesh
|
| 9 |
+
from trimesh.sample import sample_surface
|
| 10 |
+
from plyfile import PlyData, PlyElement
|
| 11 |
+
import numpy as np
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def write_ply(points, filename, text=False):
|
| 15 |
+
""" input: Nx3, write points to filename as PLY format. """
|
| 16 |
+
points = [(points[i,0], points[i,1], points[i,2]) for i in range(points.shape[0])]
|
| 17 |
+
vertex = np.array(points, dtype=[('x', 'f4'), ('y', 'f4'),('z', 'f4')])
|
| 18 |
+
el = PlyElement.describe(vertex, 'vertex', comments=['vertices'])
|
| 19 |
+
with open(filename, mode='wb') as f:
|
| 20 |
+
PlyData([el], text=text).write(f)
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def find_files(folder, extension):
|
| 24 |
+
return sorted([Path(os.path.join(folder, f)) for f in os.listdir(folder) if f.endswith(extension)])
|
| 25 |
+
|
| 26 |
+
class SamplePoints:
|
| 27 |
+
"""
|
| 28 |
+
Perform sampleing of points.
|
| 29 |
+
"""
|
| 30 |
+
|
| 31 |
+
def __init__(self):
|
| 32 |
+
"""
|
| 33 |
+
Constructor.
|
| 34 |
+
"""
|
| 35 |
+
parser = self.get_parser()
|
| 36 |
+
self.options = parser.parse_args()
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def get_parser(self):
|
| 40 |
+
"""
|
| 41 |
+
Get parser of tool.
|
| 42 |
+
|
| 43 |
+
:return: parser
|
| 44 |
+
"""
|
| 45 |
+
parser = argparse.ArgumentParser(description='Scale a set of meshes stored as OFF files.')
|
| 46 |
+
parser.add_argument('--in_dir', type=str, help='Path to input directory.')
|
| 47 |
+
parser.add_argument('--out_dir', type=str, help='Path to output directory; files within are overwritten!')
|
| 48 |
+
parser.add_argument("--single-file", action='store_true', default=False)
|
| 49 |
+
return parser
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def run_parallel(self, project_folder):
|
| 53 |
+
out_folder = os.path.join(project_folder, self.options.out_dir)
|
| 54 |
+
if not os.path.exists(out_folder):
|
| 55 |
+
os.makedirs(out_folder)
|
| 56 |
+
|
| 57 |
+
files = find_files(project_folder, 'final.stl')
|
| 58 |
+
|
| 59 |
+
for filepath in files:
|
| 60 |
+
N_POINTS = 2000
|
| 61 |
+
try:
|
| 62 |
+
out_mesh = trimesh.load(str(filepath))
|
| 63 |
+
out_pc, _ = sample_surface(out_mesh, N_POINTS)
|
| 64 |
+
save_path = os.path.join(out_folder, ntpath.basename(filepath)[:-4]+'_pcd.ply')
|
| 65 |
+
write_ply(out_pc, save_path)
|
| 66 |
+
|
| 67 |
+
except Exception as ex:
|
| 68 |
+
return project_folder
|
| 69 |
+
return
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
def run(self):
|
| 73 |
+
"""
|
| 74 |
+
Run simplification.
|
| 75 |
+
"""
|
| 76 |
+
if self.options.single_file:
|
| 77 |
+
self.run_parallel(self.options.in_dir)
|
| 78 |
+
else:
|
| 79 |
+
project_folders = sorted(glob(self.options.in_dir+'/*/'))
|
| 80 |
+
num_cpus = multiprocessing.cpu_count()
|
| 81 |
+
convert_iter = multiprocessing.Pool(num_cpus).imap(self.run_parallel, project_folders)
|
| 82 |
+
for _ in tqdm(convert_iter, total=len(project_folders)):
|
| 83 |
+
pass
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
if __name__ == '__main__':
|
| 87 |
+
app = SamplePoints()
|
| 88 |
+
app.run()
|
CADFusion/src/rendering_utils/utils/obj_reconverter.py
ADDED
|
@@ -0,0 +1,437 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
from collections import OrderedDict
|
| 3 |
+
from .util import create_point, create_unit_vec, get_transform, create_sketch_plane
|
| 4 |
+
|
| 5 |
+
# OCC
|
| 6 |
+
from OCC.Core.BRepCheck import BRepCheck_Analyzer
|
| 7 |
+
from OCC.Core.GC import GC_MakeArcOfCircle
|
| 8 |
+
from OCC.Core.BRepBuilderAPI import (
|
| 9 |
+
BRepBuilderAPI_MakeFace,
|
| 10 |
+
BRepBuilderAPI_MakeWire,
|
| 11 |
+
BRepBuilderAPI_MakeEdge,
|
| 12 |
+
)
|
| 13 |
+
from OCC.Core.BRepAlgoAPI import BRepAlgoAPI_Fuse, BRepAlgoAPI_Cut, BRepAlgoAPI_Common
|
| 14 |
+
from OCC.Core.BRepPrimAPI import BRepPrimAPI_MakePrism
|
| 15 |
+
from OCC.Core.BRepAdaptor import BRepAdaptor_Surface
|
| 16 |
+
from OCC.Core.BRepGProp import brepgprop_VolumeProperties, brepgprop_SurfaceProperties
|
| 17 |
+
from OCC.Core.GProp import GProp_GProps
|
| 18 |
+
from OCC.Core.ShapeFix import ShapeFix_Face, ShapeFix_Wire
|
| 19 |
+
from OCC.Core.gp import gp_Vec, gp_Ax2, gp_Dir, gp_Circ
|
| 20 |
+
from OCC.Extend.DataExchange import write_stl_file
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class OBJReconverter:
|
| 24 |
+
"""OBJ Data Reconverter"""
|
| 25 |
+
|
| 26 |
+
def __init__(self):
|
| 27 |
+
self.vertex_dict = OrderedDict()
|
| 28 |
+
self.PRECISION = 1e-5
|
| 29 |
+
self.eps = 1e-7
|
| 30 |
+
self.x_axis = gp_Dir(1.0, 0.0, 0.0)
|
| 31 |
+
|
| 32 |
+
def convert_curve(self, curve):
|
| 33 |
+
"""
|
| 34 |
+
convert to json dict format
|
| 35 |
+
"""
|
| 36 |
+
json_curve = {}
|
| 37 |
+
|
| 38 |
+
if curve.type == "circle":
|
| 39 |
+
json_curve["type"] = "Circle3D"
|
| 40 |
+
json_curve["center_point"] = {
|
| 41 |
+
"x": curve.center[0],
|
| 42 |
+
"y": curve.center[1],
|
| 43 |
+
"z": 0,
|
| 44 |
+
}
|
| 45 |
+
json_curve["radius"] = curve.radius
|
| 46 |
+
|
| 47 |
+
if curve.type == "line":
|
| 48 |
+
json_curve["type"] = "Line3D"
|
| 49 |
+
json_curve["start_point"] = {
|
| 50 |
+
"x": curve.start[0],
|
| 51 |
+
"y": curve.start[1],
|
| 52 |
+
"z": 0,
|
| 53 |
+
}
|
| 54 |
+
json_curve["end_point"] = {"x": curve.end[0], "y": curve.end[1], "z": 0}
|
| 55 |
+
|
| 56 |
+
if curve.type == "arc":
|
| 57 |
+
json_curve["type"] = "Arc3D"
|
| 58 |
+
json_curve["start_point"] = {
|
| 59 |
+
"x": curve.start[0],
|
| 60 |
+
"y": curve.start[1],
|
| 61 |
+
"z": 0,
|
| 62 |
+
}
|
| 63 |
+
json_curve["end_point"] = {"x": curve.end[0], "y": curve.end[1], "z": 0}
|
| 64 |
+
json_curve["mid_point"] = {"x": curve.mid[0], "y": curve.mid[1], "z": 0}
|
| 65 |
+
json_curve["center_point"] = {
|
| 66 |
+
"x": curve.center[0],
|
| 67 |
+
"y": curve.center[1],
|
| 68 |
+
"z": 0,
|
| 69 |
+
}
|
| 70 |
+
|
| 71 |
+
json_curve["is_outer"] = curve.is_outer
|
| 72 |
+
return json_curve
|
| 73 |
+
|
| 74 |
+
def convert_vertices(self):
|
| 75 |
+
"""Convert all the vertices to .obj format"""
|
| 76 |
+
vertex_strings = ""
|
| 77 |
+
for pt in self.vertex_dict.values():
|
| 78 |
+
# e.g. v 0.123 0.234 0.345 1.0
|
| 79 |
+
vertex_string = f"v {pt[0]} {pt[1]}\n"
|
| 80 |
+
vertex_strings += vertex_string
|
| 81 |
+
return vertex_strings
|
| 82 |
+
|
| 83 |
+
def parse_obj(self, faces, meta_info):
|
| 84 |
+
"""
|
| 85 |
+
reconstruct brep from obj file
|
| 86 |
+
"""
|
| 87 |
+
# At least one needs to match
|
| 88 |
+
for face in faces:
|
| 89 |
+
for loop in face:
|
| 90 |
+
if len(loop) > 1:
|
| 91 |
+
for idx, curve in enumerate(loop[:-1]):
|
| 92 |
+
next_curve = np.vstack([loop[idx + 1].start, loop[idx + 1].end])
|
| 93 |
+
diff1 = np.sum(np.abs(curve.start - next_curve), 1)
|
| 94 |
+
diff2 = np.sum(np.abs(curve.end - next_curve), 1)
|
| 95 |
+
|
| 96 |
+
if min(diff2) == 0 or min(diff1) == 0:
|
| 97 |
+
continue # edge connected
|
| 98 |
+
|
| 99 |
+
assert (
|
| 100 |
+
min(diff1) < 1e-3 or min(diff2) < 1e-3
|
| 101 |
+
) # difference should be small
|
| 102 |
+
|
| 103 |
+
if min(diff1) > min(diff2):
|
| 104 |
+
min_idx = np.argmin(diff2)
|
| 105 |
+
if min_idx == 0:
|
| 106 |
+
loop[idx + 1].start_idx = curve.end_idx
|
| 107 |
+
loop[idx + 1].start = curve.end
|
| 108 |
+
else:
|
| 109 |
+
loop[idx + 1].end_idx = curve.end_idx
|
| 110 |
+
loop[idx + 1].end = curve.end
|
| 111 |
+
else:
|
| 112 |
+
min_idx = np.argmin(diff1)
|
| 113 |
+
if min_idx == 0:
|
| 114 |
+
loop[idx + 1].start_idx = curve.start_idx
|
| 115 |
+
loop[idx + 1].start = curve.start
|
| 116 |
+
else:
|
| 117 |
+
loop[idx + 1].end_idx = curve.start_idx
|
| 118 |
+
loop[idx + 1].end = curve.start
|
| 119 |
+
|
| 120 |
+
# Solve start / end connection
|
| 121 |
+
shared_idx = list(
|
| 122 |
+
set([loop[-2].start_idx, loop[-2].end_idx]).intersection(
|
| 123 |
+
set([loop[-1].start_idx, loop[-1].end_idx])
|
| 124 |
+
)
|
| 125 |
+
)
|
| 126 |
+
|
| 127 |
+
assert len(shared_idx) >= 1
|
| 128 |
+
|
| 129 |
+
if len(shared_idx) == 2:
|
| 130 |
+
assert len(loop) == 2 # do nothing
|
| 131 |
+
else:
|
| 132 |
+
if shared_idx[0] == loop[-1].start_idx:
|
| 133 |
+
do_start = False
|
| 134 |
+
else:
|
| 135 |
+
do_start = True
|
| 136 |
+
start_curve = np.vstack([loop[0].start, loop[0].end])
|
| 137 |
+
|
| 138 |
+
if do_start:
|
| 139 |
+
diff = np.sum(np.abs(loop[-1].start - start_curve), 1)
|
| 140 |
+
else:
|
| 141 |
+
diff = np.sum(np.abs(loop[-1].end - start_curve), 1)
|
| 142 |
+
assert min(diff) < 1e-3
|
| 143 |
+
|
| 144 |
+
min_idx = np.argmin(diff)
|
| 145 |
+
if min_idx == 0:
|
| 146 |
+
if do_start:
|
| 147 |
+
loop[-1].start_idx = loop[0].start_idx
|
| 148 |
+
loop[-1].start = loop[0].start
|
| 149 |
+
else:
|
| 150 |
+
loop[-1].end_idx = loop[0].start_idx
|
| 151 |
+
loop[-1].end = loop[0].start
|
| 152 |
+
else:
|
| 153 |
+
if do_start:
|
| 154 |
+
loop[-1].start_idx = loop[0].end_idx
|
| 155 |
+
loop[-1].start = loop[0].end
|
| 156 |
+
else:
|
| 157 |
+
loop[-1].end_idx = loop[0].end_idx
|
| 158 |
+
loop[-1].end = loop[0].end
|
| 159 |
+
|
| 160 |
+
# Parse groups to json loop/curve profile
|
| 161 |
+
extrusion = {}
|
| 162 |
+
extrusion["profiles"] = []
|
| 163 |
+
for face in faces:
|
| 164 |
+
profile = {}
|
| 165 |
+
profile["loops"] = []
|
| 166 |
+
for loop in face:
|
| 167 |
+
pl = {}
|
| 168 |
+
pl["profile_curves"] = []
|
| 169 |
+
for curve in loop:
|
| 170 |
+
# convert to json format
|
| 171 |
+
pl["profile_curves"].append(self.convert_curve(curve))
|
| 172 |
+
profile["loops"].append(pl)
|
| 173 |
+
extrusion["profiles"].append(profile)
|
| 174 |
+
|
| 175 |
+
# Parse transform
|
| 176 |
+
sketch = {}
|
| 177 |
+
transform = {}
|
| 178 |
+
transform["origin"] = {
|
| 179 |
+
"x": meta_info["t_orig"][0],
|
| 180 |
+
"y": meta_info["t_orig"][1],
|
| 181 |
+
"z": meta_info["t_orig"][2],
|
| 182 |
+
}
|
| 183 |
+
transform["x_axis"] = {
|
| 184 |
+
"x": meta_info["t_x"][0],
|
| 185 |
+
"y": meta_info["t_x"][1],
|
| 186 |
+
"z": meta_info["t_x"][2],
|
| 187 |
+
}
|
| 188 |
+
transform["y_axis"] = {
|
| 189 |
+
"x": meta_info["t_y"][0],
|
| 190 |
+
"y": meta_info["t_y"][1],
|
| 191 |
+
"z": meta_info["t_y"][2],
|
| 192 |
+
}
|
| 193 |
+
transform["z_axis"] = {
|
| 194 |
+
"x": meta_info["t_z"][0],
|
| 195 |
+
"y": meta_info["t_z"][1],
|
| 196 |
+
"z": meta_info["t_z"][2],
|
| 197 |
+
}
|
| 198 |
+
sketch["transform"] = transform
|
| 199 |
+
|
| 200 |
+
# Parse extrude
|
| 201 |
+
extrude_params = {}
|
| 202 |
+
extrude_params["extrude_type"] = meta_info["set_op"]
|
| 203 |
+
extrude_params["extrude_values"] = meta_info["extrude_value"]
|
| 204 |
+
|
| 205 |
+
# Create sketch
|
| 206 |
+
all_faces = []
|
| 207 |
+
curve_strings = ""
|
| 208 |
+
curve_count = 0
|
| 209 |
+
for profile in extrusion["profiles"]:
|
| 210 |
+
ref_face, face, curve_string, c_count = self.parse_sketch(sketch, profile)
|
| 211 |
+
curve_strings += curve_string
|
| 212 |
+
curve_count += c_count
|
| 213 |
+
all_faces.append(face)
|
| 214 |
+
|
| 215 |
+
# Merge all faces in the same plane
|
| 216 |
+
plane_face = all_faces[0]
|
| 217 |
+
for face in all_faces[1:]:
|
| 218 |
+
plane_face = self.my_op(plane_face, face, "fuse")
|
| 219 |
+
solid = self.extrude_face(ref_face, plane_face, extrude_params)
|
| 220 |
+
return solid, curve_strings, curve_count
|
| 221 |
+
|
| 222 |
+
def my_op(self, big, small, op_name):
|
| 223 |
+
if op_name == "cut":
|
| 224 |
+
op = BRepAlgoAPI_Cut(big, small)
|
| 225 |
+
elif op_name == "fuse":
|
| 226 |
+
op = BRepAlgoAPI_Fuse(big, small)
|
| 227 |
+
elif op_name == "common":
|
| 228 |
+
op = BRepAlgoAPI_Common(big, small)
|
| 229 |
+
op.SetFuzzyValue(self.PRECISION)
|
| 230 |
+
op.Build()
|
| 231 |
+
return op.Shape()
|
| 232 |
+
|
| 233 |
+
def build_body(self, face, normal, value):
|
| 234 |
+
extrusion_vec = gp_Vec(normal).Multiplied(value)
|
| 235 |
+
make_prism = BRepPrimAPI_MakePrism(face, extrusion_vec)
|
| 236 |
+
make_prism.Build()
|
| 237 |
+
prism = make_prism.Prism()
|
| 238 |
+
return prism.Shape()
|
| 239 |
+
|
| 240 |
+
def extrudeBasedOnType(self, face, normal, distance):
|
| 241 |
+
# Extrude based on the two bound values
|
| 242 |
+
if not (distance[0] < distance[1]):
|
| 243 |
+
raise Exception("incorrect distance")
|
| 244 |
+
large_value = distance[1]
|
| 245 |
+
small_value = distance[0]
|
| 246 |
+
|
| 247 |
+
if large_value == 0:
|
| 248 |
+
return self.build_body(face, -normal, -small_value)
|
| 249 |
+
elif small_value == 0:
|
| 250 |
+
return self.build_body(face, normal, large_value)
|
| 251 |
+
elif np.sign(large_value) == np.sign(small_value):
|
| 252 |
+
if large_value < 0:
|
| 253 |
+
body1 = self.build_body(face, -normal, -small_value)
|
| 254 |
+
body2 = self.build_body(face, -normal, -large_value)
|
| 255 |
+
return self.my_op(body1, body2, "cut")
|
| 256 |
+
else:
|
| 257 |
+
assert large_value > 0
|
| 258 |
+
body1 = self.build_body(face, normal, small_value)
|
| 259 |
+
body2 = self.build_body(face, normal, large_value)
|
| 260 |
+
return self.my_op(body2, body1, "cut")
|
| 261 |
+
else:
|
| 262 |
+
assert np.sign(large_value) != np.sign(small_value)
|
| 263 |
+
body1 = self.build_body(face, normal, large_value)
|
| 264 |
+
body2 = self.build_body(face, -normal, -small_value)
|
| 265 |
+
return self.my_op(body1, body2, "fuse")
|
| 266 |
+
|
| 267 |
+
def extrude_face(self, ref_face, face, extrude_params):
|
| 268 |
+
distance = extrude_params["extrude_values"]
|
| 269 |
+
surf = BRepAdaptor_Surface(ref_face).Plane()
|
| 270 |
+
normal = surf.Axis().Direction()
|
| 271 |
+
extruded_shape = self.extrudeBasedOnType(face, normal, distance)
|
| 272 |
+
return extruded_shape
|
| 273 |
+
|
| 274 |
+
def parse_sketch(self, sketch, profile):
|
| 275 |
+
"""
|
| 276 |
+
Sketch in one closed loop (one out, multiple ins)
|
| 277 |
+
"""
|
| 278 |
+
# Transformation from local to global xyz coord
|
| 279 |
+
transform = get_transform(sketch["transform"])
|
| 280 |
+
|
| 281 |
+
# Create face region (automatically infer from all wires)
|
| 282 |
+
outer_facelist = []
|
| 283 |
+
inner_facelist = []
|
| 284 |
+
curve_count = 0
|
| 285 |
+
outer_string = []
|
| 286 |
+
inner_string = []
|
| 287 |
+
plane = create_sketch_plane(sketch["transform"])
|
| 288 |
+
|
| 289 |
+
for idx, pl in enumerate(profile["loops"]):
|
| 290 |
+
# Create loop
|
| 291 |
+
loop, curve_string, num_curve = self.parse_loop(
|
| 292 |
+
pl["profile_curves"], transform
|
| 293 |
+
)
|
| 294 |
+
# Create face
|
| 295 |
+
face_builder = BRepBuilderAPI_MakeFace(plane, loop)
|
| 296 |
+
if not face_builder.IsDone():
|
| 297 |
+
raise Exception("face builder not done")
|
| 298 |
+
face = face_builder.Face()
|
| 299 |
+
# Fix face
|
| 300 |
+
fixer = ShapeFix_Face(face)
|
| 301 |
+
fixer.SetPrecision(self.PRECISION)
|
| 302 |
+
fixer.FixOrientation()
|
| 303 |
+
|
| 304 |
+
analyzer = BRepCheck_Analyzer(fixer.Face())
|
| 305 |
+
if not analyzer.IsValid():
|
| 306 |
+
raise Exception("face check failed")
|
| 307 |
+
|
| 308 |
+
curve_count += num_curve
|
| 309 |
+
|
| 310 |
+
if pl["profile_curves"][0]["is_outer"]:
|
| 311 |
+
outer_facelist.append(fixer.Face())
|
| 312 |
+
outer_string.append(curve_string)
|
| 313 |
+
else:
|
| 314 |
+
inner_facelist.append(fixer.Face())
|
| 315 |
+
inner_string.append(curve_string)
|
| 316 |
+
|
| 317 |
+
# Create final closed loop face
|
| 318 |
+
assert len(outer_facelist) > 0
|
| 319 |
+
final_face = outer_facelist[0]
|
| 320 |
+
for face in outer_facelist[1:]:
|
| 321 |
+
final_face = self.my_op(final_face, face, "fuse")
|
| 322 |
+
for face in inner_facelist:
|
| 323 |
+
final_face = self.my_op(final_face, face, "cut")
|
| 324 |
+
|
| 325 |
+
# Append inner outer information to string
|
| 326 |
+
assert len(outer_string) == 1
|
| 327 |
+
out_str = ""
|
| 328 |
+
in_str = ""
|
| 329 |
+
for c_str in outer_string:
|
| 330 |
+
out_str += "out\n" + c_str + "\n"
|
| 331 |
+
for c_str in inner_string:
|
| 332 |
+
in_str += "in\n" + c_str + "\n"
|
| 333 |
+
final_str = "face\n" + out_str + in_str
|
| 334 |
+
|
| 335 |
+
return outer_facelist[0], final_face, final_str, curve_count
|
| 336 |
+
|
| 337 |
+
def parse_loop(self, profile_loop, transform):
|
| 338 |
+
"""Create face in one closed loop"""
|
| 339 |
+
topo_wire = BRepBuilderAPI_MakeWire()
|
| 340 |
+
curve_strings = ""
|
| 341 |
+
curve_count = 0
|
| 342 |
+
|
| 343 |
+
# Loop through all the curves in one loop
|
| 344 |
+
for profile_curve in profile_loop:
|
| 345 |
+
curve_edge, curve_string = self.parse_curve(profile_curve, transform)
|
| 346 |
+
topo_wire.Add(curve_edge)
|
| 347 |
+
if not topo_wire.IsDone():
|
| 348 |
+
raise Exception("wire builder not done")
|
| 349 |
+
|
| 350 |
+
curve_string += "\n"
|
| 351 |
+
curve_count += 1
|
| 352 |
+
curve_strings += curve_string
|
| 353 |
+
|
| 354 |
+
fixer = ShapeFix_Wire()
|
| 355 |
+
fixer.Load(topo_wire.Wire())
|
| 356 |
+
fixer.SetPrecision(self.PRECISION)
|
| 357 |
+
fixer.FixClosed()
|
| 358 |
+
fixer.Perform()
|
| 359 |
+
return fixer.Wire(), curve_strings, curve_count
|
| 360 |
+
|
| 361 |
+
def parse_curve(self, curve, transform):
|
| 362 |
+
if curve["type"] == "Line3D":
|
| 363 |
+
return self.create_line(curve, transform)
|
| 364 |
+
elif curve["type"] == "Circle3D":
|
| 365 |
+
return self.create_circle(curve, transform)
|
| 366 |
+
elif curve["type"] == "Arc3D":
|
| 367 |
+
return self.create_arc(curve, transform)
|
| 368 |
+
else:
|
| 369 |
+
raise Exception("unknown curve type")
|
| 370 |
+
|
| 371 |
+
def create_line(self, line, transform):
|
| 372 |
+
start = create_point(line["start_point"], transform)
|
| 373 |
+
end = create_point(line["end_point"], transform)
|
| 374 |
+
if start.Distance(end) == 0:
|
| 375 |
+
raise Exception("start/end point same location")
|
| 376 |
+
topo_edge = BRepBuilderAPI_MakeEdge(start, end)
|
| 377 |
+
|
| 378 |
+
# Save pre-transform
|
| 379 |
+
star_idx = self.save_vertex(
|
| 380 |
+
line["start_point"]["x"] + 0.0, line["start_point"]["y"] + 0.0, "p"
|
| 381 |
+
)
|
| 382 |
+
end_idx = self.save_vertex(
|
| 383 |
+
line["end_point"]["x"] + 0.0, line["end_point"]["y"] + 0.0, "p"
|
| 384 |
+
)
|
| 385 |
+
curve_string = f"l {star_idx} {end_idx}"
|
| 386 |
+
return topo_edge.Edge(), curve_string
|
| 387 |
+
|
| 388 |
+
def create_arc(self, arc, transform):
|
| 389 |
+
start = create_point(arc["start_point"], transform)
|
| 390 |
+
mid = create_point(arc["mid_point"], transform)
|
| 391 |
+
end = create_point(arc["end_point"], transform)
|
| 392 |
+
arc_occ = GC_MakeArcOfCircle(start, mid, end).Value()
|
| 393 |
+
topo_edge = BRepBuilderAPI_MakeEdge(arc_occ)
|
| 394 |
+
|
| 395 |
+
# Save pre-transform
|
| 396 |
+
start_idx = self.save_vertex(
|
| 397 |
+
arc["start_point"]["x"] + 0.0, arc["start_point"]["y"] + 0.0, "p"
|
| 398 |
+
)
|
| 399 |
+
end_idx = self.save_vertex(
|
| 400 |
+
arc["end_point"]["x"] + 0.0, arc["end_point"]["y"] + 0.0, "p"
|
| 401 |
+
)
|
| 402 |
+
center_idx = self.save_vertex(
|
| 403 |
+
arc["center_point"]["x"] + 0.0, arc["center_point"]["y"] + 0.0, "p"
|
| 404 |
+
)
|
| 405 |
+
mid_idx = self.save_vertex(
|
| 406 |
+
arc["mid_point"]["x"] + 0.0, arc["mid_point"]["y"] + 0.0, "p"
|
| 407 |
+
)
|
| 408 |
+
curve_string = f"a {start_idx} {mid_idx} {center_idx} {end_idx}"
|
| 409 |
+
return topo_edge.Edge(), curve_string
|
| 410 |
+
|
| 411 |
+
def create_circle(self, circle, transform):
|
| 412 |
+
center = create_point(circle["center_point"], transform)
|
| 413 |
+
radius = circle["radius"]
|
| 414 |
+
normal = create_unit_vec({"x": 0.0, "y": 0.0, "z": 1.0}, transform)
|
| 415 |
+
ref_vector3d = self.x_axis.Transformed(transform)
|
| 416 |
+
axis = gp_Ax2(center, normal, ref_vector3d)
|
| 417 |
+
gp_circle = gp_Circ(axis, abs(float(radius)))
|
| 418 |
+
topo_edge = BRepBuilderAPI_MakeEdge(gp_circle)
|
| 419 |
+
|
| 420 |
+
center_idx = self.save_vertex(
|
| 421 |
+
circle["center_point"]["x"] + 0.0, circle["center_point"]["y"] + 0.0, "p"
|
| 422 |
+
)
|
| 423 |
+
radius_idx = self.save_vertex(abs(float(radius)) + 0.0, 0, "r")
|
| 424 |
+
curve_string = f"c {center_idx} {radius_idx}"
|
| 425 |
+
return topo_edge.Edge(), curve_string
|
| 426 |
+
|
| 427 |
+
def save_vertex(self, h_x, h_y, text):
|
| 428 |
+
unique_key = f"{text}:x{h_x}y{h_y}"
|
| 429 |
+
index = 0
|
| 430 |
+
for key in self.vertex_dict.keys():
|
| 431 |
+
# Vertex location already exist in dict
|
| 432 |
+
if unique_key == key:
|
| 433 |
+
return index
|
| 434 |
+
index += 1
|
| 435 |
+
# Vertex location does not exist in dict
|
| 436 |
+
self.vertex_dict[unique_key] = [h_x, h_y]
|
| 437 |
+
return index
|
CADFusion/src/rendering_utils/utils/util.py
ADDED
|
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from OCC.Core.gp import gp_Pnt, gp_Vec, gp_Dir, gp_XYZ, gp_Ax3, gp_Trsf, gp_Pln
|
| 3 |
+
from OCC.Core.BRepMesh import BRepMesh_IncrementalMesh
|
| 4 |
+
from OCC.Core.StlAPI import StlAPI_Writer
|
| 5 |
+
|
| 6 |
+
def create_xyz(xyz):
|
| 7 |
+
return gp_XYZ(xyz["x"], xyz["y"], xyz["z"])
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def get_ax3(transform_dict):
|
| 11 |
+
origin = create_xyz(transform_dict["origin"])
|
| 12 |
+
x_axis = create_xyz(transform_dict["x_axis"])
|
| 13 |
+
y_axis = create_xyz(transform_dict["y_axis"])
|
| 14 |
+
z_axis = create_xyz(transform_dict["z_axis"])
|
| 15 |
+
# Create new coord (orig, Norm, x-axis)
|
| 16 |
+
axis3 = gp_Ax3(gp_Pnt(origin), gp_Dir(z_axis), gp_Dir(x_axis))
|
| 17 |
+
return axis3
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def get_transform(transform_dict):
|
| 21 |
+
axis3 = get_ax3(transform_dict)
|
| 22 |
+
transform_to_local = gp_Trsf()
|
| 23 |
+
transform_to_local.SetTransformation(axis3)
|
| 24 |
+
return transform_to_local.Inverted()
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def create_sketch_plane(transform_dict):
|
| 28 |
+
axis3 = get_ax3(transform_dict)
|
| 29 |
+
return gp_Pln(axis3)
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def create_point(point_dict, transform):
|
| 33 |
+
pt2d = gp_Pnt(point_dict["x"], point_dict["y"], point_dict["z"])
|
| 34 |
+
return pt2d.Transformed(transform)
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def create_unit_vec(vec_dict, transform):
|
| 38 |
+
vec2d = gp_Dir(vec_dict["x"], vec_dict["y"], vec_dict["z"])
|
| 39 |
+
return vec2d.Transformed(transform)
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def write_stl_file(a_shape, filename, mode="ascii", linear_deflection=0.001, angular_deflection=0.5):
|
| 43 |
+
""" export the shape to a STL file
|
| 44 |
+
Be careful, the shape first need to be explicitely meshed using BRepMesh_IncrementalMesh
|
| 45 |
+
a_shape: the topods_shape to export
|
| 46 |
+
filename: the filename
|
| 47 |
+
mode: optional, "ascii" by default. Can either be "binary"
|
| 48 |
+
linear_deflection: optional, default to 0.001. Lower, more occurate mesh
|
| 49 |
+
angular_deflection: optional, default to 0.5. Lower, more accurate_mesh
|
| 50 |
+
"""
|
| 51 |
+
if a_shape.IsNull():
|
| 52 |
+
raise AssertionError("Shape is null.")
|
| 53 |
+
if mode not in ["ascii", "binary"]:
|
| 54 |
+
raise AssertionError("mode should be either ascii or binary")
|
| 55 |
+
if os.path.isfile(filename):
|
| 56 |
+
print("Warning: %s file already exists and will be replaced" % filename)
|
| 57 |
+
# first mesh the shape
|
| 58 |
+
mesh = BRepMesh_IncrementalMesh(a_shape, linear_deflection, False, angular_deflection, True)
|
| 59 |
+
#mesh.SetDeflection(0.05)
|
| 60 |
+
mesh.Perform()
|
| 61 |
+
if not mesh.IsDone():
|
| 62 |
+
raise AssertionError("Mesh is not done.")
|
| 63 |
+
|
| 64 |
+
stl_exporter = StlAPI_Writer()
|
| 65 |
+
if mode == "ascii":
|
| 66 |
+
stl_exporter.SetASCIIMode(True)
|
| 67 |
+
else: # binary, just set the ASCII flag to False
|
| 68 |
+
stl_exporter.SetASCIIMode(False)
|
| 69 |
+
stl_exporter.Write(a_shape, filename)
|
| 70 |
+
|
| 71 |
+
if not os.path.isfile(filename):
|
| 72 |
+
raise IOError("File not written to disk.")
|
CADFusion/src/test/VLM_score.py
ADDED
|
@@ -0,0 +1,95 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import requests
|
| 3 |
+
import base64
|
| 4 |
+
import json
|
| 5 |
+
import time
|
| 6 |
+
import argparse
|
| 7 |
+
from mimetypes import guess_type
|
| 8 |
+
from tqdm import tqdm
|
| 9 |
+
import re
|
| 10 |
+
|
| 11 |
+
from openai import AzureOpenAI
|
| 12 |
+
from azure.identity import AzureCliCredential, get_bearer_token_provider
|
| 13 |
+
|
| 14 |
+
scope = "api://trapi/.default"
|
| 15 |
+
credential = get_bearer_token_provider(AzureCliCredential(),scope)
|
| 16 |
+
|
| 17 |
+
api_version = '2024-12-01-preview'
|
| 18 |
+
# deployment_name = 'gpt-4.1-mini_2025-04-14'
|
| 19 |
+
deployment_name = 'gpt-4o_2024-08-06'
|
| 20 |
+
instance = '<trapi/path>' # See https://aka.ms/trapi/models for the instance name, remove /openai (library adds it implicitly)
|
| 21 |
+
endpoint = f'https://trapi.research.microsoft.com/{instance}'
|
| 22 |
+
|
| 23 |
+
client = AzureOpenAI(
|
| 24 |
+
azure_endpoint=endpoint,
|
| 25 |
+
azure_ad_token_provider=credential,
|
| 26 |
+
api_version=api_version,
|
| 27 |
+
)
|
| 28 |
+
|
| 29 |
+
def local_image_to_data_url(image_path):
|
| 30 |
+
mime_type, _ = guess_type(image_path)
|
| 31 |
+
if mime_type is None:
|
| 32 |
+
mime_type = 'application/octet-stream'
|
| 33 |
+
with open(image_path, "rb") as image_file:
|
| 34 |
+
base64_encoded_data = base64.b64encode(image_file.read()).decode('utf-8')
|
| 35 |
+
return f"data:{mime_type};base64,{base64_encoded_data}"
|
| 36 |
+
|
| 37 |
+
def ask_gpt(image_path, prompt):
|
| 38 |
+
image_url = local_image_to_data_url(image_path)
|
| 39 |
+
message_text = [
|
| 40 |
+
{"role": "system", "content": "You are an AI assistant that helps people find information."},
|
| 41 |
+
{"role": "user", "content": [
|
| 42 |
+
{"type": "text", "text": prompt},
|
| 43 |
+
{"type": "image_url", "image_url": {"url": image_url}},
|
| 44 |
+
]}
|
| 45 |
+
]
|
| 46 |
+
|
| 47 |
+
completion = client.chat.completions.create(
|
| 48 |
+
model=deployment_name,
|
| 49 |
+
messages=message_text,)
|
| 50 |
+
output = completion.choices[0].message.content
|
| 51 |
+
return output
|
| 52 |
+
|
| 53 |
+
if __name__ == '__main__':
|
| 54 |
+
import argparse
|
| 55 |
+
parser = argparse.ArgumentParser()
|
| 56 |
+
parser.add_argument('--test-path', type=str, default='data/sl_data/test.jsonl', help='Path to the JSONL file containing test data')
|
| 57 |
+
parser.add_argument('--name', type=str, default='original_seq', help='Run name of the testee')
|
| 58 |
+
parser.add_argument('--figure-dir', type=str, default='exp/figures')
|
| 59 |
+
parser.add_argument('--save-path', type=str, default='exp/evals', help='Target folder to save the results')
|
| 60 |
+
parser.add_argument('--repetition', type=int, default=5, help='Number of repetitions for each image')
|
| 61 |
+
args = parser.parse_args()
|
| 62 |
+
|
| 63 |
+
results = []
|
| 64 |
+
jsonl_path = args.test_path
|
| 65 |
+
name = args.name
|
| 66 |
+
figures_dir = f"{args.figure_dir}/{name}/"
|
| 67 |
+
save_path = f"{args.save_path}/{name}.jsonl"
|
| 68 |
+
|
| 69 |
+
with open(jsonl_path, 'r+') as file:
|
| 70 |
+
test_data = json.load(file)
|
| 71 |
+
repetition = args.repetition
|
| 72 |
+
results = []
|
| 73 |
+
for i in tqdm(range(len(test_data[:800]))):
|
| 74 |
+
item = test_data[i]
|
| 75 |
+
for j in range(repetition):
|
| 76 |
+
img_num = i * repetition + j
|
| 77 |
+
image_name = f"{img_num:06d}.png"
|
| 78 |
+
image_path = os.path.join(figures_dir, image_name)
|
| 79 |
+
if os.path.exists(image_path):
|
| 80 |
+
description = item['description']
|
| 81 |
+
try:
|
| 82 |
+
score = ask_gpt(image_path, f"The following is a text description of a 3D CAD figure and an image of a CAD instance. Measure if the figure corresponds to the given description, and give a score in the scale of 10. Only return the score. Do not comment on issues such as texture, smoothness and colors.\n description:{description}\n")
|
| 83 |
+
|
| 84 |
+
# "The following is an original image of a CAD instance, a text description on editing and an image of the edited result. Measure if the figure corresponds to the given description, and give a score in the scale of 10. Only return the score. Do not comment on issues such as texture, smoothness and colors.\n description:{description}\n"
|
| 85 |
+
except Exception as e:
|
| 86 |
+
print(img_num)
|
| 87 |
+
print(e)
|
| 88 |
+
score = -1
|
| 89 |
+
result = {
|
| 90 |
+
"index": img_num,
|
| 91 |
+
"gpt_score": score
|
| 92 |
+
}
|
| 93 |
+
results.append(result)
|
| 94 |
+
with open(save_path, 'w+') as file:
|
| 95 |
+
json.dump(results, file, indent=4)
|
CADFusion/src/test/chamfer_dist.py
ADDED
|
@@ -0,0 +1,308 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import argparse
|
| 3 |
+
import os
|
| 4 |
+
import numpy as np
|
| 5 |
+
from tqdm import tqdm
|
| 6 |
+
import random
|
| 7 |
+
import warnings
|
| 8 |
+
from glob import glob
|
| 9 |
+
from scipy.stats import entropy
|
| 10 |
+
from sklearn.neighbors import NearestNeighbors
|
| 11 |
+
from plyfile import PlyData
|
| 12 |
+
from pathlib import Path
|
| 13 |
+
from multiprocessing import Pool
|
| 14 |
+
from chamfer_distance import ChamferDistance
|
| 15 |
+
|
| 16 |
+
random.seed(0)
|
| 17 |
+
N_POINTS = 2000
|
| 18 |
+
NUM_TRHEADS = 16
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def find_files(folder, extension):
|
| 22 |
+
return sorted([Path(os.path.join(folder, f)) for f in os.listdir(folder) if f.endswith(extension)])
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def read_ply(path):
|
| 26 |
+
with open(path, 'rb') as f:
|
| 27 |
+
plydata = PlyData.read(f)
|
| 28 |
+
x = np.array(plydata['vertex']['x'])
|
| 29 |
+
y = np.array(plydata['vertex']['y'])
|
| 30 |
+
z = np.array(plydata['vertex']['z'])
|
| 31 |
+
vertex = np.stack([x, y, z], axis=1)
|
| 32 |
+
return vertex
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def distChamfer(a, b):
|
| 36 |
+
x, y = a, b
|
| 37 |
+
bs, num_points, points_dim = x.size()
|
| 38 |
+
xx = torch.bmm(x, x.transpose(2, 1))
|
| 39 |
+
yy = torch.bmm(y, y.transpose(2, 1))
|
| 40 |
+
zz = torch.bmm(x, y.transpose(2, 1))
|
| 41 |
+
diag_ind = torch.arange(0, num_points).to(a).long()
|
| 42 |
+
rx = xx[:, diag_ind, diag_ind].unsqueeze(1).expand_as(xx)
|
| 43 |
+
ry = yy[:, diag_ind, diag_ind].unsqueeze(1).expand_as(yy)
|
| 44 |
+
P = (rx.transpose(2, 1) + ry - 2 * zz)
|
| 45 |
+
return P.min(1)[0], P.min(2)[0]
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def _pairwise_CD(sample_pcs, ref_pcs, batch_size):
|
| 49 |
+
N_sample = sample_pcs.shape[0]
|
| 50 |
+
N_ref = ref_pcs.shape[0]
|
| 51 |
+
all_cd = []
|
| 52 |
+
all_emd = []
|
| 53 |
+
iterator = range(N_sample)
|
| 54 |
+
matched_gt = []
|
| 55 |
+
pbar = tqdm(iterator)
|
| 56 |
+
chamfer_dist = ChamferDistance()
|
| 57 |
+
|
| 58 |
+
for sample_b_start in pbar:
|
| 59 |
+
sample_batch = sample_pcs[sample_b_start]
|
| 60 |
+
|
| 61 |
+
cd_lst = []
|
| 62 |
+
emd_lst = []
|
| 63 |
+
for ref_b_start in range(0, N_ref, batch_size):
|
| 64 |
+
ref_b_end = min(N_ref, ref_b_start + batch_size)
|
| 65 |
+
ref_batch = ref_pcs[ref_b_start:ref_b_end]
|
| 66 |
+
|
| 67 |
+
batch_size_ref = ref_batch.size(0)
|
| 68 |
+
sample_batch_exp = sample_batch.view(1, -1, 3).expand(batch_size_ref, -1, -1)
|
| 69 |
+
sample_batch_exp = sample_batch_exp.contiguous()
|
| 70 |
+
|
| 71 |
+
dl, dr, idx1, idx2 = chamfer_dist(sample_batch_exp,ref_batch)
|
| 72 |
+
cd_lst.append((dl.mean(dim=1) + dr.mean(dim=1)).view(1, -1))
|
| 73 |
+
|
| 74 |
+
cd_lst = torch.cat(cd_lst, dim=1)
|
| 75 |
+
all_cd.append(cd_lst)
|
| 76 |
+
|
| 77 |
+
hit = np.argmin(cd_lst.detach().cpu().numpy()[0])
|
| 78 |
+
matched_gt.append(hit)
|
| 79 |
+
pbar.set_postfix({"cov": len(np.unique(matched_gt)) * 1.0 / N_ref})
|
| 80 |
+
|
| 81 |
+
all_cd = torch.cat(all_cd, dim=0) # N_sample, N_ref
|
| 82 |
+
|
| 83 |
+
return all_cd
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
def compute_cov_mmd(sample_pcs, ref_pcs, batch_size):
|
| 87 |
+
all_dist = _pairwise_CD(sample_pcs, ref_pcs, batch_size)
|
| 88 |
+
print(all_dist.shape, flush=True)
|
| 89 |
+
N_sample, N_ref = all_dist.size(0), all_dist.size(1)
|
| 90 |
+
min_val_fromsmp, min_idx = torch.min(all_dist, dim=1)
|
| 91 |
+
min_val, _ = torch.min(all_dist, dim=0)
|
| 92 |
+
mmd = min_val.mean()
|
| 93 |
+
cov = float(min_idx.unique().view(-1).size(0)) / float(N_ref)
|
| 94 |
+
cov = torch.tensor(cov).to(all_dist)
|
| 95 |
+
|
| 96 |
+
return {
|
| 97 |
+
# 'med-CD': torch.diagonal(all_dist).median().item(),
|
| 98 |
+
'avg-CD': torch.diagonal(all_dist).mean().item(),
|
| 99 |
+
'COV-CD': cov.item(),
|
| 100 |
+
'MMD-CD': mmd.item()
|
| 101 |
+
}
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
def jsd_between_point_cloud_sets(sample_pcs, ref_pcs, in_unit_sphere, resolution=28):
|
| 105 |
+
'''Computes the JSD between two sets of point-clouds, as introduced in the paper ```Learning Representations And Generative Models For 3D Point Clouds```.
|
| 106 |
+
Args:
|
| 107 |
+
sample_pcs: (np.ndarray S1xR2x3) S1 point-clouds, each of R1 points.
|
| 108 |
+
ref_pcs: (np.ndarray S2xR2x3) S2 point-clouds, each of R2 points.
|
| 109 |
+
resolution: (int) grid-resolution. Affects granularity of measurements.
|
| 110 |
+
'''
|
| 111 |
+
sample_grid_var = entropy_of_occupancy_grid(sample_pcs, resolution, in_unit_sphere)[1]
|
| 112 |
+
ref_grid_var = entropy_of_occupancy_grid(ref_pcs, resolution, in_unit_sphere)[1]
|
| 113 |
+
return jensen_shannon_divergence(sample_grid_var, ref_grid_var)
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
def entropy_of_occupancy_grid(pclouds, grid_resolution, in_sphere=False):
|
| 117 |
+
'''Given a collection of point-clouds, estimate the entropy of the random variables
|
| 118 |
+
corresponding to occupancy-grid activation patterns.
|
| 119 |
+
Inputs:
|
| 120 |
+
pclouds: (numpy array) #point-clouds x points per point-cloud x 3
|
| 121 |
+
grid_resolution (int) size of occupancy grid that will be used.
|
| 122 |
+
'''
|
| 123 |
+
epsilon = 10e-4
|
| 124 |
+
bound = 1 + epsilon
|
| 125 |
+
if abs(np.max(pclouds)) > bound or abs(np.min(pclouds)) > bound:
|
| 126 |
+
print(abs(np.max(pclouds)), abs(np.min(pclouds)))
|
| 127 |
+
warnings.warn('Point-clouds are not in unit cube.')
|
| 128 |
+
|
| 129 |
+
if in_sphere and np.max(np.sqrt(np.sum(pclouds ** 2, axis=2))) > bound:
|
| 130 |
+
warnings.warn('Point-clouds are not in unit sphere.')
|
| 131 |
+
|
| 132 |
+
grid_coordinates, _ = unit_cube_grid_point_cloud(grid_resolution, in_sphere)
|
| 133 |
+
grid_coordinates = grid_coordinates.reshape(-1, 3)
|
| 134 |
+
grid_counters = np.zeros(len(grid_coordinates))
|
| 135 |
+
grid_bernoulli_rvars = np.zeros(len(grid_coordinates))
|
| 136 |
+
nn = NearestNeighbors(n_neighbors=1).fit(grid_coordinates)
|
| 137 |
+
|
| 138 |
+
for pc in pclouds:
|
| 139 |
+
_, indices = nn.kneighbors(pc)
|
| 140 |
+
indices = np.squeeze(indices)
|
| 141 |
+
for i in indices:
|
| 142 |
+
grid_counters[i] += 1
|
| 143 |
+
indices = np.unique(indices)
|
| 144 |
+
for i in indices:
|
| 145 |
+
grid_bernoulli_rvars[i] += 1
|
| 146 |
+
|
| 147 |
+
acc_entropy = 0.0
|
| 148 |
+
n = float(len(pclouds))
|
| 149 |
+
for g in grid_bernoulli_rvars:
|
| 150 |
+
p = 0.0
|
| 151 |
+
if g > 0:
|
| 152 |
+
p = float(g) / n
|
| 153 |
+
acc_entropy += entropy([p, 1.0 - p])
|
| 154 |
+
|
| 155 |
+
return acc_entropy / len(grid_counters), grid_counters
|
| 156 |
+
|
| 157 |
+
|
| 158 |
+
def unit_cube_grid_point_cloud(resolution, clip_sphere=False):
|
| 159 |
+
'''Returns the center coordinates of each cell of a 3D grid with resolution^3 cells,
|
| 160 |
+
that is placed in the unit-cube.
|
| 161 |
+
If clip_sphere it True it drops the "corner" cells that lie outside the unit-sphere.
|
| 162 |
+
'''
|
| 163 |
+
grid = np.ndarray((resolution, resolution, resolution, 3), np.float32)
|
| 164 |
+
spacing = 1.0 / float(resolution - 1) * 2
|
| 165 |
+
for i in range(resolution):
|
| 166 |
+
for j in range(resolution):
|
| 167 |
+
for k in range(resolution):
|
| 168 |
+
grid[i, j, k, 0] = i * spacing - 0.5 * 2
|
| 169 |
+
grid[i, j, k, 1] = j * spacing - 0.5 * 2
|
| 170 |
+
grid[i, j, k, 2] = k * spacing - 0.5 * 2
|
| 171 |
+
|
| 172 |
+
if clip_sphere:
|
| 173 |
+
grid = grid.reshape(-1, 3)
|
| 174 |
+
grid = grid[np.linalg.norm(grid, axis=1) <= 0.5]
|
| 175 |
+
|
| 176 |
+
return grid, spacing
|
| 177 |
+
|
| 178 |
+
|
| 179 |
+
def jensen_shannon_divergence(P, Q):
|
| 180 |
+
if np.any(P < 0) or np.any(Q < 0):
|
| 181 |
+
raise ValueError('Negative values.')
|
| 182 |
+
if len(P) != len(Q):
|
| 183 |
+
raise ValueError('Non equal size.')
|
| 184 |
+
|
| 185 |
+
P_ = P / np.sum(P) # Ensure probabilities.
|
| 186 |
+
Q_ = Q / np.sum(Q)
|
| 187 |
+
|
| 188 |
+
e1 = entropy(P_, base=2)
|
| 189 |
+
e2 = entropy(Q_, base=2)
|
| 190 |
+
e_sum = entropy((P_ + Q_) / 2.0, base=2)
|
| 191 |
+
res = e_sum - ((e1 + e2) / 2.0)
|
| 192 |
+
|
| 193 |
+
res2 = _jsdiv(P_, Q_)
|
| 194 |
+
|
| 195 |
+
if not np.allclose(res, res2, atol=10e-5, rtol=0):
|
| 196 |
+
warnings.warn('Numerical values of two JSD methods don\'t agree.')
|
| 197 |
+
|
| 198 |
+
return res
|
| 199 |
+
|
| 200 |
+
|
| 201 |
+
def _jsdiv(P, Q):
|
| 202 |
+
'''another way of computing JSD'''
|
| 203 |
+
|
| 204 |
+
def _kldiv(A, B):
|
| 205 |
+
a = A.copy()
|
| 206 |
+
b = B.copy()
|
| 207 |
+
idx = np.logical_and(a > 0, b > 0)
|
| 208 |
+
a = a[idx]
|
| 209 |
+
b = b[idx]
|
| 210 |
+
return np.sum([v for v in a * np.log2(a / b)])
|
| 211 |
+
|
| 212 |
+
P_ = P / np.sum(P)
|
| 213 |
+
Q_ = Q / np.sum(Q)
|
| 214 |
+
|
| 215 |
+
M = 0.5 * (P_ + Q_)
|
| 216 |
+
|
| 217 |
+
return 0.5 * (_kldiv(P_, M) + _kldiv(Q_, M))
|
| 218 |
+
|
| 219 |
+
|
| 220 |
+
def downsample_pc(points, n):
|
| 221 |
+
sample_idx = random.sample(list(range(points.shape[0])), n)
|
| 222 |
+
return points[sample_idx]
|
| 223 |
+
|
| 224 |
+
|
| 225 |
+
def normalize_pc(points):
|
| 226 |
+
scale = np.max(np.abs(points))
|
| 227 |
+
points = points / scale
|
| 228 |
+
return points
|
| 229 |
+
|
| 230 |
+
|
| 231 |
+
def collect_pc(cad_folder):
|
| 232 |
+
pc_path = find_files(os.path.join(cad_folder, 'ptl'), 'final_pcd.ply')
|
| 233 |
+
if len(pc_path) == 0:
|
| 234 |
+
return []
|
| 235 |
+
pc_path = pc_path[-1] # final pcd
|
| 236 |
+
pc = read_ply(pc_path)
|
| 237 |
+
if pc.shape[0] > N_POINTS:
|
| 238 |
+
pc = downsample_pc(pc, N_POINTS)
|
| 239 |
+
pc = normalize_pc(pc)
|
| 240 |
+
return pc
|
| 241 |
+
|
| 242 |
+
def collect_pc2(cad_folder):
|
| 243 |
+
pc = read_ply(cad_folder)
|
| 244 |
+
if pc.shape[0] > N_POINTS:
|
| 245 |
+
pc = downsample_pc(pc, N_POINTS)
|
| 246 |
+
pc = normalize_pc(pc)
|
| 247 |
+
return pc
|
| 248 |
+
|
| 249 |
+
|
| 250 |
+
def main():
|
| 251 |
+
parser = argparse.ArgumentParser()
|
| 252 |
+
parser.add_argument("--fake", type=str)
|
| 253 |
+
parser.add_argument("--real", type=str)
|
| 254 |
+
parser.add_argument("--output", type=str)
|
| 255 |
+
split = 1
|
| 256 |
+
args = parser.parse_args()
|
| 257 |
+
if args.output is None:
|
| 258 |
+
args.output = args.fake + '_cad_results.txt'
|
| 259 |
+
chamfer_dist = ChamferDistance()
|
| 260 |
+
cd = []
|
| 261 |
+
for i in tqdm(range(952)):
|
| 262 |
+
fake_pcs = []
|
| 263 |
+
real_pcs = []
|
| 264 |
+
for j in range(split):
|
| 265 |
+
fake_index = i * split + j
|
| 266 |
+
fake_folder = os.path.join(args.fake, f'{fake_index:06d}')
|
| 267 |
+
if not os.path.exists(fake_folder):
|
| 268 |
+
continue
|
| 269 |
+
else:
|
| 270 |
+
fake_pc = collect_pc(fake_folder)
|
| 271 |
+
if len(fake_pc) == 0:
|
| 272 |
+
continue
|
| 273 |
+
fake_pcs.append(fake_pc)
|
| 274 |
+
|
| 275 |
+
real_folder = os.path.join(args.real, f'{i:06d}')
|
| 276 |
+
if not os.path.exists(real_folder):
|
| 277 |
+
continue
|
| 278 |
+
else:
|
| 279 |
+
real_pc = collect_pc(real_folder)
|
| 280 |
+
if len(real_pc) == 0:
|
| 281 |
+
continue
|
| 282 |
+
real_pcs.append(real_pc)
|
| 283 |
+
|
| 284 |
+
if len(fake_pcs) == 0 or len(real_pcs) == 0:
|
| 285 |
+
continue
|
| 286 |
+
sample_pcs = np.stack(fake_pcs, axis=0)
|
| 287 |
+
ref_pcs = np.stack(real_pcs, axis=0)
|
| 288 |
+
|
| 289 |
+
sample_pcs = torch.tensor(sample_pcs, dtype=torch.float32).cuda()
|
| 290 |
+
ref_pcs = torch.tensor(ref_pcs, dtype=torch.float32).cuda()
|
| 291 |
+
print(sample_pcs.shape, ref_pcs.shape)
|
| 292 |
+
dl, dr, idx1, idx2 = chamfer_dist(sample_pcs, ref_pcs)
|
| 293 |
+
min_val = (dl.mean(dim=1) + dr.mean(dim=1)).view(1, -1).squeeze(0).min().item()
|
| 294 |
+
cd.append(min_val)
|
| 295 |
+
|
| 296 |
+
cd = np.array(cd)
|
| 297 |
+
mean = np.mean(cd)
|
| 298 |
+
median = np.median(cd)
|
| 299 |
+
print('mean:', mean)
|
| 300 |
+
print('median:', median)
|
| 301 |
+
|
| 302 |
+
|
| 303 |
+
if __name__ == '__main__':
|
| 304 |
+
import time
|
| 305 |
+
start_time = time.time()
|
| 306 |
+
main()
|
| 307 |
+
end_time = time.time()
|
| 308 |
+
print(end_time - start_time)
|
CADFusion/src/test/dist_eval.py
ADDED
|
@@ -0,0 +1,351 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import argparse
|
| 3 |
+
import os
|
| 4 |
+
import numpy as np
|
| 5 |
+
from tqdm import tqdm
|
| 6 |
+
import random
|
| 7 |
+
import warnings
|
| 8 |
+
from glob import glob
|
| 9 |
+
from scipy.stats import entropy
|
| 10 |
+
from sklearn.neighbors import NearestNeighbors
|
| 11 |
+
from plyfile import PlyData
|
| 12 |
+
from pathlib import Path
|
| 13 |
+
from multiprocessing import Pool
|
| 14 |
+
from chamfer_distance import ChamferDistance
|
| 15 |
+
|
| 16 |
+
random.seed(0)
|
| 17 |
+
N_POINTS = 2000
|
| 18 |
+
NUM_TRHEADS = 16
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def find_files(folder, extension):
|
| 22 |
+
return sorted([Path(os.path.join(folder, f)) for f in os.listdir(folder) if f.endswith(extension)])
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def read_ply(path):
|
| 26 |
+
with open(path, 'rb') as f:
|
| 27 |
+
plydata = PlyData.read(f)
|
| 28 |
+
x = np.array(plydata['vertex']['x'])
|
| 29 |
+
y = np.array(plydata['vertex']['y'])
|
| 30 |
+
z = np.array(plydata['vertex']['z'])
|
| 31 |
+
vertex = np.stack([x, y, z], axis=1)
|
| 32 |
+
return vertex
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def distChamfer(a, b):
|
| 36 |
+
x, y = a, b
|
| 37 |
+
bs, num_points, points_dim = x.size()
|
| 38 |
+
xx = torch.bmm(x, x.transpose(2, 1))
|
| 39 |
+
yy = torch.bmm(y, y.transpose(2, 1))
|
| 40 |
+
zz = torch.bmm(x, y.transpose(2, 1))
|
| 41 |
+
diag_ind = torch.arange(0, num_points).to(a).long()
|
| 42 |
+
rx = xx[:, diag_ind, diag_ind].unsqueeze(1).expand_as(xx)
|
| 43 |
+
ry = yy[:, diag_ind, diag_ind].unsqueeze(1).expand_as(yy)
|
| 44 |
+
P = (rx.transpose(2, 1) + ry - 2 * zz)
|
| 45 |
+
return P.min(1)[0], P.min(2)[0]
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def _pairwise_CD(sample_pcs, ref_pcs, batch_size):
|
| 49 |
+
N_sample = sample_pcs.shape[0]
|
| 50 |
+
N_ref = ref_pcs.shape[0]
|
| 51 |
+
all_cd = []
|
| 52 |
+
all_emd = []
|
| 53 |
+
iterator = range(N_sample)
|
| 54 |
+
matched_gt = []
|
| 55 |
+
pbar = tqdm(iterator)
|
| 56 |
+
chamfer_dist = ChamferDistance()
|
| 57 |
+
|
| 58 |
+
for sample_b_start in pbar:
|
| 59 |
+
sample_batch = sample_pcs[sample_b_start]
|
| 60 |
+
|
| 61 |
+
cd_lst = []
|
| 62 |
+
emd_lst = []
|
| 63 |
+
for ref_b_start in range(0, N_ref, batch_size):
|
| 64 |
+
ref_b_end = min(N_ref, ref_b_start + batch_size)
|
| 65 |
+
ref_batch = ref_pcs[ref_b_start:ref_b_end]
|
| 66 |
+
|
| 67 |
+
batch_size_ref = ref_batch.size(0)
|
| 68 |
+
sample_batch_exp = sample_batch.view(1, -1, 3).expand(batch_size_ref, -1, -1)
|
| 69 |
+
sample_batch_exp = sample_batch_exp.contiguous()
|
| 70 |
+
|
| 71 |
+
dl, dr, idx1, idx2 = chamfer_dist(sample_batch_exp,ref_batch)
|
| 72 |
+
cd_lst.append((dl.mean(dim=1) + dr.mean(dim=1)).view(1, -1))
|
| 73 |
+
|
| 74 |
+
cd_lst = torch.cat(cd_lst, dim=1)
|
| 75 |
+
all_cd.append(cd_lst)
|
| 76 |
+
|
| 77 |
+
hit = np.argmin(cd_lst.detach().cpu().numpy()[0])
|
| 78 |
+
matched_gt.append(hit)
|
| 79 |
+
pbar.set_postfix({"cov": len(np.unique(matched_gt)) * 1.0 / N_ref})
|
| 80 |
+
|
| 81 |
+
all_cd = torch.cat(all_cd, dim=0) # N_sample, N_ref
|
| 82 |
+
|
| 83 |
+
return all_cd
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
def compute_cov_mmd(sample_pcs, ref_pcs, batch_size):
|
| 87 |
+
all_dist = _pairwise_CD(sample_pcs, ref_pcs, batch_size)
|
| 88 |
+
print(all_dist.shape, flush=True)
|
| 89 |
+
N_sample, N_ref = all_dist.size(0), all_dist.size(1)
|
| 90 |
+
min_val_fromsmp, min_idx = torch.min(all_dist, dim=1)
|
| 91 |
+
min_val, _ = torch.min(all_dist, dim=0)
|
| 92 |
+
mmd = min_val.mean()
|
| 93 |
+
cov = float(min_idx.unique().view(-1).size(0)) / float(N_ref)
|
| 94 |
+
cov = torch.tensor(cov).to(all_dist)
|
| 95 |
+
|
| 96 |
+
return {
|
| 97 |
+
# 'med-CD': torch.diagonal(all_dist).median().item(),
|
| 98 |
+
'avg-CD': torch.diagonal(all_dist).mean().item(),
|
| 99 |
+
'COV-CD': cov.item(),
|
| 100 |
+
'MMD-CD': mmd.item()
|
| 101 |
+
}
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
def jsd_between_point_cloud_sets(sample_pcs, ref_pcs, in_unit_sphere, resolution=28):
|
| 105 |
+
'''Computes the JSD between two sets of point-clouds, as introduced in the paper ```Learning Representations And Generative Models For 3D Point Clouds```.
|
| 106 |
+
Args:
|
| 107 |
+
sample_pcs: (np.ndarray S1xR2x3) S1 point-clouds, each of R1 points.
|
| 108 |
+
ref_pcs: (np.ndarray S2xR2x3) S2 point-clouds, each of R2 points.
|
| 109 |
+
resolution: (int) grid-resolution. Affects granularity of measurements.
|
| 110 |
+
'''
|
| 111 |
+
sample_grid_var = entropy_of_occupancy_grid(sample_pcs, resolution, in_unit_sphere)[1]
|
| 112 |
+
ref_grid_var = entropy_of_occupancy_grid(ref_pcs, resolution, in_unit_sphere)[1]
|
| 113 |
+
return jensen_shannon_divergence(sample_grid_var, ref_grid_var)
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
def entropy_of_occupancy_grid(pclouds, grid_resolution, in_sphere=False):
|
| 117 |
+
'''Given a collection of point-clouds, estimate the entropy of the random variables
|
| 118 |
+
corresponding to occupancy-grid activation patterns.
|
| 119 |
+
Inputs:
|
| 120 |
+
pclouds: (numpy array) #point-clouds x points per point-cloud x 3
|
| 121 |
+
grid_resolution (int) size of occupancy grid that will be used.
|
| 122 |
+
'''
|
| 123 |
+
epsilon = 10e-4
|
| 124 |
+
bound = 1 + epsilon
|
| 125 |
+
if abs(np.max(pclouds)) > bound or abs(np.min(pclouds)) > bound:
|
| 126 |
+
print(abs(np.max(pclouds)), abs(np.min(pclouds)))
|
| 127 |
+
warnings.warn('Point-clouds are not in unit cube.')
|
| 128 |
+
|
| 129 |
+
if in_sphere and np.max(np.sqrt(np.sum(pclouds ** 2, axis=2))) > bound:
|
| 130 |
+
warnings.warn('Point-clouds are not in unit sphere.')
|
| 131 |
+
|
| 132 |
+
grid_coordinates, _ = unit_cube_grid_point_cloud(grid_resolution, in_sphere)
|
| 133 |
+
grid_coordinates = grid_coordinates.reshape(-1, 3)
|
| 134 |
+
grid_counters = np.zeros(len(grid_coordinates))
|
| 135 |
+
grid_bernoulli_rvars = np.zeros(len(grid_coordinates))
|
| 136 |
+
nn = NearestNeighbors(n_neighbors=1).fit(grid_coordinates)
|
| 137 |
+
|
| 138 |
+
for pc in pclouds:
|
| 139 |
+
_, indices = nn.kneighbors(pc)
|
| 140 |
+
indices = np.squeeze(indices)
|
| 141 |
+
for i in indices:
|
| 142 |
+
grid_counters[i] += 1
|
| 143 |
+
indices = np.unique(indices)
|
| 144 |
+
for i in indices:
|
| 145 |
+
grid_bernoulli_rvars[i] += 1
|
| 146 |
+
|
| 147 |
+
acc_entropy = 0.0
|
| 148 |
+
n = float(len(pclouds))
|
| 149 |
+
for g in grid_bernoulli_rvars:
|
| 150 |
+
p = 0.0
|
| 151 |
+
if g > 0:
|
| 152 |
+
p = float(g) / n
|
| 153 |
+
acc_entropy += entropy([p, 1.0 - p])
|
| 154 |
+
|
| 155 |
+
return acc_entropy / len(grid_counters), grid_counters
|
| 156 |
+
|
| 157 |
+
|
| 158 |
+
def unit_cube_grid_point_cloud(resolution, clip_sphere=False):
|
| 159 |
+
'''Returns the center coordinates of each cell of a 3D grid with resolution^3 cells,
|
| 160 |
+
that is placed in the unit-cube.
|
| 161 |
+
If clip_sphere it True it drops the "corner" cells that lie outside the unit-sphere.
|
| 162 |
+
'''
|
| 163 |
+
grid = np.ndarray((resolution, resolution, resolution, 3), np.float32)
|
| 164 |
+
spacing = 1.0 / float(resolution - 1) * 2
|
| 165 |
+
for i in range(resolution):
|
| 166 |
+
for j in range(resolution):
|
| 167 |
+
for k in range(resolution):
|
| 168 |
+
grid[i, j, k, 0] = i * spacing - 0.5 * 2
|
| 169 |
+
grid[i, j, k, 1] = j * spacing - 0.5 * 2
|
| 170 |
+
grid[i, j, k, 2] = k * spacing - 0.5 * 2
|
| 171 |
+
|
| 172 |
+
if clip_sphere:
|
| 173 |
+
grid = grid.reshape(-1, 3)
|
| 174 |
+
grid = grid[np.linalg.norm(grid, axis=1) <= 0.5]
|
| 175 |
+
|
| 176 |
+
return grid, spacing
|
| 177 |
+
|
| 178 |
+
|
| 179 |
+
def jensen_shannon_divergence(P, Q):
|
| 180 |
+
if np.any(P < 0) or np.any(Q < 0):
|
| 181 |
+
raise ValueError('Negative values.')
|
| 182 |
+
if len(P) != len(Q):
|
| 183 |
+
raise ValueError('Non equal size.')
|
| 184 |
+
|
| 185 |
+
P_ = P / np.sum(P) # Ensure probabilities.
|
| 186 |
+
Q_ = Q / np.sum(Q)
|
| 187 |
+
|
| 188 |
+
e1 = entropy(P_, base=2)
|
| 189 |
+
e2 = entropy(Q_, base=2)
|
| 190 |
+
e_sum = entropy((P_ + Q_) / 2.0, base=2)
|
| 191 |
+
res = e_sum - ((e1 + e2) / 2.0)
|
| 192 |
+
|
| 193 |
+
res2 = _jsdiv(P_, Q_)
|
| 194 |
+
|
| 195 |
+
if not np.allclose(res, res2, atol=10e-5, rtol=0):
|
| 196 |
+
warnings.warn('Numerical values of two JSD methods don\'t agree.')
|
| 197 |
+
|
| 198 |
+
return res
|
| 199 |
+
|
| 200 |
+
|
| 201 |
+
def _jsdiv(P, Q):
|
| 202 |
+
'''another way of computing JSD'''
|
| 203 |
+
|
| 204 |
+
def _kldiv(A, B):
|
| 205 |
+
a = A.copy()
|
| 206 |
+
b = B.copy()
|
| 207 |
+
idx = np.logical_and(a > 0, b > 0)
|
| 208 |
+
a = a[idx]
|
| 209 |
+
b = b[idx]
|
| 210 |
+
return np.sum([v for v in a * np.log2(a / b)])
|
| 211 |
+
|
| 212 |
+
P_ = P / np.sum(P)
|
| 213 |
+
Q_ = Q / np.sum(Q)
|
| 214 |
+
|
| 215 |
+
M = 0.5 * (P_ + Q_)
|
| 216 |
+
|
| 217 |
+
return 0.5 * (_kldiv(P_, M) + _kldiv(Q_, M))
|
| 218 |
+
|
| 219 |
+
|
| 220 |
+
def downsample_pc(points, n):
|
| 221 |
+
sample_idx = random.sample(list(range(points.shape[0])), n)
|
| 222 |
+
return points[sample_idx]
|
| 223 |
+
|
| 224 |
+
|
| 225 |
+
def normalize_pc(points):
|
| 226 |
+
scale = np.max(np.abs(points))
|
| 227 |
+
points = points / scale
|
| 228 |
+
return points
|
| 229 |
+
|
| 230 |
+
|
| 231 |
+
def collect_pc(cad_folder):
|
| 232 |
+
pc_path = find_files(os.path.join(cad_folder, 'ptl'), 'final_pcd.ply')
|
| 233 |
+
if len(pc_path) == 0:
|
| 234 |
+
return []
|
| 235 |
+
pc_path = pc_path[-1] # final pcd
|
| 236 |
+
pc = read_ply(pc_path)
|
| 237 |
+
if pc.shape[0] > N_POINTS:
|
| 238 |
+
pc = downsample_pc(pc, N_POINTS)
|
| 239 |
+
pc = normalize_pc(pc)
|
| 240 |
+
return pc
|
| 241 |
+
|
| 242 |
+
def collect_pc2(cad_folder):
|
| 243 |
+
pc = read_ply(cad_folder)
|
| 244 |
+
if pc.shape[0] > N_POINTS:
|
| 245 |
+
pc = downsample_pc(pc, N_POINTS)
|
| 246 |
+
pc = normalize_pc(pc)
|
| 247 |
+
return pc
|
| 248 |
+
|
| 249 |
+
|
| 250 |
+
def main():
|
| 251 |
+
parser = argparse.ArgumentParser()
|
| 252 |
+
parser.add_argument("--fake", type=str)
|
| 253 |
+
parser.add_argument("--real", type=str)
|
| 254 |
+
parser.add_argument("--output", type=str)
|
| 255 |
+
parser.add_argument("--n_test", type=int, default=200)
|
| 256 |
+
parser.add_argument("--multi", type=int, default=1)
|
| 257 |
+
parser.add_argument("--times", type=int, default=10)
|
| 258 |
+
parser.add_argument("--batch_size", type=int, default=64)
|
| 259 |
+
args = parser.parse_args()
|
| 260 |
+
|
| 261 |
+
print("n_test: {}, multiplier: {}, repeat times: {}".format(args.n_test, args.multi, args.times))
|
| 262 |
+
if args.output is None:
|
| 263 |
+
args.output = args.fake + '_cad_results.txt'
|
| 264 |
+
|
| 265 |
+
# Load fake pcd
|
| 266 |
+
|
| 267 |
+
fake_folders = sorted(glob(args.fake+'/*/'))
|
| 268 |
+
real_folders = sorted(glob(args.real+'/*/'))
|
| 269 |
+
|
| 270 |
+
fake_overlapped = []
|
| 271 |
+
real_overlapped = []
|
| 272 |
+
for i in range(800):
|
| 273 |
+
if f'{args.fake}/{i:06d}/' in fake_folders and f'{args.real}/{i:06d}/' in real_folders:
|
| 274 |
+
if len(glob(f'{args.fake}/{i:06d}/ptl/*')) > 0 and len(glob(f'{args.real}/{i:06d}/ptl/*')) > 0:
|
| 275 |
+
fake_overlapped.append(f'{args.fake}/{i:06d}/')
|
| 276 |
+
real_overlapped.append(f'{args.real}/{i:06d}/')
|
| 277 |
+
print(len(fake_overlapped), len(real_overlapped))
|
| 278 |
+
|
| 279 |
+
fake_folders = fake_overlapped
|
| 280 |
+
real_folders = real_overlapped
|
| 281 |
+
|
| 282 |
+
sample_pcs = []
|
| 283 |
+
load_iter = Pool(NUM_TRHEADS).imap(collect_pc, fake_folders)
|
| 284 |
+
for pc in tqdm(load_iter, total=len(fake_folders)):
|
| 285 |
+
if len(pc) > 0:
|
| 286 |
+
sample_pcs.append(pc)
|
| 287 |
+
sample_pcs = np.stack(sample_pcs, axis=0)
|
| 288 |
+
print("fake point clouds: {}".format(sample_pcs.shape))
|
| 289 |
+
|
| 290 |
+
# Load reference pcd
|
| 291 |
+
ref_pcs = []
|
| 292 |
+
load_iter = Pool(NUM_TRHEADS).imap(collect_pc, real_folders)
|
| 293 |
+
for pc in tqdm(load_iter, total=len(real_folders)):
|
| 294 |
+
if len(pc) > 0:
|
| 295 |
+
ref_pcs.append(pc)
|
| 296 |
+
ref_pcs = np.stack(ref_pcs, axis=0)
|
| 297 |
+
print("real point clouds: {}".format(ref_pcs.shape))
|
| 298 |
+
|
| 299 |
+
# # Testing
|
| 300 |
+
fp = open(args.output, "w")
|
| 301 |
+
|
| 302 |
+
rand_sample_pcs = sample_pcs
|
| 303 |
+
rand_ref_pcs = ref_pcs
|
| 304 |
+
|
| 305 |
+
jsd = jsd_between_point_cloud_sets(rand_sample_pcs, rand_ref_pcs, in_unit_sphere=False)
|
| 306 |
+
with torch.no_grad():
|
| 307 |
+
rand_sample_pcs = torch.tensor(rand_sample_pcs).cuda()
|
| 308 |
+
rand_ref_pcs = torch.tensor(rand_ref_pcs).cuda()
|
| 309 |
+
result = compute_cov_mmd(rand_sample_pcs, rand_ref_pcs, batch_size=args.batch_size)
|
| 310 |
+
result.update({"JSD": jsd})
|
| 311 |
+
|
| 312 |
+
print(result)
|
| 313 |
+
print(result, file=fp)
|
| 314 |
+
fp.close()
|
| 315 |
+
|
| 316 |
+
# Testing
|
| 317 |
+
# fp = open(args.output, "w")
|
| 318 |
+
# result_list = []
|
| 319 |
+
# for i in range(args.times):
|
| 320 |
+
# print("iteration {}...".format(i))
|
| 321 |
+
# select_idx = random.sample(list(range(len(sample_pcs))), int(args.multi * args.n_test))
|
| 322 |
+
# rand_sample_pcs = sample_pcs[select_idx]
|
| 323 |
+
|
| 324 |
+
# select_idx = random.sample(list(range(len(ref_pcs))), args.n_test)
|
| 325 |
+
# rand_ref_pcs = ref_pcs[select_idx]
|
| 326 |
+
|
| 327 |
+
# jsd = jsd_between_point_cloud_sets(rand_sample_pcs, rand_ref_pcs, in_unit_sphere=False)
|
| 328 |
+
# with torch.no_grad():
|
| 329 |
+
# rand_sample_pcs = torch.tensor(rand_sample_pcs).cuda()
|
| 330 |
+
# rand_ref_pcs = torch.tensor(rand_ref_pcs).cuda()
|
| 331 |
+
# result = compute_cov_mmd(rand_sample_pcs, rand_ref_pcs, batch_size=args.batch_size)
|
| 332 |
+
# result.update({"JSD": jsd})
|
| 333 |
+
|
| 334 |
+
# print(result)
|
| 335 |
+
# print(result, file=fp)
|
| 336 |
+
# result_list.append(result)
|
| 337 |
+
# avg_result = {}
|
| 338 |
+
# for k in result_list[0].keys():
|
| 339 |
+
# avg_result.update({"avg-" + k: np.mean([x[k] for x in result_list])})
|
| 340 |
+
# print("average result:")
|
| 341 |
+
# print(avg_result)
|
| 342 |
+
# print(avg_result, file=fp)
|
| 343 |
+
# fp.close()
|
| 344 |
+
|
| 345 |
+
|
| 346 |
+
if __name__ == '__main__':
|
| 347 |
+
import time
|
| 348 |
+
start_time = time.time()
|
| 349 |
+
main()
|
| 350 |
+
end_time = time.time()
|
| 351 |
+
print(end_time - start_time)
|
CADFusion/src/test/f1_eval.py
ADDED
|
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import re
|
| 3 |
+
import argparse
|
| 4 |
+
|
| 5 |
+
"""
|
| 6 |
+
We did not implement the Hungarian matching algorithm from text2cad, but provided a vanilla matching for f1. It is because
|
| 7 |
+
1. We argue that CAD scenarios are too complicated to be evaluated with a simple matching algorithm, especially when performed on the primitive level. Moreover, matching every primitive exactly is against the goal of our framework which attempt to encourage CAD models generate visually correct objects instead of accurate primitives compared to the ground truth.
|
| 8 |
+
2. In our exploration, discrepancies on the number of primitives between model generation and the ground truth usually indicates the entire failure of the sketch so that using any of the algorithm does not affect the final evaluation result anyway.
|
| 9 |
+
3. Our evaluation is a lower bound of the performance of the model on the matching algorithm, therefore it does not affect the overall integrety of our framework.
|
| 10 |
+
|
| 11 |
+
We encourage users to implement their own matching algorithm if they want to evaluate the model with a more strict metric.
|
| 12 |
+
"""
|
| 13 |
+
|
| 14 |
+
parser = argparse.ArgumentParser(description='Evaluate F1 scores for generated sketches.')
|
| 15 |
+
parser.add_argument('--test-path', type=str, default='data/sl_data/test.jsonl', help='Path to the JSONL file containing test data')
|
| 16 |
+
parser.add_argument('--file_path', type=str, required=True, help='Path to the JSONL file containing generated sketches.')
|
| 17 |
+
args = parser.parse_args()
|
| 18 |
+
file_path = args.file_path
|
| 19 |
+
data_path = args.test_path
|
| 20 |
+
with open(data_path, 'r') as f:
|
| 21 |
+
data = json.load(f)
|
| 22 |
+
|
| 23 |
+
def find_f1(ground_truth, pred, token):
|
| 24 |
+
num_tok_gt = len(re.findall(token, ground_truth))
|
| 25 |
+
num_tok_pred = len(re.findall(token, pred))
|
| 26 |
+
# print(num_tok_gt, num_tok_pred)
|
| 27 |
+
min_tok = min(num_tok_gt, num_tok_pred)
|
| 28 |
+
if min_tok <= 0:
|
| 29 |
+
return -1
|
| 30 |
+
tok_recall = min_tok / num_tok_gt
|
| 31 |
+
tok_precision = min_tok / num_tok_pred
|
| 32 |
+
tok_f1 = 2 * tok_recall * tok_precision / (tok_recall + tok_precision)
|
| 33 |
+
return tok_f1
|
| 34 |
+
|
| 35 |
+
with open(file_path, 'r') as f:
|
| 36 |
+
gen = json.load(f)
|
| 37 |
+
line = []
|
| 38 |
+
arc = []
|
| 39 |
+
circle = []
|
| 40 |
+
ext = []
|
| 41 |
+
for i in range(1000):
|
| 42 |
+
ground_truth = data[i]['output']
|
| 43 |
+
pred = gen[i]['output']
|
| 44 |
+
ext_f1 = find_f1(ground_truth, pred, r'<extrude_end>')
|
| 45 |
+
if ext_f1 > 0:
|
| 46 |
+
ext.append(ext_f1)
|
| 47 |
+
|
| 48 |
+
skext_gt = ground_truth.split('<extrude_end>')[:-1]
|
| 49 |
+
skext_pred = pred.split('<extrude_end>')[:-1]
|
| 50 |
+
min_len_skext = min(len(skext_gt), len(skext_pred))
|
| 51 |
+
if min_len_skext == 0:
|
| 52 |
+
continue
|
| 53 |
+
line_f1 = 0
|
| 54 |
+
arc_f1 = 0
|
| 55 |
+
circle_f1 = 0
|
| 56 |
+
for gt, pr in zip(skext_gt, skext_pred):
|
| 57 |
+
line_f1 += find_f1(gt, pr, r'line.*?<curve_end>')
|
| 58 |
+
arc_f1 += find_f1(gt, pr, r'arc.*?<curve_end>')
|
| 59 |
+
circle_f1 += find_f1(gt, pr, r'circle.*?<curve_end>')
|
| 60 |
+
|
| 61 |
+
line_f1 = line_f1 / min_len_skext
|
| 62 |
+
arc_f1 = arc_f1 / min_len_skext
|
| 63 |
+
circle_f1 = circle_f1 / min_len_skext
|
| 64 |
+
if line_f1 > 0:
|
| 65 |
+
line.append(line_f1)
|
| 66 |
+
if arc_f1 > 0:
|
| 67 |
+
arc.append(arc_f1)
|
| 68 |
+
if circle_f1 > 0:
|
| 69 |
+
circle.append(circle_f1)
|
| 70 |
+
line_avg = sum(line) / len(line)
|
| 71 |
+
arc_avg = sum(arc) / len(arc)
|
| 72 |
+
circle_avg = sum(circle) / len(circle)
|
| 73 |
+
avgf1 = (line_avg + arc_avg + circle_avg) / 3
|
| 74 |
+
print(file_path, line_avg, arc_avg, circle_avg, avgf1, sum(ext) / len(ext))
|
CADFusion/src/test/generate.ipynb
ADDED
|
@@ -0,0 +1,291 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "code",
|
| 5 |
+
"execution_count": null,
|
| 6 |
+
"id": "2d243f81",
|
| 7 |
+
"metadata": {},
|
| 8 |
+
"outputs": [],
|
| 9 |
+
"source": [
|
| 10 |
+
"import argparse\n",
|
| 11 |
+
"import random\n",
|
| 12 |
+
"import os\n",
|
| 13 |
+
"import subprocess\n",
|
| 14 |
+
"import shutil\n",
|
| 15 |
+
"\n",
|
| 16 |
+
"from PIL import Image\n",
|
| 17 |
+
"from huggingface_hub import login\n",
|
| 18 |
+
"from utils import MAX_LENGTH, prepare_model_and_tokenizer\n",
|
| 19 |
+
"from visual_utils.parser import CADparser, write_obj_sample\n",
|
| 20 |
+
"from IPython.display import clear_output"
|
| 21 |
+
]
|
| 22 |
+
},
|
| 23 |
+
{
|
| 24 |
+
"cell_type": "markdown",
|
| 25 |
+
"id": "b98812ed",
|
| 26 |
+
"metadata": {},
|
| 27 |
+
"source": [
|
| 28 |
+
"### Initializing model and arguments"
|
| 29 |
+
]
|
| 30 |
+
},
|
| 31 |
+
{
|
| 32 |
+
"cell_type": "code",
|
| 33 |
+
"execution_count": 49,
|
| 34 |
+
"id": "df625563",
|
| 35 |
+
"metadata": {},
|
| 36 |
+
"outputs": [],
|
| 37 |
+
"source": [
|
| 38 |
+
"parser = argparse.ArgumentParser()\n",
|
| 39 |
+
"# parser.add_argument(\"--model-name\", type=str, default=\"llama3\")\n",
|
| 40 |
+
"parser.add_argument(\"--device-map\", type=str, default='auto')\n",
|
| 41 |
+
"parser.add_argument(\"--lora-rank\", type=int, default=32)\n",
|
| 42 |
+
"parser.add_argument(\"--lora-alpha\", type=int, default=32)\n",
|
| 43 |
+
"parser.add_argument(\"--lora-dropout\", type=float, default=0.05)\n",
|
| 44 |
+
"parser.add_argument(\"--pretrained-path\", type=str, required=True)\n",
|
| 45 |
+
"parser.add_argument(\"--top-p\", type=float, default=0.9)\n",
|
| 46 |
+
"parser.add_argument(\"--temperature\", type=float, default=0.9)\n",
|
| 47 |
+
"\n",
|
| 48 |
+
"arguments = ['--pretrained-path', '/home/v-wangruiyu/repos/CADFusion/exp/model_ckpt/CADFusion_v1_1', '--temperature', '0.3']\n",
|
| 49 |
+
"args = parser.parse_args(arguments)"
|
| 50 |
+
]
|
| 51 |
+
},
|
| 52 |
+
{
|
| 53 |
+
"cell_type": "code",
|
| 54 |
+
"execution_count": null,
|
| 55 |
+
"id": "5624f320",
|
| 56 |
+
"metadata": {},
|
| 57 |
+
"outputs": [],
|
| 58 |
+
"source": [
|
| 59 |
+
"login() # put your own hf token to access llama\n",
|
| 60 |
+
"random.seed(0)\n",
|
| 61 |
+
"model, tokenizer = prepare_model_and_tokenizer(args)\n",
|
| 62 |
+
"model.eval()\n",
|
| 63 |
+
"clear_output()"
|
| 64 |
+
]
|
| 65 |
+
},
|
| 66 |
+
{
|
| 67 |
+
"cell_type": "markdown",
|
| 68 |
+
"id": "86b9cb09",
|
| 69 |
+
"metadata": {},
|
| 70 |
+
"source": [
|
| 71 |
+
"### Custom prompting"
|
| 72 |
+
]
|
| 73 |
+
},
|
| 74 |
+
{
|
| 75 |
+
"cell_type": "code",
|
| 76 |
+
"execution_count": 180,
|
| 77 |
+
"id": "db06d560",
|
| 78 |
+
"metadata": {},
|
| 79 |
+
"outputs": [],
|
| 80 |
+
"source": [
|
| 81 |
+
"description = input(\"Please input a description of a 3D shape: \")\n",
|
| 82 |
+
"# description = 'The 3D shape is a cylinder.'\n",
|
| 83 |
+
"\n",
|
| 84 |
+
"prompt = 'Below is a description of a 3D shape:\\n'\n",
|
| 85 |
+
"prompt += description\n",
|
| 86 |
+
"prompt += '\\nGenerate a Computer-Aided Design(CAD) command sequence of the 3D shape:\\n'"
|
| 87 |
+
]
|
| 88 |
+
},
|
| 89 |
+
{
|
| 90 |
+
"cell_type": "markdown",
|
| 91 |
+
"id": "bb16f861",
|
| 92 |
+
"metadata": {},
|
| 93 |
+
"source": [
|
| 94 |
+
"### Inference and rendering"
|
| 95 |
+
]
|
| 96 |
+
},
|
| 97 |
+
{
|
| 98 |
+
"cell_type": "markdown",
|
| 99 |
+
"id": "59c5f38e",
|
| 100 |
+
"metadata": {},
|
| 101 |
+
"source": [
|
| 102 |
+
"#### Model Inference"
|
| 103 |
+
]
|
| 104 |
+
},
|
| 105 |
+
{
|
| 106 |
+
"cell_type": "code",
|
| 107 |
+
"execution_count": 181,
|
| 108 |
+
"id": "ab5ff2e8",
|
| 109 |
+
"metadata": {},
|
| 110 |
+
"outputs": [
|
| 111 |
+
{
|
| 112 |
+
"name": "stderr",
|
| 113 |
+
"output_type": "stream",
|
| 114 |
+
"text": [
|
| 115 |
+
"Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.\n"
|
| 116 |
+
]
|
| 117 |
+
},
|
| 118 |
+
{
|
| 119 |
+
"data": {
|
| 120 |
+
"text/plain": [
|
| 121 |
+
"'circle,31,53,31,9,53,31,9,31 <curve_end> <loop_end> circle,31,51,31,11,51,31,11,31 <curve_end> <loop_end> <face_end> circle,31,51,31,11,51,31,11,31 <curve_end> <loop_end> <face_end> <sketch_end> add,0,62,31,31,31,1,0,0,0,0,1,0,-1,0,7,31,31 <extrude_end>'"
|
| 122 |
+
]
|
| 123 |
+
},
|
| 124 |
+
"execution_count": 181,
|
| 125 |
+
"metadata": {},
|
| 126 |
+
"output_type": "execute_result"
|
| 127 |
+
}
|
| 128 |
+
],
|
| 129 |
+
"source": [
|
| 130 |
+
"batch = tokenizer(\n",
|
| 131 |
+
" prompt,\n",
|
| 132 |
+
" return_tensors=\"pt\",\n",
|
| 133 |
+
")\n",
|
| 134 |
+
"batch = {k: v.cuda() for k, v in batch.items()}\n",
|
| 135 |
+
"\n",
|
| 136 |
+
"generate_ids = model.generate(\n",
|
| 137 |
+
" **batch,\n",
|
| 138 |
+
" do_sample=True,\n",
|
| 139 |
+
" max_new_tokens=MAX_LENGTH,\n",
|
| 140 |
+
" temperature=args.temperature,\n",
|
| 141 |
+
" top_p=args.top_p,\n",
|
| 142 |
+
" repetition_penalty=1.3,\n",
|
| 143 |
+
")\n",
|
| 144 |
+
"\n",
|
| 145 |
+
"gen_strs = tokenizer.batch_decode(\n",
|
| 146 |
+
" generate_ids,\n",
|
| 147 |
+
" skip_special_tokens=True,\n",
|
| 148 |
+
" clean_up_tokenization_spaces=False,\n",
|
| 149 |
+
")\n",
|
| 150 |
+
"gen_strs = gen_strs[0][len(prompt):]\n",
|
| 151 |
+
"gen_strs"
|
| 152 |
+
]
|
| 153 |
+
},
|
| 154 |
+
{
|
| 155 |
+
"cell_type": "markdown",
|
| 156 |
+
"id": "f56d6fcf",
|
| 157 |
+
"metadata": {},
|
| 158 |
+
"source": [
|
| 159 |
+
"#### Render .obj file"
|
| 160 |
+
]
|
| 161 |
+
},
|
| 162 |
+
{
|
| 163 |
+
"cell_type": "code",
|
| 164 |
+
"execution_count": 182,
|
| 165 |
+
"id": "95498ccb",
|
| 166 |
+
"metadata": {},
|
| 167 |
+
"outputs": [],
|
| 168 |
+
"source": [
|
| 169 |
+
"out_path = 'visual_cache/gen_obj'\n",
|
| 170 |
+
"# remove the existing output directory if it exists\n",
|
| 171 |
+
"if os.path.exists(out_path):\n",
|
| 172 |
+
" shutil.rmtree(out_path)\n",
|
| 173 |
+
"# create the output directory\n",
|
| 174 |
+
"os.makedirs(out_path, exist_ok=True)\n",
|
| 175 |
+
"\n",
|
| 176 |
+
"cad_parser = CADparser(bit=6)\n",
|
| 177 |
+
"parsed_data = cad_parser.perform(gen_strs)\n",
|
| 178 |
+
"write_obj_sample(out_path, parsed_data)"
|
| 179 |
+
]
|
| 180 |
+
},
|
| 181 |
+
{
|
| 182 |
+
"cell_type": "markdown",
|
| 183 |
+
"id": "79b5dfaf",
|
| 184 |
+
"metadata": {},
|
| 185 |
+
"source": [
|
| 186 |
+
"#### Render .step, .stl, .ply files\n",
|
| 187 |
+
"N.B. if the Statistics on Transfer logs do not show up, the model may not have produced renderable outputs. Re-run the inference or change your prompt to see if it gets better results. "
|
| 188 |
+
]
|
| 189 |
+
},
|
| 190 |
+
{
|
| 191 |
+
"cell_type": "code",
|
| 192 |
+
"execution_count": null,
|
| 193 |
+
"id": "8a49694f",
|
| 194 |
+
"metadata": {},
|
| 195 |
+
"outputs": [],
|
| 196 |
+
"source": [
|
| 197 |
+
"out_path = os.path.abspath(out_path)\n",
|
| 198 |
+
"py_path = os.path.abspath('../rendering_utils/parser_visual.py')\n",
|
| 199 |
+
"subprocess.run(['python3', py_path, '--data_folder', out_path, '--single-file'])\n",
|
| 200 |
+
"py_path = os.path.abspath('../rendering_utils/ptl_sampler.py')\n",
|
| 201 |
+
"subprocess.run(['python3', py_path, '--in_dir', out_path, '--out_dir', 'ptl', '--single-file'])\n",
|
| 202 |
+
"# clear_output()"
|
| 203 |
+
]
|
| 204 |
+
},
|
| 205 |
+
{
|
| 206 |
+
"cell_type": "markdown",
|
| 207 |
+
"id": "0e0f1fd1",
|
| 208 |
+
"metadata": {},
|
| 209 |
+
"source": [
|
| 210 |
+
"#### Image rendering"
|
| 211 |
+
]
|
| 212 |
+
},
|
| 213 |
+
{
|
| 214 |
+
"cell_type": "code",
|
| 215 |
+
"execution_count": null,
|
| 216 |
+
"id": "586f3a91",
|
| 217 |
+
"metadata": {},
|
| 218 |
+
"outputs": [],
|
| 219 |
+
"source": [
|
| 220 |
+
"visual_obj_path = 'visual_cache'\n",
|
| 221 |
+
"output_figure_path = 'visual_cache/figures'\n",
|
| 222 |
+
"if os.path.exists(output_figure_path):\n",
|
| 223 |
+
" shutil.rmtree(output_figure_path)\n",
|
| 224 |
+
"py_path = os.path.abspath('../rendering_utils/img_renderer.py')\n",
|
| 225 |
+
"os.makedirs(output_figure_path, exist_ok=True)\n",
|
| 226 |
+
"try:\n",
|
| 227 |
+
" xvfb_process = subprocess.Popen(\n",
|
| 228 |
+
" [\"Xvfb\", \":99\", \"-screen\", \"0\", \"640x480x24\"],\n",
|
| 229 |
+
" stdout=subprocess.DEVNULL,\n",
|
| 230 |
+
" stderr=subprocess.DEVNULL\n",
|
| 231 |
+
" )\n",
|
| 232 |
+
" print(\"Xvfb started in the background.\")\n",
|
| 233 |
+
"except FileNotFoundError:\n",
|
| 234 |
+
" print(\"Error: Xvfb not found. Please ensure it is installed and in your system's PATH.\")\n",
|
| 235 |
+
"\n",
|
| 236 |
+
"os.environ['DISPLAY'] = ':99'\n",
|
| 237 |
+
"try:\n",
|
| 238 |
+
" subprocess.run(\n",
|
| 239 |
+
" ['python3', py_path, '--input_dir', visual_obj_path, '--output_dir', output_figure_path]\n",
|
| 240 |
+
" )\n",
|
| 241 |
+
" print(\"Rendering script completed successfully.\")\n",
|
| 242 |
+
"finally:\n",
|
| 243 |
+
" if xvfb_process.poll() is None: # Check if Xvfb is still running\n",
|
| 244 |
+
" xvfb_process.terminate()\n",
|
| 245 |
+
" print(\"Xvfb terminated.\")\n",
|
| 246 |
+
" else:\n",
|
| 247 |
+
" print(\"Xvfb already exited.\")\n",
|
| 248 |
+
" \n",
|
| 249 |
+
"del os.environ['DISPLAY']\n",
|
| 250 |
+
"clear_output()\n",
|
| 251 |
+
"\n",
|
| 252 |
+
"input_image_path = os.path.join(output_figure_path, 'gen_ob.png')\n",
|
| 253 |
+
"if os.path.exists(input_image_path):\n",
|
| 254 |
+
" img = Image.open(input_image_path)\n",
|
| 255 |
+
" img.show()\n",
|
| 256 |
+
"else:\n",
|
| 257 |
+
" print(f\"{input_image_path} does not exist.\")"
|
| 258 |
+
]
|
| 259 |
+
},
|
| 260 |
+
{
|
| 261 |
+
"cell_type": "markdown",
|
| 262 |
+
"id": "c78fed0f",
|
| 263 |
+
"metadata": {},
|
| 264 |
+
"source": [
|
| 265 |
+
"#### Files retrieval\n",
|
| 266 |
+
"By default, the produced step, stl, obj and ply files are stored under the visual_cache folder. You can save them to your custom places for further use. Do not put them in the cache folder as they will be deleted after the next run."
|
| 267 |
+
]
|
| 268 |
+
}
|
| 269 |
+
],
|
| 270 |
+
"metadata": {
|
| 271 |
+
"kernelspec": {
|
| 272 |
+
"display_name": "cdfs",
|
| 273 |
+
"language": "python",
|
| 274 |
+
"name": "python3"
|
| 275 |
+
},
|
| 276 |
+
"language_info": {
|
| 277 |
+
"codemirror_mode": {
|
| 278 |
+
"name": "ipython",
|
| 279 |
+
"version": 3
|
| 280 |
+
},
|
| 281 |
+
"file_extension": ".py",
|
| 282 |
+
"mimetype": "text/x-python",
|
| 283 |
+
"name": "python",
|
| 284 |
+
"nbconvert_exporter": "python",
|
| 285 |
+
"pygments_lexer": "ipython3",
|
| 286 |
+
"version": "3.9.23"
|
| 287 |
+
}
|
| 288 |
+
},
|
| 289 |
+
"nbformat": 4,
|
| 290 |
+
"nbformat_minor": 5
|
| 291 |
+
}
|
CADFusion/src/test/inference.py
ADDED
|
@@ -0,0 +1,106 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import json
|
| 3 |
+
import random
|
| 4 |
+
|
| 5 |
+
from huggingface_hub import login
|
| 6 |
+
from tqdm import tqdm
|
| 7 |
+
from utils import MAX_LENGTH, prepare_model_and_tokenizer
|
| 8 |
+
|
| 9 |
+
login()
|
| 10 |
+
|
| 11 |
+
random.seed(0)
|
| 12 |
+
|
| 13 |
+
def conditional_sample(args):
|
| 14 |
+
model, tokenizer = prepare_model_and_tokenizer(args)
|
| 15 |
+
|
| 16 |
+
model.eval()
|
| 17 |
+
with open(args.in_path, 'r', encoding='utf-8') as file:
|
| 18 |
+
data = json.load(file)
|
| 19 |
+
|
| 20 |
+
print(data[0])
|
| 21 |
+
data = [item for item in data if item['description'] != 'null']
|
| 22 |
+
|
| 23 |
+
global_count=0
|
| 24 |
+
responses = []
|
| 25 |
+
if args.full:
|
| 26 |
+
data=data
|
| 27 |
+
else:
|
| 28 |
+
random.shuffle(data)
|
| 29 |
+
data = data[:args.sample_len]
|
| 30 |
+
|
| 31 |
+
for item in tqdm(data):
|
| 32 |
+
prompts = []
|
| 33 |
+
for _ in range(args.num_samples):
|
| 34 |
+
prompt = 'Below is a description of a 3D shape:\n'
|
| 35 |
+
prompt += item['description']
|
| 36 |
+
prompt += '\nGenerate a Computer-Aided Design(CAD) command sequence of the 3D shape:\n'
|
| 37 |
+
|
| 38 |
+
prompts.append(prompt)
|
| 39 |
+
|
| 40 |
+
outputs = []
|
| 41 |
+
|
| 42 |
+
while len(outputs) < args.num_samples:
|
| 43 |
+
batch_prompts = prompts[len(outputs) : len(outputs) + args.batch_size]
|
| 44 |
+
|
| 45 |
+
batch = tokenizer(
|
| 46 |
+
list(batch_prompts),
|
| 47 |
+
return_tensors="pt",
|
| 48 |
+
)
|
| 49 |
+
batch = {k: v.cuda() for k, v in batch.items()}
|
| 50 |
+
|
| 51 |
+
generate_ids = model.generate(
|
| 52 |
+
**batch,
|
| 53 |
+
do_sample=True,
|
| 54 |
+
max_new_tokens=MAX_LENGTH,
|
| 55 |
+
temperature=args.temperature,
|
| 56 |
+
top_p=args.top_p,
|
| 57 |
+
repetition_penalty=1.3,
|
| 58 |
+
)
|
| 59 |
+
|
| 60 |
+
gen_strs = tokenizer.batch_decode(
|
| 61 |
+
generate_ids,
|
| 62 |
+
skip_special_tokens=True,
|
| 63 |
+
clean_up_tokenization_spaces=False,
|
| 64 |
+
)
|
| 65 |
+
|
| 66 |
+
outputs.extend(gen_strs)
|
| 67 |
+
print(f"Generated {len(outputs)}/{args.num_samples}samples.")
|
| 68 |
+
|
| 69 |
+
for prompt, output in zip(prompts, outputs):
|
| 70 |
+
result = {
|
| 71 |
+
'index': global_count,
|
| 72 |
+
# 'pic_name': item['pic_name'],
|
| 73 |
+
'ground_truth': item['command_sequence'],
|
| 74 |
+
'description': item['description'],
|
| 75 |
+
'prompt': prompt,
|
| 76 |
+
'output': output[len(prompt):]
|
| 77 |
+
}
|
| 78 |
+
if 'original_seq' in item.keys():
|
| 79 |
+
result['original_seq'] = item['original_seq']
|
| 80 |
+
responses.append(result)
|
| 81 |
+
global_count += 1
|
| 82 |
+
|
| 83 |
+
with open(args.out_path, "w+") as f:
|
| 84 |
+
json.dump(responses, f, indent=4)
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
if __name__ == "__main__":
|
| 89 |
+
parser = argparse.ArgumentParser()
|
| 90 |
+
parser.add_argument("--model-name", type=str, default="llama3")
|
| 91 |
+
parser.add_argument("--lora-rank", type=int, default=32)
|
| 92 |
+
parser.add_argument("--lora-alpha", type=int, default=32)
|
| 93 |
+
parser.add_argument("--lora-dropout", type=float, default=0.05)
|
| 94 |
+
parser.add_argument("--sample-len", type=int, default=100)
|
| 95 |
+
parser.add_argument("--pretrained-path", type=str, required=True)
|
| 96 |
+
parser.add_argument("--num-samples", type=int, default=500)
|
| 97 |
+
parser.add_argument("--batch-size", type=int, default=32)
|
| 98 |
+
parser.add_argument("--in-path", type=str, default="test_description.json")
|
| 99 |
+
parser.add_argument("--out-path", type=str, default="cad_samples.jsonl")
|
| 100 |
+
parser.add_argument("--temperature", type=float, default=0.9)
|
| 101 |
+
parser.add_argument("--device-map", type=str, default='auto')
|
| 102 |
+
parser.add_argument("--top-p", type=float, default=0.9)
|
| 103 |
+
parser.add_argument("--full", action="store_true", default=False)
|
| 104 |
+
args = parser.parse_args()
|
| 105 |
+
|
| 106 |
+
conditional_sample(args)
|
CADFusion/src/test/utils.py
ADDED
|
@@ -0,0 +1,86 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import transformers
|
| 3 |
+
from peft import LoraConfig, PeftModel, get_peft_model
|
| 4 |
+
|
| 5 |
+
IGNORE_INDEX = -100
|
| 6 |
+
MAX_LENGTH = 512
|
| 7 |
+
DEFAULT_PAD_TOKEN = "[PAD]"
|
| 8 |
+
DEFAULT_EOS_TOKEN = "</s>"
|
| 9 |
+
DEFAULT_BOS_TOKEN = "<s>"
|
| 10 |
+
DEFAULT_UNK_TOKEN = "<unk>"
|
| 11 |
+
|
| 12 |
+
def smart_tokenizer_and_embedding_resize(
|
| 13 |
+
special_tokens_dict,
|
| 14 |
+
llama_tokenizer,
|
| 15 |
+
model,
|
| 16 |
+
):
|
| 17 |
+
"""Resize tokenizer and embedding.
|
| 18 |
+
|
| 19 |
+
Note: This is the unoptimized version that may make your embedding size not be divisible by 64.
|
| 20 |
+
"""
|
| 21 |
+
num_new_tokens = llama_tokenizer.add_special_tokens(special_tokens_dict)
|
| 22 |
+
model.resize_token_embeddings(len(llama_tokenizer))
|
| 23 |
+
|
| 24 |
+
if num_new_tokens > 0:
|
| 25 |
+
input_embeddings = model.get_input_embeddings().weight.data
|
| 26 |
+
output_embeddings = model.get_output_embeddings().weight.data
|
| 27 |
+
|
| 28 |
+
input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(
|
| 29 |
+
dim=0, keepdim=True
|
| 30 |
+
)
|
| 31 |
+
output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(
|
| 32 |
+
dim=0, keepdim=True
|
| 33 |
+
)
|
| 34 |
+
|
| 35 |
+
input_embeddings[-num_new_tokens:] = input_embeddings_avg
|
| 36 |
+
output_embeddings[-num_new_tokens:] = output_embeddings_avg
|
| 37 |
+
|
| 38 |
+
def prepare_model_and_tokenizer(args):
|
| 39 |
+
model_id = "meta-llama/Meta-Llama-3-8B"
|
| 40 |
+
print(f"Model size: {model_id}")
|
| 41 |
+
if hasattr(args, 'device_map'):
|
| 42 |
+
device_map = args.device_map
|
| 43 |
+
else:
|
| 44 |
+
device_map = 'auto'
|
| 45 |
+
pipeline = transformers.pipeline("text2text-generation",
|
| 46 |
+
model=model_id, model_kwargs={"torch_dtype": torch.float32}, device_map=device_map)
|
| 47 |
+
tokenizer = pipeline.tokenizer
|
| 48 |
+
base_model = pipeline.model
|
| 49 |
+
|
| 50 |
+
special_tokens_dict = dict()
|
| 51 |
+
if tokenizer.pad_token is None:
|
| 52 |
+
special_tokens_dict["pad_token"] = DEFAULT_PAD_TOKEN
|
| 53 |
+
if tokenizer.eos_token is None:
|
| 54 |
+
special_tokens_dict["eos_token"] = DEFAULT_EOS_TOKEN
|
| 55 |
+
if tokenizer.bos_token is None:
|
| 56 |
+
special_tokens_dict["bos_token"] = DEFAULT_BOS_TOKEN
|
| 57 |
+
if tokenizer.unk_token is None:
|
| 58 |
+
special_tokens_dict["unk_token"] = DEFAULT_UNK_TOKEN
|
| 59 |
+
|
| 60 |
+
smart_tokenizer_and_embedding_resize(
|
| 61 |
+
special_tokens_dict=special_tokens_dict,
|
| 62 |
+
llama_tokenizer=tokenizer,
|
| 63 |
+
model=base_model,
|
| 64 |
+
)
|
| 65 |
+
|
| 66 |
+
peft_config = LoraConfig(
|
| 67 |
+
r=args.lora_rank,
|
| 68 |
+
lora_alpha=args.lora_alpha,
|
| 69 |
+
lora_dropout=args.lora_dropout,
|
| 70 |
+
bias="none",
|
| 71 |
+
task_type="CAUSAL_LM",
|
| 72 |
+
)
|
| 73 |
+
|
| 74 |
+
tokenizer.padding_side = 'left'
|
| 75 |
+
peftmodel = get_peft_model(base_model, peft_config)
|
| 76 |
+
if args.pretrained_path:
|
| 77 |
+
# load a previous checkpoint if the path is given
|
| 78 |
+
model = PeftModel.from_pretrained(base_model, args.pretrained_path, device_map=device_map)
|
| 79 |
+
peft_state_dict = {f"{k}": v for k, v in model.state_dict().items()}
|
| 80 |
+
peftmodel.load_state_dict(peft_state_dict)
|
| 81 |
+
|
| 82 |
+
for name, param in peftmodel.named_parameters():
|
| 83 |
+
if "lora" in name: # Check if "lora" is in the parameter's name
|
| 84 |
+
param.requires_grad = True
|
| 85 |
+
peftmodel.print_trainable_parameters()
|
| 86 |
+
return peftmodel, tokenizer
|
CADFusion/src/test/visual_utils/__init__.py
ADDED
|
File without changes
|
CADFusion/src/test/visual_utils/parser.py
ADDED
|
@@ -0,0 +1,478 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
from collections import OrderedDict
|
| 3 |
+
import re
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
import argparse
|
| 6 |
+
import os
|
| 7 |
+
import json
|
| 8 |
+
import math
|
| 9 |
+
|
| 10 |
+
# hyperparameters from SkexGen project
|
| 11 |
+
SKETCH_R = 1
|
| 12 |
+
RADIUS_R = 1
|
| 13 |
+
EXTRUDE_R = 1.0
|
| 14 |
+
SCALE_R = 1.4
|
| 15 |
+
OFFSET_R = 0.9
|
| 16 |
+
PIX_PAD = 4
|
| 17 |
+
CMD_PAD = 3
|
| 18 |
+
COORD_PAD = 4
|
| 19 |
+
EXT_PAD = 1
|
| 20 |
+
EXTRA_PAD = 1
|
| 21 |
+
R_PAD = 2
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class CADparser:
|
| 25 |
+
"""Parse CAD sequence to CAD object."""
|
| 26 |
+
|
| 27 |
+
def __init__(self, bit):
|
| 28 |
+
self.vertex_dict = OrderedDict()
|
| 29 |
+
self.bit = bit
|
| 30 |
+
|
| 31 |
+
def perform(self, cad_seq):
|
| 32 |
+
# divide into sketch and extrude
|
| 33 |
+
sketches, extrudes = self.get_SE(cad_seq)
|
| 34 |
+
if sketches is None or extrudes is None:
|
| 35 |
+
return None
|
| 36 |
+
# sequentially parse each pair of SE into obj
|
| 37 |
+
se_datas = []
|
| 38 |
+
for sketch, extrude in zip(sketches, extrudes):
|
| 39 |
+
extrude_param, scale, offset = self.parse_extrude(extrude)
|
| 40 |
+
if extrude_param is None or scale is None or offset is None:
|
| 41 |
+
return None
|
| 42 |
+
vertex_str, se_str = self.parse_sketch(sketch, scale, offset)
|
| 43 |
+
if vertex_str is None or se_str is None:
|
| 44 |
+
return None
|
| 45 |
+
se_datas.append(
|
| 46 |
+
{"vertex": vertex_str, "curve": se_str, "extrude": extrude_param}
|
| 47 |
+
)
|
| 48 |
+
self.vertex_dict.clear()
|
| 49 |
+
|
| 50 |
+
return se_datas
|
| 51 |
+
|
| 52 |
+
def parse_sketch(self, sketch, scale, offset):
|
| 53 |
+
faces = self.get_faces(sketch)
|
| 54 |
+
if len(faces) == 0:
|
| 55 |
+
return None, None
|
| 56 |
+
se_str = ""
|
| 57 |
+
for face_idx, face in enumerate(faces): # each face
|
| 58 |
+
face_str = "face\n"
|
| 59 |
+
loops = self.get_loops(face)
|
| 60 |
+
if len(loops) == 0:
|
| 61 |
+
return None, None
|
| 62 |
+
for loop_idx, loop in enumerate(loops): # each loop
|
| 63 |
+
curves = self.get_curves(loop)
|
| 64 |
+
if len(curves) == 0:
|
| 65 |
+
return None, None
|
| 66 |
+
next_curves = curves[1:]
|
| 67 |
+
next_curves += curves[:1]
|
| 68 |
+
cur_str = []
|
| 69 |
+
for curve, next_curve in zip(curves, next_curves): # each curve
|
| 70 |
+
if not self.obj_curve(curve, next_curve, cur_str, scale, offset):
|
| 71 |
+
return None, None
|
| 72 |
+
loop_str = ""
|
| 73 |
+
for c in cur_str:
|
| 74 |
+
loop_str += f"{c}\n"
|
| 75 |
+
if loop_idx == 0:
|
| 76 |
+
face_str += f"out\n{loop_str}\n"
|
| 77 |
+
else:
|
| 78 |
+
face_str += f"in\n{loop_str}\n"
|
| 79 |
+
se_str += face_str
|
| 80 |
+
vertex_str = self.convert_vertices()
|
| 81 |
+
return vertex_str, se_str
|
| 82 |
+
|
| 83 |
+
def parse_extrude(self, extrude):
|
| 84 |
+
ext = extrude.split(",")
|
| 85 |
+
if len(ext) != 18:
|
| 86 |
+
return None, None, None
|
| 87 |
+
|
| 88 |
+
# operation str to int
|
| 89 |
+
ext_op = {"add": 1, "cut": 2, "intersect": 3}.get(ext[0], None)
|
| 90 |
+
if ext_op is None:
|
| 91 |
+
return None, None, None
|
| 92 |
+
# dequantize ext_v, ext_T, scale and offset
|
| 93 |
+
ext_v, ext_T, scale, offset = self.dequantize_extrude_params(ext)
|
| 94 |
+
# get ext_R
|
| 95 |
+
ext_R = np.array(ext[6:15], dtype=int)
|
| 96 |
+
|
| 97 |
+
extrude_param = {"value": ext_v, "T": ext_T, "R": ext_R, "op": ext_op}
|
| 98 |
+
return extrude_param, scale, offset
|
| 99 |
+
|
| 100 |
+
def obj_curve(self, curve, next_curve, cur_str, scale, offset):
|
| 101 |
+
cur = curve.split(",")
|
| 102 |
+
next_cur = next_curve.split(",")
|
| 103 |
+
if cur[0] == "circle":
|
| 104 |
+
if len(cur) != 9:
|
| 105 |
+
return False
|
| 106 |
+
p1, p2, p3, p4 = self.dequantize_circle_points(
|
| 107 |
+
cur, next_cur, scale, offset)
|
| 108 |
+
center = np.asarray([0.5 * (p1[0] + p2[0]), 0.5 * (p3[1] + p4[1])])
|
| 109 |
+
radius = (np.linalg.norm(p1 - p2) + np.linalg.norm(p3 - p4)) / 4.0
|
| 110 |
+
|
| 111 |
+
center = center * scale + offset
|
| 112 |
+
radius = radius * scale
|
| 113 |
+
|
| 114 |
+
center_idx = self.save_vertex(center[0], center[1], "p")
|
| 115 |
+
radius_idx = self.save_vertex(radius, 0.0, "r")
|
| 116 |
+
cur_str.append(f"c {center_idx} {radius_idx}")
|
| 117 |
+
elif cur[0] == "arc":
|
| 118 |
+
if len(cur) != 5:
|
| 119 |
+
return False
|
| 120 |
+
if (
|
| 121 |
+
cur[1:3] == cur[3:5]
|
| 122 |
+
or cur[1:3] == next_cur[1:3]
|
| 123 |
+
or cur[3:5] == next_cur[3:5]
|
| 124 |
+
): # invalid arc
|
| 125 |
+
return False
|
| 126 |
+
start_v, mid_v, end_v = self.dequantize_arc_points(
|
| 127 |
+
cur, next_cur, scale, offset
|
| 128 |
+
)
|
| 129 |
+
try:
|
| 130 |
+
center, _, _, _ = find_arc_geometry(start_v, mid_v, end_v)
|
| 131 |
+
except Exception:
|
| 132 |
+
return False
|
| 133 |
+
start_v = start_v * scale + offset
|
| 134 |
+
mid_v = mid_v * scale + offset
|
| 135 |
+
end_v = end_v * scale + offset
|
| 136 |
+
center = center * scale + offset
|
| 137 |
+
|
| 138 |
+
center_idx = self.save_vertex(center[0], center[1], "p")
|
| 139 |
+
start_idx = self.save_vertex(start_v[0], start_v[1], "p")
|
| 140 |
+
mid_idx = self.save_vertex(mid_v[0], mid_v[1], "p")
|
| 141 |
+
end_idx = self.save_vertex(end_v[0], end_v[1], "p")
|
| 142 |
+
cur_str.append(f"a {start_idx} {mid_idx} {center_idx} {end_idx}")
|
| 143 |
+
elif cur[0] == "line":
|
| 144 |
+
if len(cur) != 3:
|
| 145 |
+
return False
|
| 146 |
+
if cur[1:3] == next_cur[1:3]:
|
| 147 |
+
return False
|
| 148 |
+
start_v, end_v = self.dequantize_line_points(
|
| 149 |
+
cur, next_cur, scale, offset)
|
| 150 |
+
start_v = start_v * scale + offset
|
| 151 |
+
end_v = end_v * scale + offset
|
| 152 |
+
|
| 153 |
+
start_idx = self.save_vertex(start_v[0], start_v[1], "p")
|
| 154 |
+
end_idx = self.save_vertex(end_v[0], end_v[1], "p")
|
| 155 |
+
cur_str.append(f"l {start_idx} {end_idx}")
|
| 156 |
+
else:
|
| 157 |
+
return False
|
| 158 |
+
return True
|
| 159 |
+
|
| 160 |
+
def get_SE(self, cad_seq):
|
| 161 |
+
# sketches: 1) between sequence start and sketch_end,
|
| 162 |
+
sketches_from_start = re.findall(r"^(.+?)(?=<sketch_end>)", cad_seq)
|
| 163 |
+
# sketches: 2) between extrude_end and sketch_end
|
| 164 |
+
sketches_after_extrude = re.findall(
|
| 165 |
+
r"(?<=<extrude_end>)(.+?)(?=<sketch_end>)", cad_seq
|
| 166 |
+
)
|
| 167 |
+
sketches = [x.strip() for x in sketches_from_start] + [
|
| 168 |
+
x.strip() for x in sketches_after_extrude
|
| 169 |
+
]
|
| 170 |
+
# extrudes: between sketch_end and extrude_end
|
| 171 |
+
extrudes = [
|
| 172 |
+
x.strip() for x in re.findall(r"<sketch_end>(.+?)<extrude_end>", cad_seq)
|
| 173 |
+
]
|
| 174 |
+
if len(sketches) != len(extrudes):
|
| 175 |
+
return None, None
|
| 176 |
+
return sketches, extrudes
|
| 177 |
+
|
| 178 |
+
def get_faces(self, sketch):
|
| 179 |
+
faces = sketch.split("<face_end>")
|
| 180 |
+
return [x.strip() for x in faces if x.strip() != ""]
|
| 181 |
+
|
| 182 |
+
def get_loops(self, face):
|
| 183 |
+
loops = face.split("<loop_end>")
|
| 184 |
+
return [x.strip() for x in loops if x.strip() != ""]
|
| 185 |
+
|
| 186 |
+
def get_curves(self, loop):
|
| 187 |
+
curves = loop.split("<curve_end>")
|
| 188 |
+
return [x.strip() for x in curves if x.strip() != ""]
|
| 189 |
+
|
| 190 |
+
def dequantize_circle_points(self, curve, next_curve, scale, offset):
|
| 191 |
+
p1 = dequantize_verts(
|
| 192 |
+
np.array(curve[1:3], dtype=int),
|
| 193 |
+
n_bits=self.bit,
|
| 194 |
+
min_range=-SKETCH_R,
|
| 195 |
+
max_range=SKETCH_R,
|
| 196 |
+
add_noise=False,
|
| 197 |
+
)
|
| 198 |
+
p2 = dequantize_verts(
|
| 199 |
+
np.array(curve[3:5], dtype=int),
|
| 200 |
+
n_bits=self.bit,
|
| 201 |
+
min_range=-SKETCH_R,
|
| 202 |
+
max_range=SKETCH_R,
|
| 203 |
+
add_noise=False,
|
| 204 |
+
)
|
| 205 |
+
p3 = dequantize_verts(
|
| 206 |
+
np.array(curve[5:7], dtype=int),
|
| 207 |
+
n_bits=self.bit,
|
| 208 |
+
min_range=-SKETCH_R,
|
| 209 |
+
max_range=SKETCH_R,
|
| 210 |
+
add_noise=False,
|
| 211 |
+
)
|
| 212 |
+
p4 = dequantize_verts(
|
| 213 |
+
np.array(curve[7:9], dtype=int),
|
| 214 |
+
n_bits=self.bit,
|
| 215 |
+
min_range=-SKETCH_R,
|
| 216 |
+
max_range=SKETCH_R,
|
| 217 |
+
add_noise=False,
|
| 218 |
+
)
|
| 219 |
+
return p1, p2, p3, p4
|
| 220 |
+
|
| 221 |
+
def dequantize_arc_points(self, curve, next_curve, scale, offset):
|
| 222 |
+
start_v = dequantize_verts(
|
| 223 |
+
np.array(curve[1:3], dtype=int),
|
| 224 |
+
n_bits=self.bit,
|
| 225 |
+
min_range=-SKETCH_R,
|
| 226 |
+
max_range=SKETCH_R,
|
| 227 |
+
add_noise=False,
|
| 228 |
+
)
|
| 229 |
+
mid_v = dequantize_verts(
|
| 230 |
+
np.array(curve[3:5], dtype=int),
|
| 231 |
+
n_bits=self.bit,
|
| 232 |
+
min_range=-SKETCH_R,
|
| 233 |
+
max_range=SKETCH_R,
|
| 234 |
+
add_noise=False,
|
| 235 |
+
)
|
| 236 |
+
end_v = dequantize_verts(
|
| 237 |
+
np.array(next_curve[1:3], dtype=int),
|
| 238 |
+
n_bits=self.bit,
|
| 239 |
+
min_range=-SKETCH_R,
|
| 240 |
+
max_range=SKETCH_R,
|
| 241 |
+
add_noise=False,
|
| 242 |
+
)
|
| 243 |
+
return start_v, mid_v, end_v
|
| 244 |
+
|
| 245 |
+
def dequantize_line_points(self, curve, next_curve, scale, offset):
|
| 246 |
+
start_v = dequantize_verts(
|
| 247 |
+
np.array(curve[1:3], dtype=int),
|
| 248 |
+
n_bits=self.bit,
|
| 249 |
+
min_range=-SKETCH_R,
|
| 250 |
+
max_range=SKETCH_R,
|
| 251 |
+
add_noise=False,
|
| 252 |
+
)
|
| 253 |
+
end_v = dequantize_verts(
|
| 254 |
+
np.array(next_curve[1:3], dtype=int),
|
| 255 |
+
n_bits=self.bit,
|
| 256 |
+
min_range=-SKETCH_R,
|
| 257 |
+
max_range=SKETCH_R,
|
| 258 |
+
add_noise=False,
|
| 259 |
+
)
|
| 260 |
+
return start_v, end_v
|
| 261 |
+
|
| 262 |
+
def dequantize_extrude_params(self, extrude):
|
| 263 |
+
ext_v = dequantize_verts(
|
| 264 |
+
np.array(extrude[1:3], dtype=int),
|
| 265 |
+
n_bits=self.bit,
|
| 266 |
+
min_range=-EXTRUDE_R,
|
| 267 |
+
max_range=EXTRUDE_R,
|
| 268 |
+
add_noise=False,
|
| 269 |
+
)
|
| 270 |
+
ext_T = dequantize_verts(
|
| 271 |
+
np.array(extrude[3:6], dtype=int),
|
| 272 |
+
n_bits=self.bit,
|
| 273 |
+
min_range=-EXTRUDE_R,
|
| 274 |
+
max_range=EXTRUDE_R,
|
| 275 |
+
add_noise=False,
|
| 276 |
+
)
|
| 277 |
+
scale = dequantize_verts(
|
| 278 |
+
np.array(extrude[15], dtype=int),
|
| 279 |
+
n_bits=self.bit,
|
| 280 |
+
min_range=0.0,
|
| 281 |
+
max_range=SCALE_R,
|
| 282 |
+
add_noise=False,
|
| 283 |
+
)
|
| 284 |
+
offset = dequantize_verts(
|
| 285 |
+
np.array(extrude[16:18], dtype=int),
|
| 286 |
+
n_bits=self.bit,
|
| 287 |
+
min_range=-OFFSET_R,
|
| 288 |
+
max_range=OFFSET_R,
|
| 289 |
+
add_noise=False,
|
| 290 |
+
)
|
| 291 |
+
return ext_v, ext_T, scale, offset
|
| 292 |
+
|
| 293 |
+
def save_vertex(self, h_x, h_y, text):
|
| 294 |
+
unique_key = f"{text}:x{h_x}y{h_y}"
|
| 295 |
+
index = 0
|
| 296 |
+
for key in self.vertex_dict.keys():
|
| 297 |
+
# Vertex location already exist in dict
|
| 298 |
+
if unique_key == key:
|
| 299 |
+
return index
|
| 300 |
+
index += 1
|
| 301 |
+
# Vertex location does not exist in dict
|
| 302 |
+
self.vertex_dict[unique_key] = [h_x, h_y]
|
| 303 |
+
return index
|
| 304 |
+
|
| 305 |
+
def convert_vertices(self):
|
| 306 |
+
"""Convert all the vertices to .obj format"""
|
| 307 |
+
vertex_strings = ""
|
| 308 |
+
for pt in self.vertex_dict.values():
|
| 309 |
+
# e.g. v 0.123 0.234 0.345 1.0
|
| 310 |
+
vertex_string = f"v {pt[0]} {pt[1]}\n"
|
| 311 |
+
vertex_strings += vertex_string
|
| 312 |
+
return vertex_strings
|
| 313 |
+
|
| 314 |
+
|
| 315 |
+
def find_arc_geometry(a, b, c):
|
| 316 |
+
A = b[0] - a[0]
|
| 317 |
+
B = b[1] - a[1]
|
| 318 |
+
C = c[0] - a[0]
|
| 319 |
+
D = c[1] - a[1]
|
| 320 |
+
|
| 321 |
+
E = A*(a[0] + b[0]) + B*(a[1] + b[1])
|
| 322 |
+
F = C*(a[0] + c[0]) + D*(a[1] + c[1])
|
| 323 |
+
|
| 324 |
+
G = 2.0*(A*(c[1] - b[1])-B*(c[0] - b[0]))
|
| 325 |
+
|
| 326 |
+
if G == 0:
|
| 327 |
+
raise Exception("zero G")
|
| 328 |
+
|
| 329 |
+
p_0 = (D*E - B*F) / G
|
| 330 |
+
p_1 = (A*F - C*E) / G
|
| 331 |
+
|
| 332 |
+
center = np.array([p_0, p_1])
|
| 333 |
+
radius = np.linalg.norm(center - a)
|
| 334 |
+
|
| 335 |
+
angles = []
|
| 336 |
+
for xx in [a, b, c]:
|
| 337 |
+
angle = angle_from_vector_to_x(xx - center)
|
| 338 |
+
angles.append(angle)
|
| 339 |
+
|
| 340 |
+
ab = b-a
|
| 341 |
+
ac = c-a
|
| 342 |
+
cp = np.cross(ab, ac)
|
| 343 |
+
if cp >= 0:
|
| 344 |
+
start_angle_rads = angles[0]
|
| 345 |
+
end_angle_rads = angles[2]
|
| 346 |
+
else:
|
| 347 |
+
start_angle_rads = angles[2]
|
| 348 |
+
end_angle_rads = angles[0]
|
| 349 |
+
|
| 350 |
+
return center, radius, start_angle_rads, end_angle_rads
|
| 351 |
+
|
| 352 |
+
|
| 353 |
+
def angle_from_vector_to_x(vec):
|
| 354 |
+
assert vec.size == 2
|
| 355 |
+
# We need to find a unit vector
|
| 356 |
+
angle = 0.0
|
| 357 |
+
|
| 358 |
+
l = np.linalg.norm(vec)
|
| 359 |
+
uvec = vec/l
|
| 360 |
+
|
| 361 |
+
# 2 | 1
|
| 362 |
+
# -------
|
| 363 |
+
# 3 | 4
|
| 364 |
+
if uvec[0] >= 0:
|
| 365 |
+
if uvec[1] >= 0:
|
| 366 |
+
# Qadrant 1
|
| 367 |
+
angle = math.asin(uvec[1])
|
| 368 |
+
else:
|
| 369 |
+
# Qadrant 4
|
| 370 |
+
angle = 2.0*math.pi - math.asin(-uvec[1])
|
| 371 |
+
else:
|
| 372 |
+
if vec[1] >= 0:
|
| 373 |
+
# Qadrant 2
|
| 374 |
+
angle = math.pi - math.asin(uvec[1])
|
| 375 |
+
else:
|
| 376 |
+
# Qadrant 3
|
| 377 |
+
angle = math.pi + math.asin(-uvec[1])
|
| 378 |
+
return angle
|
| 379 |
+
|
| 380 |
+
|
| 381 |
+
def dequantize_verts(verts, n_bits=8, min_range=-0.5, max_range=0.5, add_noise=False):
|
| 382 |
+
"""Convert quantized vertices to floats."""
|
| 383 |
+
range_quantize = 2**n_bits - 1
|
| 384 |
+
verts = verts.astype("float32")
|
| 385 |
+
verts = verts * (max_range - min_range) / range_quantize + min_range
|
| 386 |
+
return verts
|
| 387 |
+
|
| 388 |
+
|
| 389 |
+
def write_obj_sample(save_folder, data):
|
| 390 |
+
for idx, write_data in enumerate(data):
|
| 391 |
+
obj_name = Path(save_folder).stem + "_" + \
|
| 392 |
+
str(idx).zfill(3) + "_param.obj"
|
| 393 |
+
obj_file = Path(save_folder) / obj_name
|
| 394 |
+
extrude_param = write_data["extrude"]
|
| 395 |
+
vertex_strings = write_data["vertex"]
|
| 396 |
+
curve_strings = write_data["curve"]
|
| 397 |
+
|
| 398 |
+
"""Write an .obj file with the curves and verts"""
|
| 399 |
+
if extrude_param["op"] == 1: # 'add'
|
| 400 |
+
set_op = "NewBodyFeatureOperation"
|
| 401 |
+
elif extrude_param["op"] == 2: # 'cut'
|
| 402 |
+
set_op = "CutFeatureOperation"
|
| 403 |
+
elif extrude_param["op"] == 3: # 'cut'
|
| 404 |
+
set_op = "IntersectFeatureOperation"
|
| 405 |
+
|
| 406 |
+
with open(obj_file, "w") as fh:
|
| 407 |
+
# Write Meta info
|
| 408 |
+
fh.write("# WaveFront *.obj file\n")
|
| 409 |
+
fh.write("# ExtrudeOperation: " + set_op + "\n")
|
| 410 |
+
fh.write("\n")
|
| 411 |
+
|
| 412 |
+
# Write vertex and curve
|
| 413 |
+
fh.write(vertex_strings)
|
| 414 |
+
fh.write("\n")
|
| 415 |
+
fh.write(curve_strings)
|
| 416 |
+
fh.write("\n")
|
| 417 |
+
|
| 418 |
+
# Write extrude value
|
| 419 |
+
extrude_string = "Extrude "
|
| 420 |
+
for value in extrude_param["value"]:
|
| 421 |
+
extrude_string += str(value) + " "
|
| 422 |
+
fh.write(extrude_string)
|
| 423 |
+
fh.write("\n")
|
| 424 |
+
|
| 425 |
+
# Write refe plane value
|
| 426 |
+
p_orig = parse3d_sample(extrude_param["T"])
|
| 427 |
+
x_axis = parse3d_sample(extrude_param["R"][0:3])
|
| 428 |
+
y_axis = parse3d_sample(extrude_param["R"][3:6])
|
| 429 |
+
z_axis = parse3d_sample(extrude_param["R"][6:9])
|
| 430 |
+
fh.write("T_origin " + p_orig)
|
| 431 |
+
fh.write("\n")
|
| 432 |
+
fh.write("T_xaxis " + x_axis)
|
| 433 |
+
fh.write("\n")
|
| 434 |
+
fh.write("T_yaxis " + y_axis)
|
| 435 |
+
fh.write("\n")
|
| 436 |
+
fh.write("T_zaxis " + z_axis)
|
| 437 |
+
|
| 438 |
+
|
| 439 |
+
def parse3d_sample(point3d):
|
| 440 |
+
x = point3d[0]
|
| 441 |
+
y = point3d[1]
|
| 442 |
+
z = point3d[2]
|
| 443 |
+
return str(x) + " " + str(y) + " " + str(z)
|
| 444 |
+
|
| 445 |
+
|
| 446 |
+
if __name__ == "__main__":
|
| 447 |
+
parser = argparse.ArgumentParser()
|
| 448 |
+
parser.add_argument("--in-path", type=str, required=True)
|
| 449 |
+
parser.add_argument("--out-path", type=str, required=True)
|
| 450 |
+
args = parser.parse_args()
|
| 451 |
+
|
| 452 |
+
# with open(args.in_path, "r") as f:
|
| 453 |
+
# data = f.readlines()
|
| 454 |
+
with open(args.in_path, 'r') as file:
|
| 455 |
+
data = file.read()
|
| 456 |
+
|
| 457 |
+
data = json.loads(data)
|
| 458 |
+
|
| 459 |
+
num_valid_str = 0
|
| 460 |
+
for idx, item in enumerate(data):
|
| 461 |
+
try:
|
| 462 |
+
cad_parser = CADparser(bit=6)
|
| 463 |
+
# print(idx)
|
| 464 |
+
if type(item) == str:
|
| 465 |
+
parsed_data = cad_parser.perform(item)
|
| 466 |
+
elif type(item) == dict:
|
| 467 |
+
parsed_data = cad_parser.perform(item['output'])
|
| 468 |
+
else:
|
| 469 |
+
raise ValueError("Invalid data type")
|
| 470 |
+
out_path = os.path.join(args.out_path, str(idx).zfill(6))
|
| 471 |
+
os.makedirs(out_path, exist_ok=True)
|
| 472 |
+
if parsed_data is not None:
|
| 473 |
+
num_valid_str += 1
|
| 474 |
+
write_obj_sample(out_path, parsed_data)
|
| 475 |
+
except Exception as e:
|
| 476 |
+
print(e)
|
| 477 |
+
pass
|
| 478 |
+
print(f"Number of valid CAD strings: {num_valid_str}/{len(data)}")
|
CADFusion/src/train/CAD_dataset.py
ADDED
|
@@ -0,0 +1,89 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import os
|
| 3 |
+
import torch
|
| 4 |
+
import random
|
| 5 |
+
import transformers
|
| 6 |
+
|
| 7 |
+
from dataclasses import dataclass
|
| 8 |
+
from torch.utils.data import Dataset
|
| 9 |
+
from utils import IGNORE_INDEX, MAX_LENGTH
|
| 10 |
+
|
| 11 |
+
class CADDataset(Dataset):
|
| 12 |
+
def __init__(self, json_fn, cutoff=True, llama_tokenizer=None):
|
| 13 |
+
if not os.path.exists(json_fn):
|
| 14 |
+
raise ValueError(f"{json_fn} does not exist")
|
| 15 |
+
self.inputs = json.load(open(json_fn, "r"))
|
| 16 |
+
print(len(self.inputs))
|
| 17 |
+
self.inputs = [item for item in self.inputs if 'null' not in item['description']]
|
| 18 |
+
random.shuffle(self.inputs)
|
| 19 |
+
if cutoff:
|
| 20 |
+
self.inputs = self.inputs[:18953]
|
| 21 |
+
print(len(self.inputs))
|
| 22 |
+
self.llama_tokenizer = llama_tokenizer
|
| 23 |
+
|
| 24 |
+
def __len__(self):
|
| 25 |
+
return len(self.inputs)
|
| 26 |
+
|
| 27 |
+
def __getitem__(self, index):
|
| 28 |
+
item = self.inputs[index]
|
| 29 |
+
seq = item['command_sequence']
|
| 30 |
+
des = item['description']
|
| 31 |
+
val = self.tokenize(seq, des)
|
| 32 |
+
return val
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def tokenize(self, seq, des):
|
| 36 |
+
tokens, prompt_length = self.conditional_generation_task(seq=seq, des=des)
|
| 37 |
+
input_ids = tokens.input_ids[0]
|
| 38 |
+
labels = tokens.input_ids[0].clone() # Clone the input_ids for labels
|
| 39 |
+
# Set the labels for the prompt part to IGNORE_INDEX so they are ignored in loss calculation
|
| 40 |
+
labels[:prompt_length] = IGNORE_INDEX
|
| 41 |
+
input_id_lens = label_lens = (
|
| 42 |
+
tokens.input_ids.ne(self.llama_tokenizer.pad_token_id).sum().item()
|
| 43 |
+
)
|
| 44 |
+
return dict(
|
| 45 |
+
input_ids=input_ids,
|
| 46 |
+
input_id_lens=input_id_lens,
|
| 47 |
+
labels=labels,
|
| 48 |
+
label_lens=label_lens,
|
| 49 |
+
)
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def conditional_generation_task(self, seq, des):
|
| 53 |
+
prompt = 'Below is a description of a 3D shape:\n'
|
| 54 |
+
prompt += des
|
| 55 |
+
prompt += '\nGenerate a Computer-Aided Design(CAD) command sequence of the 3D shape:\n'
|
| 56 |
+
full_text = prompt + seq + self.llama_tokenizer.eos_token
|
| 57 |
+
tokens = self.llama_tokenizer(
|
| 58 |
+
full_text,
|
| 59 |
+
max_length=MAX_LENGTH,
|
| 60 |
+
return_tensors="pt",
|
| 61 |
+
truncation=True,
|
| 62 |
+
)
|
| 63 |
+
prompt_length = len(self.llama_tokenizer(prompt)['input_ids'])
|
| 64 |
+
return tokens, prompt_length
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
@dataclass
|
| 68 |
+
class DataCollatorForSupervisedDataset(object):
|
| 69 |
+
"""Collate examples for supervised fine-tuning."""
|
| 70 |
+
|
| 71 |
+
tokenizer: transformers.PreTrainedTokenizer
|
| 72 |
+
|
| 73 |
+
def __call__(self, instances):
|
| 74 |
+
input_ids, labels = tuple(
|
| 75 |
+
[instance[key].clone().detach() for instance in instances]
|
| 76 |
+
for key in ("input_ids", "labels")
|
| 77 |
+
)
|
| 78 |
+
# force left padding
|
| 79 |
+
reversed_sequences = [torch.flip(input_id, [0]) for input_id in input_ids]
|
| 80 |
+
input_ids = torch.nn.utils.rnn.pad_sequence(reversed_sequences, batch_first=True, padding_value=self.tokenizer.pad_token_id)
|
| 81 |
+
input_ids = torch.flip(input_ids, [0, 1])
|
| 82 |
+
labels = torch.nn.utils.rnn.pad_sequence(
|
| 83 |
+
labels, batch_first=True, padding_value=IGNORE_INDEX
|
| 84 |
+
)
|
| 85 |
+
return dict(
|
| 86 |
+
input_ids=input_ids,
|
| 87 |
+
labels=labels,
|
| 88 |
+
attention_mask=input_ids.ne(self.tokenizer.pad_token_id),
|
| 89 |
+
)
|
CADFusion/src/train/dpo.py
ADDED
|
@@ -0,0 +1,79 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import os
|
| 3 |
+
import torch
|
| 4 |
+
import json
|
| 5 |
+
import random
|
| 6 |
+
import transformers
|
| 7 |
+
from huggingface_hub import login
|
| 8 |
+
|
| 9 |
+
login() # put your huggingface token here
|
| 10 |
+
os.environ["WANDB_PROJECT"] = "CADFusion_VF"
|
| 11 |
+
|
| 12 |
+
from datasets import Dataset
|
| 13 |
+
from trl import DPOTrainer, DPOConfig
|
| 14 |
+
from utils import prepare_model_and_tokenizer
|
| 15 |
+
|
| 16 |
+
parser = argparse.ArgumentParser()
|
| 17 |
+
parser.add_argument("--run-name", type=str, required=True)
|
| 18 |
+
parser.add_argument("--lora-rank", type=int, default=32)
|
| 19 |
+
parser.add_argument("--lora-alpha", type=int, default=32)
|
| 20 |
+
parser.add_argument("--lora-dropout", type=float, default=0.05)
|
| 21 |
+
parser.add_argument("--sample-cutoff", default=100000, type=int)
|
| 22 |
+
parser.add_argument("--pretrained-path", type=str, required=True)
|
| 23 |
+
parser.add_argument("--data-path", type=str, required=True)
|
| 24 |
+
parser.add_argument("--output-path", type=str, required=True)
|
| 25 |
+
parser.add_argument("--num-epochs", type=int, default=3)
|
| 26 |
+
parser.add_argument("--batch-size", type=int, default=2)
|
| 27 |
+
parser.add_argument("--eval-freq", default=1000, type=int)
|
| 28 |
+
parser.add_argument("--save-freq", default=500, type=int)
|
| 29 |
+
parser.add_argument("--debug", action="store_true", default=False)
|
| 30 |
+
args = parser.parse_args()
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
with open(args.data_path, 'r') as f:
|
| 35 |
+
raw_data = json.load(f)
|
| 36 |
+
|
| 37 |
+
random.shuffle(raw_data)
|
| 38 |
+
|
| 39 |
+
if len(raw_data) > args.sample_cutoff + 100:
|
| 40 |
+
ds = {
|
| 41 |
+
"train": Dataset.from_list(raw_data[:args.sample_cutoff]),
|
| 42 |
+
"val": Dataset.from_list(raw_data[-100:])
|
| 43 |
+
}
|
| 44 |
+
else:
|
| 45 |
+
ds = {
|
| 46 |
+
"train": Dataset.from_list(raw_data[:-100]),
|
| 47 |
+
"val": Dataset.from_list(raw_data[-100:])
|
| 48 |
+
}
|
| 49 |
+
|
| 50 |
+
llama_model, llama_tokenizer = prepare_model_and_tokenizer(args)
|
| 51 |
+
|
| 52 |
+
for name, param in llama_model.named_parameters():
|
| 53 |
+
if "lora" in name: # Check if "lora" is in the parameter's name
|
| 54 |
+
param.requires_grad = True
|
| 55 |
+
|
| 56 |
+
training_args = DPOConfig(
|
| 57 |
+
run_name=args.run_name,
|
| 58 |
+
learning_rate=1.41e-5,
|
| 59 |
+
per_device_train_batch_size=2,
|
| 60 |
+
per_device_eval_batch_size=args.batch_size,
|
| 61 |
+
report_to="wandb",
|
| 62 |
+
num_train_epochs=args.num_epochs,
|
| 63 |
+
do_eval=True,
|
| 64 |
+
eval_steps=args.eval_freq,
|
| 65 |
+
save_steps=args.save_freq,
|
| 66 |
+
output_dir=args.output_path
|
| 67 |
+
)
|
| 68 |
+
|
| 69 |
+
trainer = DPOTrainer(
|
| 70 |
+
llama_model,
|
| 71 |
+
None,
|
| 72 |
+
args=training_args,
|
| 73 |
+
train_dataset=ds['train'],
|
| 74 |
+
eval_dataset=ds['val'],
|
| 75 |
+
tokenizer=llama_tokenizer,
|
| 76 |
+
)
|
| 77 |
+
trainer.save_model()
|
| 78 |
+
trainer.train()
|
| 79 |
+
trainer.save_model()
|
CADFusion/src/train/llama_finetune.py
ADDED
|
@@ -0,0 +1,127 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
import argparse
|
| 3 |
+
import os
|
| 4 |
+
import torch
|
| 5 |
+
import transformers
|
| 6 |
+
|
| 7 |
+
from CAD_dataset import CADDataset, DataCollatorForSupervisedDataset
|
| 8 |
+
from huggingface_hub import login
|
| 9 |
+
from pathlib import Path
|
| 10 |
+
from peft import LoraConfig, get_peft_model
|
| 11 |
+
from transformers import Trainer, TrainingArguments
|
| 12 |
+
from utils import prepare_model_and_tokenizer
|
| 13 |
+
|
| 14 |
+
login() # put your huggingface token here
|
| 15 |
+
|
| 16 |
+
def setup_datasets(args, llama_tokenizer, transform_args={}):
|
| 17 |
+
datasets = {
|
| 18 |
+
"train": CADDataset(
|
| 19 |
+
args.data_path,
|
| 20 |
+
llama_tokenizer=llama_tokenizer,
|
| 21 |
+
),
|
| 22 |
+
"val": CADDataset(
|
| 23 |
+
args.eval_data_path,
|
| 24 |
+
llama_tokenizer=llama_tokenizer,
|
| 25 |
+
),
|
| 26 |
+
}
|
| 27 |
+
return datasets
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def setup_training_args(args):
|
| 31 |
+
output_dir = args.expdir / args.run_name
|
| 32 |
+
output_dir.mkdir(parents=True, exist_ok=True)
|
| 33 |
+
|
| 34 |
+
if args.debug:
|
| 35 |
+
os.environ["WANDB_DISABLED"] = "True"
|
| 36 |
+
os.environ["ACCELERATE_MIXED_PRECISION"] = "no"
|
| 37 |
+
training_args = TrainingArguments(
|
| 38 |
+
fsdp=False,
|
| 39 |
+
fp16=False,
|
| 40 |
+
bf16=False,
|
| 41 |
+
do_eval=True,
|
| 42 |
+
gradient_checkpointing=False,
|
| 43 |
+
ddp_find_unused_parameters=False,
|
| 44 |
+
num_train_epochs=args.num_epochs,
|
| 45 |
+
eval_steps=args.eval_freq,
|
| 46 |
+
save_steps=args.save_freq,
|
| 47 |
+
logging_steps=10,
|
| 48 |
+
evaluation_strategy="steps",
|
| 49 |
+
per_device_train_batch_size=args.batch_size,
|
| 50 |
+
per_device_eval_batch_size=args.batch_size,
|
| 51 |
+
learning_rate=args.lr,
|
| 52 |
+
lr_scheduler_type=args.lr_scheduler,
|
| 53 |
+
warmup_steps=args.num_warmup_steps,
|
| 54 |
+
weight_decay=args.weight_decay,
|
| 55 |
+
gradient_accumulation_steps=args.grad_accum,
|
| 56 |
+
output_dir=output_dir,
|
| 57 |
+
run_name=args.run_name,
|
| 58 |
+
report_to="wandb",
|
| 59 |
+
dataloader_num_workers=8,
|
| 60 |
+
remove_unused_columns=False,
|
| 61 |
+
# label_names=["cad_ids"], # this is to make trainer behave as expected
|
| 62 |
+
)
|
| 63 |
+
return training_args
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def setup_trainer(args):
|
| 67 |
+
training_args = setup_training_args(args)
|
| 68 |
+
if args.device_map == 'accelerate':
|
| 69 |
+
args.device_map = {'': training_args.local_rank}
|
| 70 |
+
model, llama_tokenizer = prepare_model_and_tokenizer(args)
|
| 71 |
+
|
| 72 |
+
datasets = setup_datasets(args, llama_tokenizer)
|
| 73 |
+
|
| 74 |
+
data_collator = DataCollatorForSupervisedDataset(
|
| 75 |
+
tokenizer=llama_tokenizer,
|
| 76 |
+
)
|
| 77 |
+
|
| 78 |
+
trainer = Trainer(
|
| 79 |
+
model=model,
|
| 80 |
+
args=training_args,
|
| 81 |
+
train_dataset=datasets["train"],
|
| 82 |
+
eval_dataset=datasets["val"],
|
| 83 |
+
data_collator=data_collator,
|
| 84 |
+
)
|
| 85 |
+
|
| 86 |
+
return trainer
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
def main(args):
|
| 90 |
+
trainer = setup_trainer(args)
|
| 91 |
+
|
| 92 |
+
if args.resume_dir is not None:
|
| 93 |
+
train_result = trainer.train(resume_from_checkpoint=args.resume_dir)
|
| 94 |
+
else:
|
| 95 |
+
train_result = trainer.train()
|
| 96 |
+
|
| 97 |
+
print(train_result)
|
| 98 |
+
trainer.save_state()
|
| 99 |
+
trainer.save_model()
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
if __name__ == "__main__":
|
| 103 |
+
parser = argparse.ArgumentParser()
|
| 104 |
+
parser.add_argument("--run-name", type=str, required=True)
|
| 105 |
+
parser.add_argument("--expdir", type=Path, default="exp")
|
| 106 |
+
parser.add_argument("--model-name", default="llama3")
|
| 107 |
+
parser.add_argument("--lora-rank", type=int, default=32)
|
| 108 |
+
parser.add_argument("--lora-alpha", type=int, default=32)
|
| 109 |
+
parser.add_argument("--lora-dropout", type=float, default=0.05)
|
| 110 |
+
parser.add_argument("--data-path", type=Path, default="data/train.json")
|
| 111 |
+
parser.add_argument("--eval-data-path", type=Path, default="data/eval.json")
|
| 112 |
+
parser.add_argument("--pretrained-path", type=Path, default=None)
|
| 113 |
+
parser.add_argument("--num-epochs", type=int, default=40)
|
| 114 |
+
parser.add_argument("--batch-size", type=int, default=1)
|
| 115 |
+
parser.add_argument("--grad-accum", type=int, default=1)
|
| 116 |
+
parser.add_argument("--lr", type=float, default=1e-4)
|
| 117 |
+
parser.add_argument("--lr-scheduler", type=str, default="cosine")
|
| 118 |
+
parser.add_argument("--num-warmup-steps", type=int, default=100)
|
| 119 |
+
parser.add_argument("--weight-decay", type=float, default=0.0)
|
| 120 |
+
parser.add_argument("--eval-freq", default=1000, type=int)
|
| 121 |
+
parser.add_argument("--save-freq", default=50000, type=int)
|
| 122 |
+
parser.add_argument("--device-map", type=str, default='auto')
|
| 123 |
+
parser.add_argument("--resume-dir", type=Path, default=None)
|
| 124 |
+
parser.add_argument("--debug", action="store_true", default=False)
|
| 125 |
+
args = parser.parse_args()
|
| 126 |
+
os.environ["WANDB_PROJECT"] = "CADFusion_SL"
|
| 127 |
+
main(args)
|
CADFusion/src/train/utils.py
ADDED
|
@@ -0,0 +1,86 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import transformers
|
| 3 |
+
from peft import LoraConfig, PeftModel, get_peft_model
|
| 4 |
+
|
| 5 |
+
IGNORE_INDEX = -100
|
| 6 |
+
MAX_LENGTH = 512
|
| 7 |
+
DEFAULT_PAD_TOKEN = "[PAD]"
|
| 8 |
+
DEFAULT_EOS_TOKEN = "</s>"
|
| 9 |
+
DEFAULT_BOS_TOKEN = "<s>"
|
| 10 |
+
DEFAULT_UNK_TOKEN = "<unk>"
|
| 11 |
+
|
| 12 |
+
def smart_tokenizer_and_embedding_resize(
|
| 13 |
+
special_tokens_dict,
|
| 14 |
+
llama_tokenizer,
|
| 15 |
+
model,
|
| 16 |
+
):
|
| 17 |
+
"""Resize tokenizer and embedding.
|
| 18 |
+
|
| 19 |
+
Note: This is the unoptimized version that may make your embedding size not be divisible by 64.
|
| 20 |
+
"""
|
| 21 |
+
num_new_tokens = llama_tokenizer.add_special_tokens(special_tokens_dict)
|
| 22 |
+
model.resize_token_embeddings(len(llama_tokenizer))
|
| 23 |
+
|
| 24 |
+
if num_new_tokens > 0:
|
| 25 |
+
input_embeddings = model.get_input_embeddings().weight.data
|
| 26 |
+
output_embeddings = model.get_output_embeddings().weight.data
|
| 27 |
+
|
| 28 |
+
input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(
|
| 29 |
+
dim=0, keepdim=True
|
| 30 |
+
)
|
| 31 |
+
output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(
|
| 32 |
+
dim=0, keepdim=True
|
| 33 |
+
)
|
| 34 |
+
|
| 35 |
+
input_embeddings[-num_new_tokens:] = input_embeddings_avg
|
| 36 |
+
output_embeddings[-num_new_tokens:] = output_embeddings_avg
|
| 37 |
+
|
| 38 |
+
def prepare_model_and_tokenizer(args):
|
| 39 |
+
model_id = "meta-llama/Meta-Llama-3-8B"
|
| 40 |
+
print(f"Model size: {model_id}")
|
| 41 |
+
if hasattr(args, 'device_map'):
|
| 42 |
+
device_map = args.device_map
|
| 43 |
+
else:
|
| 44 |
+
device_map = 'auto'
|
| 45 |
+
pipeline = transformers.pipeline("text2text-generation",
|
| 46 |
+
model=model_id, model_kwargs={"torch_dtype": torch.float32}, device_map=device_map)
|
| 47 |
+
tokenizer = pipeline.tokenizer
|
| 48 |
+
base_model = pipeline.model
|
| 49 |
+
|
| 50 |
+
special_tokens_dict = dict()
|
| 51 |
+
if tokenizer.pad_token is None:
|
| 52 |
+
special_tokens_dict["pad_token"] = DEFAULT_PAD_TOKEN
|
| 53 |
+
if tokenizer.eos_token is None:
|
| 54 |
+
special_tokens_dict["eos_token"] = DEFAULT_EOS_TOKEN
|
| 55 |
+
if tokenizer.bos_token is None:
|
| 56 |
+
special_tokens_dict["bos_token"] = DEFAULT_BOS_TOKEN
|
| 57 |
+
if tokenizer.unk_token is None:
|
| 58 |
+
special_tokens_dict["unk_token"] = DEFAULT_UNK_TOKEN
|
| 59 |
+
|
| 60 |
+
smart_tokenizer_and_embedding_resize(
|
| 61 |
+
special_tokens_dict=special_tokens_dict,
|
| 62 |
+
llama_tokenizer=tokenizer,
|
| 63 |
+
model=base_model,
|
| 64 |
+
)
|
| 65 |
+
|
| 66 |
+
peft_config = LoraConfig(
|
| 67 |
+
r=args.lora_rank,
|
| 68 |
+
lora_alpha=args.lora_alpha,
|
| 69 |
+
lora_dropout=args.lora_dropout,
|
| 70 |
+
bias="none",
|
| 71 |
+
task_type="CAUSAL_LM",
|
| 72 |
+
)
|
| 73 |
+
|
| 74 |
+
tokenizer.padding_side = 'left'
|
| 75 |
+
peftmodel = get_peft_model(base_model, peft_config)
|
| 76 |
+
if args.pretrained_path:
|
| 77 |
+
# load a previous checkpoint if the path is given
|
| 78 |
+
model = PeftModel.from_pretrained(base_model, args.pretrained_path, device_map=device_map)
|
| 79 |
+
peft_state_dict = {f"{k}": v for k, v in model.state_dict().items()}
|
| 80 |
+
peftmodel.load_state_dict(peft_state_dict)
|
| 81 |
+
|
| 82 |
+
for name, param in peftmodel.named_parameters():
|
| 83 |
+
if "lora" in name: # Check if "lora" is in the parameter's name
|
| 84 |
+
param.requires_grad = True
|
| 85 |
+
peftmodel.print_trainable_parameters()
|
| 86 |
+
return peftmodel, tokenizer
|