kshdes37 committed on
Commit 91daf98 · verified · 1 Parent(s): ea9e3ff

Upload 50 files

Files changed (50)
  1. CADFusion/.gitignore +171 -0
  2. CADFusion/CODE_OF_CONDUCT.md +9 -0
  3. CADFusion/LICENSE +21 -0
  4. CADFusion/README.md +194 -0
  5. CADFusion/SECURITY.md +41 -0
  6. CADFusion/SUPPORT.md +25 -0
  7. CADFusion/data/sl_data/convert.py +125 -0
  8. CADFusion/data/sl_data/sl_data.zip +3 -0
  9. CADFusion/data/vf_data/example_vf_data.zip +3 -0
  10. CADFusion/ds_config.yaml +22 -0
  11. CADFusion/pyproject.toml +38 -0
  12. CADFusion/scripts/alternate_VF.sh +47 -0
  13. CADFusion/scripts/alternate_VF_quadra_gpu.sh +50 -0
  14. CADFusion/scripts/generate_samples.sh +44 -0
  15. CADFusion/scripts/make_dpo_data.sh +5 -0
  16. CADFusion/scripts/preprocess_skexgen.sh +28 -0
  17. CADFusion/scripts/train_loop.sh +42 -0
  18. CADFusion/scripts/train_with_shuffling.sh +20 -0
  19. CADFusion/src/data_preprocessing/call_openai.py +37 -0
  20. CADFusion/src/data_preprocessing/captioning.py +101 -0
  21. CADFusion/src/data_preprocessing/convert.py +120 -0
  22. CADFusion/src/dpo/llava_utils.py +95 -0
  23. CADFusion/src/dpo/make_dpo_dataset.py +162 -0
  24. CADFusion/src/dpo/openai_utils.py +88 -0
  25. CADFusion/src/rendering_utils/geometry/arc.py +32 -0
  26. CADFusion/src/rendering_utils/geometry/circle.py +27 -0
  27. CADFusion/src/rendering_utils/geometry/curve.py +13 -0
  28. CADFusion/src/rendering_utils/geometry/geom_utils.py +95 -0
  29. CADFusion/src/rendering_utils/geometry/line.py +24 -0
  30. CADFusion/src/rendering_utils/geometry/obj_parser.py +276 -0
  31. CADFusion/src/rendering_utils/geometry/obj_utils.py +93 -0
  32. CADFusion/src/rendering_utils/img_renderer.py +84 -0
  33. CADFusion/src/rendering_utils/parser.py +478 -0
  34. CADFusion/src/rendering_utils/parser_visual.py +110 -0
  35. CADFusion/src/rendering_utils/ptl_sampler.py +88 -0
  36. CADFusion/src/rendering_utils/utils/obj_reconverter.py +437 -0
  37. CADFusion/src/rendering_utils/utils/util.py +72 -0
  38. CADFusion/src/test/VLM_score.py +95 -0
  39. CADFusion/src/test/chamfer_dist.py +308 -0
  40. CADFusion/src/test/dist_eval.py +351 -0
  41. CADFusion/src/test/f1_eval.py +74 -0
  42. CADFusion/src/test/generate.ipynb +291 -0
  43. CADFusion/src/test/inference.py +106 -0
  44. CADFusion/src/test/utils.py +86 -0
  45. CADFusion/src/test/visual_utils/__init__.py +0 -0
  46. CADFusion/src/test/visual_utils/parser.py +478 -0
  47. CADFusion/src/train/CAD_dataset.py +89 -0
  48. CADFusion/src/train/dpo.py +79 -0
  49. CADFusion/src/train/llama_finetune.py +127 -0
  50. CADFusion/src/train/utils.py +86 -0
CADFusion/.gitignore ADDED
@@ -0,0 +1,171 @@
+ # Byte-compiled / optimized / DLL files
+ __pycache__/
+ *.py[cod]
+ *$py.class
+
+ # C extensions
+ *.so
+
+ # Distribution / packaging
+ .Python
+ build/
+ develop-eggs/
+ dist/
+ downloads/
+ eggs/
+ .eggs/
+ lib/
+ lib64/
+ parts/
+ sdist/
+ var/
+ wheels/
+ share/python-wheels/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+ MANIFEST
+
+ # PyInstaller
+ # Usually these files are written by a python script from a template
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
+ *.manifest
+ *.spec
+
+ # Installer logs
+ pip-log.txt
+ pip-delete-this-directory.txt
+
+ # Unit test / coverage reports
+ htmlcov/
+ .tox/
+ .nox/
+ .coverage
+ .coverage.*
+ .cache
+ nosetests.xml
+ coverage.xml
+ *.cover
+ *.py,cover
+ .hypothesis/
+ .pytest_cache/
+ cover/
+
+ # Translations
+ *.mo
+ *.pot
+
+ # Django stuff:
+ *.log
+ local_settings.py
+ db.sqlite3
+ db.sqlite3-journal
+
+ # Flask stuff:
+ instance/
+ .webassets-cache
+
+ # Scrapy stuff:
+ .scrapy
+
+ # Sphinx documentation
+ docs/_build/
+
+ # PyBuilder
+ .pybuilder/
+ target/
+
+ # Jupyter Notebook
+ .ipynb_checkpoints
+
+ # IPython
+ profile_default/
+ ipython_config.py
+
+ # pyenv
+ # For a library or package, you might want to ignore these files since the code is
+ # intended to run in multiple environments; otherwise, check them in:
+ # .python-version
+
+ # pipenv
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
+ # install all needed dependencies.
+ #Pipfile.lock
+
+ # UV
+ # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
+ # commonly ignored for libraries.
+ #uv.lock
+
+ # poetry
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
+ # commonly ignored for libraries.
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+ #poetry.lock
+
+ # pdm
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
+ #pdm.lock
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
+ # in version control.
+ # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
+ .pdm.toml
+ .pdm-python
+ .pdm-build/
+
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+ __pypackages__/
+
+ # Celery stuff
+ celerybeat-schedule
+ celerybeat.pid
+
+ # SageMath parsed files
+ *.sage.py
+
+ # Environments
+ .env
+ .venv
+ env/
+ venv/
+ ENV/
+ env.bak/
+ venv.bak/
+
+ # Spyder project settings
+ .spyderproject
+ .spyproject
+
+ # Rope project settings
+ .ropeproject
+
+ # mkdocs documentation
+ /site
+
+ # mypy
+ .mypy_cache/
+ .dmypy.json
+ dmypy.json
+
+ # Pyre type checker
+ .pyre/
+
+ # pytype static type analyzer
+ .pytype/
+
+ # Cython debug symbols
+ cython_debug/
+
+ # PyCharm
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
+ #.idea/
+
+ # PyPI configuration file
+ .pypirc
CADFusion/CODE_OF_CONDUCT.md ADDED
@@ -0,0 +1,9 @@
+ # Microsoft Open Source Code of Conduct
+
+ This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/).
+
+ Resources:
+
+ - [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/)
+ - [Microsoft Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/)
+ - Contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with questions or concerns
CADFusion/LICENSE ADDED
@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) Microsoft Corporation.
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
CADFusion/README.md ADDED
@@ -0,0 +1,194 @@
+ # CADFusion
+
+
+ This repo is the official implementation of the paper **[ICML 2025] Text-to-CAD Generation Through Infusing Visual Feedback in Large Language Models** by *Ruiyu Wang, Yu Yuan, Shizhao Sun, Jiang Bian*.
+
+ [Paper](https://arxiv.org/abs/2501.19054) | [Video](https://www.youtube-nocookie.com/embed/LK8LAzR0v5M?si=FD1Vg9wjkROTKjDV) | [Huggingface](https://huggingface.co/microsoft/CADFusion)
+
+ CADFusion is a text-to-CAD generation framework that leverages visual feedback to enhance the performance of large language models (LLMs) in generating CAD models from textual descriptions. It consists of two main components: sequential learning and visual learning. The sequential learning component fine-tunes LLMs on a text-to-CAD dataset, while the visual learning component alternates between training a visual feedback model and fine-tuning the LLM with the generated visual feedback.
+
+ ## Installation
+
+ - Create a conda environment and install the generic dependencies.
+
+ ```
+ name=<your-env-name>
+ conda create -n $name python=3.9
+ conda activate $name
+ python -m pip install -e .
+ ```
+
+ - Install the additional dependencies for training.
+
+ ```
+ python -m pip install -e .["train"]
+ ```
+
+ - Install the additional dependencies for evaluation and rendering.
+
+ ```
+ python -m pip install -e .["render"]
+ conda install -c conda-forge pythonocc-core=7.7.0
+ python -m pip install git+https://github.com/otaheri/chamfer_distance@dc9987dcf70888d387d96893ba1fb9ba9a333992
+ python -m pip install -e .["eval"]
+ ```
+
+ ## Data Preparation
+ CADFusion is trained by alternating the **Sequential Learning (SL)** stage and the **Visual Feedback (VF)** stage.
+ Below, we describe how to prepare the training data for these two stages.
+
+ ### Data for Sequential Learning
+
+ #### Approach 1: use human-annotated textual descriptions provided by us
+ We provide human-annotated textual descriptions and their corresponding CAD model IDs in [SkexGen](https://github.com/samxuxiang/SkexGen) under `data/sl_data/sl_data.zip`. It should contain the following files after unzipping:
+ ```
+ data/sl_data
+ ├── train.json
+ ├── val.json
+ ├── test.json
+ ```
+ To use our annotated data, download the SkexGen data, unzip it as the reference dataset, and run the conversion script to get the dataset. In detail, run the following commands:
+ ```
+ # make sure you are in the root directory of this repo and have 'data/sl_data/sl_data.zip' unzipped
+ gdown --id 1so_CCGLIhqGEDQxMoiR--A4CQk4MjuOp
+ unzip cad_data.zip
+ python3 data/sl_data/convert.py
+ ```
+ The `train.json`, `val.json` and `test.json` under `data/sl_data` are the datasets.
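+ For reference, after the conversion script runs, each record in these files has roughly the following shape. The field names are taken from `data/sl_data/convert.py`; the values below are abbreviated illustrations, not real entries:
+ ```
+ {
+     "serial_num": "<SkexGen model ID>",
+     "description": "Create a rectangular panel with two circular through-holes ...",
+     "command_sequence": "line,9,9 <curve_end> ... <sketch_end> add,... <extrude_end>"
+ }
+ ```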
+
+ #### Approach 2: create human-annotated textual descriptions by yourself
+ We provide a script that executes all the preprocessing steps up to human annotation.
+ ```
+ ./scripts/preprocess_skexgen.sh
+ ```
+ If you want to customize the internal steps, expand the following section for more details.
+ <details>
+ <summary>Start from scratch (click to expand).</summary>
+
+ 1. Download the [SkexGen](https://github.com/samxuxiang/SkexGen) data from the [Google Drive link](https://drive.google.com/file/d/1so_CCGLIhqGEDQxMoiR--A4CQk4MjuOp/view).
+
+ ```
+ gdown --id 1so_CCGLIhqGEDQxMoiR--A4CQk4MjuOp
+ unzip cad_data.zip
+ ```
+
+ 2. Convert the SkexGen data into sequences. Note that `train_deduplicate_s.pkl`, `val.pkl` and `test.pkl` should be converted separately.
+ ```
+ python3 src/data_preprocessing/convert.py --in-path <skexgen_path> --out-path <sequence_path>
+ ```
+
+ 3. Render the sequences into images. *Note that running the last step on Linux requires an X server (e.g. `xvfb`). See [this discussion.](https://github.com/tpaviot/pythonocc-core/issues/1302#issuecomment-2053526444)*
+ ```
+ python3 src/rendering_utils/parser.py --in-path <sequence_path> --out-path <visual_object_folder>
+ timeout 180 python3 src/rendering_utils/parser_visual.py --data_folder <visual_object_folder>
+ python3 src/rendering_utils/img_renderer.py --input_dir <visual_object_folder> --output_dir <image_folder>
+ ```
+
+ 4. Annotate these data with LLM captioning.
+ ```
+ # Generic:
+ python3 src/data_preprocessing/captioning.py --image-folder-path <image_folder> --out-path <sl_data_path>
+
+ ```
+ * We use the OpenAI and Azure stack for LLM calls. You are welcome to use your own LLMs and prompts by changing lines 21-22 of `src/data_preprocessing/captioning.py` with your own client definition and function calls; a minimal sketch follows below.
+ </details>
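+
+ For example, a minimal drop-in replacement for `setup_client`/`call_openai` that targets the public OpenAI API instead of our Azure deployment could look like the sketch below (the model name is illustrative, and the retry logic of `src/data_preprocessing/call_openai.py` is omitted):
+ ```
+ # hypothetical replacement for src/data_preprocessing/call_openai.py,
+ # assuming the public OpenAI API rather than the Azure endpoint we use
+ from openai import OpenAI
+
+ def setup_client():
+     client = OpenAI()        # reads OPENAI_API_KEY from the environment
+     return client, "gpt-4o"  # any vision-capable chat model works here
+
+ def call_openai(client, deployment, prompt):
+     completion = client.chat.completions.create(model=deployment, messages=prompt)
+     return completion.choices[0].message.content
+ ```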
+
+
+ ### Data for Visual Feedback
+
+ The Visual Feedback dataset is generated automatically by the Visual Feedback pipeline described in the Training section.
+ We provide an example under `data/vf_data/example_vf_data.json` to help you understand what it should look like.
+ You can retrieve this file by unzipping `data/vf_data/example_vf_data.zip`.
+ We do not recommend using this example data as training data, since the policy update should depend on the model's own generations.
+
+
+ ## Training
+ Our training recipe contains two parts. In the first part, we conduct initial sequential learning. In the second part, we alternate between sequential learning and visual feedback.
+ ### Initial Sequential Learning
+ We use the following script to train the model in the sequential learning stage.
+ ```
+ ./scripts/train_with_shuffling.sh <run_name>
+ ```
+
+ You are also welcome to customize the training procedure. A typical training command on multiple GPUs is provided below. Change `num_processes` in `ds_config.yaml` to specify how many GPUs will be used.
+ ```
+ CUDA_VISIBLE_DEVICES=<gpu_ids> accelerate launch --config_file ds_config.yaml src/train/llama_finetune.py \
+ --num-epochs <num_epochs> --run-name <run_name> --data-path <train_data> --eval-data-path <eval_data> \
+ --device-map accelerate --model-name llama3 --expdir <model_saving_path>
+ ```
+
+ In our work we shuffle the dataset every x epochs. To train the model with this implementation, inspect and modify `scripts/train_with_shuffling.sh`.
+
+ ### Alternate Training between Sequential Learning and Visual Feedback
+ We provide a script for executing our alternate training rounds. See `scripts/alternate_VF.sh`.
+ ```
+ ./scripts/alternate_VF.sh # change the value of base_name in the script as instructed
+ ```
+ We also provide a script for training on multiple GPUs to save time: `scripts/alternate_VF_quadra_gpu.sh`. In our setting, we use 4 GPUs for training. You can change the script to use more GPUs if you have them available.
+
+ If you only want to conduct a single round of visual learning, run
+ ```
+ python src/train/dpo.py --run-name <dpo_run_name> --pretrained-path <pretrained_model_path> --data-path <dpo_data_path> --output-path <model_saving_path>
+ ```
+ By default it runs DPO for 3 epochs; you can change this by adding the flag `--num-epochs <x>`.
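+ For example, following the run-naming convention of `scripts/alternate_VF.sh` (the run name and paths below are illustrative):
+ ```
+ python src/train/dpo.py --run-name CAD-1-dpo --pretrained-path exp/model_ckpt/CAD-0 --data-path data/vf_data/CAD-0-train.json --output-path exp/model_ckpt/CAD-1-dpo --num-epochs 1
+ ```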
+
+
+ ## Model Checkpoints
+ We provide two versions.
+ v1.0 has 5 rounds of alternate training and is used for evaluation in our paper.
+ v1.1 has 9 rounds of alternate training and is considered to have better performance than v1.0.
+ - [CADFusion v1.0](https://huggingface.co/microsoft/CADFusion/tree/main/v1_0)
+ - [CADFusion v1.1](https://huggingface.co/microsoft/CADFusion/tree/main/v1_1)
+
+ You should download, unzip and place them under the `exp/model_ckpt` folder before use.
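+ One way to fetch a checkpoint is with the `huggingface-cli` that ships with the `huggingface_hub` dependency (a sketch; the target directory is illustrative):
+ ```
+ huggingface-cli download microsoft/CADFusion --include "v1_0/*" --local-dir exp/model_ckpt
+ ```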
+
+ ## Inference & Visualization
+ Use `scripts/generate_samples.sh`.
+ ```
+ ./scripts/generate_samples.sh <run_name> test --full
+ ```
+ You can find the generated samples in `exp/model_generation/<run_name>.jsonl` and the rendered figures under the `exp/figures/<run_name>` folder. The point clouds, .obj, .step and .stl files are saved under the `exp/visual_objects/<run_name>` directory for your own usage and evaluation.
+
+ ## Evaluation
+ Use the functions in `src/test`. These include Chamfer Distance (`chamfer_dist.py`); Minimum Matching Distance, Coverage, and Jensen-Shannon Divergence (`dist_eval.py`); and the VLM score (`VLM_score.py`).
+
+ For the VLM score, we use the Azure OpenAI API to access the GPT-4o model for scoring the CAD objects,
+ so you should log in to your own Azure account before using this module.
+ If you are using another LLM/VLM service and find it difficult to adapt to our setup, the prompt is provided in the Python module so that you can integrate it into your own testing pipeline.
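+ For example, assuming the Azure CLI is installed, a one-time login in the same environment is sufficient, since the scoring code authenticates via `AzureCliCredential`:
+ ```
+ az login
+ ```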
+
+ ## Acknowledgements
+ We would like to acknowledge that the CAD rendering and distributional metrics in this repository are partially based on and adapted from the [SkexGen](https://github.com/samxuxiang/SkexGen) project.
+
+ ## Citation
+ If you find our work useful, please cite the following paper:
+ ```
+ @inproceedings{wang2025texttocad,
+ title = {Text-to-CAD Generation Through Infusing Visual Feedback in Large Language Models},
+ author = {Wang, Ruiyu and Yuan, Yu and Sun, Shizhao and Bian, Jiang},
+ booktitle = {International Conference on Machine Learning},
+ year = {2025}
+ }
+ ```
+ ## Contributing
+
+ This project welcomes contributions and suggestions. Most contributions require you to agree to a
+ Contributor License Agreement (CLA) declaring that you have the right to, and actually do, grant us
+ the rights to use your contribution. For details, visit https://cla.opensource.microsoft.com.
+
+ When you submit a pull request, a CLA bot will automatically determine whether you need to provide
+ a CLA and decorate the PR appropriately (e.g., status check, comment). Simply follow the instructions
+ provided by the bot. You will only need to do this once across all repos using our CLA.
+
+ This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/).
+ For more information see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) or
+ contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with any additional questions or comments.
+
+ ## Trademarks
+
+ This project may contain trademarks or logos for projects, products, or services. Authorized use of Microsoft
+ trademarks or logos is subject to and must follow
+ [Microsoft's Trademark & Brand Guidelines](https://www.microsoft.com/en-us/legal/intellectualproperty/trademarks/usage/general).
+ Use of Microsoft trademarks or logos in modified versions of this project must not cause confusion or imply Microsoft sponsorship.
+ Any use of third-party trademarks or logos is subject to those third parties' policies.
CADFusion/SECURITY.md ADDED
@@ -0,0 +1,41 @@
+ <!-- BEGIN MICROSOFT SECURITY.MD V0.0.9 BLOCK -->
+
+ ## Security
+
+ Microsoft takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organizations, which include [Microsoft](https://github.com/Microsoft), [Azure](https://github.com/Azure), [DotNet](https://github.com/dotnet), [AspNet](https://github.com/aspnet) and [Xamarin](https://github.com/xamarin).
+
+ If you believe you have found a security vulnerability in any Microsoft-owned repository that meets [Microsoft's definition of a security vulnerability](https://aka.ms/security.md/definition), please report it to us as described below.
+
+ ## Reporting Security Issues
+
+ **Please do not report security vulnerabilities through public GitHub issues.**
+
+ Instead, please report them to the Microsoft Security Response Center (MSRC) at [https://msrc.microsoft.com/create-report](https://aka.ms/security.md/msrc/create-report).
+
+ If you prefer to submit without logging in, send email to [secure@microsoft.com](mailto:secure@microsoft.com). If possible, encrypt your message with our PGP key; please download it from the [Microsoft Security Response Center PGP Key page](https://aka.ms/security.md/msrc/pgp).
+
+ You should receive a response within 24 hours. If for some reason you do not, please follow up via email to ensure we received your original message. Additional information can be found at [microsoft.com/msrc](https://www.microsoft.com/msrc).
+
+ Please include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue:
+
+ * Type of issue (e.g. buffer overflow, SQL injection, cross-site scripting, etc.)
+ * Full paths of source file(s) related to the manifestation of the issue
+ * The location of the affected source code (tag/branch/commit or direct URL)
+ * Any special configuration required to reproduce the issue
+ * Step-by-step instructions to reproduce the issue
+ * Proof-of-concept or exploit code (if possible)
+ * Impact of the issue, including how an attacker might exploit the issue
+
+ This information will help us triage your report more quickly.
+
+ If you are reporting for a bug bounty, more complete reports can contribute to a higher bounty award. Please visit our [Microsoft Bug Bounty Program](https://aka.ms/security.md/msrc/bounty) page for more details about our active programs.
+
+ ## Preferred Languages
+
+ We prefer all communications to be in English.
+
+ ## Policy
+
+ Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://aka.ms/security.md/cvd).
+
+ <!-- END MICROSOFT SECURITY.MD BLOCK -->
CADFusion/SUPPORT.md ADDED
@@ -0,0 +1,25 @@
+ # TODO: The maintainer of this repo has not yet edited this file
+
+ **REPO OWNER**: Do you want Customer Service & Support (CSS) support for this product/project?
+
+ - **No CSS support:** Fill out this template with information about how to file issues and get help.
+ - **Yes CSS support:** Fill out an intake form at [aka.ms/onboardsupport](https://aka.ms/onboardsupport). CSS will work with/help you to determine next steps.
+ - **Not sure?** Fill out an intake as though the answer were "Yes". CSS will help you decide.
+
+ *Then remove this first heading from this SUPPORT.MD file before publishing your repo.*
+
+ # Support
+
+ ## How to file issues and get help
+
+ This project uses GitHub Issues to track bugs and feature requests. Please search the existing
+ issues before filing new issues to avoid duplicates. For new issues, file your bug or
+ feature request as a new Issue.
+
+ For help and questions about using this project, please **REPO MAINTAINER: INSERT INSTRUCTIONS HERE
+ FOR HOW TO ENGAGE REPO OWNERS OR COMMUNITY FOR HELP. COULD BE A STACK OVERFLOW TAG OR OTHER
+ CHANNEL. WHERE WILL YOU HELP PEOPLE?**.
+
+ ## Microsoft Support Policy
+
+ Support for this **PROJECT or PRODUCT** is limited to the resources listed above.
CADFusion/data/sl_data/convert.py ADDED
@@ -0,0 +1,125 @@
+ import json
+ import pickle
+
+ SKETCH_R = 1
+ RADIUS_R = 1
+ EXTRUDE_R = 1.0
+ SCALE_R = 1.4
+ OFFSET_R = 0.9
+ PIX_PAD = 4
+ CMD_PAD = 3
+ COORD_PAD = 4
+ EXT_PAD = 1
+ EXTRA_PAD = 1
+ R_PAD = 2
+
+ def create_curve_str(se_xy, se_cmd):
+     curve_str = ""
+     xy_offset = 0
+     if se_cmd == 0:  # line
+         curve_str = " line," + ",".join(str(x) for x in se_xy[0])
+         xy_offset = 2
+     elif se_cmd == 1:  # arc
+         curve_str = " arc," + ",".join(str(x) for x in se_xy[0:2].flatten())
+         xy_offset = 3
+     elif se_cmd == 2:  # circle
+         curve_str = " circle," + ",".join(str(x) for x in se_xy[0:4].flatten())
+         xy_offset = 5
+     curve_str += " <curve_end>"
+     return curve_str, xy_offset
+
+
+ def create_sketch_str(se_xy, se_cmd):
+     sketch_str = ""
+     len_xy, len_cmd = len(se_xy), len(se_cmd)
+     xy_idx = 0
+     for cmd_item in se_cmd:  # for each command
+         if 0 <= cmd_item <= 2:  # curve
+             curve_str, xy_offset = create_curve_str(se_xy[xy_idx:], cmd_item)
+             sketch_str += curve_str
+             xy_idx += xy_offset
+         elif cmd_item == -1:  # loop
+             sketch_str += " <loop_end>"
+             xy_idx += 1
+         elif cmd_item == -2:  # face
+             sketch_str += " <face_end>"
+             xy_idx += 1
+         elif cmd_item == -3:  # sketch
+             sketch_str += " <sketch_end>"
+             xy_idx += 1
+         else:
+             raise ValueError("Invalid command: " + str(cmd_item))
+     if xy_idx != len_xy:
+         raise ValueError("xy_idx != len_xy")
+     return sketch_str
+
+
+ def create_extrude_str(se_ext):
+     extrude_str = ""
+     # extrude operation
+     if se_ext[14] == 1:
+         extrude_str += "add"
+     elif se_ext[14] == 2:
+         extrude_str += "cut"
+     elif se_ext[14] == 3:
+         extrude_str += "intersect"
+     else:
+         raise ValueError("Invalid extrude operation: " + str(se_ext[14]))
+     # other extrude parameters
+     extrude_str = (
+         extrude_str + "," + ",".join(str(x - EXT_PAD) for x in se_ext[0:5])
+     )  # ext_v, ext_T
+     extrude_str = (
+         extrude_str + "," + ",".join(str(x - R_PAD) for x in se_ext[5:14])
+     )  # ext_R
+     extrude_str = (
+         extrude_str + "," + ",".join(str(x - EXT_PAD) for x in se_ext[15:18])
+     )  # scale, offset
+     # extrude end
+     extrude_str += " <extrude_end>"
+     return extrude_str
+
+ def create_command_sequence(item):
+     se_str = ""
+     num_se = item["num_se"]
+     for se_idx in range(num_se):  # for each sketch-extrude
+         xy, cmd, ext = (
+             item["se_xy"][se_idx] - COORD_PAD,
+             item["se_cmd"][se_idx] - CMD_PAD,
+             item["se_ext"][se_idx],
+         )
+         se_str = se_str + " " + create_sketch_str(xy, cmd).strip()
+         se_str = se_str + " " + create_extrude_str(ext).strip()
+     return se_str.strip()
+
+ with open("data/sl_data/train.json", "r") as f:
+     train_data = json.load(f)
+ with open("data/sl_data/test.json", "r") as f:
+     test_data = json.load(f)
+ with open("data/sl_data/val.json", "r") as f:
+     val_data = json.load(f)
+
+ with open("cad_data/train_deduplicate_s.pkl", "rb") as f:
+     sk_data = pickle.load(f)
+
+ for item in train_data:
+     serial_num = item['serial_num']
+     description = item['description']
+     item["command_sequence"] = create_command_sequence(sk_data[serial_num])
+
+ for item in test_data:
+     serial_num = item['serial_num']
+     description = item['description']
+     item["command_sequence"] = create_command_sequence(sk_data[serial_num])
+
+ for item in val_data:
+     serial_num = item['serial_num']
+     description = item['description']
+     item["command_sequence"] = create_command_sequence(sk_data[serial_num])
+
+ with open("data/sl_data/train.json", "w+") as f:
+     json.dump(train_data, f, indent=4)
+ with open("data/sl_data/test.json", "w+") as f:
+     json.dump(test_data, f, indent=4)
+ with open("data/sl_data/val.json", "w+") as f:
+     json.dump(val_data, f, indent=4)
CADFusion/data/sl_data/sl_data.zip ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:a080e00591a07420d916e82365d8602ebeab00ffd909f87bc9911b231f2f5ea0
+ size 1084518
CADFusion/data/vf_data/example_vf_data.zip ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:907df4efd2ceafd9d8c336dfbf62d1754f692c0aab72b1b212ea7b844125e702
+ size 2142
CADFusion/ds_config.yaml ADDED
@@ -0,0 +1,22 @@
+ compute_environment: LOCAL_MACHINE
+ debug: false
+ deepspeed_config:
+   gradient_accumulation_steps: 1
+   gradient_clipping: 1.0
+   offload_optimizer_device: none
+   offload_param_device: none
+   zero3_init_flag: true
+   zero_stage: 2
+ distributed_type: DEEPSPEED
+ downcast_bf16: 'no'
+ machine_rank: 0
+ main_training_function: main
+ mixed_precision: fp16
+ num_machines: 1
+ num_processes: 1
+ rdzv_backend: static
+ same_network: true
+ tpu_env: []
+ tpu_use_cluster: false
+ tpu_use_sudo: false
+ use_cpu: false
CADFusion/pyproject.toml ADDED
@@ -0,0 +1,38 @@
+ [build-system]
+ requires = ["setuptools>=61.0"]
+ build-backend = "setuptools.build_meta"
+
+ [project]
+ name = "CADFusion"
+ version = "1.0.0"
+ description = "Enhancing Text-to-CAD generation via sequential learning and visual feedback."
+ readme = "README.md"
+ requires-python = ">=3.8"
+ classifiers = [
+     "Programming Language :: Python :: 3",
+     "License :: OSI Approved :: MIT License",
+ ]
+ dependencies = [
+     "torch==2.7.1",
+     "transformers==4.50.0",
+     "huggingface_hub==0.26.0",
+     "peft==0.9.0",
+     "accelerate==0.28.0",
+     "psutil==5.9.8",
+     "pillow==10.4.0",
+     "datasets==3.1.0",
+     "trl==0.11.4",
+     "gdown==5.2.0"
+ ]
+
+ [project.optional-dependencies]
+ train = ["wandb==0.16.4", "deepspeed==0.15.0"]
+ render = ["trimesh==4.4.9", "plyfile==1.0.3"]
+ eval = ["openai==1.75.0", "azure-identity==1.21.0", "scikit-learn==1.3.2"]
+ build = ["build", "twine"]
+
+ [tool.setuptools.packages.find]
+ exclude = ["assets*", "benchmark*", "docs", "dist*", "playground*", "scripts*", "tests*"]
+
+ [tool.wheel]
+ exclude = ["assets*", "benchmark*", "docs", "dist*", "playground*", "scripts*", "tests*"]
CADFusion/scripts/alternate_VF.sh ADDED
@@ -0,0 +1,47 @@
+ # set it to your data path
+ data_path=data/sl_data
+ # by default set it to CADFusion/exp
+ exp_path=exp/model_ckpt
+ # by default set it to CADFusion/data
+ vf_path=data/vf_data
+ train_data=$data_path/train.json
+ eval_data=$data_path/val.json
+
+ # This script requires your SL run to be named xxxx0, because for each VF stage the final digit increments
+ # to show the number of VF rounds finished.
+ # e.g. SL name:  CAD-0
+ #      base_name: CAD- (remove the last digit; the script autofills it)
+ #      VF run 1:  CAD-1 (automatically)
+ #      VF run 2:  CAD-2 (automatically)
+ #      ...
+ base_name=model_name_you_trained_for_SL_with_last_digit_removed
+
+ run_name=${base_name}0
+ ./scripts/generate_samples.sh $run_name test "--full --device-map auto"
+ ./scripts/generate_samples.sh $run_name train "--sample-len 1000 --device-map auto"
+
+ ./scripts/make_dpo_data.sh $run_name --score-only
+ ./scripts/make_dpo_data.sh $run_name-train "--gpu 0"
+
+
+ for LOOP in 1 2 3 4 5
+ do
+     echo "Starting VF round $LOOP"
+     run_name=$base_name$LOOP
+     dpo_training_path=$vf_path/$base_name$((LOOP-1))-train.json
+     dpo_run_name=$base_name$LOOP-dpo
+     dpo_save_path=$exp_path/$dpo_run_name
+     sft_run_name=$base_name$LOOP
+
+     python src/train/dpo.py --run-name $dpo_run_name --pretrained-path $exp_path/$base_name$((LOOP-1)) --data-path $dpo_training_path --output-path $dpo_save_path
+     python src/train/llama_finetune.py --num-epochs 1 --run-name $sft_run_name --data-path $train_data --eval-data-path $eval_data --eval-freq 3000 --pretrained-path $dpo_save_path --expdir $exp_path
+
+     ./scripts/generate_samples.sh $dpo_run_name test "--full --device-map auto"
+     ./scripts/generate_samples.sh $run_name test "--full --device-map auto"
+     ./scripts/generate_samples.sh $run_name train "--sample-len 1000 --device-map auto"
+
+     ./scripts/make_dpo_data.sh $dpo_run_name --score-only
+     ./scripts/make_dpo_data.sh $run_name "--score-only --gpu 0"
+     ./scripts/make_dpo_data.sh $run_name-train "--gpu 0"
+
+ done
CADFusion/scripts/alternate_VF_quadra_gpu.sh ADDED
@@ -0,0 +1,50 @@
+ # set it to your data path
+ data_path=data/sl_data
+ # by default set it to CADFusion/exp
+ exp_path=exp/model_ckpt
+ # by default set it to CADFusion/data
+ vf_path=data/vf_data
+ train_data=$data_path/train.json
+ eval_data=$data_path/val.json
+
+ # This script requires your SL run to be named xxxx0, because for each VF stage the final digit increments
+ # to show the number of VF rounds finished.
+ # e.g. SL name:  CAD-0
+ #      base_name: CAD- (remove the last digit; the script autofills it)
+ #      VF run 1:  CAD-1 (automatically)
+ #      VF run 2:  CAD-2 (automatically)
+ #      ...
+ base_name=model_name_you_trained_for_SL_with_last_digit_removed
+
+ run_name=${base_name}0
+ CUDA_VISIBLE_DEVICES=0,1 ./scripts/generate_samples.sh $run_name test "--full --device-map auto" &
+ CUDA_VISIBLE_DEVICES=2,3 ./scripts/generate_samples.sh $run_name train "--sample-len 10 --device-map auto"
+ wait
+
+ ./scripts/make_dpo_data.sh $run_name --score-only &
+ ./scripts/make_dpo_data.sh $run_name-train "--gpu 1"
+ wait
+
+
+ for LOOP in 1 2 3 4 5
+ do
+     echo "Starting VF round $LOOP"
+     run_name=$base_name$LOOP
+     dpo_training_path=$vf_path/$base_name$((LOOP-1))-train.json
+     dpo_run_name=$base_name$LOOP-dpo
+     dpo_save_path=$exp_path/$dpo_run_name
+     sft_run_name=$base_name$LOOP
+
+     python src/train/dpo.py --run-name $dpo_run_name --pretrained-path $exp_path/$base_name$((LOOP-1)) --data-path $dpo_training_path --output-path $dpo_save_path
+     python src/train/llama_finetune.py --num-epochs 1 --run-name $sft_run_name --data-path $train_data --eval-data-path $eval_data --eval-freq 3000 --pretrained-path $dpo_save_path --expdir $exp_path
+
+     CUDA_VISIBLE_DEVICES=0 ./scripts/generate_samples.sh $dpo_run_name test "--full --device-map auto" &
+     CUDA_VISIBLE_DEVICES=1 ./scripts/generate_samples.sh $run_name test "--full --device-map auto" &
+     CUDA_VISIBLE_DEVICES=2,3 ./scripts/generate_samples.sh $run_name train "--sample-len 1000 --device-map auto"
+     wait
+
+     ./scripts/make_dpo_data.sh $dpo_run_name --score-only &
+     ./scripts/make_dpo_data.sh $run_name "--score-only --gpu 1" &
+     ./scripts/make_dpo_data.sh $run_name-train "--gpu 2"
+     wait
+ done
CADFusion/scripts/generate_samples.sh ADDED
@@ -0,0 +1,44 @@
+ train_data_path=data/sl_data/train.json
+ test_data_path=data/sl_data/test.json
+ run_name=$1
+ temperature=0.9
+
+ if [ -z "$2" ]
+ then
+     data_path=$test_data_path
+ else
+     if [ $2 = "train" ]; then
+         data_path=$train_data_path
+         run_name=$1-train
+     else
+         data_path=$test_data_path
+         temperature=0.3
+     fi
+ fi
+
+ model_path=exp/model_ckpt/$1
+ inference_path=exp/model_generation/$run_name.jsonl
+ visual_obj_path=exp/visual_objects/$run_name
+ output_figure_path=exp/figures/$run_name
+ log_path=exp/logs/$run_name
+
+ mkdir -p $log_path
+ mkdir -p exp/model_generation
+
+ echo "--------------------Inferencing--------------------" > $log_path/inference.txt
+ rm $inference_path
+ python3 src/test/inference.py --pretrained-path $model_path --in-path $data_path --out-path $inference_path --num-samples 5 --temperature $temperature --model-name llama3 > $log_path/inference.txt $3
+
+ echo "--------------------Parsing CAD objects--------------------" > $log_path/parsing_cad.txt
+ rm -rf $visual_obj_path
+ python3 src/rendering_utils/parser.py --in-path $inference_path --out-path $visual_obj_path > $log_path/parsing_cad.txt
+
+ echo "--------------------Parsing visual objects--------------------" > $log_path/parsing_visual.txt
+ python3 src/rendering_utils/parser_visual.py --data_folder $visual_obj_path > $log_path/parsing_visual.txt
+ python3 src/rendering_utils/ptl_sampler.py --in_dir $visual_obj_path --out_dir ptl > $log_path/sampling_ptl.out
+
+ echo "--------------------Rendering--------------------" > $log_path/rendering.txt
+ rm -rf $output_figure_path
+ export DISPLAY=:99
+ Xvfb :99 -screen 0 640x480x24 &
+ python3 src/rendering_utils/img_renderer.py --input_dir $visual_obj_path --output_dir $output_figure_path > $log_path/rendering.txt
CADFusion/scripts/make_dpo_data.sh ADDED
@@ -0,0 +1,5 @@
+ source_path=exp/model_generation/$1.jsonl
+ figure_path=exp/figures/$1/
+ save_path=data/vf_data/$1.json
+
+ python src/dpo/make_dpo_dataset.py --source-data-path $source_path --figure-path $figure_path --save-path $save_path --num-samples 5 $2
CADFusion/scripts/preprocess_skexgen.sh ADDED
@@ -0,0 +1,28 @@
+ gdown --id 1so_CCGLIhqGEDQxMoiR--A4CQk4MjuOp
+ unzip cad_data.zip
+
+ # convert data into sequence and save in json
+ mkdir data
+ mkdir data/raw
+ python3 src/data_preprocessing/convert.py --in-path cad_data/train_deduplicate_s.pkl --out-path data/raw/train.json
+ python3 src/data_preprocessing/convert.py --in-path cad_data/val.pkl --out-path data/raw/val.json
+ python3 src/data_preprocessing/convert.py --in-path cad_data/test.pkl --out-path data/raw/test.json
+
+ # render the image for each entry in order to retrieve textual information by captioning:
+ mkdir exp
+ mkdir exp/visual_objects
+ mkdir exp/figures
+ for file in test val train; do
+     python3 src/rendering_utils/parser.py --in-path data/raw/$file.json --out-path exp/visual_objects/$file
+     timeout 180 python3 src/rendering_utils/parser_visual.py --data_folder exp/visual_objects/$file
+
+     export DISPLAY=:99
+     Xvfb :99 -screen 0 640x480x24 &
+     python3 src/rendering_utils/img_renderer.py --input_dir exp/visual_objects/$file --output_dir exp/figures/$file
+ done
+
+ # caption the images to generate descriptions
+ mkdir data/sl_data
+ python3 src/data_preprocessing/captioning.py --image-folder-path exp/figures/train --out-path data/sl_data/train.json
+ python3 src/data_preprocessing/captioning.py --image-folder-path exp/figures/val --out-path data/sl_data/val.json
+ python3 src/data_preprocessing/captioning.py --image-folder-path exp/figures/test --out-path data/sl_data/test.json
CADFusion/scripts/train_loop.sh ADDED
@@ -0,0 +1,42 @@
+ # by default set it to CADFusion/data
+ data_path=/your/path/to/data/folder
+ # by default set it to CADFusion/exp
+ exp_path=/your/path/to/exp/folder
+ # by default set it to CADFusion/data
+ vf_path=/your/path/to/vf_data/folder
+ train_data=$data_path/train.json
+ eval_data=$data_path/eval.json
+
+ base_name=model_name_you_trained_for_SL
+
+ run_name=${base_name}0
+ CUDA_VISIBLE_DEVICES=0,1 ./scripts/inference.sh $run_name test "--full --device-map auto" &
+ CUDA_VISIBLE_DEVICES=2,3 ./scripts/inference.sh $run_name train "--sample-len 1000 --device-map auto"
+ wait
+
+ ./scripts/make_dpo_data.sh $run_name --score-only &
+ ./scripts/make_dpo_data.sh $run_name-train "--gpu 1"
+ wait
+
+
+ for LOOP in 1 2 3 4 5
+ do
+     run_name=$base_name$LOOP
+     dpo_training_path=$vf_path/$base_name$((LOOP-1))-train.json
+     dpo_run_name=$base_name$LOOP-dpo
+     dpo_save_path=$exp_path/$dpo_run_name
+     sft_run_name=$base_name$LOOP
+
+     python src/train/dpo.py --run-name $dpo_run_name --pretrained-path $exp_path/$base_name$((LOOP-1)) --data-path $dpo_training_path --output-path $dpo_save_path
+     python src/train/llama_finetune.py --num-epochs 1 --run-name $sft_run_name --data-path $train_data --eval-data-path $eval_data --eval-freq 3000 --pretrained-path $dpo_save_path --expdir $exp_path
+
+     CUDA_VISIBLE_DEVICES=0 ./scripts/inference.sh $dpo_run_name test "--full --device-map auto" &
+     CUDA_VISIBLE_DEVICES=1 ./scripts/inference.sh $run_name test "--full --device-map auto" &
+     CUDA_VISIBLE_DEVICES=2,3 ./scripts/inference.sh $run_name train "--sample-len 1000 --device-map auto"
+     wait
+
+     ./scripts/make_dpo_data.sh $dpo_run_name --score-only &
+     ./scripts/make_dpo_data.sh $run_name "--score-only --gpu 1" &
+     ./scripts/make_dpo_data.sh $run_name-train "--gpu 2"
+     wait
+ done
CADFusion/scripts/train_with_shuffling.sh ADDED
@@ -0,0 +1,20 @@
+ # set it to your data path
+ data_path=data/sl_data
+ # set it to your experiment path
+ exp_path=exp/model_ckpt
+ train_data=$data_path/train.json
+ eval_data=$data_path/val.json
+ shuffle_dataset_between_x_epochs=2
+ mkdir -p $exp_path
+
+ # round 0
+ accelerate launch --config_file ds_config.yaml src/train/llama_finetune.py --lora-rank 32 --lora-alpha 32 \
+     --num-epochs $shuffle_dataset_between_x_epochs --run-name $1 --data-path $train_data --eval-data-path $eval_data \
+     --device-map accelerate --eval-freq 1000 --save-freq 50000 --model-name llama3 --expdir $exp_path
+
+ for round in 1 2 3 4 5 6 7 8 9
+ do
+     python src/train/llama_finetune.py --lora-rank 32 --pretrained-path $exp_path/$1 --lora-alpha 32 \
+         --num-epochs $shuffle_dataset_between_x_epochs --run-name $1 --data-path $train_data --eval-data-path $eval_data \
+         --eval-freq 4000 --save-freq 50000 --expdir $exp_path
+ done
CADFusion/src/data_preprocessing/call_openai.py ADDED
@@ -0,0 +1,37 @@
+ from openai import AzureOpenAI
+ from azure.identity import AzureCliCredential, get_bearer_token_provider
+
+ import time
+
+ def setup_client():
+     scope = "api://trapi/.default"
+     credential = get_bearer_token_provider(AzureCliCredential(), scope)
+
+     api_version = '2024-12-01-preview'
+     deployment_name = 'gpt-4o_2024-08-06'
+     instance = 'gcr/shared/'  # See https://aka.ms/trapi/models for the instance name, remove /openai (library adds it implicitly)
+     endpoint = f'https://trapi.research.microsoft.com/{instance}'
+
+     client = AzureOpenAI(
+         azure_endpoint=endpoint,
+         azure_ad_token_provider=credential,
+         api_version=api_version,
+     )
+     return client, deployment_name
+
+
+ def call_openai(client, deployment, prompt):
+     output = None
+     while output is None:
+         try:
+             time.sleep(0.5)
+             completion = client.chat.completions.create(
+                 model=deployment,
+                 messages=prompt,
+             )
+             output = completion.choices[0].message.content
+         except Exception as e:
+             print("API error:", e)
+             time.sleep(1)
+             output = None
+     return output
CADFusion/src/data_preprocessing/captioning.py ADDED
@@ -0,0 +1,101 @@
+ import os
+ import requests
+ import base64
+ import json
+ import time
+ from mimetypes import guess_type
+ from tqdm import tqdm
+ # from parse_sequence import parse_sequence
+ # from parse_visual import run_parallel
+ # from parse_image import render_file
+ from call_openai import setup_client, call_openai
+ import argparse
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--image-folder-path', type=str, default='exp/figures/test', help='Path to the input folder')
+ parser.add_argument('--out-path', type=str, default='data/raw', help='Path to the output file')
+ args = parser.parse_args()
+ file_path = args.image_folder_path
+ out_path = args.out_path
+
+ client, deployment_name = setup_client()
+ call_client = call_openai
+
+ def local_image_to_data_url(image_path):
+     # Encode a local image into data URL
+     mime_type, _ = guess_type(image_path)
+     if mime_type is None:
+         mime_type = 'application/octet-stream'
+     with open(image_path, "rb") as image_file:
+         base64_encoded_data = base64.b64encode(image_file.read()).decode('utf-8')
+     return f"data:{mime_type};base64,{base64_encoded_data}"
+
+ def call_model_1(prompt, image_path):
+     message_text = [
+         {"role": "system", "content": "You are an AI assistant that helps people find information."},
+         {"role": "user", "content": [
+             {
+                 "type": "text",
+                 "text": prompt
+             },
+             {
+                 "type": "image_url",
+                 "image_url": {"url": local_image_to_data_url(image_path)}
+             }
+         ]}
+     ]
+     return call_client(client, deployment_name, message_text)
+
+ def call_model_2(prompt1, image_path, output1, prompt2):
+     message_text = [
+         {"role": "system", "content": "You are an AI assistant that helps people find information."},
+         {"role": "user", "content": [
+             {
+                 "type": "text",
+                 "text": prompt1
+             },
+             {
+                 "type": "image_url",
+                 "image_url": {"url": local_image_to_data_url(image_path)}
+             }
+         ]},
+         {"role": "assistant", "content": output1},
+         {"role": "user", "content": prompt2}
+     ]
+     return call_client(client, deployment_name, message_text)
+
+ files = [f for f in os.listdir(args.image_folder_path) if os.path.isfile(os.path.join(args.image_folder_path, f))]
+ files.sort()
+ results = []
+ for filename in tqdm(files):
+     time.sleep(0.5)
+     output1 = None
+     output2 = None
+     image_path = os.path.join(file_path, filename)
+     # Send request
+     prompt1 = """Propose a series of questions about the 3D shape and give the answers. The first question should ask for a detailed description and others should focus on the specific geometric properties, number, size proportions and positional relationship, and other details."""
+     prompt2 = """Based on the dialogue, please give a final description of the 3D shape. No more than 70 words."""
+     while output1 is None or str(output1).startswith("I'm sorry"):
+         try:
+             output1 = call_model_1(prompt1, image_path)
+         except requests.RequestException as e:
+             print(f"Request failed: {e}")
+             time.sleep(1)
+             output1 = None
+     while output2 is None or str(output2).startswith("I'm sorry"):
+         try:
+             output2 = call_model_2(prompt1, image_path, output1, prompt2)
+         except requests.RequestException as e:
+             print(f"Request failed: {e}")
+             time.sleep(1)
+             output2 = None
+
+     result = {
+         "pic_name": filename,
+         "questions": output1,
+         "description": output2
+     }
+     results.append(result)
+
+ with open(out_path, 'w+', encoding='utf-8') as f:
+     json.dump(results, f, ensure_ascii=False, indent=4)
CADFusion/src/data_preprocessing/convert.py ADDED
@@ -0,0 +1,120 @@
+ import pickle
+ import argparse
+ import json
+ # hyperparameters from SkexGen project
+ SKETCH_R = 1
+ RADIUS_R = 1
+ EXTRUDE_R = 1.0
+ SCALE_R = 1.4
+ OFFSET_R = 0.9
+ PIX_PAD = 4
+ CMD_PAD = 3
+ COORD_PAD = 4
+ EXT_PAD = 1
+ EXTRA_PAD = 1
+ R_PAD = 2
+
+
+ def create_curve_str(se_xy, se_cmd):
+     curve_str = ""
+     xy_offset = 0
+     if se_cmd == 0:  # line
+         curve_str = " line," + ",".join(str(x) for x in se_xy[0])
+         xy_offset = 2
+     elif se_cmd == 1:  # arc
+         curve_str = " arc," + ",".join(str(x) for x in se_xy[0:2].flatten())
+         xy_offset = 3
+     elif se_cmd == 2:  # circle
+         curve_str = " circle," + ",".join(str(x) for x in se_xy[0:4].flatten())
+         xy_offset = 5
+     curve_str += " <curve_end>"
+     return curve_str, xy_offset
+
+
+ def create_sketch_str(se_xy, se_cmd):
+     sketch_str = ""
+     len_xy, len_cmd = len(se_xy), len(se_cmd)
+     xy_idx = 0
+     for cmd_item in se_cmd:  # for each command
+         if 0 <= cmd_item <= 2:  # curve
+             curve_str, xy_offset = create_curve_str(se_xy[xy_idx:], cmd_item)
+             sketch_str += curve_str
+             xy_idx += xy_offset
+         elif cmd_item == -1:  # loop
+             sketch_str += " <loop_end>"
+             xy_idx += 1
+         elif cmd_item == -2:  # face
+             sketch_str += " <face_end>"
+             xy_idx += 1
+         elif cmd_item == -3:  # sketch
+             sketch_str += " <sketch_end>"
+             xy_idx += 1
+         else:
+             raise ValueError("Invalid command: " + str(cmd_item))
+     if xy_idx != len_xy:
+         raise ValueError("xy_idx != len_xy")
+     return sketch_str
+
+
+ def create_extrude_str(se_ext):
+     extrude_str = ""
+     # extrude operation
+     if se_ext[14] == 1:
+         extrude_str += "add"
+     elif se_ext[14] == 2:
+         extrude_str += "cut"
+     elif se_ext[14] == 3:
+         extrude_str += "intersect"
+     else:
+         raise ValueError("Invalid extrude operation: " + str(se_ext[14]))
+     # other extrude parameters
+     extrude_str = (
+         extrude_str + "," + ",".join(str(x - EXT_PAD) for x in se_ext[0:5])
+     )  # ext_v, ext_T
+     extrude_str = (
+         extrude_str + "," + ",".join(str(x - R_PAD) for x in se_ext[5:14])
+     )  # ext_R
+     extrude_str = (
+         extrude_str + "," + ",".join(str(x - EXT_PAD) for x in se_ext[15:18])
+     )  # scale, offset
+     # extrude end
+     extrude_str += " <extrude_end>"
+     return extrude_str
+
+
+ def convert(in_path, out_path):
+     with open(in_path, "rb") as f:
+         data = pickle.load(f)
+     print("Data loaded: " + str(len(data)) + " samples")
+
+     results = []
+     for item in data:  # for each data
+         se_str = ""
+         num_se = item["num_se"]
+         for se_idx in range(num_se):  # for each sketch-extrude
+             xy, cmd, ext = (
+                 item["se_xy"][se_idx] - COORD_PAD,
+                 item["se_cmd"][se_idx] - CMD_PAD,
+                 item["se_ext"][se_idx],
+             )
+             se_str = se_str + " " + create_sketch_str(xy, cmd).strip()
+             se_str = se_str + " " + create_extrude_str(ext).strip()
+         results.append(se_str.strip())
+
+     # with open(out_path, "wb") as f:
+     #     pickle.dump(results, f)
+     # print("Data converted: " + str(len(results)) + " samples")
+     with open(out_path, "w") as f:
+         json.dump(results, f, indent=4)
+     print("Data converted: " + str(len(results)) + " samples")
+     # with open(out_path, "w") as f:  # Open in text mode
+     #     for result in results:
+     #         f.write(result + "\n")
+
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser()
+     parser.add_argument("--in-path", type=str, required=True)
+     parser.add_argument("--out-path", type=str, required=True)
+     args = parser.parse_args()
+
+     convert(args.in_path, args.out_path)
CADFusion/src/dpo/llava_utils.py ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import base64
3
+ import time
4
+ import json
5
+ import requests
6
+ from mimetypes import guess_type
7
+ from transformers import pipeline
8
+ from transformers import LlavaNextProcessor
9
+ from transformers import AutoProcessor, LlavaOnevisionForConditionalGeneration
10
+ import torch
11
+ from PIL import Image
12
+ dev='cuda:0'
13
+
14
+ # processor = AutoProcessor.from_pretrained("llava-hf/llava-onevision-qwen2-7b-ov-hf")
15
+ # model = LlavaOnevisionForConditionalGeneration.from_pretrained("llava-hf/llava-onevision-qwen2-7b-ov-hf", torch_dtype=torch.float16, low_cpu_mem_usage=True)
16
+ # model.to(device)
17
+
18
+ def restart_model(device):
19
+ global dev
20
+ dev = device
21
+ processor = AutoProcessor.from_pretrained("llava-hf/llava-onevision-qwen2-7b-ov-hf")
22
+ model = LlavaOnevisionForConditionalGeneration.from_pretrained("llava-hf/llava-onevision-qwen2-7b-ov-hf", torch_dtype=torch.float16, low_cpu_mem_usage=True)
23
+ model.to(device)
24
+ return model, processor
25
+
26
+ def ask_llm_on_figure(data, model, processor):
27
+ """
28
+ The layout of a typical data item
29
+ {
30
+ "index": 1,
31
+ "pic_name": "000000_001_final.png",
32
+ "ground_truth": "line,9,9 <curve_end> line,9,53 <curve_end> line,53,53 <curve_end> line,53,9 <curve_end> <loop_end> circle,31,29,31,20,35,25,27,25 <curve_end> <loop_end> circle,31,41,31,32,35,37,27,37 <curve_end> <loop_end> <face_end> <sketch_end> add,31,32,31,31,31,0,1,0,0,0,1,1,0,0,62,31,31 <extrude_end>",
33
+ "description": "Create a rectangular panel with two circular through-holes centrally aligned on the vertical axis.",
34
+ "prompt": "Below is a description of a 3D shape:\nCreate a rectangular panel with two circular through-holes centrally aligned on the vertical axis.\nGenerate a Computer-Aided Design(CAD) command sequence of the 3D shape:\n",
35
+ "output": "line,se,9 <curve_end> line,ne,9 <curve_end> line,ne,53 <curve_end> line,se,53 <curve_end> <loop_end> circle,22,41,22, Twenty1 ,31,30,12,30 <curve_end> <loop_end> circle,40,21,40, Ten2 ,50,32,29,32 <curve_end> <loop_end> <face_end> <sketch_end> add,31,33,31,31,31,1,0,0,0,0,1,0,-1,0,62,31,31 <extrude_end>"
36
+ },
37
+ """
38
+ url = data['figure_path']
39
+ image = Image.open(url)
40
+ description = data['description']
41
+ # data_scale = 10
42
+ # measurement = 'the degree of correspondence between them'
43
+
44
+ prompt = 'You are a harsh grader for new CAD designers\' works. The following is a text description of a CAD figure that they designed and an image of a CAD instance.' +\
45
+ f'\nDescription: {description}\n ' + \
46
+ f'Comment on this work for \n '+\
47
+ '1. If the overall shape remains correct; \n '+\
48
+ '2. If the number of components are correct, especially the circular holes; \n '+\
49
+ '3. If the distribution of the components are natural, i.e. they are not clustered together or collide with each other.\n'+\
50
+ 'After that, give a score out of 10. Do not comment on issues such as texture, smoothness and colors'
51
+
52
+
53
+ conversation = [
54
+ {
55
+ "role": "user",
56
+ "content": [
57
+ {"type": "text", "text": prompt},
58
+ {"type": "image"},
59
+ ],
60
+ },
61
+ ]
62
+ prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
63
+ inputs = processor(images=image, text=prompt, return_tensors="pt",).to(dev, torch.float16)
64
+
65
+ # autoregressively complete prompt
66
+ output = model.generate(**inputs, max_new_tokens=256, pad_token_id=processor.tokenizer.eos_token_id)
67
+ output = processor.decode(output[0], skip_special_tokens=True)
68
+ idx = output.index('assistant\n')
69
+ response = output[idx+10:]
70
+ return(response)
71
+
72
+
73
+ def ask_llm(data, model, processor):
74
+ description = data['gpt_label']
75
+
76
+ prompt = 'The following is an evaluation of an CAD object.' +\
77
+ f'\n evaluation: {description}\n' +\
78
+ 'Extract the integer score of the evaluation. The score is between 0 and 10. Return the number only.'
79
+
80
+ conversation = [
81
+ {
82
+ "role": "user",
83
+ "content": [
84
+ {"type": "text", "text": prompt},
85
+ ],
86
+ },
87
+ ]
88
+ prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
89
+ inputs = processor(text=prompt, return_tensors="pt",).to(dev, torch.float16)
90
+
91
+ output = model.generate(**inputs, max_new_tokens=16, pad_token_id=processor.tokenizer.eos_token_id)
92
+ output = processor.decode(output[0], skip_special_tokens=True)
93
+ idx = output.index('assistant\n')
94
+ response = output[idx+10:]
95
+ return response
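A minimal usage sketch for these helpers (the paths and field values below are placeholders, not files shipped with the repo): load the scorer once, critique a rendered figure, then distill the critique into an integer score.

model, processor = restart_model('cuda:0')
item = {
    'figure_path': 'renders/000000_001_final.png',  # hypothetical render
    'description': 'A rectangular panel with two circular through-holes.',
}
item['gpt_label'] = ask_llm_on_figure(item, model, processor)  # free-form critique
score_text = ask_llm(item, model, processor)                   # e.g. "7"
print(int(score_text))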
CADFusion/src/dpo/make_dpo_dataset.py ADDED
@@ -0,0 +1,162 @@
1
+ import json
2
+ import os
3
+ import time
4
+ import argparse
5
+
6
+ from openai_utils import ask_gpt_on_figure, ask_gpt
7
+ from llava_utils import ask_llm, ask_llm_on_figure, restart_model
8
+ from tqdm import tqdm
9
+
10
+
11
+ if __name__ == '__main__':
12
+ parser = argparse.ArgumentParser()
13
+ parser.add_argument("--source-data-path", type=str, required=True)
14
+ parser.add_argument("--figure-path", type=str, required=True)
15
+ parser.add_argument("--save-path", type=str, required=True)
16
+ parser.add_argument("--num-samples", type=int, required=True)
17
+ parser.add_argument("--gpu", type=int, default=0)
18
+ parser.add_argument("--score-only", action="store_true", default=False)
19
+ parser.add_argument("--gpt", action="store_true", default=False)
20
+ args = parser.parse_args()
21
+
22
+ source_path = args.source_data_path
23
+ folder_path = args.figure_path
24
+ save_path = args.save_path
25
+ num_samples = args.num_samples
26
+ device=f'cuda:{args.gpu}'
27
+ if args.gpt:
28
+ func1, func2 = ask_gpt_on_figure, ask_gpt
29
+ model = None
30
+ processor = None
31
+ else:
32
+ func1, func2 = ask_llm_on_figure, ask_llm
33
+ model, processor = restart_model(device)
34
+
35
+ with open(source_path, 'r') as f:
36
+ test_data = json.load(f)
37
+
38
+ ####### Stage 1 #######
39
+ # for model generations that are able to render pictures,
40
+ # ask gpt to rate the generation quality.
41
+ for data in tqdm(test_data):
42
+ file_id = str(data['index']).zfill(6)
43
+ file = None
44
+ for f in os.listdir(folder_path):
45
+ if f.startswith(file_id):
46
+ file = os.path.join(folder_path, f)
47
+ data['figure_path'] = file
48
+ error_cnt = 0
49
+ while 1:
50
+ try:
51
+ data['gpt_label'] = func1(data, model, processor)
52
+ break
53
+ except Exception as e:
54
+ print(e)
55
+ if args.gpt:
56
+ time.sleep(3)
57
+ else:
58
+ if error_cnt == 5:
59
+ exit()
60
+ model, processor = restart_model(device)
61
+ error_cnt += 1
62
+ with open(save_path, 'w+') as f:
63
+ json.dump(test_data, f, indent=4)
64
+
65
+ with open(save_path, 'r') as f:
66
+ test_data = json.load(f)
67
+ ####### Stage 2 #######
68
+ # clean up the dataset by summarizing the generation quality estimate into a numerical score, and
69
+ # remove the failed ones, i.e. the generations that cannot render
70
+ for data in tqdm(test_data):
71
+ if "gpt_label" in data.keys():
72
+ error_cnt = 0
73
+ while 1:
74
+ try:
75
+ score = func2(data, model, processor)
76
+ print(score)
77
+ break
78
+ except Exception as e:
79
+ print(e)
80
+ if args.gpt:
81
+ time.sleep(3)
82
+ else:
83
+ if error_cnt == 5:
84
+ exit()
85
+ model, processor = restart_model(device)
86
+ error_cnt += 1
87
+ try:
88
+ data['gpt_score'] = int(score)
89
+ except (ValueError, TypeError):
90
+ print(f'ERROR: {score}')
91
+ pass
92
+
93
+ saved_data = [data for data in test_data if 'gpt_score' in data.keys()]
94
+ with open(save_path, 'w+') as f:
95
+ json.dump(saved_data, f, indent=4)
96
+
97
+ if args.score_only:
98
+ exit()
99
+
100
+ ####### Stage 3 #######
101
+ # 1. group up the scored generations by their description: we do not compare
102
+ # generation results that come from different original prompts
103
+ temp_data = []
104
+ max_idx = test_data[-1]['index']
105
+ sample_size = max_idx // num_samples + 1
106
+ # a. select if any above 6
107
+
108
+ # for i in range(sample_size):
109
+ # next_sample = test_data[i*num_samples:(i+1)*num_samples]
110
+ # next_sample = [item for item in next_sample if 'gpt_score' in item.keys()]
111
+ # above_score = [item['gpt_score'] >= 6 for item in next_sample]
112
+ # if any(above_score):
113
+ # temp_data.extend(next_sample)
114
+ # temp_data = [data for data in temp_data if 'gpt_score' in data.keys()]
115
+
116
+ # b. select if avg above 6
117
+
118
+ # for i in range(sample_size):
119
+ # next_sample = test_data[i*num_samples:(i+1)*num_samples]
120
+ # next_sample = [item for item in next_sample if 'gpt_score' in item.keys()]
121
+ # if len(next_sample) == 0:
122
+ # continue
123
+ # scores = sum(item['gpt_score'] for item in next_sample) / len(next_sample)
124
+ # if scores >= 6:
125
+ # temp_data.extend(next_sample)
126
+ # temp_data = [data for data in temp_data if 'gpt_score' in data.keys()]
127
+
128
+ # c. select if individual above 6
129
+ test_data = saved_data
130
+ for item in test_data:
131
+ if 'gpt_score' not in item.keys():
132
+ continue
133
+ if item['gpt_score'] >= 6:
134
+ temp_data.append(item)
135
+ print(test_data[-1]['index'], max_idx)
136
+
137
+ grouped = [[] for _ in range(max_idx // num_samples + 1)]  # one bucket per origin prompt
138
+ for item in temp_data:
139
+ idx = item['index']
140
+ grouped[idx // num_samples].append(item)
141
+ grouped = [item for item in grouped if len(item) > 0]
142
+
143
+ # 2. within each group, make pairs where the chosen outputs have higher scores than the rejected ones.
144
+ # TODO: find a way to balance the data generated from each group
145
+ final_data = []
146
+ for group in grouped:
147
+ for item1 in group:
148
+ for item2 in group:
149
+ if item2['gpt_score'] > item1['gpt_score']:
150
+ info_dict = {
151
+ "description": item1['description'],
152
+ "prompt": item1['prompt'],
153
+ "chosen": item2['output'],
154
+ "rejected": item1['output']
155
+ }
156
+ final_data.append(info_dict)
157
+ # uncomment this break if you do not want too much data.
158
+ # break
159
+
160
+
161
+ with open(save_path, 'w+') as f:
162
+ json.dump(final_data, f, indent=4)
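Stage 3 boils down to: bucket generations by their origin prompt, then emit one (chosen, rejected) pair for every ordered pair whose scores differ. A self-contained toy run of that pairing rule, with invented scores:

group = [
    {'description': 'd', 'prompt': 'p', 'output': 'A', 'gpt_score': 8},
    {'description': 'd', 'prompt': 'p', 'output': 'B', 'gpt_score': 6},
    {'description': 'd', 'prompt': 'p', 'output': 'C', 'gpt_score': 6},
]
pairs = [
    {'description': lo['description'], 'prompt': lo['prompt'],
     'chosen': hi['output'], 'rejected': lo['output']}
    for lo in group for hi in group
    if hi['gpt_score'] > lo['gpt_score']  # strict inequality: ties yield no pair
]
# -> two pairs: A over B and A over C; B and C tie, so nothing is emitted for them.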
CADFusion/src/dpo/openai_utils.py ADDED
@@ -0,0 +1,88 @@
1
+ import os
2
+ import base64
3
+ import time
4
+ import json
5
+
6
+ from mimetypes import guess_type
7
+ from openai import AzureOpenAI
8
+ from azure.identity import DefaultAzureCredential, get_bearer_token_provider
9
+
10
+ END_POINT = '<endpoint>'
11
+ MODEL_NAME = 'gpt-4o_2024-08-06'
12
+ API_VER = '2024-02-01'
13
+
14
+ def local_image_to_data_url(image_path):
15
+ # Encode a local image into data URL
16
+ mime_type, _ = guess_type(image_path)
17
+ if mime_type is None:
18
+ mime_type = 'application/octet-stream'
19
+ with open(image_path, "rb") as image_file:
20
+ base64_encoded_data = base64.b64encode(image_file.read()).decode('utf-8')
21
+ return f"data:{mime_type};base64,{base64_encoded_data}"
22
+
23
+ def ask_gpt_on_figure(data, _, __):
24
+ endpoint = END_POINT
25
+ token_provider = get_bearer_token_provider(
26
+ DefaultAzureCredential(),
27
+ "https://cognitiveservices.azure.com/.default"
28
+ )
29
+ deployment_name = MODEL_NAME
30
+
31
+ client = AzureOpenAI(
32
+ azure_ad_token_provider=token_provider,
33
+ azure_endpoint=endpoint,
34
+ api_version=API_VER
35
+ )
36
+ description = data['description']
37
+ data_scale = 10
38
+ measurement = 'if the figure corresponds to the given description'
39
+
40
+ prompt = 'The following is a text description of a 3D CAD figure and an image of a CAD instance. ' +\
41
+ f'Measure {measurement}, and give a score on a scale of {data_scale}. Do not comment on issues such as texture, smoothness and colors.' +\
42
+ f'\n description: {description}\n'
43
+ image_path = data['figure_path']
44
+ response = client.chat.completions.create(
45
+ model=deployment_name,
46
+ messages=[
47
+ {'role': 'system', 'content': 'You are a helpful assistant'},
48
+ {'role': 'user', 'content': [
49
+ {'type': 'text', 'text': prompt},
50
+ {'type': 'image_url', 'image_url': {'url': local_image_to_data_url(image_path)}},
51
+ ]}
52
+ ]
53
+ )
54
+ time.sleep(3)
55
+ return response.choices[0].message.content
56
+
57
+
58
+ def ask_gpt(data, _, __):
59
+ endpoint = END_POINT
60
+ token_provider = get_bearer_token_provider(
61
+ DefaultAzureCredential(),
62
+ "https://cognitiveservices.azure.com/.default"
63
+ )
64
+ deployment_name = MODEL_NAME
65
+
66
+ client = AzureOpenAI(
67
+ azure_ad_token_provider=token_provider,
68
+ azure_endpoint=endpoint,
69
+ api_version=API_VER
70
+ )
71
+ description = data['gpt_label']
72
+
73
+ prompt = 'The following is an evaluation of a CAD object.' +\
74
+ f'\n evaluation: {description}\n' +\
75
+ 'Extract the integer score of the evaluation. The score is between 0 and 10. Return the number only.'
76
+
77
+ response = client.chat.completions.create(
78
+ model=deployment_name,
79
+ messages=[
80
+ {'role': 'system', 'content': 'You are a helpful assistant'},
81
+ {'role': 'user', 'content': [
82
+ {'type': 'text', 'text': prompt},
83
+ ]}
84
+ ]
85
+ )
86
+ # print(response.choices[0].message.content)
87
+ time.sleep(3)
88
+ return response.choices[0].message.content
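local_image_to_data_url inlines the image as a base64 data URL, which is the form the Azure OpenAI vision endpoint accepts in an image_url content part. A quick illustration of the string it produces (the file name is a placeholder):

url = local_image_to_data_url('render.png')  # hypothetical PNG
assert url.startswith('data:image/png;base64,')
# the remainder is base64.b64encode(<raw PNG bytes>).decode('utf-8')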
CADFusion/src/rendering_utils/geometry/arc.py ADDED
@@ -0,0 +1,32 @@
1
+ import numpy as np
2
+ import math
3
+ from geometry.curve import Curve
4
+
5
+
6
+ class Arc(Curve):
7
+ def __init__(self, point_indices, point_data, is_outer):
8
+ assert len(point_indices) == 4, "Arc must be defined by 4 point indices"
9
+ assert point_data is not None
10
+ super(Arc, self).__init__(point_indices, point_data)
11
+ self.type = 'arc'
12
+ self.is_outer = is_outer
13
+ self.start = self.point_geom[0, :]
14
+ self.mid = self.point_geom[1, :]
15
+ self.center = self.point_geom[2, :]
16
+ self.end = self.point_geom[3, :]
17
+
18
+ self.r1 = math.sqrt( (self.start[0] - self.center[0])**2 + (self.start[1] - self.center[1])**2 )
19
+ self.r2 = math.sqrt( (self.end[0] - self.center[0])**2 + (self.end[1] - self.center[1])**2 )
20
+ self.radius = (self.r1+self.r2)/2
21
+
22
+ self.start_idx = point_indices[0]
23
+ self.mid_idx = point_indices[1]
24
+ self.center_idx = point_indices[2]
25
+ self.end_idx = point_indices[3]
26
+
27
+ self.bbox = self.verts_to_bbox(np.vstack([self.start, self.end, self.mid]))
28
+ self.bottom_left = np.array([self.bbox[0], self.bbox[2]])
29
+
30
+
31
+
32
+
CADFusion/src/rendering_utils/geometry/circle.py ADDED
@@ -0,0 +1,27 @@
1
+ import numpy as np
2
+ from geometry.curve import Curve
3
+ import pdb
4
+
5
+ class Circle(Curve):
6
+ def __init__(self, point_indices, point_data, is_outer):
7
+ assert len(point_indices) == 2, "Circle must be defined by 2 point indices (center and radius)"
8
+ assert point_data is not None
9
+ super(Circle, self).__init__(point_indices, point_data)
10
+ self.type = 'circle'
11
+ self.center = self.point_geom[0, :]
12
+ self.radius = self.point_geom[1, 0]
13
+ self.center_idx = point_indices[0]
14
+ self.radius_idx = point_indices[1]
15
+ self.is_outer = is_outer
16
+
17
+ self.pt1 = np.array([self.center[0], self.center[1]+self.radius])
18
+ self.pt2 = np.array([self.center[0], self.center[1]-self.radius])
19
+ self.pt3 = np.array([self.center[0]+self.radius, self.center[1]])
20
+ self.pt4 = np.array([self.center[0]-self.radius, self.center[1]])
21
+ self.bbox = self.verts_to_bbox(np.vstack([self.pt1, self.pt2, self.pt3, self.pt4]))
22
+ self.bottom_left = np.array([self.bbox[0], self.bbox[2]])
23
+
24
+
25
+
26
+
27
+
CADFusion/src/rendering_utils/geometry/curve.py ADDED
@@ -0,0 +1,13 @@
1
+ import math
2
+
3
+
4
+ class Curve():
5
+ def __init__(self, point_indices, point_data):
6
+ self.point_indices = point_indices
7
+ self.point_geom = point_data[point_indices, 0:2]
8
+
9
+ def verts_to_bbox(self, verts):
10
+ xs = [v[0] for v in verts]
11
+ ys = [v[1] for v in verts]
12
+ bbox = [min(xs), max(xs), min(ys), max(ys)]
13
+ return bbox
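verts_to_bbox returns [min_x, max_x, min_y, max_y], which is why the subclasses build bottom_left from indices 0 and 2. A small check (values invented):

import numpy as np
from geometry.curve import Curve

pts = np.array([[1.0, 4.0], [3.0, 2.0]])
c = Curve([0, 1], pts)
print(c.verts_to_bbox(pts))  # [1.0, 3.0, 2.0, 4.0]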
CADFusion/src/rendering_utils/geometry/geom_utils.py ADDED
@@ -0,0 +1,95 @@
1
+ import math
2
+ import numpy as np
3
+
4
+ def angle_from_vector_to_x(vec):
5
+ assert vec.size == 2
6
+ # We need to find a unit vector
7
+ angle = 0.0
8
+
9
+ l = np.linalg.norm(vec)
10
+ uvec = vec/l
11
+
12
+ # 2 | 1
13
+ #-------
14
+ # 3 | 4
15
+ if uvec[0] >=0:
16
+ if uvec[1] >= 0:
17
+ # Quadrant 1
18
+ angle = math.asin(uvec[1])
19
+ else:
20
+ # Quadrant 4
21
+ angle = 2.0*math.pi - math.asin(-uvec[1])
22
+ else:
23
+ if vec[1] >= 0:
24
+ # Quadrant 2
25
+ angle = math.pi - math.asin(uvec[1])
26
+ else:
27
+ # Quadrant 3
28
+ angle = math.pi + math.asin(-uvec[1])
29
+ return angle
30
+
31
+
32
+ def convert_angle_to_1to360_range(angle_rad):
33
+ """
34
+ Converts the given angle in radians into 1-360 degrees range
35
+ """
36
+ angle = math.degrees(angle_rad)
37
+ # Lifted from: https://stackoverflow.com/questions/12234574/calculating-if-an-angle-is-between-two-angles
38
+ angle=(int(angle) % 360) + (angle-math.trunc(angle)) # converts angle to range -360 + 360
39
+ if angle > 0.0:
40
+ return angle
41
+ else:
42
+ return angle + 360.0
43
+
44
+
45
+ def angle_is_between(angle_rad, a_rad, b_rad):
46
+ """
47
+ Checks if angle is in between the range of a and b
48
+ (All angles must be given in radians)
49
+ """
50
+ angle = convert_angle_to_1to360_range(angle_rad)
51
+ a = convert_angle_to_1to360_range(a_rad)
52
+ b = convert_angle_to_1to360_range(b_rad)
53
+ if a < b:
54
+ return a <= angle and angle <= b
55
+ return a <= angle or angle <= b
56
+
57
+
58
+ def quantize_verts(verts, n_bits=8):
59
+ """Convert vertices in [-1., 1.] to discrete values in [0, n_bits**2 - 1]."""
60
+ min_range = -0.5
61
+ max_range = 0.5
62
+ range_quantize = 2 ** n_bits - 1
63
+ verts_quantize = (verts - min_range) * range_quantize / (max_range - min_range)
64
+ return verts_quantize.astype("int32")
65
+
66
+
67
+ def dequantize_verts(verts, n_bits=8, add_noise=False):
68
+ """Convert quantized vertices to floats."""
69
+ min_range = -0.5
70
+ max_range = 0.5
71
+ range_quantize = 2 ** n_bits - 1
72
+ verts = verts.astype("float32")
73
+ verts = verts * (max_range - min_range) / range_quantize + min_range
74
+ if add_noise:
75
+ verts += np.random.uniform(size=verts.shape) * (1 / range_quantize)
76
+ return verts
77
+
78
+
79
+ def center_vertices(vertices):
80
+ """Translate the vertices so that bounding box is centered at zero."""
81
+ vert_min = vertices.min(axis=0)
82
+ vert_max = vertices.max(axis=0)
83
+ vert_center = 0.5 * (vert_min + vert_max)
84
+ return vertices - vert_center, vert_center
85
+
86
+
87
+ def scale_vertices(vertices):
88
+ """Scale the vertices so that the long diagonal of the bounding box is one."""
89
+ vert_min = vertices.min(axis=0)
90
+ vert_max = vertices.max(axis=0)
91
+ extents = vert_max - vert_min
92
+ scale = np.sqrt(np.sum(extents ** 2))
93
+ return vertices / scale, scale
94
+
95
+
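quantize_verts truncates rather than rounds (astype("int32")), so a round-trip through the grid is exact only at the range endpoints; elsewhere the error is bounded by one quantization step. A sketch of that bound:

import numpy as np

v = np.array([[-0.5, 0.0], [0.25, 0.5]])
q = quantize_verts(v, n_bits=8)       # integers in [0, 255]
r = dequantize_verts(q, n_bits=8)
assert np.abs(r - v).max() <= 1.0 / 255  # one step of the 8-bit grid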
CADFusion/src/rendering_utils/geometry/line.py ADDED
@@ -0,0 +1,24 @@
1
+ import numpy as np
2
+ from geometry.curve import Curve
3
+
4
+ class Line(Curve):
5
+ def __init__(self, point_indices, point_data, is_outer):
6
+ assert len(point_indices) == 2, "Line must be defined by two points"
7
+ assert point_data is not None
8
+ super(Line, self).__init__(point_indices, point_data)
9
+ pt0 = self.point_geom[0, :]
10
+ pt1 = self.point_geom[1, :]
11
+ self.type = 'line'
12
+ self.start = pt0
13
+ self.end = pt1
14
+ self.start_idx = point_indices[0]
15
+ self.end_idx = point_indices[1]
16
+ self.is_outer = is_outer
17
+
18
+ self.bbox = self.verts_to_bbox(np.vstack([pt0, pt1]))
19
+ self.bottom_left = np.array([self.bbox[0], self.bbox[2]])
20
+
21
+
22
+
23
+
24
+
CADFusion/src/rendering_utils/geometry/obj_parser.py ADDED
@@ -0,0 +1,276 @@
1
+ import os
2
+ import sys
3
+ import numpy as np
4
+
5
+ from geometry.arc import Arc
6
+ from geometry.circle import Circle
7
+ from geometry.line import Line
8
+
9
+ from geometry import geom_utils
10
+ import pdb
11
+
12
+
13
+ class OBJParser:
14
+ """
15
+ A class to read an OBJ file containing the sketch data
16
+ and hand it back in a form which is easy to work with.
17
+ """
18
+ def __init__(self, pathname=None):
19
+ self.pathname = pathname
20
+
21
+
22
+ def convert_vertices(self, vertices):
23
+ """Convert all the vertices to .obj format"""
24
+ vertex_strings = ""
25
+ for pt in vertices:
26
+ # e.g. v 0.123 0.234 0.345 1.0
27
+ vertex_string = f"v {pt[0]} {pt[1]}\n"
28
+ vertex_strings += vertex_string
29
+ return vertex_strings
30
+
31
+
32
+ def convert_curves(self, faces):
33
+ curve_strings = ""
34
+ total_curve = 0
35
+
36
+ # Faces (multiple closed regions)
37
+ for group_idx, loops in enumerate(faces):
38
+ curve_strings += f"\nface\n"
39
+ # Multiple loops (inner and outer)
40
+ for loop in loops:
41
+ if loop[0].is_outer:
42
+ curve_strings += f"out\n"
43
+ else:
44
+ curve_strings += f"in\n"
45
+ # All curves in one loop
46
+ for curve in loop:
47
+ total_curve += 1
48
+ if curve.type == 'line':
49
+ curve_strings += f"l {curve.start_idx} {curve.end_idx}\n"
50
+ elif curve.type == 'circle':
51
+ curve_strings += f"c {curve.center_idx} {curve.radius_idx}\n"
52
+ elif curve.type == 'arc':
53
+ curve_strings += f"a {curve.start_idx} {curve.mid_idx} {curve.center_idx} {curve.end_idx}\n"
54
+
55
+ return curve_strings, total_curve
56
+
57
+
58
+ def parse3d(self, point3d):
59
+ x = point3d[0]
60
+ y = point3d[1]
61
+ z = point3d[2]
62
+ return str(x)+' '+str(y)+' '+str(z)
63
+
64
+
65
+ def write_obj2(self, file, vertices, faces, meta_info, scale=None):
66
+ """ Write to .obj file """
67
+ vertex_strings = self.convert_vertices(vertices)
68
+ curve_strings, total_curve = self.convert_curves(faces)
69
+
70
+ with open(file, "w") as fh:
71
+ # Write Meta info
72
+ fh.write("# WaveFront *.obj file\n")
73
+ fh.write(f"# Vertices: {len(vertices)}\n")
74
+ fh.write(f"# Curves: {total_curve}\n")
75
+ fh.write("\n")
76
+
77
+ # Write vertex and curve
78
+ fh.write(vertex_strings)
79
+ fh.write("\n")
80
+ fh.write(curve_strings)
81
+ fh.write("\n")
82
+
83
+ #Write extrude value
84
+ fh.write("ExtrudeOperation: " + meta_info['set_op']+"\n")
85
+ extrude_string = 'Extrude '
86
+ for value in meta_info['extrude_value']:
87
+ extrude_string += str(value)+' '
88
+ fh.write(extrude_string)
89
+ fh.write("\n")
90
+
91
+ # Write reference plane transformation
92
+ p_orig = self.parse3d(meta_info['t_orig'])
93
+ x_axis = self.parse3d(meta_info['t_x'])
94
+ y_axis = self.parse3d(meta_info['t_y'])
95
+ z_axis = self.parse3d(meta_info['t_z'])
96
+ fh.write('T_origin '+p_orig)
97
+ fh.write("\n")
98
+ fh.write('T_xaxis '+x_axis)
99
+ fh.write("\n")
100
+ fh.write('T_yaxis '+y_axis)
101
+ fh.write("\n")
102
+ fh.write('T_zaxis '+z_axis)
103
+ fh.write("\n")
104
+
105
+ # Normalized object
106
+ if scale is not None:
107
+ fh.write('Scale '+str(scale))
108
+
109
+
110
+ def write_obj(self, file, curve_strings, total_curve, vertex_strings, total_v, meta_info, scale=None):
111
+ """ Write to .obj file """
112
+ #vertex_strings = self.convert_vertices(vertices)
113
+ #curve_strings, total_curve = self.convert_curves(faces)
114
+
115
+ with open(file, "w") as fh:
116
+ # Write Meta info
117
+ fh.write("# WaveFront *.obj file\n")
118
+ fh.write(f"# Vertices: {total_v}\n")
119
+ fh.write(f"# Curves: {total_curve}\n")
120
+ fh.write("\n")
121
+
122
+ # Write vertex and curve
123
+ fh.write(vertex_strings)
124
+ fh.write("\n")
125
+ fh.write(curve_strings)
126
+ fh.write("\n")
127
+
128
+ #Write extrude value
129
+ fh.write("ExtrudeOperation: " + meta_info['set_op']+"\n")
130
+ extrude_string = 'Extrude '
131
+ for value in meta_info['extrude_value']:
132
+ extrude_string += str(value)+' '
133
+ fh.write(extrude_string)
134
+ fh.write("\n")
135
+
136
+ # Write reference plane transformation
137
+ p_orig = self.parse3d(meta_info['t_orig'])
138
+ x_axis = self.parse3d(meta_info['t_x'])
139
+ y_axis = self.parse3d(meta_info['t_y'])
140
+ z_axis = self.parse3d(meta_info['t_z'])
141
+ fh.write('T_origin '+p_orig)
142
+ fh.write("\n")
143
+ fh.write('T_xaxis '+x_axis)
144
+ fh.write("\n")
145
+ fh.write('T_yaxis '+y_axis)
146
+ fh.write("\n")
147
+ fh.write('T_zaxis '+z_axis)
148
+ fh.write("\n")
149
+
150
+ # Normalized object
151
+ if scale is not None:
152
+ fh.write('Scale '+str(scale))
153
+
154
+
155
+ def parse_file(self, scale=1.0):
156
+ """
157
+ Parse obj file
158
+ Return
159
+ vertex 2D location numpy
160
+ curve list (geometry class)
161
+ extrude parameters
162
+ """
163
+
164
+ assert self.pathname is not None, "File is None"
165
+ assert self.pathname.exists(), "No such file"
166
+
167
+ # Parse file
168
+ vertex_list = []
169
+ loops = []
170
+ closed_loop = []
171
+
172
+ # Read vertices
173
+ with open(self.pathname) as obj_file:
174
+ for line in obj_file:
175
+ tokens = line.split()
176
+ if not tokens:
177
+ continue
178
+ line_type = tokens[0]
179
+ # Vertex
180
+ if line_type == "v":
181
+ vertex_list.append([float(x) for x in tokens[1:]])
182
+ vertices = np.array(vertex_list, dtype=np.float64) * scale
183
+
184
+ # Read curves
185
+ faces = []
186
+ loops = []
187
+ loop = []
188
+
189
+ # Read in all lines
190
+ lines = []
191
+ with open(self.pathname) as obj_file:
192
+ for line in obj_file:
193
+ lines.append(line)
194
+
195
+ # Parse all lines
196
+ faces = []
197
+ for str_idx, line in enumerate(lines):
198
+ tokens = line.split()
199
+ if not tokens:
200
+ continue
201
+ line_type = tokens[0]
202
+
203
+ # Start of a new face
204
+ if line_type == "face":
205
+ faces.append(self.read_face(lines, str_idx+1, vertices))
206
+
207
+ # Read meta data
208
+ meta_data = line.strip('# ').strip(' \n').split(' ')
209
+ meta_name = meta_data[0]
210
+
211
+ if meta_name == 'Extrude':
212
+ extrude_values = [float(x) for x in meta_data[1:]]
213
+ extrude_values = [x*scale for x in extrude_values]
214
+ elif meta_name == 'T_origin':
215
+ t_orig = [float(x) for x in meta_data[1:]]
216
+ t_orig = [x*scale for x in t_orig]
217
+ elif meta_name == 'T_xaxis':
218
+ t_x = [float(x) for x in meta_data[1:]]
219
+ elif meta_name == 'T_yaxis':
220
+ t_y = [float(x) for x in meta_data[1:]]
221
+ elif meta_name == 'T_zaxis':
222
+ t_z = [float(x) for x in meta_data[1:]]
223
+ elif meta_name == 'ExtrudeOperation:':
224
+ set_op = meta_data[1]
225
+
226
+ meta_info = {'extrude_value': extrude_values,
227
+ 'set_op': set_op,
228
+ 't_orig': t_orig,
229
+ 't_x': t_x,
230
+ 't_y': t_y,
231
+ 't_z': t_z,
232
+ }
233
+
234
+ return vertices, faces, meta_info
235
+
236
+
237
+
238
+ def read_face(self, lines, str_idx, vertices):
239
+ loops = []
240
+ loop = []
241
+ for line in lines[str_idx:]:
242
+ tokens = line.split()
243
+ if not tokens:
244
+ continue
245
+ line_type = tokens[0]
246
+
247
+ if line_type == 'face':
248
+ break
249
+
250
+ # Start of a new loop
251
+ if line_type == "out" or line_type == "in":
252
+ if len(loop) > 0:
253
+ loops.append(loop)
254
+ loop = []
255
+ is_outer = (line_type == 'out')
256
+
257
+ # Line
258
+ if line_type == 'l':
259
+ c_tok = tokens[1:]
260
+ curve = Line([int(c_tok[0]), int(c_tok[1])], vertices, is_outer=is_outer)
261
+ loop.append(curve)
262
+
263
+ # Arc
264
+ if line_type == 'a':
265
+ c_tok = tokens[1:]
266
+ curve = Arc([int(c_tok[0]), int(c_tok[1]), int(c_tok[2]), int(c_tok[3])], vertices, is_outer=is_outer)
267
+ loop.append(curve)
268
+
269
+ # Circle
270
+ if line_type == 'c':
271
+ c_tok = tokens[1:]
272
+ curve = Circle([int(c_tok[0]), int(c_tok[1])], vertices, is_outer=is_outer)
273
+ loop.append(curve)
274
+
275
+ loops.append(loop)
276
+ return loops
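For reference, a minimal instance of the sketch-OBJ dialect this class reads, reconstructed from write_obj2/parse_file: one face with a single outer circle, where vertex 0 is the center and vertex 1 stores the radius in its x coordinate. The numeric values are invented.

MINIMAL_SKETCH_OBJ = """\
v 0.0 0.0
v 0.5 0.0

face
out
c 0 1

ExtrudeOperation: NewBodyFeatureOperation
Extrude 0.5 0.0
T_origin 0.0 0.0 0.0
T_xaxis 1.0 0.0 0.0
T_yaxis 0.0 1.0 0.0
T_zaxis 0.0 0.0 1.0
"""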
CADFusion/src/rendering_utils/geometry/obj_utils.py ADDED
@@ -0,0 +1,93 @@
1
+ import json
2
+ import numpy as np
3
+ from pathlib import Path
4
+ import pdb
5
+
6
+
7
+ def read_wire_obj(obj_path):
8
+ """Read vertices and lines from .obj file defining a wire body."""
9
+ vertex_list = []
10
+ loops = []
11
+
12
+ # Read vertice and curves
13
+ with open(obj_path) as obj_file:
14
+
15
+ for line in obj_file:
16
+ tokens = line.split()
17
+ if not tokens:
18
+ continue
19
+
20
+ line_type = tokens[0]
21
+
22
+ if line_type == "v":
23
+ vertex_list.append([float(x) for x in tokens[1:]])
24
+
25
+ if line_type == "g":
26
+ pdb.set_trace()
27
+
28
+
29
+
30
+
31
+ # Read meta data
32
+ meta_data = line.strip('# ').strip(' \n').split(' ')
33
+ meta_name = meta_data[0]
34
+ if meta_name == 'Extrude':
35
+ extrude_values= [float(x) for x in meta_data[1:]]
36
+ elif meta_name == 'T_origin':
37
+ t_orig = [float(x) for x in meta_data[1:]]
38
+ elif meta_name == 'T_xaxis':
39
+ t_x = [float(x) for x in meta_data[1:]]
40
+ elif meta_name == 'T_yaxis':
41
+ t_y = [float(x) for x in meta_data[1:]]
42
+ elif meta_name == 'T_zaxis':
43
+ t_z = [float(x) for x in meta_data[1:]]
44
+ elif meta_name == 'ExtrudeOperation:':
45
+ set_op = meta_data[1]
46
+
47
+
48
+ vertices = np.array(vertex_list)
49
+
50
+
51
+
52
+ meta_info = {'extrude_value': extrude_values,
53
+ 'set_op': set_op,
54
+ 't_orig': t_orig,
55
+ 't_x': t_x,
56
+ 't_y': t_y,
57
+ 't_z': t_z}
58
+
59
+ # The original return referenced flat_vertices_list, flat_hyperedge and
+ # total_in_outs, none of which are defined in this function (apparently
+ # leftovers from a PolyGen-style reader). Return what is actually built.
+ return vertices, meta_info
62
+
63
+
64
+ def write_wire_obj(vertices, faces, file_path, transpose=True, scale=1.0):
65
+ """Write vertices and hyperedges to obj."""
66
+ vertex_dimension = vertices.shape[1]
67
+ assert vertex_dimension in (2, 3)
68
+ if transpose and vertex_dimension == 3:
69
+ # Permute 3D vertices where z comes first followed by x and y
70
+ vertices = vertices[:, [1, 2, 0]]
71
+ vertices *= scale
72
+ if faces is not None:
73
+ if len(faces) > 0:
74
+ if min(min(faces)) == 0:
75
+ f_add = 1
76
+ else:
77
+ f_add = 0
78
+ with open(file_path, "w") as f:
79
+ for v in vertices:
80
+ if vertex_dimension == 2:
81
+ f.write("v {} {} {}\n".format(v[0], v[1], 0.0))
82
+ else:
83
+ f.write("v {} {} {}\n".format(v[0], v[1], v[2]))
84
+ for face in faces:
85
+ line = "l"
86
+ for i in face:
87
+ # Pradeep: always adding 1 to the face index makes sense to me. Not sure why
88
+ # PolyGen does this conditionally (see L95 above)
89
+ # Something to note.
90
+ line += " {}".format(i + 1)
91
+ line += "\n"
92
+ f.write(line)
93
+
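Note that write_wire_obj pads 2D vertices with z = 0 and always writes 1-based indices. A minimal call (the output path is a placeholder):

import numpy as np

verts = np.array([[0.0, 0.0], [1.0, 0.0]])
write_wire_obj(verts, [[0, 1]], 'wire.obj')
# emits "v 0.0 0.0 0.0", "v 1.0 0.0 0.0" and "l 1 2"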
CADFusion/src/rendering_utils/img_renderer.py ADDED
@@ -0,0 +1,84 @@
1
+ import argparse
2
+ from OCC.Core.Graphic3d import *
3
+ from OCC.Display.OCCViewer import Viewer3d
4
+ from OCC.Extend.DataExchange import read_step_file
5
+ from OCC.Extend.TopologyUtils import TopologyExplorer
6
+ from OCC.Core.Quantity import Quantity_Color, Quantity_TOC_RGB, Quantity_NOC_WHITE
7
+ from OCC.Core.V3d import V3d_DirectionalLight
8
+ from OCC.Core.gp import gp_Dir
9
+ from glob import glob
10
+ import pathlib
11
+ from tqdm import tqdm
12
+
13
+
14
+ def render(shape, filename, width=1024, height=768, face_color_rgb=(0.2, 0.2, 0.2), edge_color_rgb=(0, 0, 0), show_face_boundary=True):
15
+ viewer = Viewer3d()
16
+ viewer.Create(phong_shading=True, create_default_lights=True)
17
+ viewer.set_bg_gradient_color([255, 255, 255], [255, 255, 255])
18
+ viewer.SetModeShaded()
19
+ viewer.hide_triedron()
20
+ viewer.EnableAntiAliasing()
21
+ dir_light = V3d_DirectionalLight(gp_Dir(0, 0.5, -1), Quantity_Color(Quantity_NOC_WHITE))
22
+ dir_light.SetEnabled(True)
23
+ dir_light.SetIntensity(500.0)
24
+ viewer.Viewer.AddLight(dir_light)
25
+ viewer.Viewer.SetLightOn()
26
+
27
+ viewer.default_drawer.EnableDrawHiddenLine()
28
+ viewer.default_drawer.SetFaceBoundaryDraw(show_face_boundary)
29
+ ais_context = viewer.GetContext()
30
+ dc = ais_context.DeviationCoefficient()
31
+ da = ais_context.DeviationAngle()
32
+ factor = 10
33
+ ais_context.SetDeviationCoefficient(dc / factor)
34
+ ais_context.SetDeviationAngle(da / factor)
35
+ topexp = TopologyExplorer(shape)
36
+ for face in topexp.faces():
37
+ if face is not None:
38
+ viewer.DisplayShape(face, color=Quantity_Color(*face_color_rgb, Quantity_TOC_RGB))
39
+ for edge in topexp.edges():
40
+ if edge is not None:
41
+ viewer.DisplayShape(edge, color=Quantity_Color(*edge_color_rgb, Quantity_TOC_RGB))
42
+ viewer.FitAll()
43
+ viewer.SetSize(width, height)
44
+ viewer.View.Dump(str(filename))
45
+
46
+
47
+ def main():
48
+ p = argparse.ArgumentParser()
49
+ p.add_argument("--input_dir", type=str, required=True, help="Input folder of STP/STEP files")
50
+ p.add_argument("--output_dir", type=str, required=True, help="Output folder of PNG files")
51
+ p.add_argument("--width", type=int, default=1024, help="Width of image")
52
+ p.add_argument("--height", type=int, default=768, help="Height of image")
53
+
54
+ args = p.parse_args()
55
+
56
+ files = []
57
+ cad_folders = sorted(glob(args.input_dir+'/*/'))
58
+ for folder in cad_folders:
59
+ input_path = pathlib.Path(folder)
60
+ files += list(input_path.glob("*.st*p"))
61
+ print(len(files))
62
+ # files = files[36000:] # debug only (* remove *)
63
+ output_path = pathlib.Path(args.output_dir)
64
+ if not output_path.exists():
65
+ output_path.mkdir(parents=True, exist_ok=True)
66
+
67
+ i = 0
68
+ j = 0
69
+ for fn in tqdm(files):
70
+ j += 1
71
+ try:
72
+ shape = read_step_file(str(fn))
73
+ # render(shape, output_path.joinpath(f'{j:06d}' + ".png"), args.width, args.height)
74
+ render(shape, output_path.joinpath(fn.stem[:6] + ".png"), args.width, args.height)
75
+ except Exception as e:
76
+ i += 1
77
+ # raise e
78
+ print(e)
79
+ continue
80
+ print("error number: ", i)
81
+ print("total number: ", j)
82
+
83
+ if __name__ == "__main__":
84
+ main()
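Outside the CLI loop, render can also be driven directly; a sketch for a single STEP file (paths are placeholders, and an offscreen-capable pythonocc build is assumed):

from OCC.Extend.DataExchange import read_step_file

shape = read_step_file('model.step')
render(shape, 'model.png', width=1024, height=768)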
CADFusion/src/rendering_utils/parser.py ADDED
@@ -0,0 +1,478 @@
1
+ import numpy as np
2
+ from collections import OrderedDict
3
+ import re
4
+ from pathlib import Path
5
+ import argparse
6
+ import os
7
+ import json
8
+ import math
9
+
10
+ # hyperparameters from SkexGen project
11
+ SKETCH_R = 1
12
+ RADIUS_R = 1
13
+ EXTRUDE_R = 1.0
14
+ SCALE_R = 1.4
15
+ OFFSET_R = 0.9
16
+ PIX_PAD = 4
17
+ CMD_PAD = 3
18
+ COORD_PAD = 4
19
+ EXT_PAD = 1
20
+ EXTRA_PAD = 1
21
+ R_PAD = 2
22
+
23
+
24
+ class CADparser:
25
+ """Parse CAD sequence to CAD object."""
26
+
27
+ def __init__(self, bit):
28
+ self.vertex_dict = OrderedDict()
29
+ self.bit = bit
30
+
31
+ def perform(self, cad_seq):
32
+ # divide into sketch and extrude
33
+ sketches, extrudes = self.get_SE(cad_seq)
34
+ if sketches is None or extrudes is None:
35
+ return None
36
+ # sequentially parse each pair of SE into obj
37
+ se_datas = []
38
+ for sketch, extrude in zip(sketches, extrudes):
39
+ extrude_param, scale, offset = self.parse_extrude(extrude)
40
+ if extrude_param is None or scale is None or offset is None:
41
+ return None
42
+ vertex_str, se_str = self.parse_sketch(sketch, scale, offset)
43
+ if vertex_str is None or se_str is None:
44
+ return None
45
+ se_datas.append(
46
+ {"vertex": vertex_str, "curve": se_str, "extrude": extrude_param}
47
+ )
48
+ self.vertex_dict.clear()
49
+
50
+ return se_datas
51
+
52
+ def parse_sketch(self, sketch, scale, offset):
53
+ faces = self.get_faces(sketch)
54
+ if len(faces) == 0:
55
+ return None, None
56
+ se_str = ""
57
+ for face_idx, face in enumerate(faces): # each face
58
+ face_str = "face\n"
59
+ loops = self.get_loops(face)
60
+ if len(loops) == 0:
61
+ return None, None
62
+ for loop_idx, loop in enumerate(loops): # each loop
63
+ curves = self.get_curves(loop)
64
+ if len(curves) == 0:
65
+ return None, None
66
+ next_curves = curves[1:]
67
+ next_curves += curves[:1]
68
+ cur_str = []
69
+ for curve, next_curve in zip(curves, next_curves): # each curve
70
+ if not self.obj_curve(curve, next_curve, cur_str, scale, offset):
71
+ return None, None
72
+ loop_str = ""
73
+ for c in cur_str:
74
+ loop_str += f"{c}\n"
75
+ if loop_idx == 0:
76
+ face_str += f"out\n{loop_str}\n"
77
+ else:
78
+ face_str += f"in\n{loop_str}\n"
79
+ se_str += face_str
80
+ vertex_str = self.convert_vertices()
81
+ return vertex_str, se_str
82
+
83
+ def parse_extrude(self, extrude):
84
+ ext = extrude.split(",")
85
+ if len(ext) != 18:
86
+ return None, None, None
87
+
88
+ # operation str to int
89
+ ext_op = {"add": 1, "cut": 2, "intersect": 3}.get(ext[0], None)
90
+ if ext_op is None:
91
+ return None, None, None
92
+ # dequantize ext_v, ext_T, scale and offset
93
+ ext_v, ext_T, scale, offset = self.dequantize_extrude_params(ext)
94
+ # get ext_R
95
+ ext_R = np.array(ext[6:15], dtype=int)
96
+
97
+ extrude_param = {"value": ext_v, "T": ext_T, "R": ext_R, "op": ext_op}
98
+ return extrude_param, scale, offset
99
+
100
+ def obj_curve(self, curve, next_curve, cur_str, scale, offset):
101
+ cur = curve.split(",")
102
+ next_cur = next_curve.split(",")
103
+ if cur[0] == "circle":
104
+ if len(cur) != 9:
105
+ return False
106
+ p1, p2, p3, p4 = self.dequantize_circle_points(
107
+ cur, next_cur, scale, offset)
108
+ center = np.asarray([0.5 * (p1[0] + p2[0]), 0.5 * (p3[1] + p4[1])])
109
+ radius = (np.linalg.norm(p1 - p2) + np.linalg.norm(p3 - p4)) / 4.0
110
+
111
+ center = center * scale + offset
112
+ radius = radius * scale
113
+
114
+ center_idx = self.save_vertex(center[0], center[1], "p")
115
+ radius_idx = self.save_vertex(radius, 0.0, "r")
116
+ cur_str.append(f"c {center_idx} {radius_idx}")
117
+ elif cur[0] == "arc":
118
+ if len(cur) != 5:
119
+ return False
120
+ if (
121
+ cur[1:3] == cur[3:5]
122
+ or cur[1:3] == next_cur[1:3]
123
+ or cur[3:5] == next_cur[3:5]
124
+ ): # invalid arc
125
+ return False
126
+ start_v, mid_v, end_v = self.dequantize_arc_points(
127
+ cur, next_cur, scale, offset
128
+ )
129
+ try:
130
+ center, _, _, _ = find_arc_geometry(start_v, mid_v, end_v)
131
+ except Exception:
132
+ return False
133
+ start_v = start_v * scale + offset
134
+ mid_v = mid_v * scale + offset
135
+ end_v = end_v * scale + offset
136
+ center = center * scale + offset
137
+
138
+ center_idx = self.save_vertex(center[0], center[1], "p")
139
+ start_idx = self.save_vertex(start_v[0], start_v[1], "p")
140
+ mid_idx = self.save_vertex(mid_v[0], mid_v[1], "p")
141
+ end_idx = self.save_vertex(end_v[0], end_v[1], "p")
142
+ cur_str.append(f"a {start_idx} {mid_idx} {center_idx} {end_idx}")
143
+ elif cur[0] == "line":
144
+ if len(cur) != 3:
145
+ return False
146
+ if cur[1:3] == next_cur[1:3]:
147
+ return False
148
+ start_v, end_v = self.dequantize_line_points(
149
+ cur, next_cur, scale, offset)
150
+ start_v = start_v * scale + offset
151
+ end_v = end_v * scale + offset
152
+
153
+ start_idx = self.save_vertex(start_v[0], start_v[1], "p")
154
+ end_idx = self.save_vertex(end_v[0], end_v[1], "p")
155
+ cur_str.append(f"l {start_idx} {end_idx}")
156
+ else:
157
+ return False
158
+ return True
159
+
160
+ def get_SE(self, cad_seq):
161
+ # sketches: 1) between sequence start and sketch_end,
162
+ sketches_from_start = re.findall(r"^(.+?)(?=<sketch_end>)", cad_seq)
163
+ # sketches: 2) between extrude_end and sketch_end
164
+ sketches_after_extrude = re.findall(
165
+ r"(?<=<extrude_end>)(.+?)(?=<sketch_end>)", cad_seq
166
+ )
167
+ sketches = [x.strip() for x in sketches_from_start] + [
168
+ x.strip() for x in sketches_after_extrude
169
+ ]
170
+ # extrudes: between sketch_end and extrude_end
171
+ extrudes = [
172
+ x.strip() for x in re.findall(r"<sketch_end>(.+?)<extrude_end>", cad_seq)
173
+ ]
174
+ if len(sketches) != len(extrudes):
175
+ return None, None
176
+ return sketches, extrudes
177
+
178
+ def get_faces(self, sketch):
179
+ faces = sketch.split("<face_end>")
180
+ return [x.strip() for x in faces if x.strip() != ""]
181
+
182
+ def get_loops(self, face):
183
+ loops = face.split("<loop_end>")
184
+ return [x.strip() for x in loops if x.strip() != ""]
185
+
186
+ def get_curves(self, loop):
187
+ curves = loop.split("<curve_end>")
188
+ return [x.strip() for x in curves if x.strip() != ""]
189
+
190
+ def dequantize_circle_points(self, curve, next_curve, scale, offset):
191
+ p1 = dequantize_verts(
192
+ np.array(curve[1:3], dtype=int),
193
+ n_bits=self.bit,
194
+ min_range=-SKETCH_R,
195
+ max_range=SKETCH_R,
196
+ add_noise=False,
197
+ )
198
+ p2 = dequantize_verts(
199
+ np.array(curve[3:5], dtype=int),
200
+ n_bits=self.bit,
201
+ min_range=-SKETCH_R,
202
+ max_range=SKETCH_R,
203
+ add_noise=False,
204
+ )
205
+ p3 = dequantize_verts(
206
+ np.array(curve[5:7], dtype=int),
207
+ n_bits=self.bit,
208
+ min_range=-SKETCH_R,
209
+ max_range=SKETCH_R,
210
+ add_noise=False,
211
+ )
212
+ p4 = dequantize_verts(
213
+ np.array(curve[7:9], dtype=int),
214
+ n_bits=self.bit,
215
+ min_range=-SKETCH_R,
216
+ max_range=SKETCH_R,
217
+ add_noise=False,
218
+ )
219
+ return p1, p2, p3, p4
220
+
221
+ def dequantize_arc_points(self, curve, next_curve, scale, offset):
222
+ start_v = dequantize_verts(
223
+ np.array(curve[1:3], dtype=int),
224
+ n_bits=self.bit,
225
+ min_range=-SKETCH_R,
226
+ max_range=SKETCH_R,
227
+ add_noise=False,
228
+ )
229
+ mid_v = dequantize_verts(
230
+ np.array(curve[3:5], dtype=int),
231
+ n_bits=self.bit,
232
+ min_range=-SKETCH_R,
233
+ max_range=SKETCH_R,
234
+ add_noise=False,
235
+ )
236
+ end_v = dequantize_verts(
237
+ np.array(next_curve[1:3], dtype=int),
238
+ n_bits=self.bit,
239
+ min_range=-SKETCH_R,
240
+ max_range=SKETCH_R,
241
+ add_noise=False,
242
+ )
243
+ return start_v, mid_v, end_v
244
+
245
+ def dequantize_line_points(self, curve, next_curve, scale, offset):
246
+ start_v = dequantize_verts(
247
+ np.array(curve[1:3], dtype=int),
248
+ n_bits=self.bit,
249
+ min_range=-SKETCH_R,
250
+ max_range=SKETCH_R,
251
+ add_noise=False,
252
+ )
253
+ end_v = dequantize_verts(
254
+ np.array(next_curve[1:3], dtype=int),
255
+ n_bits=self.bit,
256
+ min_range=-SKETCH_R,
257
+ max_range=SKETCH_R,
258
+ add_noise=False,
259
+ )
260
+ return start_v, end_v
261
+
262
+ def dequantize_extrude_params(self, extrude):
263
+ ext_v = dequantize_verts(
264
+ np.array(extrude[1:3], dtype=int),
265
+ n_bits=self.bit,
266
+ min_range=-EXTRUDE_R,
267
+ max_range=EXTRUDE_R,
268
+ add_noise=False,
269
+ )
270
+ ext_T = dequantize_verts(
271
+ np.array(extrude[3:6], dtype=int),
272
+ n_bits=self.bit,
273
+ min_range=-EXTRUDE_R,
274
+ max_range=EXTRUDE_R,
275
+ add_noise=False,
276
+ )
277
+ scale = dequantize_verts(
278
+ np.array(extrude[15], dtype=int),
279
+ n_bits=self.bit,
280
+ min_range=0.0,
281
+ max_range=SCALE_R,
282
+ add_noise=False,
283
+ )
284
+ offset = dequantize_verts(
285
+ np.array(extrude[16:18], dtype=int),
286
+ n_bits=self.bit,
287
+ min_range=-OFFSET_R,
288
+ max_range=OFFSET_R,
289
+ add_noise=False,
290
+ )
291
+ return ext_v, ext_T, scale, offset
292
+
293
+ def save_vertex(self, h_x, h_y, text):
294
+ unique_key = f"{text}:x{h_x}y{h_y}"
295
+ index = 0
296
+ for key in self.vertex_dict.keys():
297
+ # Vertex location already exist in dict
298
+ if unique_key == key:
299
+ return index
300
+ index += 1
301
+ # Vertex location does not exist in dict
302
+ self.vertex_dict[unique_key] = [h_x, h_y]
303
+ return index
304
+
305
+ def convert_vertices(self):
306
+ """Convert all the vertices to .obj format"""
307
+ vertex_strings = ""
308
+ for pt in self.vertex_dict.values():
309
+ # e.g. v 0.123 0.234 0.345 1.0
310
+ vertex_string = f"v {pt[0]} {pt[1]}\n"
311
+ vertex_strings += vertex_string
312
+ return vertex_strings
313
+
314
+
315
+ def find_arc_geometry(a, b, c):
316
+ A = b[0] - a[0]
317
+ B = b[1] - a[1]
318
+ C = c[0] - a[0]
319
+ D = c[1] - a[1]
320
+
321
+ E = A*(a[0] + b[0]) + B*(a[1] + b[1])
322
+ F = C*(a[0] + c[0]) + D*(a[1] + c[1])
323
+
324
+ G = 2.0*(A*(c[1] - b[1])-B*(c[0] - b[0]))
325
+
326
+ if G == 0:
327
+ raise Exception("zero G")
328
+
329
+ p_0 = (D*E - B*F) / G
330
+ p_1 = (A*F - C*E) / G
331
+
332
+ center = np.array([p_0, p_1])
333
+ radius = np.linalg.norm(center - a)
334
+
335
+ angles = []
336
+ for xx in [a, b, c]:
337
+ angle = angle_from_vector_to_x(xx - center)
338
+ angles.append(angle)
339
+
340
+ ab = b-a
341
+ ac = c-a
342
+ cp = np.cross(ab, ac)
343
+ if cp >= 0:
344
+ start_angle_rads = angles[0]
345
+ end_angle_rads = angles[2]
346
+ else:
347
+ start_angle_rads = angles[2]
348
+ end_angle_rads = angles[0]
349
+
350
+ return center, radius, start_angle_rads, end_angle_rads
351
+
352
+
353
+ def angle_from_vector_to_x(vec):
354
+ assert vec.size == 2
355
+ # We need to find a unit vector
356
+ angle = 0.0
357
+
358
+ l = np.linalg.norm(vec)
359
+ uvec = vec/l
360
+
361
+ # 2 | 1
362
+ # -------
363
+ # 3 | 4
364
+ if uvec[0] >= 0:
365
+ if uvec[1] >= 0:
366
+ # Qadrant 1
367
+ angle = math.asin(uvec[1])
368
+ else:
369
+ # Qadrant 4
370
+ angle = 2.0*math.pi - math.asin(-uvec[1])
371
+ else:
372
+ if vec[1] >= 0:
373
+ # Qadrant 2
374
+ angle = math.pi - math.asin(uvec[1])
375
+ else:
376
+ # Qadrant 3
377
+ angle = math.pi + math.asin(-uvec[1])
378
+ return angle
379
+
380
+
381
+ def dequantize_verts(verts, n_bits=8, min_range=-0.5, max_range=0.5, add_noise=False):
382
+ """Convert quantized vertices to floats."""
383
+ range_quantize = 2**n_bits - 1
384
+ verts = verts.astype("float32")
385
+ verts = verts * (max_range - min_range) / range_quantize + min_range
386
+ return verts
387
+
388
+
389
+ def write_obj_sample(save_folder, data):
390
+ for idx, write_data in enumerate(data):
391
+ obj_name = Path(save_folder).stem + "_" + \
392
+ str(idx).zfill(3) + "_param.obj"
393
+ obj_file = Path(save_folder) / obj_name
394
+ extrude_param = write_data["extrude"]
395
+ vertex_strings = write_data["vertex"]
396
+ curve_strings = write_data["curve"]
397
+
398
+ """Write an .obj file with the curves and verts"""
399
+ if extrude_param["op"] == 1: # 'add'
400
+ set_op = "NewBodyFeatureOperation"
401
+ elif extrude_param["op"] == 2: # 'cut'
402
+ set_op = "CutFeatureOperation"
403
+ elif extrude_param["op"] == 3: # 'intersect'
404
+ set_op = "IntersectFeatureOperation"
405
+
406
+ with open(obj_file, "w") as fh:
407
+ # Write Meta info
408
+ fh.write("# WaveFront *.obj file\n")
409
+ fh.write("# ExtrudeOperation: " + set_op + "\n")
410
+ fh.write("\n")
411
+
412
+ # Write vertex and curve
413
+ fh.write(vertex_strings)
414
+ fh.write("\n")
415
+ fh.write(curve_strings)
416
+ fh.write("\n")
417
+
418
+ # Write extrude value
419
+ extrude_string = "Extrude "
420
+ for value in extrude_param["value"]:
421
+ extrude_string += str(value) + " "
422
+ fh.write(extrude_string)
423
+ fh.write("\n")
424
+
425
+ # Write reference plane values
426
+ p_orig = parse3d_sample(extrude_param["T"])
427
+ x_axis = parse3d_sample(extrude_param["R"][0:3])
428
+ y_axis = parse3d_sample(extrude_param["R"][3:6])
429
+ z_axis = parse3d_sample(extrude_param["R"][6:9])
430
+ fh.write("T_origin " + p_orig)
431
+ fh.write("\n")
432
+ fh.write("T_xaxis " + x_axis)
433
+ fh.write("\n")
434
+ fh.write("T_yaxis " + y_axis)
435
+ fh.write("\n")
436
+ fh.write("T_zaxis " + z_axis)
437
+
438
+
439
+ def parse3d_sample(point3d):
440
+ x = point3d[0]
441
+ y = point3d[1]
442
+ z = point3d[2]
443
+ return str(x) + " " + str(y) + " " + str(z)
444
+
445
+
446
+ if __name__ == "__main__":
447
+ parser = argparse.ArgumentParser()
448
+ parser.add_argument("--in-path", type=str, required=True)
449
+ parser.add_argument("--out-path", type=str, required=True)
450
+ args = parser.parse_args()
451
+
452
+ # with open(args.in_path, "r") as f:
453
+ # data = f.readlines()
454
+ with open(args.in_path, 'r') as file:
455
+ data = file.read()
456
+
457
+ data = json.loads(data)
458
+
459
+ num_valid_str = 0
460
+ for idx, item in enumerate(data):
461
+ try:
462
+ cad_parser = CADparser(bit=6)
463
+ # print(idx)
464
+ if isinstance(item, str):
465
+ parsed_data = cad_parser.perform(item)
466
+ elif isinstance(item, dict):
467
+ parsed_data = cad_parser.perform(item['output'])
468
+ else:
469
+ raise ValueError("Invalid data type")
470
+ out_path = os.path.join(args.out_path, str(idx).zfill(6))
471
+ os.makedirs(out_path, exist_ok=True)
472
+ if parsed_data is not None:
473
+ num_valid_str += 1
474
+ write_obj_sample(out_path, parsed_data)
475
+ except Exception as e:
476
+ print(e)
477
+ pass
478
+ print(f"Number of valid CAD strings: {num_valid_str}/{len(data)}")
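A sketch of driving CADparser programmatically on one command sequence in the token format shown in the dataset examples (the sequence and the output folder are illustrative; the folder must already exist):

seq = ('line,9,9 <curve_end> line,9,53 <curve_end> line,53,53 <curve_end> '
       'line,53,9 <curve_end> <loop_end> <face_end> <sketch_end> '
       'add,31,32,31,31,31,0,1,0,0,0,1,1,0,0,62,31,31 <extrude_end>')
cad_parser = CADparser(bit=6)
se_datas = cad_parser.perform(seq)        # None if the sequence is malformed
if se_datas is not None:
    write_obj_sample('out_dir', se_datas)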
CADFusion/src/rendering_utils/parser_visual.py ADDED
@@ -0,0 +1,110 @@
1
+ import os
2
+ import argparse
3
+ from pathlib import Path
4
+ from tqdm import tqdm
5
+ from multiprocessing import Pool
6
+ from glob import glob
7
+ from utils.obj_reconverter import OBJReconverter
8
+ from OCC.Core.BRepCheck import BRepCheck_Analyzer
9
+ from geometry.obj_parser import OBJParser
10
+ from utils.util import write_stl_file
11
+ from OCC.Extend.DataExchange import write_step_file
12
+
13
+ import signal
14
+ from contextlib import contextmanager
15
+ @contextmanager
16
+ def timeout(time):
17
+ # Register a function to raise a TimeoutError on the signal.
18
+ signal.signal(signal.SIGALRM, raise_timeout)
19
+ # Schedule the signal to be sent after ``time``.
20
+ signal.alarm(time)
21
+ try:
22
+ yield
23
+ except TimeoutError:
24
+ raise Exception("time out")
25
+ finally:
26
+ # Unregister the signal so it won't be triggered
27
+ # if the timeout is not reached.
28
+ signal.signal(signal.SIGALRM, signal.SIG_IGN)
29
+ def raise_timeout(signum, frame):
30
+ raise TimeoutError
31
+
32
+ NUM_THREADS = 36
33
+
34
+ def find_files(folder, extension):
35
+ return sorted([Path(os.path.join(folder, f)) for f in os.listdir(folder) if f.endswith(extension)])
36
+
37
+
38
+ def run_parallel(project_folder):
39
+ output_folder = project_folder
40
+
41
+ param_objs = find_files(project_folder, 'param.obj')
42
+
43
+ cur_solid = None
44
+ extrude_idx = 0
45
+ for obj in param_objs:
46
+ try:
47
+ with timeout(30):
48
+ parser = OBJParser(obj)
49
+ _, faces, meta_info = parser.parse_file(1.0)
50
+ converter = OBJReconverter()
51
+ ext_solid, _, _ = converter.parse_obj(faces, meta_info)
52
+ set_op = meta_info["set_op"]
53
+ if set_op == "NewBodyFeatureOperation" or set_op == "JoinFeatureOperation":
54
+ if cur_solid is None:
55
+ cur_solid = ext_solid
56
+ else:
57
+ cur_solid = converter.my_op(cur_solid, ext_solid, 'fuse')
58
+ elif set_op == "CutFeatureOperation":
59
+ cur_solid = converter.my_op(cur_solid, ext_solid, 'cut')
60
+ elif set_op == "IntersectFeatureOperation":
61
+ cur_solid = converter.my_op(cur_solid, ext_solid, 'common')
62
+ else:
63
+ raise Exception("Unknown operation type")
64
+
65
+ analyzer = BRepCheck_Analyzer(cur_solid)
66
+ if not analyzer.IsValid():
67
+ raise Exception("brep check failed")
68
+
69
+ extrude_idx += 1
70
+
71
+ except Exception as ex:
72
+ print(ex)
73
+ msg = [project_folder, str(ex)[:100]]
74
+ return None
75
+ try:
76
+ with timeout(30):
77
+ stl_name = Path(output_folder).stem + '_'+ str(extrude_idx).zfill(3) + "_final.stl"
78
+ output_path = os.path.join(output_folder, stl_name)
79
+ write_stl_file(cur_solid, output_path, linear_deflection=0.001, angular_deflection=0.5)
80
+
81
+ step_name = Path(output_folder).stem + '_'+ str(extrude_idx).zfill(3) + "_final.step"
82
+ output_path = os.path.join(output_folder, step_name)
83
+ write_step_file(cur_solid, output_path)
84
+
85
+ except Exception as ex:
86
+ print(ex)
87
+ msg = [project_folder, str(ex)[:500]]
88
+ return None
89
+
90
+ return cur_solid
91
+
92
+
93
+ if __name__ == "__main__":
94
+ parser = argparse.ArgumentParser()
95
+ parser.add_argument("--data_folder", type=str, required=True)
96
+ parser.add_argument("--single-file", action='store_true', default=False)
97
+ args = parser.parse_args()
98
+
99
+ if args.single_file:
100
+ # If single file, just run the function on that file
101
+ run_parallel(args.data_folder)
102
+ exit(0)
103
+ else:
104
+ solids = []
105
+ # cad_folders = sorted(glob(args.data_folder+'/*'))[50000:] # why after 50000?
106
+ cad_folders = sorted(glob(args.data_folder+'/*'))
107
+ # print("len of cad_folder:", len(cad_folders))
108
+ convert_iter = Pool(NUM_THREADS).imap(run_parallel, cad_folders)
109
+ for solid in tqdm(convert_iter, total=len(cad_folders)):
110
+ pass
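The SIGALRM-based timeout only works in the main thread on Unix; within those limits it can wrap any step, not just the conversion loop (the path below is a placeholder):

with timeout(30):
    parser = OBJParser(some_param_obj_path)      # hypothetical pathlib.Path
    _, faces, meta_info = parser.parse_file(1.0)
# raises Exception("time out") if parsing exceeds 30 seconds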
CADFusion/src/rendering_utils/ptl_sampler.py ADDED
@@ -0,0 +1,88 @@
1
+ import os
2
+ import argparse
3
+ import ntpath
4
+ from tqdm import tqdm
5
+ import multiprocessing
6
+ from pathlib import Path
7
+ from glob import glob
8
+ import trimesh
9
+ from trimesh.sample import sample_surface
10
+ from plyfile import PlyData, PlyElement
11
+ import numpy as np
12
+
13
+
14
+ def write_ply(points, filename, text=False):
15
+ """ input: Nx3, write points to filename as PLY format. """
16
+ points = [(points[i,0], points[i,1], points[i,2]) for i in range(points.shape[0])]
17
+ vertex = np.array(points, dtype=[('x', 'f4'), ('y', 'f4'),('z', 'f4')])
18
+ el = PlyElement.describe(vertex, 'vertex', comments=['vertices'])
19
+ with open(filename, mode='wb') as f:
20
+ PlyData([el], text=text).write(f)
21
+
22
+
23
+ def find_files(folder, extension):
24
+ return sorted([Path(os.path.join(folder, f)) for f in os.listdir(folder) if f.endswith(extension)])
25
+
26
+ class SamplePoints:
27
+ """
28
+ Perform sampleing of points.
29
+ """
30
+
31
+ def __init__(self):
32
+ """
33
+ Constructor.
34
+ """
35
+ parser = self.get_parser()
36
+ self.options = parser.parse_args()
37
+
38
+
39
+ def get_parser(self):
40
+ """
41
+ Get parser of tool.
42
+
43
+ :return: parser
44
+ """
45
+ parser = argparse.ArgumentParser(description='Sample point clouds from meshes stored as STL files.')
46
+ parser.add_argument('--in_dir', type=str, help='Path to input directory.')
47
+ parser.add_argument('--out_dir', type=str, help='Path to output directory; files within are overwritten!')
48
+ parser.add_argument("--single-file", action='store_true', default=False)
49
+ return parser
50
+
51
+
52
+ def run_parallel(self, project_folder):
53
+ out_folder = os.path.join(project_folder, self.options.out_dir)
54
+ if not os.path.exists(out_folder):
55
+ os.makedirs(out_folder)
56
+
57
+ files = find_files(project_folder, 'final.stl')
58
+
59
+ for filepath in files:
60
+ N_POINTS = 2000
61
+ try:
62
+ out_mesh = trimesh.load(str(filepath))
63
+ out_pc, _ = sample_surface(out_mesh, N_POINTS)
64
+ save_path = os.path.join(out_folder, ntpath.basename(filepath)[:-4]+'_pcd.ply')
65
+ write_ply(out_pc, save_path)
66
+
67
+ except Exception as ex:
68
+ return project_folder
69
+ return
70
+
71
+
72
+ def run(self):
73
+ """
74
+ Run simplification.
75
+ """
76
+ if self.options.single_file:
77
+ self.run_parallel(self.options.in_dir)
78
+ else:
79
+ project_folders = sorted(glob(self.options.in_dir+'/*/'))
80
+ num_cpus = multiprocessing.cpu_count()
81
+ convert_iter = multiprocessing.Pool(num_cpus).imap(self.run_parallel, project_folders)
82
+ for _ in tqdm(convert_iter, total=len(project_folders)):
83
+ pass
84
+
85
+
86
+ if __name__ == '__main__':
87
+ app = SamplePoints()
88
+ app.run()
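The same sampling can be done for a single mesh without the CLI; this mirrors one iteration of run_parallel (file names are placeholders):

import trimesh
from trimesh.sample import sample_surface

mesh = trimesh.load('000000_001_final.stl')
pc, _ = sample_surface(mesh, 2000)          # 2000 surface points, Nx3
write_ply(pc, '000000_001_final_pcd.ply')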
CADFusion/src/rendering_utils/utils/obj_reconverter.py ADDED
@@ -0,0 +1,437 @@
+ import numpy as np
+ from collections import OrderedDict
+ from .util import create_point, create_unit_vec, get_transform, create_sketch_plane
+
+ # OCC
+ from OCC.Core.BRepCheck import BRepCheck_Analyzer
+ from OCC.Core.GC import GC_MakeArcOfCircle
+ from OCC.Core.BRepBuilderAPI import (
+     BRepBuilderAPI_MakeFace,
+     BRepBuilderAPI_MakeWire,
+     BRepBuilderAPI_MakeEdge,
+ )
+ from OCC.Core.BRepAlgoAPI import BRepAlgoAPI_Fuse, BRepAlgoAPI_Cut, BRepAlgoAPI_Common
+ from OCC.Core.BRepPrimAPI import BRepPrimAPI_MakePrism
+ from OCC.Core.BRepAdaptor import BRepAdaptor_Surface
+ from OCC.Core.BRepGProp import brepgprop_VolumeProperties, brepgprop_SurfaceProperties
+ from OCC.Core.GProp import GProp_GProps
+ from OCC.Core.ShapeFix import ShapeFix_Face, ShapeFix_Wire
+ from OCC.Core.gp import gp_Vec, gp_Ax2, gp_Dir, gp_Circ
+ from OCC.Extend.DataExchange import write_stl_file
+
+
+ class OBJReconverter:
+     """OBJ Data Reconverter"""
+
+     def __init__(self):
+         self.vertex_dict = OrderedDict()
+         self.PRECISION = 1e-5
+         self.eps = 1e-7
+         self.x_axis = gp_Dir(1.0, 0.0, 0.0)
+
+     def convert_curve(self, curve):
+         """
+         Convert a curve to the json dict format.
+         """
+         json_curve = {}
+
+         if curve.type == "circle":
+             json_curve["type"] = "Circle3D"
+             json_curve["center_point"] = {
+                 "x": curve.center[0],
+                 "y": curve.center[1],
+                 "z": 0,
+             }
+             json_curve["radius"] = curve.radius
+
+         if curve.type == "line":
+             json_curve["type"] = "Line3D"
+             json_curve["start_point"] = {
+                 "x": curve.start[0],
+                 "y": curve.start[1],
+                 "z": 0,
+             }
+             json_curve["end_point"] = {"x": curve.end[0], "y": curve.end[1], "z": 0}
+
+         if curve.type == "arc":
+             json_curve["type"] = "Arc3D"
+             json_curve["start_point"] = {
+                 "x": curve.start[0],
+                 "y": curve.start[1],
+                 "z": 0,
+             }
+             json_curve["end_point"] = {"x": curve.end[0], "y": curve.end[1], "z": 0}
+             json_curve["mid_point"] = {"x": curve.mid[0], "y": curve.mid[1], "z": 0}
+             json_curve["center_point"] = {
+                 "x": curve.center[0],
+                 "y": curve.center[1],
+                 "z": 0,
+             }
+
+         json_curve["is_outer"] = curve.is_outer
+         return json_curve
+
+     def convert_vertices(self):
+         """Convert all the vertices to .obj format"""
+         vertex_strings = ""
+         for pt in self.vertex_dict.values():
+             # e.g. v 0.123 0.234 (vertices here are 2D sketch points)
+             vertex_string = f"v {pt[0]} {pt[1]}\n"
+             vertex_strings += vertex_string
+         return vertex_strings
+
+     def parse_obj(self, faces, meta_info):
+         """
+         Reconstruct the B-rep from an obj file.
+         """
+         # At least one needs to match
+         for face in faces:
+             for loop in face:
+                 if len(loop) > 1:
+                     for idx, curve in enumerate(loop[:-1]):
+                         next_curve = np.vstack([loop[idx + 1].start, loop[idx + 1].end])
+                         diff1 = np.sum(np.abs(curve.start - next_curve), 1)
+                         diff2 = np.sum(np.abs(curve.end - next_curve), 1)
+
+                         if min(diff2) == 0 or min(diff1) == 0:
+                             continue  # edge connected
+
+                         assert (
+                             min(diff1) < 1e-3 or min(diff2) < 1e-3
+                         )  # difference should be small
+
+                         if min(diff1) > min(diff2):
+                             min_idx = np.argmin(diff2)
+                             if min_idx == 0:
+                                 loop[idx + 1].start_idx = curve.end_idx
+                                 loop[idx + 1].start = curve.end
+                             else:
+                                 loop[idx + 1].end_idx = curve.end_idx
+                                 loop[idx + 1].end = curve.end
+                         else:
+                             min_idx = np.argmin(diff1)
+                             if min_idx == 0:
+                                 loop[idx + 1].start_idx = curve.start_idx
+                                 loop[idx + 1].start = curve.start
+                             else:
+                                 loop[idx + 1].end_idx = curve.start_idx
+                                 loop[idx + 1].end = curve.start
+
+                     # Solve start / end connection
+                     shared_idx = list(
+                         set([loop[-2].start_idx, loop[-2].end_idx]).intersection(
+                             set([loop[-1].start_idx, loop[-1].end_idx])
+                         )
+                     )
+
+                     assert len(shared_idx) >= 1
+
+                     if len(shared_idx) == 2:
+                         assert len(loop) == 2  # do nothing
+                     else:
+                         if shared_idx[0] == loop[-1].start_idx:
+                             do_start = False
+                         else:
+                             do_start = True
+                         start_curve = np.vstack([loop[0].start, loop[0].end])
+
+                         if do_start:
+                             diff = np.sum(np.abs(loop[-1].start - start_curve), 1)
+                         else:
+                             diff = np.sum(np.abs(loop[-1].end - start_curve), 1)
+                         assert min(diff) < 1e-3
+
+                         min_idx = np.argmin(diff)
+                         if min_idx == 0:
+                             if do_start:
+                                 loop[-1].start_idx = loop[0].start_idx
+                                 loop[-1].start = loop[0].start
+                             else:
+                                 loop[-1].end_idx = loop[0].start_idx
+                                 loop[-1].end = loop[0].start
+                         else:
+                             if do_start:
+                                 loop[-1].start_idx = loop[0].end_idx
+                                 loop[-1].start = loop[0].end
+                             else:
+                                 loop[-1].end_idx = loop[0].end_idx
+                                 loop[-1].end = loop[0].end
+
+         # Parse groups to json loop/curve profile
+         extrusion = {}
+         extrusion["profiles"] = []
+         for face in faces:
+             profile = {}
+             profile["loops"] = []
+             for loop in face:
+                 pl = {}
+                 pl["profile_curves"] = []
+                 for curve in loop:
+                     # convert to json format
+                     pl["profile_curves"].append(self.convert_curve(curve))
+                 profile["loops"].append(pl)
+             extrusion["profiles"].append(profile)
+
+         # Parse transform
+         sketch = {}
+         transform = {}
+         transform["origin"] = {
+             "x": meta_info["t_orig"][0],
+             "y": meta_info["t_orig"][1],
+             "z": meta_info["t_orig"][2],
+         }
+         transform["x_axis"] = {
+             "x": meta_info["t_x"][0],
+             "y": meta_info["t_x"][1],
+             "z": meta_info["t_x"][2],
+         }
+         transform["y_axis"] = {
+             "x": meta_info["t_y"][0],
+             "y": meta_info["t_y"][1],
+             "z": meta_info["t_y"][2],
+         }
+         transform["z_axis"] = {
+             "x": meta_info["t_z"][0],
+             "y": meta_info["t_z"][1],
+             "z": meta_info["t_z"][2],
+         }
+         sketch["transform"] = transform
+
+         # Parse extrude
+         extrude_params = {}
+         extrude_params["extrude_type"] = meta_info["set_op"]
+         extrude_params["extrude_values"] = meta_info["extrude_value"]
+
+         # Create sketch
+         all_faces = []
+         curve_strings = ""
+         curve_count = 0
+         for profile in extrusion["profiles"]:
+             ref_face, face, curve_string, c_count = self.parse_sketch(sketch, profile)
+             curve_strings += curve_string
+             curve_count += c_count
+             all_faces.append(face)
+
+         # Merge all faces in the same plane
+         plane_face = all_faces[0]
+         for face in all_faces[1:]:
+             plane_face = self.my_op(plane_face, face, "fuse")
+         solid = self.extrude_face(ref_face, plane_face, extrude_params)
+         return solid, curve_strings, curve_count
+
+     def my_op(self, big, small, op_name):
+         if op_name == "cut":
+             op = BRepAlgoAPI_Cut(big, small)
+         elif op_name == "fuse":
+             op = BRepAlgoAPI_Fuse(big, small)
+         elif op_name == "common":
+             op = BRepAlgoAPI_Common(big, small)
+         op.SetFuzzyValue(self.PRECISION)
+         op.Build()
+         return op.Shape()
+
+     def build_body(self, face, normal, value):
+         extrusion_vec = gp_Vec(normal).Multiplied(value)
+         make_prism = BRepPrimAPI_MakePrism(face, extrusion_vec)
+         make_prism.Build()
+         prism = make_prism.Prism()
+         return prism.Shape()
+
+     def extrudeBasedOnType(self, face, normal, distance):
+         # Extrude based on the two bound values
+         if not (distance[0] < distance[1]):
+             raise Exception("incorrect distance")
+         large_value = distance[1]
+         small_value = distance[0]
+
+         if large_value == 0:
+             return self.build_body(face, -normal, -small_value)
+         elif small_value == 0:
+             return self.build_body(face, normal, large_value)
+         elif np.sign(large_value) == np.sign(small_value):
+             if large_value < 0:
+                 body1 = self.build_body(face, -normal, -small_value)
+                 body2 = self.build_body(face, -normal, -large_value)
+                 return self.my_op(body1, body2, "cut")
+             else:
+                 assert large_value > 0
+                 body1 = self.build_body(face, normal, small_value)
+                 body2 = self.build_body(face, normal, large_value)
+                 return self.my_op(body2, body1, "cut")
+         else:
+             assert np.sign(large_value) != np.sign(small_value)
+             body1 = self.build_body(face, normal, large_value)
+             body2 = self.build_body(face, -normal, -small_value)
+             return self.my_op(body1, body2, "fuse")
+
+     def extrude_face(self, ref_face, face, extrude_params):
+         distance = extrude_params["extrude_values"]
+         surf = BRepAdaptor_Surface(ref_face).Plane()
+         normal = surf.Axis().Direction()
+         extruded_shape = self.extrudeBasedOnType(face, normal, distance)
+         return extruded_shape
+
+     def parse_sketch(self, sketch, profile):
+         """
+         Sketch in one closed loop (one out, multiple ins).
+         """
+         # Transformation from local to global xyz coord
+         transform = get_transform(sketch["transform"])
+
+         # Create face region (automatically infer from all wires)
+         outer_facelist = []
+         inner_facelist = []
+         curve_count = 0
+         outer_string = []
+         inner_string = []
+         plane = create_sketch_plane(sketch["transform"])
+
+         for idx, pl in enumerate(profile["loops"]):
+             # Create loop
+             loop, curve_string, num_curve = self.parse_loop(
+                 pl["profile_curves"], transform
+             )
+             # Create face
+             face_builder = BRepBuilderAPI_MakeFace(plane, loop)
+             if not face_builder.IsDone():
+                 raise Exception("face builder not done")
+             face = face_builder.Face()
+             # Fix face
+             fixer = ShapeFix_Face(face)
+             fixer.SetPrecision(self.PRECISION)
+             fixer.FixOrientation()
+
+             analyzer = BRepCheck_Analyzer(fixer.Face())
+             if not analyzer.IsValid():
+                 raise Exception("face check failed")
+
+             curve_count += num_curve
+
+             if pl["profile_curves"][0]["is_outer"]:
+                 outer_facelist.append(fixer.Face())
+                 outer_string.append(curve_string)
+             else:
+                 inner_facelist.append(fixer.Face())
+                 inner_string.append(curve_string)
+
+         # Create final closed loop face
+         assert len(outer_facelist) > 0
+         final_face = outer_facelist[0]
+         for face in outer_facelist[1:]:
+             final_face = self.my_op(final_face, face, "fuse")
+         for face in inner_facelist:
+             final_face = self.my_op(final_face, face, "cut")
+
+         # Append inner / outer information to string
+         assert len(outer_string) == 1
+         out_str = ""
+         in_str = ""
+         for c_str in outer_string:
+             out_str += "out\n" + c_str + "\n"
+         for c_str in inner_string:
+             in_str += "in\n" + c_str + "\n"
+         final_str = "face\n" + out_str + in_str
+
+         return outer_facelist[0], final_face, final_str, curve_count
+
+     def parse_loop(self, profile_loop, transform):
+         """Create a wire from one closed loop of curves."""
+         topo_wire = BRepBuilderAPI_MakeWire()
+         curve_strings = ""
+         curve_count = 0
+
+         # Loop through all the curves in one loop
+         for profile_curve in profile_loop:
+             curve_edge, curve_string = self.parse_curve(profile_curve, transform)
+             topo_wire.Add(curve_edge)
+             if not topo_wire.IsDone():
+                 raise Exception("wire builder not done")
+
+             curve_string += "\n"
+             curve_count += 1
+             curve_strings += curve_string
+
+         fixer = ShapeFix_Wire()
+         fixer.Load(topo_wire.Wire())
+         fixer.SetPrecision(self.PRECISION)
+         fixer.FixClosed()
+         fixer.Perform()
+         return fixer.Wire(), curve_strings, curve_count
+
+     def parse_curve(self, curve, transform):
+         if curve["type"] == "Line3D":
+             return self.create_line(curve, transform)
+         elif curve["type"] == "Circle3D":
+             return self.create_circle(curve, transform)
+         elif curve["type"] == "Arc3D":
+             return self.create_arc(curve, transform)
+         else:
+             raise Exception("unknown curve type")
+
+     def create_line(self, line, transform):
+         start = create_point(line["start_point"], transform)
+         end = create_point(line["end_point"], transform)
+         if start.Distance(end) == 0:
+             raise Exception("start/end point same location")
+         topo_edge = BRepBuilderAPI_MakeEdge(start, end)
+
+         # Save pre-transform
+         start_idx = self.save_vertex(
+             line["start_point"]["x"] + 0.0, line["start_point"]["y"] + 0.0, "p"
+         )
+         end_idx = self.save_vertex(
+             line["end_point"]["x"] + 0.0, line["end_point"]["y"] + 0.0, "p"
+         )
+         curve_string = f"l {start_idx} {end_idx}"
+         return topo_edge.Edge(), curve_string
+
+     def create_arc(self, arc, transform):
+         start = create_point(arc["start_point"], transform)
+         mid = create_point(arc["mid_point"], transform)
+         end = create_point(arc["end_point"], transform)
+         arc_occ = GC_MakeArcOfCircle(start, mid, end).Value()
+         topo_edge = BRepBuilderAPI_MakeEdge(arc_occ)
+
+         # Save pre-transform
+         start_idx = self.save_vertex(
+             arc["start_point"]["x"] + 0.0, arc["start_point"]["y"] + 0.0, "p"
+         )
+         end_idx = self.save_vertex(
+             arc["end_point"]["x"] + 0.0, arc["end_point"]["y"] + 0.0, "p"
+         )
+         center_idx = self.save_vertex(
+             arc["center_point"]["x"] + 0.0, arc["center_point"]["y"] + 0.0, "p"
+         )
+         mid_idx = self.save_vertex(
+             arc["mid_point"]["x"] + 0.0, arc["mid_point"]["y"] + 0.0, "p"
+         )
+         curve_string = f"a {start_idx} {mid_idx} {center_idx} {end_idx}"
+         return topo_edge.Edge(), curve_string
+
+     def create_circle(self, circle, transform):
+         center = create_point(circle["center_point"], transform)
+         radius = circle["radius"]
+         normal = create_unit_vec({"x": 0.0, "y": 0.0, "z": 1.0}, transform)
+         ref_vector3d = self.x_axis.Transformed(transform)
+         axis = gp_Ax2(center, normal, ref_vector3d)
+         gp_circle = gp_Circ(axis, abs(float(radius)))
+         topo_edge = BRepBuilderAPI_MakeEdge(gp_circle)
+
+         center_idx = self.save_vertex(
+             circle["center_point"]["x"] + 0.0, circle["center_point"]["y"] + 0.0, "p"
+         )
+         radius_idx = self.save_vertex(abs(float(radius)) + 0.0, 0, "r")
+         curve_string = f"c {center_idx} {radius_idx}"
+         return topo_edge.Edge(), curve_string
+
+     def save_vertex(self, h_x, h_y, text):
+         unique_key = f"{text}:x{h_x}y{h_y}"
+         index = 0
+         for key in self.vertex_dict.keys():
+             # Vertex location already exists in dict
+             if unique_key == key:
+                 return index
+             index += 1
+         # Vertex location does not exist in dict
+         self.vertex_dict[unique_key] = [h_x, h_y]
+         return index
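`save_vertex` de-duplicates vertices by a string key and returns a stable index, which is what keeps the emitted `.obj` curve strings consistent. A small usage sketch follows; the import path is hypothetical, and pythonocc-core must be installed for the OCC imports at the top of the file to resolve.

```python
from obj_reconverter import OBJReconverter  # hypothetical import path

rec = OBJReconverter()
i0 = rec.save_vertex(0.0, 0.0, 'p')  # new vertex  -> index 0
i1 = rec.save_vertex(1.0, 0.0, 'p')  # new vertex  -> index 1
i2 = rec.save_vertex(0.0, 0.0, 'p')  # duplicate   -> index 0 again
assert (i0, i1, i2) == (0, 1, 0)
print(rec.convert_vertices())  # "v 0.0 0.0" and "v 1.0 0.0", one per line
```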
CADFusion/src/rendering_utils/utils/util.py ADDED
@@ -0,0 +1,72 @@
+ import os
+ from OCC.Core.gp import gp_Pnt, gp_Vec, gp_Dir, gp_XYZ, gp_Ax3, gp_Trsf, gp_Pln
+ from OCC.Core.BRepMesh import BRepMesh_IncrementalMesh
+ from OCC.Core.StlAPI import StlAPI_Writer
+
+ def create_xyz(xyz):
+     return gp_XYZ(xyz["x"], xyz["y"], xyz["z"])
+
+
+ def get_ax3(transform_dict):
+     origin = create_xyz(transform_dict["origin"])
+     x_axis = create_xyz(transform_dict["x_axis"])
+     y_axis = create_xyz(transform_dict["y_axis"])
+     z_axis = create_xyz(transform_dict["z_axis"])
+     # Create new coord (orig, Norm, x-axis)
+     axis3 = gp_Ax3(gp_Pnt(origin), gp_Dir(z_axis), gp_Dir(x_axis))
+     return axis3
+
+
+ def get_transform(transform_dict):
+     axis3 = get_ax3(transform_dict)
+     transform_to_local = gp_Trsf()
+     transform_to_local.SetTransformation(axis3)
+     return transform_to_local.Inverted()
+
+
+ def create_sketch_plane(transform_dict):
+     axis3 = get_ax3(transform_dict)
+     return gp_Pln(axis3)
+
+
+ def create_point(point_dict, transform):
+     pt2d = gp_Pnt(point_dict["x"], point_dict["y"], point_dict["z"])
+     return pt2d.Transformed(transform)
+
+
+ def create_unit_vec(vec_dict, transform):
+     vec2d = gp_Dir(vec_dict["x"], vec_dict["y"], vec_dict["z"])
+     return vec2d.Transformed(transform)
+
+
+ def write_stl_file(a_shape, filename, mode="ascii", linear_deflection=0.001, angular_deflection=0.5):
+     """Export the shape to an STL file.
+     Be careful: the shape first needs to be explicitly meshed using BRepMesh_IncrementalMesh.
+     a_shape: the topods_shape to export
+     filename: the filename
+     mode: optional, "ascii" by default. Can also be "binary"
+     linear_deflection: optional, defaults to 0.001. Lower gives a more accurate mesh.
+     angular_deflection: optional, defaults to 0.5. Lower gives a more accurate mesh.
+     """
+     if a_shape.IsNull():
+         raise AssertionError("Shape is null.")
+     if mode not in ["ascii", "binary"]:
+         raise AssertionError("mode should be either ascii or binary")
+     if os.path.isfile(filename):
+         print("Warning: %s file already exists and will be replaced" % filename)
+     # first mesh the shape
+     mesh = BRepMesh_IncrementalMesh(a_shape, linear_deflection, False, angular_deflection, True)
+     # mesh.SetDeflection(0.05)
+     mesh.Perform()
+     if not mesh.IsDone():
+         raise AssertionError("Mesh is not done.")
+
+     stl_exporter = StlAPI_Writer()
+     if mode == "ascii":
+         stl_exporter.SetASCIIMode(True)
+     else:  # binary, just set the ASCII flag to False
+         stl_exporter.SetASCIIMode(False)
+     stl_exporter.Write(a_shape, filename)
+
+     if not os.path.isfile(filename):
+         raise IOError("File not written to disk.")
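A hedged usage sketch for the transform helpers above: `get_transform` maps sketch-local coordinates into world coordinates, so with an identity sketch plane a point comes back unchanged. The import path is an assumption, and pythonocc-core must be installed.

```python
from util import get_transform, create_point  # hypothetical import path

# Identity sketch plane: origin at the world origin, axes aligned with world axes.
transform_dict = {
    'origin': {'x': 0.0, 'y': 0.0, 'z': 0.0},
    'x_axis': {'x': 1.0, 'y': 0.0, 'z': 0.0},
    'y_axis': {'x': 0.0, 'y': 1.0, 'z': 0.0},
    'z_axis': {'x': 0.0, 'y': 0.0, 'z': 1.0},
}
trsf = get_transform(transform_dict)
pt = create_point({'x': 1.0, 'y': 2.0, 'z': 0.0}, trsf)
print(pt.X(), pt.Y(), pt.Z())  # 1.0 2.0 0.0 under the identity transform
```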
CADFusion/src/test/VLM_score.py ADDED
@@ -0,0 +1,95 @@
+ import os
+ import requests
+ import base64
+ import json
+ import time
+ import argparse
+ from mimetypes import guess_type
+ from tqdm import tqdm
+ import re
+
+ from openai import AzureOpenAI
+ from azure.identity import AzureCliCredential, get_bearer_token_provider
+
+ scope = "api://trapi/.default"
+ credential = get_bearer_token_provider(AzureCliCredential(), scope)
+
+ api_version = '2024-12-01-preview'
+ # deployment_name = 'gpt-4.1-mini_2025-04-14'
+ deployment_name = 'gpt-4o_2024-08-06'
+ instance = '<trapi/path>'  # See https://aka.ms/trapi/models for the instance name, remove /openai (library adds it implicitly)
+ endpoint = f'https://trapi.research.microsoft.com/{instance}'
+
+ client = AzureOpenAI(
+     azure_endpoint=endpoint,
+     azure_ad_token_provider=credential,
+     api_version=api_version,
+ )
+
+ def local_image_to_data_url(image_path):
+     mime_type, _ = guess_type(image_path)
+     if mime_type is None:
+         mime_type = 'application/octet-stream'
+     with open(image_path, "rb") as image_file:
+         base64_encoded_data = base64.b64encode(image_file.read()).decode('utf-8')
+     return f"data:{mime_type};base64,{base64_encoded_data}"
+
+ def ask_gpt(image_path, prompt):
+     image_url = local_image_to_data_url(image_path)
+     message_text = [
+         {"role": "system", "content": "You are an AI assistant that helps people find information."},
+         {"role": "user", "content": [
+             {"type": "text", "text": prompt},
+             {"type": "image_url", "image_url": {"url": image_url}},
+         ]}
+     ]
+
+     completion = client.chat.completions.create(
+         model=deployment_name,
+         messages=message_text,
+     )
+     output = completion.choices[0].message.content
+     return output
+
+ if __name__ == '__main__':
+     parser = argparse.ArgumentParser()
+     parser.add_argument('--test-path', type=str, default='data/sl_data/test.jsonl', help='Path to the JSONL file containing test data')
+     parser.add_argument('--name', type=str, default='original_seq', help='Run name of the testee')
+     parser.add_argument('--figure-dir', type=str, default='exp/figures')
+     parser.add_argument('--save-path', type=str, default='exp/evals', help='Target folder to save the results')
+     parser.add_argument('--repetition', type=int, default=5, help='Number of repetitions for each image')
+     args = parser.parse_args()
+
+     jsonl_path = args.test_path
+     name = args.name
+     figures_dir = f"{args.figure_dir}/{name}/"
+     save_path = f"{args.save_path}/{name}.jsonl"
+
+     with open(jsonl_path, 'r') as file:
+         test_data = json.load(file)
+     repetition = args.repetition
+     results = []
+     for i in tqdm(range(len(test_data[:800]))):
+         item = test_data[i]
+         for j in range(repetition):
+             img_num = i * repetition + j
+             image_name = f"{img_num:06d}.png"
+             image_path = os.path.join(figures_dir, image_name)
+             if os.path.exists(image_path):
+                 description = item['description']
+                 try:
+                     score = ask_gpt(image_path, f"The following is a text description of a 3D CAD figure and an image of a CAD instance. Measure if the figure corresponds to the given description, and give a score in the scale of 10. Only return the score. Do not comment on issues such as texture, smoothness and colors.\n description:{description}\n")
+                     # Alternative prompt for editing evaluation:
+                     # "The following is an original image of a CAD instance, a text description on editing and an image of the edited result. Measure if the figure corresponds to the given description, and give a score in the scale of 10. Only return the score. Do not comment on issues such as texture, smoothness and colors.\n description:{description}\n"
+                 except Exception as e:
+                     print(img_num)
+                     print(e)
+                     score = -1
+                 result = {
+                     "index": img_num,
+                     "gpt_score": score
+                 }
+                 results.append(result)
+     with open(save_path, 'w+') as file:
+         json.dump(results, file, indent=4)
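`ask_gpt` returns free text, and although the prompt asks for a bare score, replies like `Score: 7/10` can still slip through. A defensive post-processing sketch (not part of the script above) that extracts the leading number before averaging:

```python
import re

def parse_score(raw):
    """Return the first number found in the model reply, or -1.0 if none."""
    m = re.search(r'-?\d+(?:\.\d+)?', str(raw))
    return float(m.group()) if m else -1.0

replies = [{'gpt_score': '8'}, {'gpt_score': 'Score: 7/10'}]
scores = [parse_score(r['gpt_score']) for r in replies]
valid = [s for s in scores if s >= 0]
print(sum(valid) / len(valid))  # 7.5
```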
CADFusion/src/test/chamfer_dist.py ADDED
@@ -0,0 +1,308 @@
+ import torch
+ import argparse
+ import os
+ import numpy as np
+ from tqdm import tqdm
+ import random
+ import warnings
+ from glob import glob
+ from scipy.stats import entropy
+ from sklearn.neighbors import NearestNeighbors
+ from plyfile import PlyData
+ from pathlib import Path
+ from multiprocessing import Pool
+ from chamfer_distance import ChamferDistance
+
+ random.seed(0)
+ N_POINTS = 2000
+ NUM_THREADS = 16
+
+
+ def find_files(folder, extension):
+     return sorted([Path(os.path.join(folder, f)) for f in os.listdir(folder) if f.endswith(extension)])
+
+
+ def read_ply(path):
+     with open(path, 'rb') as f:
+         plydata = PlyData.read(f)
+         x = np.array(plydata['vertex']['x'])
+         y = np.array(plydata['vertex']['y'])
+         z = np.array(plydata['vertex']['z'])
+         vertex = np.stack([x, y, z], axis=1)
+     return vertex
+
+
+ def distChamfer(a, b):
+     x, y = a, b
+     bs, num_points, points_dim = x.size()
+     xx = torch.bmm(x, x.transpose(2, 1))
+     yy = torch.bmm(y, y.transpose(2, 1))
+     zz = torch.bmm(x, y.transpose(2, 1))
+     diag_ind = torch.arange(0, num_points).to(a).long()
+     rx = xx[:, diag_ind, diag_ind].unsqueeze(1).expand_as(xx)
+     ry = yy[:, diag_ind, diag_ind].unsqueeze(1).expand_as(yy)
+     P = (rx.transpose(2, 1) + ry - 2 * zz)
+     return P.min(1)[0], P.min(2)[0]
+
+
+ def _pairwise_CD(sample_pcs, ref_pcs, batch_size):
+     N_sample = sample_pcs.shape[0]
+     N_ref = ref_pcs.shape[0]
+     all_cd = []
+     all_emd = []
+     iterator = range(N_sample)
+     matched_gt = []
+     pbar = tqdm(iterator)
+     chamfer_dist = ChamferDistance()
+
+     for sample_b_start in pbar:
+         sample_batch = sample_pcs[sample_b_start]
+
+         cd_lst = []
+         emd_lst = []
+         for ref_b_start in range(0, N_ref, batch_size):
+             ref_b_end = min(N_ref, ref_b_start + batch_size)
+             ref_batch = ref_pcs[ref_b_start:ref_b_end]
+
+             batch_size_ref = ref_batch.size(0)
+             sample_batch_exp = sample_batch.view(1, -1, 3).expand(batch_size_ref, -1, -1)
+             sample_batch_exp = sample_batch_exp.contiguous()
+
+             dl, dr, idx1, idx2 = chamfer_dist(sample_batch_exp, ref_batch)
+             cd_lst.append((dl.mean(dim=1) + dr.mean(dim=1)).view(1, -1))
+
+         cd_lst = torch.cat(cd_lst, dim=1)
+         all_cd.append(cd_lst)
+
+         hit = np.argmin(cd_lst.detach().cpu().numpy()[0])
+         matched_gt.append(hit)
+         pbar.set_postfix({"cov": len(np.unique(matched_gt)) * 1.0 / N_ref})
+
+     all_cd = torch.cat(all_cd, dim=0)  # N_sample, N_ref
+
+     return all_cd
+
+
+ def compute_cov_mmd(sample_pcs, ref_pcs, batch_size):
+     all_dist = _pairwise_CD(sample_pcs, ref_pcs, batch_size)
+     print(all_dist.shape, flush=True)
+     N_sample, N_ref = all_dist.size(0), all_dist.size(1)
+     min_val_fromsmp, min_idx = torch.min(all_dist, dim=1)
+     min_val, _ = torch.min(all_dist, dim=0)
+     mmd = min_val.mean()
+     cov = float(min_idx.unique().view(-1).size(0)) / float(N_ref)
+     cov = torch.tensor(cov).to(all_dist)
+
+     return {
+         # 'med-CD': torch.diagonal(all_dist).median().item(),
+         'avg-CD': torch.diagonal(all_dist).mean().item(),
+         'COV-CD': cov.item(),
+         'MMD-CD': mmd.item()
+     }
+
+
+ def jsd_between_point_cloud_sets(sample_pcs, ref_pcs, in_unit_sphere, resolution=28):
+     '''Computes the JSD between two sets of point clouds, as introduced in the paper
+     "Learning Representations And Generative Models For 3D Point Clouds".
+     Args:
+         sample_pcs: (np.ndarray S1xR1x3) S1 point clouds, each of R1 points.
+         ref_pcs: (np.ndarray S2xR2x3) S2 point clouds, each of R2 points.
+         resolution: (int) grid resolution. Affects granularity of measurements.
+     '''
+     sample_grid_var = entropy_of_occupancy_grid(sample_pcs, resolution, in_unit_sphere)[1]
+     ref_grid_var = entropy_of_occupancy_grid(ref_pcs, resolution, in_unit_sphere)[1]
+     return jensen_shannon_divergence(sample_grid_var, ref_grid_var)
+
+
+ def entropy_of_occupancy_grid(pclouds, grid_resolution, in_sphere=False):
+     '''Given a collection of point clouds, estimate the entropy of the random variables
+     corresponding to occupancy-grid activation patterns.
+     Inputs:
+         pclouds: (numpy array) #point-clouds x points per point-cloud x 3
+         grid_resolution: (int) size of the occupancy grid that will be used.
+     '''
+     epsilon = 10e-4
+     bound = 1 + epsilon
+     if abs(np.max(pclouds)) > bound or abs(np.min(pclouds)) > bound:
+         print(abs(np.max(pclouds)), abs(np.min(pclouds)))
+         warnings.warn('Point-clouds are not in unit cube.')
+
+     if in_sphere and np.max(np.sqrt(np.sum(pclouds ** 2, axis=2))) > bound:
+         warnings.warn('Point-clouds are not in unit sphere.')
+
+     grid_coordinates, _ = unit_cube_grid_point_cloud(grid_resolution, in_sphere)
+     grid_coordinates = grid_coordinates.reshape(-1, 3)
+     grid_counters = np.zeros(len(grid_coordinates))
+     grid_bernoulli_rvars = np.zeros(len(grid_coordinates))
+     nn = NearestNeighbors(n_neighbors=1).fit(grid_coordinates)
+
+     for pc in pclouds:
+         _, indices = nn.kneighbors(pc)
+         indices = np.squeeze(indices)
+         for i in indices:
+             grid_counters[i] += 1
+         indices = np.unique(indices)
+         for i in indices:
+             grid_bernoulli_rvars[i] += 1
+
+     acc_entropy = 0.0
+     n = float(len(pclouds))
+     for g in grid_bernoulli_rvars:
+         p = 0.0
+         if g > 0:
+             p = float(g) / n
+         acc_entropy += entropy([p, 1.0 - p])
+
+     return acc_entropy / len(grid_counters), grid_counters
+
+
+ def unit_cube_grid_point_cloud(resolution, clip_sphere=False):
+     '''Returns the center coordinates of each cell of a 3D grid with resolution^3 cells
+     that is placed in the unit cube.
+     If clip_sphere is True, it drops the "corner" cells that lie outside the unit sphere.
+     '''
+     grid = np.ndarray((resolution, resolution, resolution, 3), np.float32)
+     spacing = 1.0 / float(resolution - 1) * 2
+     for i in range(resolution):
+         for j in range(resolution):
+             for k in range(resolution):
+                 grid[i, j, k, 0] = i * spacing - 0.5 * 2
+                 grid[i, j, k, 1] = j * spacing - 0.5 * 2
+                 grid[i, j, k, 2] = k * spacing - 0.5 * 2
+
+     if clip_sphere:
+         grid = grid.reshape(-1, 3)
+         grid = grid[np.linalg.norm(grid, axis=1) <= 0.5]
+
+     return grid, spacing
+
+
+ def jensen_shannon_divergence(P, Q):
+     if np.any(P < 0) or np.any(Q < 0):
+         raise ValueError('Negative values.')
+     if len(P) != len(Q):
+         raise ValueError('Non equal size.')
+
+     P_ = P / np.sum(P)  # Ensure probabilities.
+     Q_ = Q / np.sum(Q)
+
+     e1 = entropy(P_, base=2)
+     e2 = entropy(Q_, base=2)
+     e_sum = entropy((P_ + Q_) / 2.0, base=2)
+     res = e_sum - ((e1 + e2) / 2.0)
+
+     res2 = _jsdiv(P_, Q_)
+
+     if not np.allclose(res, res2, atol=10e-5, rtol=0):
+         warnings.warn('Numerical values of two JSD methods don\'t agree.')
+
+     return res
+
+
+ def _jsdiv(P, Q):
+     '''Another way of computing JSD.'''
+
+     def _kldiv(A, B):
+         a = A.copy()
+         b = B.copy()
+         idx = np.logical_and(a > 0, b > 0)
+         a = a[idx]
+         b = b[idx]
+         return np.sum([v for v in a * np.log2(a / b)])
+
+     P_ = P / np.sum(P)
+     Q_ = Q / np.sum(Q)
+
+     M = 0.5 * (P_ + Q_)
+
+     return 0.5 * (_kldiv(P_, M) + _kldiv(Q_, M))
+
+
+ def downsample_pc(points, n):
+     sample_idx = random.sample(list(range(points.shape[0])), n)
+     return points[sample_idx]
+
+
+ def normalize_pc(points):
+     scale = np.max(np.abs(points))
+     points = points / scale
+     return points
+
+
+ def collect_pc(cad_folder):
+     pc_path = find_files(os.path.join(cad_folder, 'ptl'), 'final_pcd.ply')
+     if len(pc_path) == 0:
+         return []
+     pc_path = pc_path[-1]  # final pcd
+     pc = read_ply(pc_path)
+     if pc.shape[0] > N_POINTS:
+         pc = downsample_pc(pc, N_POINTS)
+     pc = normalize_pc(pc)
+     return pc
+
+
+ def collect_pc2(cad_folder):
+     pc = read_ply(cad_folder)
+     if pc.shape[0] > N_POINTS:
+         pc = downsample_pc(pc, N_POINTS)
+     pc = normalize_pc(pc)
+     return pc
+
+
+ def main():
+     parser = argparse.ArgumentParser()
+     parser.add_argument("--fake", type=str)
+     parser.add_argument("--real", type=str)
+     parser.add_argument("--output", type=str)
+     split = 1
+     args = parser.parse_args()
+     if args.output is None:
+         args.output = args.fake + '_cad_results.txt'
+     chamfer_dist = ChamferDistance()
+     cd = []
+     for i in tqdm(range(952)):
+         fake_pcs = []
+         real_pcs = []
+         for j in range(split):
+             fake_index = i * split + j
+             fake_folder = os.path.join(args.fake, f'{fake_index:06d}')
+             if not os.path.exists(fake_folder):
+                 continue
+             else:
+                 fake_pc = collect_pc(fake_folder)
+                 if len(fake_pc) == 0:
+                     continue
+                 fake_pcs.append(fake_pc)
+
+         real_folder = os.path.join(args.real, f'{i:06d}')
+         if not os.path.exists(real_folder):
+             continue
+         else:
+             real_pc = collect_pc(real_folder)
+             if len(real_pc) == 0:
+                 continue
+             real_pcs.append(real_pc)
+
+         if len(fake_pcs) == 0 or len(real_pcs) == 0:
+             continue
+         sample_pcs = np.stack(fake_pcs, axis=0)
+         ref_pcs = np.stack(real_pcs, axis=0)
+
+         sample_pcs = torch.tensor(sample_pcs, dtype=torch.float32).cuda()
+         ref_pcs = torch.tensor(ref_pcs, dtype=torch.float32).cuda()
+         print(sample_pcs.shape, ref_pcs.shape)
+         dl, dr, idx1, idx2 = chamfer_dist(sample_pcs, ref_pcs)
+         min_val = (dl.mean(dim=1) + dr.mean(dim=1)).view(1, -1).squeeze(0).min().item()
+         cd.append(min_val)
+
+     cd = np.array(cd)
+     mean = np.mean(cd)
+     median = np.median(cd)
+     print('mean:', mean)
+     print('median:', median)
+
+
+ if __name__ == '__main__':
+     import time
+     start_time = time.time()
+     main()
+     end_time = time.time()
+     print(end_time - start_time)
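A tiny worked example of the `distChamfer` helper on two 4-point clouds: shifting every point by 0.1 along x gives squared nearest-neighbour distances of 0.01 in each direction, so the symmetric Chamfer distance is about 0.02. The import path is an assumption, and importing the module requires the `chamfer_distance` package used above to be installed.

```python
import torch
from chamfer_dist import distChamfer  # hypothetical import path

a = torch.tensor([[[0., 0., 0.], [1., 0., 0.], [0., 1., 0.], [1., 1., 0.]]])
b = a + torch.tensor([0.1, 0., 0.])  # shift every point by 0.1 in x
d_ab, d_ba = distChamfer(a, b)       # per-point squared nearest-neighbour distances
cd = d_ab.mean() + d_ba.mean()
print(cd.item())  # ~0.02 == 2 * 0.1**2
```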
CADFusion/src/test/dist_eval.py ADDED
@@ -0,0 +1,351 @@
+ import torch
+ import argparse
+ import os
+ import numpy as np
+ from tqdm import tqdm
+ import random
+ import warnings
+ from glob import glob
+ from scipy.stats import entropy
+ from sklearn.neighbors import NearestNeighbors
+ from plyfile import PlyData
+ from pathlib import Path
+ from multiprocessing import Pool
+ from chamfer_distance import ChamferDistance
+
+ random.seed(0)
+ N_POINTS = 2000
+ NUM_THREADS = 16
+
+
+ def find_files(folder, extension):
+     return sorted([Path(os.path.join(folder, f)) for f in os.listdir(folder) if f.endswith(extension)])
+
+
+ def read_ply(path):
+     with open(path, 'rb') as f:
+         plydata = PlyData.read(f)
+         x = np.array(plydata['vertex']['x'])
+         y = np.array(plydata['vertex']['y'])
+         z = np.array(plydata['vertex']['z'])
+         vertex = np.stack([x, y, z], axis=1)
+     return vertex
+
+
+ def distChamfer(a, b):
+     x, y = a, b
+     bs, num_points, points_dim = x.size()
+     xx = torch.bmm(x, x.transpose(2, 1))
+     yy = torch.bmm(y, y.transpose(2, 1))
+     zz = torch.bmm(x, y.transpose(2, 1))
+     diag_ind = torch.arange(0, num_points).to(a).long()
+     rx = xx[:, diag_ind, diag_ind].unsqueeze(1).expand_as(xx)
+     ry = yy[:, diag_ind, diag_ind].unsqueeze(1).expand_as(yy)
+     P = (rx.transpose(2, 1) + ry - 2 * zz)
+     return P.min(1)[0], P.min(2)[0]
+
+
+ def _pairwise_CD(sample_pcs, ref_pcs, batch_size):
+     N_sample = sample_pcs.shape[0]
+     N_ref = ref_pcs.shape[0]
+     all_cd = []
+     all_emd = []
+     iterator = range(N_sample)
+     matched_gt = []
+     pbar = tqdm(iterator)
+     chamfer_dist = ChamferDistance()
+
+     for sample_b_start in pbar:
+         sample_batch = sample_pcs[sample_b_start]
+
+         cd_lst = []
+         emd_lst = []
+         for ref_b_start in range(0, N_ref, batch_size):
+             ref_b_end = min(N_ref, ref_b_start + batch_size)
+             ref_batch = ref_pcs[ref_b_start:ref_b_end]
+
+             batch_size_ref = ref_batch.size(0)
+             sample_batch_exp = sample_batch.view(1, -1, 3).expand(batch_size_ref, -1, -1)
+             sample_batch_exp = sample_batch_exp.contiguous()
+
+             dl, dr, idx1, idx2 = chamfer_dist(sample_batch_exp, ref_batch)
+             cd_lst.append((dl.mean(dim=1) + dr.mean(dim=1)).view(1, -1))
+
+         cd_lst = torch.cat(cd_lst, dim=1)
+         all_cd.append(cd_lst)
+
+         hit = np.argmin(cd_lst.detach().cpu().numpy()[0])
+         matched_gt.append(hit)
+         pbar.set_postfix({"cov": len(np.unique(matched_gt)) * 1.0 / N_ref})
+
+     all_cd = torch.cat(all_cd, dim=0)  # N_sample, N_ref
+
+     return all_cd
+
+
+ def compute_cov_mmd(sample_pcs, ref_pcs, batch_size):
+     all_dist = _pairwise_CD(sample_pcs, ref_pcs, batch_size)
+     print(all_dist.shape, flush=True)
+     N_sample, N_ref = all_dist.size(0), all_dist.size(1)
+     min_val_fromsmp, min_idx = torch.min(all_dist, dim=1)
+     min_val, _ = torch.min(all_dist, dim=0)
+     mmd = min_val.mean()
+     cov = float(min_idx.unique().view(-1).size(0)) / float(N_ref)
+     cov = torch.tensor(cov).to(all_dist)
+
+     return {
+         # 'med-CD': torch.diagonal(all_dist).median().item(),
+         'avg-CD': torch.diagonal(all_dist).mean().item(),
+         'COV-CD': cov.item(),
+         'MMD-CD': mmd.item()
+     }
+
+
+ def jsd_between_point_cloud_sets(sample_pcs, ref_pcs, in_unit_sphere, resolution=28):
+     '''Computes the JSD between two sets of point clouds, as introduced in the paper
+     "Learning Representations And Generative Models For 3D Point Clouds".
+     Args:
+         sample_pcs: (np.ndarray S1xR1x3) S1 point clouds, each of R1 points.
+         ref_pcs: (np.ndarray S2xR2x3) S2 point clouds, each of R2 points.
+         resolution: (int) grid resolution. Affects granularity of measurements.
+     '''
+     sample_grid_var = entropy_of_occupancy_grid(sample_pcs, resolution, in_unit_sphere)[1]
+     ref_grid_var = entropy_of_occupancy_grid(ref_pcs, resolution, in_unit_sphere)[1]
+     return jensen_shannon_divergence(sample_grid_var, ref_grid_var)
+
+
+ def entropy_of_occupancy_grid(pclouds, grid_resolution, in_sphere=False):
+     '''Given a collection of point clouds, estimate the entropy of the random variables
+     corresponding to occupancy-grid activation patterns.
+     Inputs:
+         pclouds: (numpy array) #point-clouds x points per point-cloud x 3
+         grid_resolution: (int) size of the occupancy grid that will be used.
+     '''
+     epsilon = 10e-4
+     bound = 1 + epsilon
+     if abs(np.max(pclouds)) > bound or abs(np.min(pclouds)) > bound:
+         print(abs(np.max(pclouds)), abs(np.min(pclouds)))
+         warnings.warn('Point-clouds are not in unit cube.')
+
+     if in_sphere and np.max(np.sqrt(np.sum(pclouds ** 2, axis=2))) > bound:
+         warnings.warn('Point-clouds are not in unit sphere.')
+
+     grid_coordinates, _ = unit_cube_grid_point_cloud(grid_resolution, in_sphere)
+     grid_coordinates = grid_coordinates.reshape(-1, 3)
+     grid_counters = np.zeros(len(grid_coordinates))
+     grid_bernoulli_rvars = np.zeros(len(grid_coordinates))
+     nn = NearestNeighbors(n_neighbors=1).fit(grid_coordinates)
+
+     for pc in pclouds:
+         _, indices = nn.kneighbors(pc)
+         indices = np.squeeze(indices)
+         for i in indices:
+             grid_counters[i] += 1
+         indices = np.unique(indices)
+         for i in indices:
+             grid_bernoulli_rvars[i] += 1
+
+     acc_entropy = 0.0
+     n = float(len(pclouds))
+     for g in grid_bernoulli_rvars:
+         p = 0.0
+         if g > 0:
+             p = float(g) / n
+         acc_entropy += entropy([p, 1.0 - p])
+
+     return acc_entropy / len(grid_counters), grid_counters
+
+
+ def unit_cube_grid_point_cloud(resolution, clip_sphere=False):
+     '''Returns the center coordinates of each cell of a 3D grid with resolution^3 cells
+     that is placed in the unit cube.
+     If clip_sphere is True, it drops the "corner" cells that lie outside the unit sphere.
+     '''
+     grid = np.ndarray((resolution, resolution, resolution, 3), np.float32)
+     spacing = 1.0 / float(resolution - 1) * 2
+     for i in range(resolution):
+         for j in range(resolution):
+             for k in range(resolution):
+                 grid[i, j, k, 0] = i * spacing - 0.5 * 2
+                 grid[i, j, k, 1] = j * spacing - 0.5 * 2
+                 grid[i, j, k, 2] = k * spacing - 0.5 * 2
+
+     if clip_sphere:
+         grid = grid.reshape(-1, 3)
+         grid = grid[np.linalg.norm(grid, axis=1) <= 0.5]
+
+     return grid, spacing
+
+
+ def jensen_shannon_divergence(P, Q):
+     if np.any(P < 0) or np.any(Q < 0):
+         raise ValueError('Negative values.')
+     if len(P) != len(Q):
+         raise ValueError('Non equal size.')
+
+     P_ = P / np.sum(P)  # Ensure probabilities.
+     Q_ = Q / np.sum(Q)
+
+     e1 = entropy(P_, base=2)
+     e2 = entropy(Q_, base=2)
+     e_sum = entropy((P_ + Q_) / 2.0, base=2)
+     res = e_sum - ((e1 + e2) / 2.0)
+
+     res2 = _jsdiv(P_, Q_)
+
+     if not np.allclose(res, res2, atol=10e-5, rtol=0):
+         warnings.warn('Numerical values of two JSD methods don\'t agree.')
+
+     return res
+
+
+ def _jsdiv(P, Q):
+     '''Another way of computing JSD.'''
+
+     def _kldiv(A, B):
+         a = A.copy()
+         b = B.copy()
+         idx = np.logical_and(a > 0, b > 0)
+         a = a[idx]
+         b = b[idx]
+         return np.sum([v for v in a * np.log2(a / b)])
+
+     P_ = P / np.sum(P)
+     Q_ = Q / np.sum(Q)
+
+     M = 0.5 * (P_ + Q_)
+
+     return 0.5 * (_kldiv(P_, M) + _kldiv(Q_, M))
+
+
+ def downsample_pc(points, n):
+     sample_idx = random.sample(list(range(points.shape[0])), n)
+     return points[sample_idx]
+
+
+ def normalize_pc(points):
+     scale = np.max(np.abs(points))
+     points = points / scale
+     return points
+
+
+ def collect_pc(cad_folder):
+     pc_path = find_files(os.path.join(cad_folder, 'ptl'), 'final_pcd.ply')
+     if len(pc_path) == 0:
+         return []
+     pc_path = pc_path[-1]  # final pcd
+     pc = read_ply(pc_path)
+     if pc.shape[0] > N_POINTS:
+         pc = downsample_pc(pc, N_POINTS)
+     pc = normalize_pc(pc)
+     return pc
+
+
+ def collect_pc2(cad_folder):
+     pc = read_ply(cad_folder)
+     if pc.shape[0] > N_POINTS:
+         pc = downsample_pc(pc, N_POINTS)
+     pc = normalize_pc(pc)
+     return pc
+
+
+ def main():
+     parser = argparse.ArgumentParser()
+     parser.add_argument("--fake", type=str)
+     parser.add_argument("--real", type=str)
+     parser.add_argument("--output", type=str)
+     parser.add_argument("--n_test", type=int, default=200)
+     parser.add_argument("--multi", type=int, default=1)
+     parser.add_argument("--times", type=int, default=10)
+     parser.add_argument("--batch_size", type=int, default=64)
+     args = parser.parse_args()
+
+     print("n_test: {}, multiplier: {}, repeat times: {}".format(args.n_test, args.multi, args.times))
+     if args.output is None:
+         args.output = args.fake + '_cad_results.txt'
+
+     # Load fake pcd
+     fake_folders = sorted(glob(args.fake + '/*/'))
+     real_folders = sorted(glob(args.real + '/*/'))
+
+     fake_overlapped = []
+     real_overlapped = []
+     for i in range(800):
+         if f'{args.fake}/{i:06d}/' in fake_folders and f'{args.real}/{i:06d}/' in real_folders:
+             if len(glob(f'{args.fake}/{i:06d}/ptl/*')) > 0 and len(glob(f'{args.real}/{i:06d}/ptl/*')) > 0:
+                 fake_overlapped.append(f'{args.fake}/{i:06d}/')
+                 real_overlapped.append(f'{args.real}/{i:06d}/')
+     print(len(fake_overlapped), len(real_overlapped))
+
+     fake_folders = fake_overlapped
+     real_folders = real_overlapped
+
+     sample_pcs = []
+     load_iter = Pool(NUM_THREADS).imap(collect_pc, fake_folders)
+     for pc in tqdm(load_iter, total=len(fake_folders)):
+         if len(pc) > 0:
+             sample_pcs.append(pc)
+     sample_pcs = np.stack(sample_pcs, axis=0)
+     print("fake point clouds: {}".format(sample_pcs.shape))
+
+     # Load reference pcd
+     ref_pcs = []
+     load_iter = Pool(NUM_THREADS).imap(collect_pc, real_folders)
+     for pc in tqdm(load_iter, total=len(real_folders)):
+         if len(pc) > 0:
+             ref_pcs.append(pc)
+     ref_pcs = np.stack(ref_pcs, axis=0)
+     print("real point clouds: {}".format(ref_pcs.shape))
+
+     # Testing
+     fp = open(args.output, "w")
+
+     rand_sample_pcs = sample_pcs
+     rand_ref_pcs = ref_pcs
+
+     jsd = jsd_between_point_cloud_sets(rand_sample_pcs, rand_ref_pcs, in_unit_sphere=False)
+     with torch.no_grad():
+         rand_sample_pcs = torch.tensor(rand_sample_pcs).cuda()
+         rand_ref_pcs = torch.tensor(rand_ref_pcs).cuda()
+         result = compute_cov_mmd(rand_sample_pcs, rand_ref_pcs, batch_size=args.batch_size)
+     result.update({"JSD": jsd})
+
+     print(result)
+     print(result, file=fp)
+     fp.close()
+
+     # Repeated subsampled testing (kept for reference):
+     # fp = open(args.output, "w")
+     # result_list = []
+     # for i in range(args.times):
+     #     print("iteration {}...".format(i))
+     #     select_idx = random.sample(list(range(len(sample_pcs))), int(args.multi * args.n_test))
+     #     rand_sample_pcs = sample_pcs[select_idx]
+     #
+     #     select_idx = random.sample(list(range(len(ref_pcs))), args.n_test)
+     #     rand_ref_pcs = ref_pcs[select_idx]
+     #
+     #     jsd = jsd_between_point_cloud_sets(rand_sample_pcs, rand_ref_pcs, in_unit_sphere=False)
+     #     with torch.no_grad():
+     #         rand_sample_pcs = torch.tensor(rand_sample_pcs).cuda()
+     #         rand_ref_pcs = torch.tensor(rand_ref_pcs).cuda()
+     #         result = compute_cov_mmd(rand_sample_pcs, rand_ref_pcs, batch_size=args.batch_size)
+     #     result.update({"JSD": jsd})
+     #
+     #     print(result)
+     #     print(result, file=fp)
+     #     result_list.append(result)
+     # avg_result = {}
+     # for k in result_list[0].keys():
+     #     avg_result.update({"avg-" + k: np.mean([x[k] for x in result_list])})
+     # print("average result:")
+     # print(avg_result)
+     # print(avg_result, file=fp)
+     # fp.close()
+
+
+ if __name__ == '__main__':
+     import time
+     start_time = time.time()
+     main()
+     end_time = time.time()
+     print(end_time - start_time)
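As a sanity check on the JSD helper shared with chamfer_dist.py: it takes raw occupancy counts (normalization happens inside), identical inputs give exactly 0, and the result is measured in bits. A toy sketch, under the assumption that the module and its `chamfer_distance` dependency are importable:

```python
import numpy as np
from dist_eval import jensen_shannon_divergence  # hypothetical import path

P = np.array([4., 3., 2., 1.])  # occupancy counts, not probabilities
Q = np.array([1., 2., 3., 4.])
print(jensen_shannon_divergence(P, P))  # 0.0 for identical distributions
print(jensen_shannon_divergence(P, Q))  # positive, at most 1 bit
```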
CADFusion/src/test/f1_eval.py ADDED
@@ -0,0 +1,74 @@
+ import json
+ import re
+ import argparse
+
+ """
+ We did not implement the Hungarian matching algorithm from text2cad, but provide a vanilla matching for F1, for three reasons:
+ 1. We argue that CAD scenarios are too complicated to be evaluated with a simple matching algorithm, especially at the primitive level. Moreover, matching every primitive exactly runs against the goal of our framework, which encourages CAD models to generate visually correct objects rather than primitives that exactly reproduce the ground truth.
+ 2. In our exploration, a discrepancy in the number of primitives between the model generation and the ground truth usually indicates that the entire sketch failed, so the choice of matching algorithm does not change the final evaluation outcome.
+ 3. Our evaluation is a lower bound on the model's performance under any matching algorithm, so it does not affect the overall integrity of our framework.
+
+ We encourage users to implement their own matching algorithm if they want to evaluate the model with a stricter metric.
+ """
+
+ parser = argparse.ArgumentParser(description='Evaluate F1 scores for generated sketches.')
+ parser.add_argument('--test-path', type=str, default='data/sl_data/test.jsonl', help='Path to the JSONL file containing test data')
+ parser.add_argument('--file_path', type=str, required=True, help='Path to the JSONL file containing generated sketches.')
+ args = parser.parse_args()
+ file_path = args.file_path
+ data_path = args.test_path
+ with open(data_path, 'r') as f:
+     data = json.load(f)
+
+ def find_f1(ground_truth, pred, token):
+     num_tok_gt = len(re.findall(token, ground_truth))
+     num_tok_pred = len(re.findall(token, pred))
+     # print(num_tok_gt, num_tok_pred)
+     min_tok = min(num_tok_gt, num_tok_pred)
+     if min_tok <= 0:
+         return -1
+     tok_recall = min_tok / num_tok_gt
+     tok_precision = min_tok / num_tok_pred
+     tok_f1 = 2 * tok_recall * tok_precision / (tok_recall + tok_precision)
+     return tok_f1
+
+ with open(file_path, 'r') as f:
+     gen = json.load(f)
+ line = []
+ arc = []
+ circle = []
+ ext = []
+ for i in range(1000):
+     ground_truth = data[i]['output']
+     pred = gen[i]['output']
+     ext_f1 = find_f1(ground_truth, pred, r'<extrude_end>')
+     if ext_f1 > 0:
+         ext.append(ext_f1)
+
+     skext_gt = ground_truth.split('<extrude_end>')[:-1]
+     skext_pred = pred.split('<extrude_end>')[:-1]
+     min_len_skext = min(len(skext_gt), len(skext_pred))
+     if min_len_skext == 0:
+         continue
+     line_f1 = 0
+     arc_f1 = 0
+     circle_f1 = 0
+     for gt, pr in zip(skext_gt, skext_pred):
+         line_f1 += find_f1(gt, pr, r'line.*?<curve_end>')
+         arc_f1 += find_f1(gt, pr, r'arc.*?<curve_end>')
+         circle_f1 += find_f1(gt, pr, r'circle.*?<curve_end>')
+
+     line_f1 = line_f1 / min_len_skext
+     arc_f1 = arc_f1 / min_len_skext
+     circle_f1 = circle_f1 / min_len_skext
+     if line_f1 > 0:
+         line.append(line_f1)
+     if arc_f1 > 0:
+         arc.append(arc_f1)
+     if circle_f1 > 0:
+         circle.append(circle_f1)
+ line_avg = sum(line) / len(line)
+ arc_avg = sum(arc) / len(arc)
+ circle_avg = sum(circle) / len(circle)
+ avgf1 = (line_avg + arc_avg + circle_avg) / 3
+ print(file_path, line_avg, arc_avg, circle_avg, avgf1, sum(ext) / len(ext))
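A worked example of the vanilla token-count F1 above, with `find_f1` restated standalone so it runs without the script's CLI. It compares occurrence counts only, so two matching `line ... <curve_end>` tokens give a perfect 1.0 even though the second line's parameters differ, and a token class absent from the prediction yields the -1 sentinel:

```python
import re

def find_f1(ground_truth, pred, token):
    # same logic as above, restated for a standalone demo
    num_gt = len(re.findall(token, ground_truth))
    num_pred = len(re.findall(token, pred))
    matched = min(num_gt, num_pred)
    if matched <= 0:
        return -1
    recall, precision = matched / num_gt, matched / num_pred
    return 2 * recall * precision / (recall + precision)

gt = 'line,1,2 <curve_end> line,3,4 <curve_end> arc,5,6 <curve_end>'
pr = 'line,1,2 <curve_end> line,9,9 <curve_end>'
print(find_f1(gt, pr, r'line.*?<curve_end>'))  # 1.0 (2 line tokens on both sides)
print(find_f1(gt, pr, r'arc.*?<curve_end>'))   # -1  (no arc token in the prediction)
```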
CADFusion/src/test/generate.ipynb ADDED
@@ -0,0 +1,291 @@
+ {
+  "cells": [
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "id": "2d243f81",
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "import argparse\n",
+     "import random\n",
+     "import os\n",
+     "import subprocess\n",
+     "import shutil\n",
+     "\n",
+     "from PIL import Image\n",
+     "from huggingface_hub import login\n",
+     "from utils import MAX_LENGTH, prepare_model_and_tokenizer\n",
+     "from visual_utils.parser import CADparser, write_obj_sample\n",
+     "from IPython.display import clear_output"
+    ]
+   },
+   {
+    "cell_type": "markdown",
+    "id": "b98812ed",
+    "metadata": {},
+    "source": [
+     "### Initializing model and arguments"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": 49,
+    "id": "df625563",
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "parser = argparse.ArgumentParser()\n",
+     "# parser.add_argument(\"--model-name\", type=str, default=\"llama3\")\n",
+     "parser.add_argument(\"--device-map\", type=str, default='auto')\n",
+     "parser.add_argument(\"--lora-rank\", type=int, default=32)\n",
+     "parser.add_argument(\"--lora-alpha\", type=int, default=32)\n",
+     "parser.add_argument(\"--lora-dropout\", type=float, default=0.05)\n",
+     "parser.add_argument(\"--pretrained-path\", type=str, required=True)\n",
+     "parser.add_argument(\"--top-p\", type=float, default=0.9)\n",
+     "parser.add_argument(\"--temperature\", type=float, default=0.9)\n",
+     "\n",
+     "arguments = ['--pretrained-path', '/home/v-wangruiyu/repos/CADFusion/exp/model_ckpt/CADFusion_v1_1', '--temperature', '0.3']\n",
+     "args = parser.parse_args(arguments)"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "id": "5624f320",
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "login()  # put your own hf token to access llama\n",
+     "random.seed(0)\n",
+     "model, tokenizer = prepare_model_and_tokenizer(args)\n",
+     "model.eval()\n",
+     "clear_output()"
+    ]
+   },
+   {
+    "cell_type": "markdown",
+    "id": "86b9cb09",
+    "metadata": {},
+    "source": [
+     "### Custom prompting"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": 180,
+    "id": "db06d560",
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "description = input(\"Please input a description of a 3D shape: \")\n",
+     "# description = 'The 3D shape is a cylinder.'\n",
+     "\n",
+     "prompt = 'Below is a description of a 3D shape:\\n'\n",
+     "prompt += description\n",
+     "prompt += '\\nGenerate a Computer-Aided Design(CAD) command sequence of the 3D shape:\\n'"
+    ]
+   },
+   {
+    "cell_type": "markdown",
+    "id": "bb16f861",
+    "metadata": {},
+    "source": [
+     "### Inference and rendering"
+    ]
+   },
+   {
+    "cell_type": "markdown",
+    "id": "59c5f38e",
+    "metadata": {},
+    "source": [
+     "#### Model Inference"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": 181,
+    "id": "ab5ff2e8",
+    "metadata": {},
+    "outputs": [
+     {
+      "name": "stderr",
+      "output_type": "stream",
+      "text": [
+       "Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.\n"
+      ]
+     },
+     {
+      "data": {
+       "text/plain": [
+        "'circle,31,53,31,9,53,31,9,31 <curve_end> <loop_end> circle,31,51,31,11,51,31,11,31 <curve_end> <loop_end> <face_end> circle,31,51,31,11,51,31,11,31 <curve_end> <loop_end> <face_end> <sketch_end> add,0,62,31,31,31,1,0,0,0,0,1,0,-1,0,7,31,31 <extrude_end>'"
+       ]
+      },
+      "execution_count": 181,
+      "metadata": {},
+      "output_type": "execute_result"
+     }
+    ],
+    "source": [
+     "batch = tokenizer(\n",
+     "    prompt,\n",
+     "    return_tensors=\"pt\",\n",
+     ")\n",
+     "batch = {k: v.cuda() for k, v in batch.items()}\n",
+     "\n",
+     "generate_ids = model.generate(\n",
+     "    **batch,\n",
+     "    do_sample=True,\n",
+     "    max_new_tokens=MAX_LENGTH,\n",
+     "    temperature=args.temperature,\n",
+     "    top_p=args.top_p,\n",
+     "    repetition_penalty=1.3,\n",
+     ")\n",
+     "\n",
+     "gen_strs = tokenizer.batch_decode(\n",
+     "    generate_ids,\n",
+     "    skip_special_tokens=True,\n",
+     "    clean_up_tokenization_spaces=False,\n",
+     ")\n",
+     "gen_strs = gen_strs[0][len(prompt):]\n",
+     "gen_strs"
+    ]
+   },
+   {
+    "cell_type": "markdown",
+    "id": "f56d6fcf",
+    "metadata": {},
+    "source": [
+     "#### Render .obj file"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": 182,
+    "id": "95498ccb",
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "out_path = 'visual_cache/gen_obj'\n",
+     "# remove the existing output directory if it exists\n",
+     "if os.path.exists(out_path):\n",
+     "    shutil.rmtree(out_path)\n",
+     "# create the output directory\n",
+     "os.makedirs(out_path, exist_ok=True)\n",
+     "\n",
+     "cad_parser = CADparser(bit=6)\n",
+     "parsed_data = cad_parser.perform(gen_strs)\n",
+     "write_obj_sample(out_path, parsed_data)"
+    ]
+   },
+   {
+    "cell_type": "markdown",
+    "id": "79b5dfaf",
+    "metadata": {},
+    "source": [
+     "#### Render .step, .stl, .ply files\n",
+     "N.B. If the Statistics on Transfer logs do not show up, the model may not have produced renderable outputs. Re-run the inference or change your prompt to see if it gets better results."
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "id": "8a49694f",
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "out_path = os.path.abspath(out_path)\n",
+     "py_path = os.path.abspath('../rendering_utils/parser_visual.py')\n",
+     "subprocess.run(['python3', py_path, '--data_folder', out_path, '--single-file'])\n",
+     "py_path = os.path.abspath('../rendering_utils/ptl_sampler.py')\n",
+     "subprocess.run(['python3', py_path, '--in_dir', out_path, '--out_dir', 'ptl', '--single-file'])\n",
+     "# clear_output()"
+    ]
+   },
+   {
+    "cell_type": "markdown",
+    "id": "0e0f1fd1",
+    "metadata": {},
+    "source": [
+     "#### Image rendering"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "id": "586f3a91",
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "visual_obj_path = 'visual_cache'\n",
+     "output_figure_path = 'visual_cache/figures'\n",
+     "if os.path.exists(output_figure_path):\n",
+     "    shutil.rmtree(output_figure_path)\n",
+     "py_path = os.path.abspath('../rendering_utils/img_renderer.py')\n",
+     "os.makedirs(output_figure_path, exist_ok=True)\n",
+     "try:\n",
+     "    xvfb_process = subprocess.Popen(\n",
+     "        [\"Xvfb\", \":99\", \"-screen\", \"0\", \"640x480x24\"],\n",
+     "        stdout=subprocess.DEVNULL,\n",
+     "        stderr=subprocess.DEVNULL\n",
+     "    )\n",
+     "    print(\"Xvfb started in the background.\")\n",
+     "except FileNotFoundError:\n",
+     "    print(\"Error: Xvfb not found. Please ensure it is installed and in your system's PATH.\")\n",
+     "\n",
+     "os.environ['DISPLAY'] = ':99'\n",
+     "try:\n",
+     "    subprocess.run(\n",
+     "        ['python3', py_path, '--input_dir', visual_obj_path, '--output_dir', output_figure_path]\n",
+     "    )\n",
+     "    print(\"Rendering script completed successfully.\")\n",
+     "finally:\n",
+     "    if xvfb_process.poll() is None:  # Check if Xvfb is still running\n",
+     "        xvfb_process.terminate()\n",
+     "        print(\"Xvfb terminated.\")\n",
+     "    else:\n",
+     "        print(\"Xvfb already exited.\")\n",
+     "\n",
+     "del os.environ['DISPLAY']\n",
+     "clear_output()\n",
+     "\n",
+     "input_image_path = os.path.join(output_figure_path, 'gen_ob.png')\n",
+     "if os.path.exists(input_image_path):\n",
+     "    img = Image.open(input_image_path)\n",
+     "    img.show()\n",
+     "else:\n",
+     "    print(f\"{input_image_path} does not exist.\")"
+    ]
+   },
+   {
+    "cell_type": "markdown",
+    "id": "c78fed0f",
+    "metadata": {},
+    "source": [
+     "#### Files retrieval\n",
+     "By default, the produced step, stl, obj and ply files are stored under the visual_cache folder. Copy them to a location of your choice for further use. Do not keep them in the cache folder, as they will be deleted on the next run."
+    ]
+   }
+  ],
+  "metadata": {
+   "kernelspec": {
+    "display_name": "cdfs",
+    "language": "python",
+    "name": "python3"
+   },
+   "language_info": {
+    "codemirror_mode": {
+     "name": "ipython",
+     "version": 3
+    },
+    "file_extension": ".py",
+    "mimetype": "text/x-python",
+    "name": "python",
+    "nbconvert_exporter": "python",
+    "pygments_lexer": "ipython3",
+    "version": "3.9.23"
+   }
+  },
+  "nbformat": 4,
+  "nbformat_minor": 5
+ }
CADFusion/src/test/inference.py ADDED
@@ -0,0 +1,106 @@
1
+ import argparse
2
+ import json
3
+ import random
4
+
5
+ from huggingface_hub import login
6
+ from tqdm import tqdm
7
+ from utils import MAX_LENGTH, prepare_model_and_tokenizer
8
+
9
+ login()
10
+
11
+ random.seed(0)
12
+
13
+ def conditional_sample(args):
14
+ model, tokenizer = prepare_model_and_tokenizer(args)
15
+
16
+ model.eval()
17
+ with open(args.in_path, 'r', encoding='utf-8') as file:
18
+ data = json.load(file)
19
+
20
+ print(data[0])
21
+ data = [item for item in data if item['description'] != 'null']
22
+
23
+     global_count = 0
24
+ responses = []
25
+     # subsample unless --full is given
26
+     if not args.full:
27
+         # draw a random subset of size --sample-len
28
+         random.shuffle(data)
29
+         data = data[:args.sample_len]
30
+
31
+ for item in tqdm(data):
32
+ prompts = []
33
+ for _ in range(args.num_samples):
34
+ prompt = 'Below is a description of a 3D shape:\n'
35
+ prompt += item['description']
36
+ prompt += '\nGenerate a Computer-Aided Design(CAD) command sequence of the 3D shape:\n'
37
+
38
+ prompts.append(prompt)
39
+
40
+ outputs = []
41
+
42
+ while len(outputs) < args.num_samples:
43
+ batch_prompts = prompts[len(outputs) : len(outputs) + args.batch_size]
44
+
45
+ batch = tokenizer(
46
+ list(batch_prompts),
47
+ return_tensors="pt",
48
+ )
49
+ batch = {k: v.cuda() for k, v in batch.items()}
50
+
51
+ generate_ids = model.generate(
52
+ **batch,
53
+ do_sample=True,
54
+ max_new_tokens=MAX_LENGTH,
55
+ temperature=args.temperature,
56
+ top_p=args.top_p,
57
+ repetition_penalty=1.3,
58
+ )
59
+
60
+ gen_strs = tokenizer.batch_decode(
61
+ generate_ids,
62
+ skip_special_tokens=True,
63
+ clean_up_tokenization_spaces=False,
64
+ )
65
+
66
+ outputs.extend(gen_strs)
67
+ print(f"Generated {len(outputs)}/{args.num_samples}samples.")
68
+
69
+ for prompt, output in zip(prompts, outputs):
70
+ result = {
71
+ 'index': global_count,
72
+ # 'pic_name': item['pic_name'],
73
+ 'ground_truth': item['command_sequence'],
74
+ 'description': item['description'],
75
+ 'prompt': prompt,
76
+ 'output': output[len(prompt):]
77
+ }
78
+ if 'original_seq' in item.keys():
79
+ result['original_seq'] = item['original_seq']
80
+ responses.append(result)
81
+ global_count += 1
82
+
83
+ with open(args.out_path, "w+") as f:
84
+ json.dump(responses, f, indent=4)
85
+
86
+
87
+
88
+ if __name__ == "__main__":
89
+ parser = argparse.ArgumentParser()
90
+ parser.add_argument("--model-name", type=str, default="llama3")
91
+ parser.add_argument("--lora-rank", type=int, default=32)
92
+ parser.add_argument("--lora-alpha", type=int, default=32)
93
+ parser.add_argument("--lora-dropout", type=float, default=0.05)
94
+ parser.add_argument("--sample-len", type=int, default=100)
95
+ parser.add_argument("--pretrained-path", type=str, required=True)
96
+ parser.add_argument("--num-samples", type=int, default=500)
97
+ parser.add_argument("--batch-size", type=int, default=32)
98
+ parser.add_argument("--in-path", type=str, default="test_description.json")
99
+ parser.add_argument("--out-path", type=str, default="cad_samples.jsonl")
100
+ parser.add_argument("--temperature", type=float, default=0.9)
101
+ parser.add_argument("--device-map", type=str, default='auto')
102
+ parser.add_argument("--top-p", type=float, default=0.9)
103
+ parser.add_argument("--full", action="store_true", default=False)
104
+ args = parser.parse_args()
105
+
106
+ conditional_sample(args)
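
A minimal sketch of reading the file this script writes, assuming the default `--out-path`; the field names mirror the `result` dict built in `conditional_sample`:

```python
import json

with open("cad_samples.jsonl") as f:
    samples = json.load(f)  # written with json.dump despite the .jsonl suffix

print(f"{len(samples)} samples")
first = samples[0]
print(first["description"][:80])   # the conditioning text
print(first["output"][:120])       # generated CAD sequence, prompt already stripped
```
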
CADFusion/src/test/utils.py ADDED
@@ -0,0 +1,86 @@
1
+ import torch
2
+ import transformers
3
+ from peft import LoraConfig, PeftModel, get_peft_model
4
+
5
+ IGNORE_INDEX = -100
6
+ MAX_LENGTH = 512
7
+ DEFAULT_PAD_TOKEN = "[PAD]"
8
+ DEFAULT_EOS_TOKEN = "</s>"
9
+ DEFAULT_BOS_TOKEN = "<s>"
10
+ DEFAULT_UNK_TOKEN = "<unk>"
11
+
12
+ def smart_tokenizer_and_embedding_resize(
13
+ special_tokens_dict,
14
+ llama_tokenizer,
15
+ model,
16
+ ):
17
+ """Resize tokenizer and embedding.
18
+
19
+ Note: This is the unoptimized version that may make your embedding size not be divisible by 64.
20
+ """
21
+ num_new_tokens = llama_tokenizer.add_special_tokens(special_tokens_dict)
22
+ model.resize_token_embeddings(len(llama_tokenizer))
23
+
24
+ if num_new_tokens > 0:
25
+ input_embeddings = model.get_input_embeddings().weight.data
26
+ output_embeddings = model.get_output_embeddings().weight.data
27
+
28
+ input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(
29
+ dim=0, keepdim=True
30
+ )
31
+ output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(
32
+ dim=0, keepdim=True
33
+ )
34
+
35
+ input_embeddings[-num_new_tokens:] = input_embeddings_avg
36
+ output_embeddings[-num_new_tokens:] = output_embeddings_avg
37
+
38
+ def prepare_model_and_tokenizer(args):
39
+ model_id = "meta-llama/Meta-Llama-3-8B"
40
+ print(f"Model size: {model_id}")
41
+ if hasattr(args, 'device_map'):
42
+ device_map = args.device_map
43
+ else:
44
+ device_map = 'auto'
45
+     pipeline = transformers.pipeline("text-generation",  # Llama-3 is decoder-only; "text2text-generation" targets encoder-decoder models
46
+ model=model_id, model_kwargs={"torch_dtype": torch.float32}, device_map=device_map)
47
+ tokenizer = pipeline.tokenizer
48
+ base_model = pipeline.model
49
+
50
+ special_tokens_dict = dict()
51
+ if tokenizer.pad_token is None:
52
+ special_tokens_dict["pad_token"] = DEFAULT_PAD_TOKEN
53
+ if tokenizer.eos_token is None:
54
+ special_tokens_dict["eos_token"] = DEFAULT_EOS_TOKEN
55
+ if tokenizer.bos_token is None:
56
+ special_tokens_dict["bos_token"] = DEFAULT_BOS_TOKEN
57
+ if tokenizer.unk_token is None:
58
+ special_tokens_dict["unk_token"] = DEFAULT_UNK_TOKEN
59
+
60
+ smart_tokenizer_and_embedding_resize(
61
+ special_tokens_dict=special_tokens_dict,
62
+ llama_tokenizer=tokenizer,
63
+ model=base_model,
64
+ )
65
+
66
+ peft_config = LoraConfig(
67
+ r=args.lora_rank,
68
+ lora_alpha=args.lora_alpha,
69
+ lora_dropout=args.lora_dropout,
70
+ bias="none",
71
+ task_type="CAUSAL_LM",
72
+ )
73
+
74
+ tokenizer.padding_side = 'left'
75
+ peftmodel = get_peft_model(base_model, peft_config)
76
+ if args.pretrained_path:
77
+ # load a previous checkpoint if the path is given
78
+ model = PeftModel.from_pretrained(base_model, args.pretrained_path, device_map=device_map)
79
+ peft_state_dict = {f"{k}": v for k, v in model.state_dict().items()}
80
+ peftmodel.load_state_dict(peft_state_dict)
81
+
82
+ for name, param in peftmodel.named_parameters():
83
+ if "lora" in name: # Check if "lora" is in the parameter's name
84
+ param.requires_grad = True
85
+ peftmodel.print_trainable_parameters()
86
+ return peftmodel, tokenizer
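
A toy illustration of the mean-initialization performed by `smart_tokenizer_and_embedding_resize` above: embedding rows for newly added tokens are set to the average of the pre-existing rows.

```python
import torch

num_new_tokens = 2
emb = torch.arange(12, dtype=torch.float32).reshape(6, 2)  # 6 tokens, dim 2
avg = emb[:-num_new_tokens].mean(dim=0, keepdim=True)      # mean over the old rows
emb[-num_new_tokens:] = avg
print(emb[-num_new_tokens:])  # both new rows equal the old-row average
```
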
CADFusion/src/test/visual_utils/__init__.py ADDED
File without changes
CADFusion/src/test/visual_utils/parser.py ADDED
@@ -0,0 +1,478 @@
1
+ import numpy as np
2
+ from collections import OrderedDict
3
+ import re
4
+ from pathlib import Path
5
+ import argparse
6
+ import os
7
+ import json
8
+ import math
9
+
10
+ # hyperparameters from SkexGen project
11
+ SKETCH_R = 1
12
+ RADIUS_R = 1
13
+ EXTRUDE_R = 1.0
14
+ SCALE_R = 1.4
15
+ OFFSET_R = 0.9
16
+ PIX_PAD = 4
17
+ CMD_PAD = 3
18
+ COORD_PAD = 4
19
+ EXT_PAD = 1
20
+ EXTRA_PAD = 1
21
+ R_PAD = 2
22
+
23
+
24
+ class CADparser:
25
+ """Parse CAD sequence to CAD object."""
26
+
27
+ def __init__(self, bit):
28
+ self.vertex_dict = OrderedDict()
29
+ self.bit = bit
30
+
31
+ def perform(self, cad_seq):
32
+ # divide into sketch and extrude
33
+ sketches, extrudes = self.get_SE(cad_seq)
34
+ if sketches is None or extrudes is None:
35
+ return None
36
+ # sequentially parse each pair of SE into obj
37
+ se_datas = []
38
+ for sketch, extrude in zip(sketches, extrudes):
39
+ extrude_param, scale, offset = self.parse_extrude(extrude)
40
+ if extrude_param is None or scale is None or offset is None:
41
+ return None
42
+ vertex_str, se_str = self.parse_sketch(sketch, scale, offset)
43
+ if vertex_str is None or se_str is None:
44
+ return None
45
+ se_datas.append(
46
+ {"vertex": vertex_str, "curve": se_str, "extrude": extrude_param}
47
+ )
48
+ self.vertex_dict.clear()
49
+
50
+ return se_datas
51
+
52
+ def parse_sketch(self, sketch, scale, offset):
53
+ faces = self.get_faces(sketch)
54
+ if len(faces) == 0:
55
+ return None, None
56
+ se_str = ""
57
+ for face_idx, face in enumerate(faces): # each face
58
+ face_str = "face\n"
59
+ loops = self.get_loops(face)
60
+ if len(loops) == 0:
61
+ return None, None
62
+ for loop_idx, loop in enumerate(loops): # each loop
63
+ curves = self.get_curves(loop)
64
+ if len(curves) == 0:
65
+ return None, None
66
+ next_curves = curves[1:]
67
+ next_curves += curves[:1]
68
+ cur_str = []
69
+ for curve, next_curve in zip(curves, next_curves): # each curve
70
+ if not self.obj_curve(curve, next_curve, cur_str, scale, offset):
71
+ return None, None
72
+ loop_str = ""
73
+ for c in cur_str:
74
+ loop_str += f"{c}\n"
75
+ if loop_idx == 0:
76
+ face_str += f"out\n{loop_str}\n"
77
+ else:
78
+ face_str += f"in\n{loop_str}\n"
79
+ se_str += face_str
80
+ vertex_str = self.convert_vertices()
81
+ return vertex_str, se_str
82
+
83
+ def parse_extrude(self, extrude):
84
+ ext = extrude.split(",")
85
+ if len(ext) != 18:
86
+ return None, None, None
87
+
88
+ # operation str to int
89
+ ext_op = {"add": 1, "cut": 2, "intersect": 3}.get(ext[0], None)
90
+ if ext_op is None:
91
+ return None, None, None
92
+ # dequantize ext_v, ext_T, scale and offset
93
+ ext_v, ext_T, scale, offset = self.dequantize_extrude_params(ext)
94
+ # get ext_R
95
+ ext_R = np.array(ext[6:15], dtype=int)
96
+
97
+ extrude_param = {"value": ext_v, "T": ext_T, "R": ext_R, "op": ext_op}
98
+ return extrude_param, scale, offset
99
+
100
+ def obj_curve(self, curve, next_curve, cur_str, scale, offset):
101
+ cur = curve.split(",")
102
+ next_cur = next_curve.split(",")
103
+ if cur[0] == "circle":
104
+ if len(cur) != 9:
105
+ return False
106
+ p1, p2, p3, p4 = self.dequantize_circle_points(
107
+ cur, next_cur, scale, offset)
108
+ center = np.asarray([0.5 * (p1[0] + p2[0]), 0.5 * (p3[1] + p4[1])])
109
+ radius = (np.linalg.norm(p1 - p2) + np.linalg.norm(p3 - p4)) / 4.0
110
+
111
+ center = center * scale + offset
112
+ radius = radius * scale
113
+
114
+ center_idx = self.save_vertex(center[0], center[1], "p")
115
+ radius_idx = self.save_vertex(radius, 0.0, "r")
116
+ cur_str.append(f"c {center_idx} {radius_idx}")
117
+ elif cur[0] == "arc":
118
+ if len(cur) != 5:
119
+ return False
120
+ if (
121
+ cur[1:3] == cur[3:5]
122
+ or cur[1:3] == next_cur[1:3]
123
+ or cur[3:5] == next_cur[3:5]
124
+ ): # invalid arc
125
+ return False
126
+ start_v, mid_v, end_v = self.dequantize_arc_points(
127
+ cur, next_cur, scale, offset
128
+ )
129
+ try:
130
+ center, _, _, _ = find_arc_geometry(start_v, mid_v, end_v)
131
+ except Exception:
132
+ return False
133
+ start_v = start_v * scale + offset
134
+ mid_v = mid_v * scale + offset
135
+ end_v = end_v * scale + offset
136
+ center = center * scale + offset
137
+
138
+ center_idx = self.save_vertex(center[0], center[1], "p")
139
+ start_idx = self.save_vertex(start_v[0], start_v[1], "p")
140
+ mid_idx = self.save_vertex(mid_v[0], mid_v[1], "p")
141
+ end_idx = self.save_vertex(end_v[0], end_v[1], "p")
142
+ cur_str.append(f"a {start_idx} {mid_idx} {center_idx} {end_idx}")
143
+ elif cur[0] == "line":
144
+ if len(cur) != 3:
145
+ return False
146
+ if cur[1:3] == next_cur[1:3]:
147
+ return False
148
+ start_v, end_v = self.dequantize_line_points(
149
+ cur, next_cur, scale, offset)
150
+ start_v = start_v * scale + offset
151
+ end_v = end_v * scale + offset
152
+
153
+ start_idx = self.save_vertex(start_v[0], start_v[1], "p")
154
+ end_idx = self.save_vertex(end_v[0], end_v[1], "p")
155
+ cur_str.append(f"l {start_idx} {end_idx}")
156
+ else:
157
+ return False
158
+ return True
159
+
160
+ def get_SE(self, cad_seq):
161
+ # sketches: 1) between sequence start and sketch_end,
162
+ sketches_from_start = re.findall(r"^(.+?)(?=<sketch_end>)", cad_seq)
163
+ # sketches: 2) between extrude_end and sketch_end
164
+ sketches_after_extrude = re.findall(
165
+ r"(?<=<extrude_end>)(.+?)(?=<sketch_end>)", cad_seq
166
+ )
167
+ sketches = [x.strip() for x in sketches_from_start] + [
168
+ x.strip() for x in sketches_after_extrude
169
+ ]
170
+ # extrudes: between sketch_end and extrude_end
171
+ extrudes = [
172
+ x.strip() for x in re.findall(r"<sketch_end>(.+?)<extrude_end>", cad_seq)
173
+ ]
174
+ if len(sketches) != len(extrudes):
175
+ return None, None
176
+ return sketches, extrudes
177
+
178
+ def get_faces(self, sketch):
179
+ faces = sketch.split("<face_end>")
180
+ return [x.strip() for x in faces if x.strip() != ""]
181
+
182
+ def get_loops(self, face):
183
+ loops = face.split("<loop_end>")
184
+ return [x.strip() for x in loops if x.strip() != ""]
185
+
186
+ def get_curves(self, loop):
187
+ curves = loop.split("<curve_end>")
188
+ return [x.strip() for x in curves if x.strip() != ""]
189
+
190
+ def dequantize_circle_points(self, curve, next_curve, scale, offset):
191
+ p1 = dequantize_verts(
192
+ np.array(curve[1:3], dtype=int),
193
+ n_bits=self.bit,
194
+ min_range=-SKETCH_R,
195
+ max_range=SKETCH_R,
196
+ add_noise=False,
197
+ )
198
+ p2 = dequantize_verts(
199
+ np.array(curve[3:5], dtype=int),
200
+ n_bits=self.bit,
201
+ min_range=-SKETCH_R,
202
+ max_range=SKETCH_R,
203
+ add_noise=False,
204
+ )
205
+ p3 = dequantize_verts(
206
+ np.array(curve[5:7], dtype=int),
207
+ n_bits=self.bit,
208
+ min_range=-SKETCH_R,
209
+ max_range=SKETCH_R,
210
+ add_noise=False,
211
+ )
212
+ p4 = dequantize_verts(
213
+ np.array(curve[7:9], dtype=int),
214
+ n_bits=self.bit,
215
+ min_range=-SKETCH_R,
216
+ max_range=SKETCH_R,
217
+ add_noise=False,
218
+ )
219
+ return p1, p2, p3, p4
220
+
221
+ def dequantize_arc_points(self, curve, next_curve, scale, offset):
222
+ start_v = dequantize_verts(
223
+ np.array(curve[1:3], dtype=int),
224
+ n_bits=self.bit,
225
+ min_range=-SKETCH_R,
226
+ max_range=SKETCH_R,
227
+ add_noise=False,
228
+ )
229
+ mid_v = dequantize_verts(
230
+ np.array(curve[3:5], dtype=int),
231
+ n_bits=self.bit,
232
+ min_range=-SKETCH_R,
233
+ max_range=SKETCH_R,
234
+ add_noise=False,
235
+ )
236
+ end_v = dequantize_verts(
237
+ np.array(next_curve[1:3], dtype=int),
238
+ n_bits=self.bit,
239
+ min_range=-SKETCH_R,
240
+ max_range=SKETCH_R,
241
+ add_noise=False,
242
+ )
243
+ return start_v, mid_v, end_v
244
+
245
+ def dequantize_line_points(self, curve, next_curve, scale, offset):
246
+ start_v = dequantize_verts(
247
+ np.array(curve[1:3], dtype=int),
248
+ n_bits=self.bit,
249
+ min_range=-SKETCH_R,
250
+ max_range=SKETCH_R,
251
+ add_noise=False,
252
+ )
253
+ end_v = dequantize_verts(
254
+ np.array(next_curve[1:3], dtype=int),
255
+ n_bits=self.bit,
256
+ min_range=-SKETCH_R,
257
+ max_range=SKETCH_R,
258
+ add_noise=False,
259
+ )
260
+ return start_v, end_v
261
+
262
+ def dequantize_extrude_params(self, extrude):
263
+ ext_v = dequantize_verts(
264
+ np.array(extrude[1:3], dtype=int),
265
+ n_bits=self.bit,
266
+ min_range=-EXTRUDE_R,
267
+ max_range=EXTRUDE_R,
268
+ add_noise=False,
269
+ )
270
+ ext_T = dequantize_verts(
271
+ np.array(extrude[3:6], dtype=int),
272
+ n_bits=self.bit,
273
+ min_range=-EXTRUDE_R,
274
+ max_range=EXTRUDE_R,
275
+ add_noise=False,
276
+ )
277
+ scale = dequantize_verts(
278
+ np.array(extrude[15], dtype=int),
279
+ n_bits=self.bit,
280
+ min_range=0.0,
281
+ max_range=SCALE_R,
282
+ add_noise=False,
283
+ )
284
+ offset = dequantize_verts(
285
+ np.array(extrude[16:18], dtype=int),
286
+ n_bits=self.bit,
287
+ min_range=-OFFSET_R,
288
+ max_range=OFFSET_R,
289
+ add_noise=False,
290
+ )
291
+ return ext_v, ext_T, scale, offset
292
+
293
+ def save_vertex(self, h_x, h_y, text):
294
+ unique_key = f"{text}:x{h_x}y{h_y}"
295
+ index = 0
296
+ for key in self.vertex_dict.keys():
297
+ # Vertex location already exist in dict
298
+ if unique_key == key:
299
+ return index
300
+ index += 1
301
+ # Vertex location does not exist in dict
302
+ self.vertex_dict[unique_key] = [h_x, h_y]
303
+ return index
304
+
305
+ def convert_vertices(self):
306
+ """Convert all the vertices to .obj format"""
307
+ vertex_strings = ""
308
+ for pt in self.vertex_dict.values():
309
+             # e.g. v 0.123 0.234  (2D sketch vertices)
310
+ vertex_string = f"v {pt[0]} {pt[1]}\n"
311
+ vertex_strings += vertex_string
312
+ return vertex_strings
313
+
314
+
315
+ def find_arc_geometry(a, b, c):
316
+ A = b[0] - a[0]
317
+ B = b[1] - a[1]
318
+ C = c[0] - a[0]
319
+ D = c[1] - a[1]
320
+
321
+ E = A*(a[0] + b[0]) + B*(a[1] + b[1])
322
+ F = C*(a[0] + c[0]) + D*(a[1] + c[1])
323
+
324
+ G = 2.0*(A*(c[1] - b[1])-B*(c[0] - b[0]))
325
+
326
+ if G == 0:
327
+ raise Exception("zero G")
328
+
329
+ p_0 = (D*E - B*F) / G
330
+ p_1 = (A*F - C*E) / G
331
+
332
+ center = np.array([p_0, p_1])
333
+ radius = np.linalg.norm(center - a)
334
+
335
+ angles = []
336
+ for xx in [a, b, c]:
337
+ angle = angle_from_vector_to_x(xx - center)
338
+ angles.append(angle)
339
+
340
+ ab = b-a
341
+ ac = c-a
342
+ cp = np.cross(ab, ac)
343
+ if cp >= 0:
344
+ start_angle_rads = angles[0]
345
+ end_angle_rads = angles[2]
346
+ else:
347
+ start_angle_rads = angles[2]
348
+ end_angle_rads = angles[0]
349
+
350
+ return center, radius, start_angle_rads, end_angle_rads
351
+
352
+
353
+ def angle_from_vector_to_x(vec):
354
+ assert vec.size == 2
355
+ # We need to find a unit vector
356
+ angle = 0.0
357
+
358
+ l = np.linalg.norm(vec)
359
+ uvec = vec/l
360
+
361
+ # 2 | 1
362
+ # -------
363
+ # 3 | 4
364
+ if uvec[0] >= 0:
365
+ if uvec[1] >= 0:
366
+             # Quadrant 1
367
+ angle = math.asin(uvec[1])
368
+ else:
369
+             # Quadrant 4
370
+ angle = 2.0*math.pi - math.asin(-uvec[1])
371
+ else:
372
+         if uvec[1] >= 0:
373
+             # Quadrant 2
374
+ angle = math.pi - math.asin(uvec[1])
375
+ else:
376
+             # Quadrant 3
377
+ angle = math.pi + math.asin(-uvec[1])
378
+ return angle
379
+
380
+
381
+ def dequantize_verts(verts, n_bits=8, min_range=-0.5, max_range=0.5, add_noise=False):
382
+ """Convert quantized vertices to floats."""
383
+ range_quantize = 2**n_bits - 1
384
+ verts = verts.astype("float32")
385
+ verts = verts * (max_range - min_range) / range_quantize + min_range
386
+ return verts
387
+
388
+
389
+ def write_obj_sample(save_folder, data):
390
+ for idx, write_data in enumerate(data):
391
+ obj_name = Path(save_folder).stem + "_" + \
392
+ str(idx).zfill(3) + "_param.obj"
393
+ obj_file = Path(save_folder) / obj_name
394
+ extrude_param = write_data["extrude"]
395
+ vertex_strings = write_data["vertex"]
396
+ curve_strings = write_data["curve"]
397
+
398
+ """Write an .obj file with the curves and verts"""
399
+ if extrude_param["op"] == 1: # 'add'
400
+ set_op = "NewBodyFeatureOperation"
401
+ elif extrude_param["op"] == 2: # 'cut'
402
+ set_op = "CutFeatureOperation"
403
+ elif extrude_param["op"] == 3: # 'cut'
404
+ set_op = "IntersectFeatureOperation"
405
+
406
+ with open(obj_file, "w") as fh:
407
+ # Write Meta info
408
+ fh.write("# WaveFront *.obj file\n")
409
+ fh.write("# ExtrudeOperation: " + set_op + "\n")
410
+ fh.write("\n")
411
+
412
+ # Write vertex and curve
413
+ fh.write(vertex_strings)
414
+ fh.write("\n")
415
+ fh.write(curve_strings)
416
+ fh.write("\n")
417
+
418
+ # Write extrude value
419
+ extrude_string = "Extrude "
420
+ for value in extrude_param["value"]:
421
+ extrude_string += str(value) + " "
422
+ fh.write(extrude_string)
423
+ fh.write("\n")
424
+
425
+ # Write refe plane value
426
+ p_orig = parse3d_sample(extrude_param["T"])
427
+ x_axis = parse3d_sample(extrude_param["R"][0:3])
428
+ y_axis = parse3d_sample(extrude_param["R"][3:6])
429
+ z_axis = parse3d_sample(extrude_param["R"][6:9])
430
+ fh.write("T_origin " + p_orig)
431
+ fh.write("\n")
432
+ fh.write("T_xaxis " + x_axis)
433
+ fh.write("\n")
434
+ fh.write("T_yaxis " + y_axis)
435
+ fh.write("\n")
436
+ fh.write("T_zaxis " + z_axis)
437
+
438
+
439
+ def parse3d_sample(point3d):
440
+ x = point3d[0]
441
+ y = point3d[1]
442
+ z = point3d[2]
443
+ return str(x) + " " + str(y) + " " + str(z)
444
+
445
+
446
+ if __name__ == "__main__":
447
+ parser = argparse.ArgumentParser()
448
+ parser.add_argument("--in-path", type=str, required=True)
449
+ parser.add_argument("--out-path", type=str, required=True)
450
+ args = parser.parse_args()
451
+
452
+ # with open(args.in_path, "r") as f:
453
+ # data = f.readlines()
454
+ with open(args.in_path, 'r') as file:
455
+ data = file.read()
456
+
457
+ data = json.loads(data)
458
+
459
+ num_valid_str = 0
460
+ for idx, item in enumerate(data):
461
+ try:
462
+ cad_parser = CADparser(bit=6)
463
+ # print(idx)
464
+             if isinstance(item, str):
465
+ parsed_data = cad_parser.perform(item)
466
+             elif isinstance(item, dict):
467
+ parsed_data = cad_parser.perform(item['output'])
468
+ else:
469
+ raise ValueError("Invalid data type")
470
+ out_path = os.path.join(args.out_path, str(idx).zfill(6))
471
+ os.makedirs(out_path, exist_ok=True)
472
+ if parsed_data is not None:
473
+ num_valid_str += 1
474
+ write_obj_sample(out_path, parsed_data)
475
+ except Exception as e:
476
+ print(e)
477
+ pass
478
+ print(f"Number of valid CAD strings: {num_valid_str}/{len(data)}")
CADFusion/src/train/CAD_dataset.py ADDED
@@ -0,0 +1,89 @@
1
+ import json
2
+ import os
3
+ import torch
4
+ import random
5
+ import transformers
6
+
7
+ from dataclasses import dataclass
8
+ from torch.utils.data import Dataset
9
+ from utils import IGNORE_INDEX, MAX_LENGTH
10
+
11
+ class CADDataset(Dataset):
12
+ def __init__(self, json_fn, cutoff=True, llama_tokenizer=None):
13
+ if not os.path.exists(json_fn):
14
+ raise ValueError(f"{json_fn} does not exist")
15
+ self.inputs = json.load(open(json_fn, "r"))
16
+ print(len(self.inputs))
17
+ self.inputs = [item for item in self.inputs if 'null' not in item['description']]
18
+ random.shuffle(self.inputs)
19
+ if cutoff:
20
+             self.inputs = self.inputs[:18953]  # hard-coded dataset-size cutoff
21
+ print(len(self.inputs))
22
+ self.llama_tokenizer = llama_tokenizer
23
+
24
+ def __len__(self):
25
+ return len(self.inputs)
26
+
27
+ def __getitem__(self, index):
28
+ item = self.inputs[index]
29
+ seq = item['command_sequence']
30
+ des = item['description']
31
+ val = self.tokenize(seq, des)
32
+ return val
33
+
34
+
35
+ def tokenize(self, seq, des):
36
+ tokens, prompt_length = self.conditional_generation_task(seq=seq, des=des)
37
+ input_ids = tokens.input_ids[0]
38
+ labels = tokens.input_ids[0].clone() # Clone the input_ids for labels
39
+ # Set the labels for the prompt part to IGNORE_INDEX so they are ignored in loss calculation
40
+ labels[:prompt_length] = IGNORE_INDEX
41
+ input_id_lens = label_lens = (
42
+ tokens.input_ids.ne(self.llama_tokenizer.pad_token_id).sum().item()
43
+ )
44
+ return dict(
45
+ input_ids=input_ids,
46
+ input_id_lens=input_id_lens,
47
+ labels=labels,
48
+ label_lens=label_lens,
49
+ )
50
+
51
+
52
+ def conditional_generation_task(self, seq, des):
53
+ prompt = 'Below is a description of a 3D shape:\n'
54
+ prompt += des
55
+ prompt += '\nGenerate a Computer-Aided Design(CAD) command sequence of the 3D shape:\n'
56
+ full_text = prompt + seq + self.llama_tokenizer.eos_token
57
+ tokens = self.llama_tokenizer(
58
+ full_text,
59
+ max_length=MAX_LENGTH,
60
+ return_tensors="pt",
61
+ truncation=True,
62
+ )
63
+ prompt_length = len(self.llama_tokenizer(prompt)['input_ids'])
64
+ return tokens, prompt_length
65
+
66
+
67
+ @dataclass
68
+ class DataCollatorForSupervisedDataset(object):
69
+ """Collate examples for supervised fine-tuning."""
70
+
71
+ tokenizer: transformers.PreTrainedTokenizer
72
+
73
+ def __call__(self, instances):
74
+ input_ids, labels = tuple(
75
+ [instance[key].clone().detach() for instance in instances]
76
+ for key in ("input_ids", "labels")
77
+ )
78
+ # force left padding
79
+ reversed_sequences = [torch.flip(input_id, [0]) for input_id in input_ids]
80
+ input_ids = torch.nn.utils.rnn.pad_sequence(reversed_sequences, batch_first=True, padding_value=self.tokenizer.pad_token_id)
81
+ input_ids = torch.flip(input_ids, [0, 1])
82
+ labels = torch.nn.utils.rnn.pad_sequence(
83
+ labels, batch_first=True, padding_value=IGNORE_INDEX
84
+ )
85
+ return dict(
86
+ input_ids=input_ids,
87
+ labels=labels,
88
+ attention_mask=input_ids.ne(self.tokenizer.pad_token_id),
89
+ )
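
A toy illustration of the masking in `CADDataset.tokenize`: prompt tokens stay in `input_ids` but are excluded from the loss via `IGNORE_INDEX`.

```python
import torch

IGNORE_INDEX = -100  # same sentinel as in utils.py

input_ids = torch.tensor([11, 12, 13, 21, 22, 2])  # 3 prompt tokens + sequence + eos
labels = input_ids.clone()
prompt_length = 3
labels[:prompt_length] = IGNORE_INDEX  # prompt positions contribute no loss
print(labels)  # tensor([-100, -100, -100,  21,  22,   2])
```
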
CADFusion/src/train/dpo.py ADDED
@@ -0,0 +1,79 @@
1
+ import argparse
2
+ import os
3
+ import torch
4
+ import json
5
+ import random
6
+ import transformers
7
+ from huggingface_hub import login
8
+
9
+ login() # put your huggingface token here
10
+ os.environ["WANDB_PROJECT"] = "CADFusion_VF"
11
+
12
+ from datasets import Dataset
13
+ from trl import DPOTrainer, DPOConfig
14
+ from utils import prepare_model_and_tokenizer
15
+
16
+ parser = argparse.ArgumentParser()
17
+ parser.add_argument("--run-name", type=str, required=True)
18
+ parser.add_argument("--lora-rank", type=int, default=32)
19
+ parser.add_argument("--lora-alpha", type=int, default=32)
20
+ parser.add_argument("--lora-dropout", type=float, default=0.05)
21
+ parser.add_argument("--sample-cutoff", default=100000, type=int)
22
+ parser.add_argument("--pretrained-path", type=str, required=True)
23
+ parser.add_argument("--data-path", type=str, required=True)
24
+ parser.add_argument("--output-path", type=str, required=True)
25
+ parser.add_argument("--num-epochs", type=int, default=3)
26
+ parser.add_argument("--batch-size", type=int, default=2)
27
+ parser.add_argument("--eval-freq", default=1000, type=int)
28
+ parser.add_argument("--save-freq", default=500, type=int)
29
+ parser.add_argument("--debug", action="store_true", default=False)
30
+ args = parser.parse_args()
31
+
32
+
33
+
34
+ with open(args.data_path, 'r') as f:
35
+ raw_data = json.load(f)
36
+
37
+ random.shuffle(raw_data)
38
+
39
+ if len(raw_data) > args.sample_cutoff + 100:
40
+ ds = {
41
+ "train": Dataset.from_list(raw_data[:args.sample_cutoff]),
42
+ "val": Dataset.from_list(raw_data[-100:])
43
+ }
44
+ else:
45
+ ds = {
46
+ "train": Dataset.from_list(raw_data[:-100]),
47
+ "val": Dataset.from_list(raw_data[-100:])
48
+ }
49
+
50
+ llama_model, llama_tokenizer = prepare_model_and_tokenizer(args)
51
+
52
+ for name, param in llama_model.named_parameters():
53
+ if "lora" in name: # Check if "lora" is in the parameter's name
54
+ param.requires_grad = True
55
+
56
+ training_args = DPOConfig(
57
+ run_name=args.run_name,
58
+ learning_rate=1.41e-5,
59
+ per_device_train_batch_size=2,
60
+ per_device_eval_batch_size=args.batch_size,
61
+ report_to="wandb",
62
+ num_train_epochs=args.num_epochs,
63
+ do_eval=True,
64
+ eval_steps=args.eval_freq,
65
+ save_steps=args.save_freq,
66
+ output_dir=args.output_path
67
+ )
68
+
69
+ trainer = DPOTrainer(
70
+ llama_model,
71
+     None,  # ref_model: None lets TRL derive the reference model internally
72
+ args=training_args,
73
+ train_dataset=ds['train'],
74
+ eval_dataset=ds['val'],
75
+ tokenizer=llama_tokenizer,
76
+ )
77
+ trainer.save_model()
78
+ trainer.train()
79
+ trainer.save_model()
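
TRL's `DPOTrainer` consumes preference triples with `prompt`/`chosen`/`rejected` fields, so the JSON behind `--data-path` is expected to hold records of that shape. A minimal sketch; the contents below are hypothetical placeholders, and the real pairs come from the DPO-data scripts:

```python
from datasets import Dataset

raw_data = [
    {
        "prompt": "Below is a description of a 3D shape:\n...",
        "chosen": "...command sequence preferred by the feedback signal...",
        "rejected": "...command sequence scored lower...",
    }
]
ds = Dataset.from_list(raw_data)
print(ds.column_names)  # ['prompt', 'chosen', 'rejected']
```
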
CADFusion/src/train/llama_finetune.py ADDED
@@ -0,0 +1,127 @@
1
+
2
+ import argparse
3
+ import os
4
+ import torch
5
+ import transformers
6
+
7
+ from CAD_dataset import CADDataset, DataCollatorForSupervisedDataset
8
+ from huggingface_hub import login
9
+ from pathlib import Path
10
+ from peft import LoraConfig, get_peft_model
11
+ from transformers import Trainer, TrainingArguments
12
+ from utils import prepare_model_and_tokenizer
13
+
14
+ login() # put your huggingface token here
15
+
16
+ def setup_datasets(args, llama_tokenizer):
17
+ datasets = {
18
+ "train": CADDataset(
19
+ args.data_path,
20
+ llama_tokenizer=llama_tokenizer,
21
+ ),
22
+ "val": CADDataset(
23
+ args.eval_data_path,
24
+ llama_tokenizer=llama_tokenizer,
25
+ ),
26
+ }
27
+ return datasets
28
+
29
+
30
+ def setup_training_args(args):
31
+ output_dir = args.expdir / args.run_name
32
+ output_dir.mkdir(parents=True, exist_ok=True)
33
+
34
+ if args.debug:
35
+ os.environ["WANDB_DISABLED"] = "True"
36
+ os.environ["ACCELERATE_MIXED_PRECISION"] = "no"
37
+ training_args = TrainingArguments(
38
+ fsdp=False,
39
+ fp16=False,
40
+ bf16=False,
41
+ do_eval=True,
42
+ gradient_checkpointing=False,
43
+ ddp_find_unused_parameters=False,
44
+ num_train_epochs=args.num_epochs,
45
+ eval_steps=args.eval_freq,
46
+ save_steps=args.save_freq,
47
+ logging_steps=10,
48
+ evaluation_strategy="steps",
49
+ per_device_train_batch_size=args.batch_size,
50
+ per_device_eval_batch_size=args.batch_size,
51
+ learning_rate=args.lr,
52
+ lr_scheduler_type=args.lr_scheduler,
53
+ warmup_steps=args.num_warmup_steps,
54
+ weight_decay=args.weight_decay,
55
+ gradient_accumulation_steps=args.grad_accum,
56
+ output_dir=output_dir,
57
+ run_name=args.run_name,
58
+ report_to="wandb",
59
+ dataloader_num_workers=8,
60
+ remove_unused_columns=False,
61
+ # label_names=["cad_ids"], # this is to make trainer behave as expected
62
+ )
63
+ return training_args
64
+
65
+
66
+ def setup_trainer(args):
67
+ training_args = setup_training_args(args)
68
+ if args.device_map == 'accelerate':
69
+ args.device_map = {'': training_args.local_rank}
70
+ model, llama_tokenizer = prepare_model_and_tokenizer(args)
71
+
72
+ datasets = setup_datasets(args, llama_tokenizer)
73
+
74
+ data_collator = DataCollatorForSupervisedDataset(
75
+ tokenizer=llama_tokenizer,
76
+ )
77
+
78
+ trainer = Trainer(
79
+ model=model,
80
+ args=training_args,
81
+ train_dataset=datasets["train"],
82
+ eval_dataset=datasets["val"],
83
+ data_collator=data_collator,
84
+ )
85
+
86
+ return trainer
87
+
88
+
89
+ def main(args):
90
+ trainer = setup_trainer(args)
91
+
92
+ if args.resume_dir is not None:
93
+ train_result = trainer.train(resume_from_checkpoint=args.resume_dir)
94
+ else:
95
+ train_result = trainer.train()
96
+
97
+ print(train_result)
98
+ trainer.save_state()
99
+ trainer.save_model()
100
+
101
+
102
+ if __name__ == "__main__":
103
+ parser = argparse.ArgumentParser()
104
+ parser.add_argument("--run-name", type=str, required=True)
105
+ parser.add_argument("--expdir", type=Path, default="exp")
106
+ parser.add_argument("--model-name", default="llama3")
107
+ parser.add_argument("--lora-rank", type=int, default=32)
108
+ parser.add_argument("--lora-alpha", type=int, default=32)
109
+ parser.add_argument("--lora-dropout", type=float, default=0.05)
110
+ parser.add_argument("--data-path", type=Path, default="data/train.json")
111
+ parser.add_argument("--eval-data-path", type=Path, default="data/eval.json")
112
+ parser.add_argument("--pretrained-path", type=Path, default=None)
113
+ parser.add_argument("--num-epochs", type=int, default=40)
114
+ parser.add_argument("--batch-size", type=int, default=1)
115
+ parser.add_argument("--grad-accum", type=int, default=1)
116
+ parser.add_argument("--lr", type=float, default=1e-4)
117
+ parser.add_argument("--lr-scheduler", type=str, default="cosine")
118
+ parser.add_argument("--num-warmup-steps", type=int, default=100)
119
+ parser.add_argument("--weight-decay", type=float, default=0.0)
120
+ parser.add_argument("--eval-freq", default=1000, type=int)
121
+ parser.add_argument("--save-freq", default=50000, type=int)
122
+ parser.add_argument("--device-map", type=str, default='auto')
123
+ parser.add_argument("--resume-dir", type=Path, default=None)
124
+ parser.add_argument("--debug", action="store_true", default=False)
125
+ args = parser.parse_args()
126
+ os.environ["WANDB_PROJECT"] = "CADFusion_SL"
127
+ main(args)
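
For quick experiments, `main` can also be driven programmatically instead of through the CLI; a minimal sketch where every field mirrors a parser default above and the paths are placeholders:

```python
from argparse import Namespace
from pathlib import Path

args = Namespace(
    run_name="sl_debug", expdir=Path("exp"), model_name="llama3",
    lora_rank=32, lora_alpha=32, lora_dropout=0.05,
    data_path=Path("data/train.json"), eval_data_path=Path("data/eval.json"),
    pretrained_path=None, num_epochs=1, batch_size=1, grad_accum=1,
    lr=1e-4, lr_scheduler="cosine", num_warmup_steps=100, weight_decay=0.0,
    eval_freq=1000, save_freq=50000, device_map="auto",
    resume_dir=None, debug=True,
)
# main(args)  # call from within llama_finetune.py's namespace
```
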
CADFusion/src/train/utils.py ADDED
@@ -0,0 +1,86 @@
1
+ import torch
2
+ import transformers
3
+ from peft import LoraConfig, PeftModel, get_peft_model
4
+
5
+ IGNORE_INDEX = -100
6
+ MAX_LENGTH = 512
7
+ DEFAULT_PAD_TOKEN = "[PAD]"
8
+ DEFAULT_EOS_TOKEN = "</s>"
9
+ DEFAULT_BOS_TOKEN = "<s>"
10
+ DEFAULT_UNK_TOKEN = "<unk>"
11
+
12
+ def smart_tokenizer_and_embedding_resize(
13
+ special_tokens_dict,
14
+ llama_tokenizer,
15
+ model,
16
+ ):
17
+ """Resize tokenizer and embedding.
18
+
19
+ Note: This is the unoptimized version that may make your embedding size not be divisible by 64.
20
+ """
21
+ num_new_tokens = llama_tokenizer.add_special_tokens(special_tokens_dict)
22
+ model.resize_token_embeddings(len(llama_tokenizer))
23
+
24
+ if num_new_tokens > 0:
25
+ input_embeddings = model.get_input_embeddings().weight.data
26
+ output_embeddings = model.get_output_embeddings().weight.data
27
+
28
+ input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(
29
+ dim=0, keepdim=True
30
+ )
31
+ output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(
32
+ dim=0, keepdim=True
33
+ )
34
+
35
+ input_embeddings[-num_new_tokens:] = input_embeddings_avg
36
+ output_embeddings[-num_new_tokens:] = output_embeddings_avg
37
+
38
+ def prepare_model_and_tokenizer(args):
39
+ model_id = "meta-llama/Meta-Llama-3-8B"
40
+ print(f"Model size: {model_id}")
41
+ if hasattr(args, 'device_map'):
42
+ device_map = args.device_map
43
+ else:
44
+ device_map = 'auto'
45
+     pipeline = transformers.pipeline("text-generation",  # Llama-3 is decoder-only; "text2text-generation" targets encoder-decoder models
46
+ model=model_id, model_kwargs={"torch_dtype": torch.float32}, device_map=device_map)
47
+ tokenizer = pipeline.tokenizer
48
+ base_model = pipeline.model
49
+
50
+ special_tokens_dict = dict()
51
+ if tokenizer.pad_token is None:
52
+ special_tokens_dict["pad_token"] = DEFAULT_PAD_TOKEN
53
+ if tokenizer.eos_token is None:
54
+ special_tokens_dict["eos_token"] = DEFAULT_EOS_TOKEN
55
+ if tokenizer.bos_token is None:
56
+ special_tokens_dict["bos_token"] = DEFAULT_BOS_TOKEN
57
+ if tokenizer.unk_token is None:
58
+ special_tokens_dict["unk_token"] = DEFAULT_UNK_TOKEN
59
+
60
+ smart_tokenizer_and_embedding_resize(
61
+ special_tokens_dict=special_tokens_dict,
62
+ llama_tokenizer=tokenizer,
63
+ model=base_model,
64
+ )
65
+
66
+ peft_config = LoraConfig(
67
+ r=args.lora_rank,
68
+ lora_alpha=args.lora_alpha,
69
+ lora_dropout=args.lora_dropout,
70
+ bias="none",
71
+ task_type="CAUSAL_LM",
72
+ )
73
+
74
+ tokenizer.padding_side = 'left'
75
+ peftmodel = get_peft_model(base_model, peft_config)
76
+ if args.pretrained_path:
77
+ # load a previous checkpoint if the path is given
78
+ model = PeftModel.from_pretrained(base_model, args.pretrained_path, device_map=device_map)
79
+ peft_state_dict = {f"{k}": v for k, v in model.state_dict().items()}
80
+ peftmodel.load_state_dict(peft_state_dict)
81
+
82
+ for name, param in peftmodel.named_parameters():
83
+ if "lora" in name: # Check if "lora" is in the parameter's name
84
+ param.requires_grad = True
85
+ peftmodel.print_trainable_parameters()
86
+ return peftmodel, tokenizer
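
Since this file mirrors `src/test/utils.py`, one extra note on the LoRA setup: the same `LoraConfig` pattern also applies to a plain module, which makes the trainable-parameter accounting easy to see in isolation. A toy sketch (`target_modules` names the single `Linear` layer, a hypothetical stand-in for the Llama projection layers):

```python
import torch.nn as nn
from peft import LoraConfig, get_peft_model

base = nn.Sequential(nn.Linear(64, 64))
peft_cfg = LoraConfig(r=32, lora_alpha=32, lora_dropout=0.05, bias="none",
                      target_modules=["0"])
model = get_peft_model(base, peft_cfg)
model.print_trainable_parameters()  # only the LoRA A/B matrices are trainable
```
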