Spaces:
Runtime error
Runtime error
Upload folder using huggingface_hub
Browse files- .gitignore +42 -0
- README.md +43 -12
- README_zh_cn.md +242 -0
- app.py +231 -110
- env_install.sh +1 -1
- infer/gif_render.py +3 -3
- infer/image_to_views.py +9 -4
- infer/text_to_image.py +1 -2
- infer/utils.py +7 -1
- infer/views_to_mesh.py +7 -4
- main.py +60 -12
- requirements.txt +1 -0
- svrm/ldm/models/svrm.py +16 -19
- svrm/ldm/modules/attention.py +20 -11
- svrm/ldm/vis_util.py +14 -15
- svrm/predictor.py +1 -3
- third_party/check.py +25 -0
- third_party/dust3r_utils.py +366 -0
- third_party/gen_baking.py +288 -0
- third_party/mesh_baker.py +142 -0
- third_party/utils/camera_utils.py +90 -0
- third_party/utils/img_utils.py +211 -0
.gitignore
ADDED
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
**/*~
|
2 |
+
**/*.bk
|
3 |
+
**/*.xx
|
4 |
+
**/*.so
|
5 |
+
**/*.ipynb
|
6 |
+
**/*.log
|
7 |
+
**/*.swp
|
8 |
+
**/*.zip
|
9 |
+
**/*.look
|
10 |
+
**/*.lock
|
11 |
+
**/*.think
|
12 |
+
**/dosth.sh
|
13 |
+
**/nohup.out
|
14 |
+
**/*polaris*
|
15 |
+
**/*egg*/
|
16 |
+
**/cl5/
|
17 |
+
**/tmp/
|
18 |
+
**/look/
|
19 |
+
**/temp/
|
20 |
+
**/build/
|
21 |
+
**/model/
|
22 |
+
**/log/
|
23 |
+
**/backup/
|
24 |
+
**/outputs/
|
25 |
+
**/work_dir/
|
26 |
+
**/work_dirs/
|
27 |
+
**/__pycache__/
|
28 |
+
**/.ipynb_checkpoints/
|
29 |
+
*.jpg
|
30 |
+
*.png
|
31 |
+
*.gif
|
32 |
+
### PreCI ###
|
33 |
+
.codecc
|
34 |
+
|
35 |
+
app_hg.py
|
36 |
+
outputs
|
37 |
+
weights
|
38 |
+
.vscode/
|
39 |
+
baking
|
40 |
+
inference.py
|
41 |
+
third_party/weights
|
42 |
+
third_party/dust3r
|
README.md
CHANGED
@@ -1,14 +1,5 @@
|
|
1 |
-
|
2 |
-
|
3 |
-
emoji: 😻
|
4 |
-
colorFrom: purple
|
5 |
-
colorTo: red
|
6 |
-
sdk: gradio
|
7 |
-
sdk_version: 5.5.0
|
8 |
-
app_file: app_hg.py
|
9 |
-
pinned: false
|
10 |
-
short_description: Text-to-3D and Image-to-3D Generation
|
11 |
-
---
|
12 |
<!-- ## **Hunyuan3D-1.0** -->
|
13 |
|
14 |
<p align="center">
|
@@ -19,7 +10,7 @@ short_description: Text-to-3D and Image-to-3D Generation
|
|
19 |
|
20 |
<div align="center">
|
21 |
<a href="https://github.com/tencent/Hunyuan3D-1"><img src="https://img.shields.io/static/v1?label=Code&message=Github&color=blue&logo=github-pages"></a>  
|
22 |
-
<a href="https://3d.hunyuan.tencent.com"><img src="https://img.shields.io/static/v1?label=Homepage&message=Tencent
|
23 |
<a href="https://arxiv.org/pdf/2411.02293"><img src="https://img.shields.io/static/v1?label=Tech Report&message=Arxiv&color=red&logo=arxiv"></a>  
|
24 |
<a href="https://huggingface.co/Tencent/Hunyuan3D-1"><img src="https://img.shields.io/static/v1?label=Checkpoints&message=HuggingFace&color=yellow"></a>  
|
25 |
<a href="https://huggingface.co/spaces/Tencent/Hunyuan3D-1"><img src="https://img.shields.io/static/v1?label=Demo&message=HuggingFace&color=yellow"></a>  
|
@@ -101,6 +92,19 @@ pip install torch torchvision --index-url https://download.pytorch.org/whl/cu121
|
|
101 |
# step 3. install other packages
|
102 |
bash env_install.sh
|
103 |
```
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
104 |
<details>
|
105 |
<summary>💡Other tips for envrionment installation</summary>
|
106 |
|
@@ -204,6 +208,33 @@ bash scripts/image_to_3d_std_separately.sh ./demos/example_000.png ./outputs/tes
|
|
204 |
bash scripts/image_to_3d_lite_separately.sh ./demos/example_000.png ./outputs/test # >= 10G
|
205 |
```
|
206 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
207 |
#### Using Gradio
|
208 |
|
209 |
We have prepared two versions of multi-view generation, std and lite.
|
|
|
1 |
+
[English](README.md) | [简体中文](README_zh_cn.md)
|
2 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
3 |
<!-- ## **Hunyuan3D-1.0** -->
|
4 |
|
5 |
<p align="center">
|
|
|
10 |
|
11 |
<div align="center">
|
12 |
<a href="https://github.com/tencent/Hunyuan3D-1"><img src="https://img.shields.io/static/v1?label=Code&message=Github&color=blue&logo=github-pages"></a>  
|
13 |
+
<a href="https://3d.hunyuan.tencent.com"><img src="https://img.shields.io/static/v1?label=Homepage&message=Tencent%20Hunyuan3D&color=blue&logo=github-pages"></a>  
|
14 |
<a href="https://arxiv.org/pdf/2411.02293"><img src="https://img.shields.io/static/v1?label=Tech Report&message=Arxiv&color=red&logo=arxiv"></a>  
|
15 |
<a href="https://huggingface.co/Tencent/Hunyuan3D-1"><img src="https://img.shields.io/static/v1?label=Checkpoints&message=HuggingFace&color=yellow"></a>  
|
16 |
<a href="https://huggingface.co/spaces/Tencent/Hunyuan3D-1"><img src="https://img.shields.io/static/v1?label=Demo&message=HuggingFace&color=yellow"></a>  
|
|
|
92 |
# step 3. install other packages
|
93 |
bash env_install.sh
|
94 |
```
|
95 |
+
|
96 |
+
because of dust3r, we offer a guide:
|
97 |
+
|
98 |
+
```
|
99 |
+
cd third_party
|
100 |
+
git clone --recursive https://github.com/naver/dust3r.git
|
101 |
+
|
102 |
+
cd ../third_party/weights
|
103 |
+
wget https://download.europe.naverlabs.com/ComputerVision/DUSt3R/DUSt3R_ViTLarge_BaseDecoder_512_dpt.pth
|
104 |
+
|
105 |
+
```
|
106 |
+
|
107 |
+
|
108 |
<details>
|
109 |
<summary>💡Other tips for envrionment installation</summary>
|
110 |
|
|
|
208 |
bash scripts/image_to_3d_lite_separately.sh ./demos/example_000.png ./outputs/test # >= 10G
|
209 |
```
|
210 |
|
211 |
+
#### Baking related
|
212 |
+
|
213 |
+
```bash
|
214 |
+
cd ./third_party
|
215 |
+
git clone --recursive https://github.com/naver/dust3r.git
|
216 |
+
|
217 |
+
mkdir -p weights/DUSt3R_ViTLarge_BaseDecoder_512_dpt
|
218 |
+
cd weights/DUSt3R_ViTLarge_BaseDecoder_512_dpt
|
219 |
+
|
220 |
+
wget https://download.europe.naverlabs.com/ComputerVision/DUSt3R/DUSt3R_ViTLarge_BaseDecoder_512_dpt.pth
|
221 |
+
cd ../../..
|
222 |
+
```
|
223 |
+
|
224 |
+
If you download related code and weights, we list some additional arg:
|
225 |
+
|
226 |
+
| Argument | Default | Description |
|
227 |
+
|:------------------:|:---------:|:---------------------------------------------------:|
|
228 |
+
|`--do_bake` | False | baking multi-view into mesh |
|
229 |
+
|`--bake_align_times` | 3 | the times of align image with mesh |
|
230 |
+
|
231 |
+
|
232 |
+
Note: When running main.py, ensure that do_bake is set to True and do_texture_mapping is also set to True.
|
233 |
+
|
234 |
+
```bash
|
235 |
+
python main.py ... --do_texture_mapping --do_bake (--do_render)
|
236 |
+
```
|
237 |
+
|
238 |
#### Using Gradio
|
239 |
|
240 |
We have prepared two versions of multi-view generation, std and lite.
|
README_zh_cn.md
ADDED
@@ -0,0 +1,242 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
[English](README.md) | [简体中文](README_zh_cn.md)
|
2 |
+
|
3 |
+
<!-- ## **Hunyuan3D-1.0** -->
|
4 |
+
|
5 |
+
<p align="center">
|
6 |
+
<img src="./assets/logo.png" height=200>
|
7 |
+
</p>
|
8 |
+
|
9 |
+
# Tencent Hunyuan3D-1.0: A Unified Framework for Text-to-3D and Image-to-3D Generation
|
10 |
+
|
11 |
+
<div align="center">
|
12 |
+
<a href="https://github.com/tencent/Hunyuan3D-1"><img src="https://img.shields.io/static/v1?label=Code&message=Github&color=blue&logo=github-pages"></a>  
|
13 |
+
<a href="https://3d.hunyuan.tencent.com"><img src="https://img.shields.io/static/v1?label=Homepage&message=Tencent%20Hunyuan3D&color=blue&logo=github-pages"></a>  
|
14 |
+
<a href="https://arxiv.org/pdf/2411.02293"><img src="https://img.shields.io/static/v1?label=Tech Report&message=Arxiv&color=red&logo=arxiv"></a>  
|
15 |
+
<a href="https://huggingface.co/Tencent/Hunyuan3D-1"><img src="https://img.shields.io/static/v1?label=Checkpoints&message=HuggingFace&color=yellow"></a>  
|
16 |
+
<a href="https://huggingface.co/spaces/Tencent/Hunyuan3D-1"><img src="https://img.shields.io/static/v1?label=Demo&message=HuggingFace&color=yellow"></a>  
|
17 |
+
</div>
|
18 |
+
|
19 |
+
|
20 |
+
## 🔥🔥🔥 更新!!
|
21 |
+
|
22 |
+
* Nov 5, 2024: 💬 已经支持图生3D。请在[script](#using-gradio)体验。
|
23 |
+
* Nov 5, 2024: 💬 已经支持文生3D,请在[script](#using-gradio)体验。
|
24 |
+
|
25 |
+
|
26 |
+
## 📑 开源计划
|
27 |
+
|
28 |
+
- [x] Inference
|
29 |
+
- [x] Checkpoints
|
30 |
+
- [ ] Baking related
|
31 |
+
- [ ] Training
|
32 |
+
- [ ] ComfyUI
|
33 |
+
- [ ] Distillation Version
|
34 |
+
- [ ] TensorRT Version
|
35 |
+
|
36 |
+
|
37 |
+
|
38 |
+
## **概要**
|
39 |
+
<p align="center">
|
40 |
+
<img src="./assets/teaser.png" height=450>
|
41 |
+
</p>
|
42 |
+
|
43 |
+
为了解决现有的3D生成模型在生成速度和泛化能力上存在不足,我们开源了混元3D-1.0模型,可以帮助3D创作者和艺术家自动化生产3D资产。我们的模型采用两阶段生成方法,在保证质量和可控的基础上,仅需10秒即可生成3D资产。在第一阶段,我们采用了一种多视角扩散模型,轻量版模型能够在大约4秒内高效生成多视角图像,这些多视角图像从不同的视角捕捉了3D资产的丰富的纹理和几何先验,将任务从单视角重建松弛到多视角重建。在第二阶段,我们引入了一种前馈重建模型,利用上一阶段生成的多视角图像。该模型能够在大约3秒内快速而准确地重建3D资产。重建模型学习处理多视角扩散引入的噪声和不一致性,并利用条件图像中的可用信息高效恢复3D结构。最终,该模型可以实现输入任意单视角实现三维生成。
|
44 |
+
|
45 |
+
|
46 |
+
## 🎉 **Hunyuan3D-1.0 模型架构**
|
47 |
+
|
48 |
+
<p align="center">
|
49 |
+
<img src="./assets/overview_3.png" height=400>
|
50 |
+
</p>
|
51 |
+
|
52 |
+
|
53 |
+
## 📈 比较
|
54 |
+
|
55 |
+
通过和其他开源模型比较, 混元3D-1.0在5项指标都得到了最高用户评分。细节请查看以下用户研究结果。
|
56 |
+
|
57 |
+
在A100显卡上,轻量版模型仅需10s即可完成单图生成3D,标准版则大约需要25s。以下散点图表明腾讯混元3D-1.0实现了质量和速度的合理平衡。
|
58 |
+
|
59 |
+
<p align="center">
|
60 |
+
<img src="./assets/radar.png" height=300>
|
61 |
+
<img src="./assets/runtime.png" height=300>
|
62 |
+
</p>
|
63 |
+
|
64 |
+
## 使用
|
65 |
+
|
66 |
+
#### 复制代码仓库
|
67 |
+
|
68 |
+
```shell
|
69 |
+
git clone https://github.com/tencent/Hunyuan3D-1
|
70 |
+
cd Hunyuan3D-1
|
71 |
+
```
|
72 |
+
|
73 |
+
#### Linux系统安装
|
74 |
+
|
75 |
+
env_install.sh 脚本提供了如何安装环境:
|
76 |
+
|
77 |
+
```
|
78 |
+
# 第一步:创建环境
|
79 |
+
conda create -n hunyuan3d-1 python=3.9 or 3.10 or 3.11 or 3.12
|
80 |
+
conda activate hunyuan3d-1
|
81 |
+
|
82 |
+
# 第二部:安装torch和相关依赖包
|
83 |
+
which pip # check pip corresponds to python
|
84 |
+
|
85 |
+
# modify the cuda version according to your machine (recommended)
|
86 |
+
pip install torch torchvision --index-url https://download.pytorch.org/whl/cu121
|
87 |
+
|
88 |
+
# 第三步:安装其他相关依赖包
|
89 |
+
bash env_install.sh
|
90 |
+
```
|
91 |
+
|
92 |
+
由于dust3r的许可证限制, 我们仅提供其安装途径:
|
93 |
+
|
94 |
+
```
|
95 |
+
cd third_party
|
96 |
+
git clone --recursive https://github.com/naver/dust3r.git
|
97 |
+
|
98 |
+
cd ../third_party/weights
|
99 |
+
wget https://download.europe.naverlabs.com/ComputerVision/DUSt3R/DUSt3R_ViTLarge_BaseDecoder_512_dpt.pth
|
100 |
+
|
101 |
+
```
|
102 |
+
|
103 |
+
|
104 |
+
<details>
|
105 |
+
<summary>💡一些环境安装建议</summary>
|
106 |
+
|
107 |
+
可以选择安装 xformers 或 flash_attn 进行加速:
|
108 |
+
|
109 |
+
```
|
110 |
+
pip install xformers --index-url https://download.pytorch.org/whl/cu121
|
111 |
+
```
|
112 |
+
```
|
113 |
+
pip install flash_attn
|
114 |
+
```
|
115 |
+
|
116 |
+
Most environment errors are caused by a mismatch between machine and packages. You can try manually specifying the version, as shown in the following successful cases:
|
117 |
+
```
|
118 |
+
# python3.9
|
119 |
+
pip install torch==2.0.1 torchvision==0.15.2 --index-url https://download.pytorch.org/whl/cu118
|
120 |
+
```
|
121 |
+
|
122 |
+
when install pytorch3d, the gcc version is preferably greater than 9, and the gpu driver should not be too old.
|
123 |
+
|
124 |
+
</details>
|
125 |
+
|
126 |
+
#### 下载预训练模型
|
127 |
+
|
128 |
+
模型下载链接 [https://huggingface.co/tencent/Hunyuan3D-1](https://huggingface.co/tencent/Hunyuan3D-1):
|
129 |
+
|
130 |
+
+ `Hunyuan3D-1/lite`, lite model for multi-view generation.
|
131 |
+
+ `Hunyuan3D-1/std`, standard model for multi-view generation.
|
132 |
+
+ `Hunyuan3D-1/svrm`, sparse-view reconstruction model.
|
133 |
+
|
134 |
+
|
135 |
+
为了通过Hugging Face下载模型,请先下载 huggingface-cli. (安装细节可见 [here](https://huggingface.co/docs/huggingface_hub/guides/cli).)
|
136 |
+
|
137 |
+
```shell
|
138 |
+
python3 -m pip install "huggingface_hub[cli]"
|
139 |
+
```
|
140 |
+
|
141 |
+
请使用以下命令下载模型:
|
142 |
+
|
143 |
+
```shell
|
144 |
+
mkdir weights
|
145 |
+
huggingface-cli download tencent/Hunyuan3D-1 --local-dir ./weights
|
146 |
+
|
147 |
+
mkdir weights/hunyuanDiT
|
148 |
+
huggingface-cli download Tencent-Hunyuan/HunyuanDiT-v1.1-Diffusers-Distilled --local-dir ./weights/hunyuanDiT
|
149 |
+
```
|
150 |
+
|
151 |
+
#### 推理
|
152 |
+
对于文生3D,我们支持中/英双语生成,请使用以下命令进行本地推理:
|
153 |
+
```python
|
154 |
+
python3 main.py \
|
155 |
+
--text_prompt "a lovely rabbit" \
|
156 |
+
--save_folder ./outputs/test/ \
|
157 |
+
--max_faces_num 90000 \
|
158 |
+
--do_texture_mapping \
|
159 |
+
--do_render
|
160 |
+
```
|
161 |
+
|
162 |
+
对于图生3D,请使用以下命令进行本地推理:
|
163 |
+
```python
|
164 |
+
python3 main.py \
|
165 |
+
--image_prompt "/path/to/your/image" \
|
166 |
+
--save_folder ./outputs/test/ \
|
167 |
+
--max_faces_num 90000 \
|
168 |
+
--do_texture_mapping \
|
169 |
+
--do_render
|
170 |
+
```
|
171 |
+
更多参数详解:
|
172 |
+
|
173 |
+
| Argument | Default | Description |
|
174 |
+
|:------------------:|:---------:|:---------------------------------------------------:|
|
175 |
+
|`--text_prompt` | None |The text prompt for 3D generation |
|
176 |
+
|`--image_prompt` | None |The image prompt for 3D generation |
|
177 |
+
|`--t2i_seed` | 0 |The random seed for generating images |
|
178 |
+
|`--t2i_steps` | 25 |The number of steps for sampling of text to image |
|
179 |
+
|`--gen_seed` | 0 |The random seed for generating 3d generation |
|
180 |
+
|`--gen_steps` | 50 |The number of steps for sampling of 3d generation |
|
181 |
+
|`--max_faces_numm` | 90000 |The limit number of faces of 3d mesh |
|
182 |
+
|`--save_memory` | False |module will move to cpu automatically|
|
183 |
+
|`--do_texture_mapping` | False |Change vertex shadding to texture shading |
|
184 |
+
|`--do_render` | False |render gif |
|
185 |
+
|
186 |
+
|
187 |
+
如果显卡内存有限,可以使用`--save_memory`命令,最低显卡内存要求如下:
|
188 |
+
- Inference Std-pipeline requires 30GB VRAM (24G VRAM with --save_memory).
|
189 |
+
- Inference Lite-pipeline requires 22GB VRAM (18G VRAM with --save_memory).
|
190 |
+
- Note: --save_memory will increase inference time
|
191 |
+
|
192 |
+
```bash
|
193 |
+
bash scripts/text_to_3d_std.sh
|
194 |
+
bash scripts/text_to_3d_lite.sh
|
195 |
+
bash scripts/image_to_3d_std.sh
|
196 |
+
bash scripts/image_to_3d_lite.sh
|
197 |
+
```
|
198 |
+
|
199 |
+
如果你的显卡内存为16G,可以分别加载模型到显卡:
|
200 |
+
```bash
|
201 |
+
bash scripts/text_to_3d_std_separately.sh 'a lovely rabbit' ./outputs/test # >= 16G
|
202 |
+
bash scripts/text_to_3d_lite_separately.sh 'a lovely rabbit' ./outputs/test # >= 14G
|
203 |
+
bash scripts/image_to_3d_std_separately.sh ./demos/example_000.png ./outputs/test # >= 16G
|
204 |
+
bash scripts/image_to_3d_lite_separately.sh ./demos/example_000.png ./outputs/test # >= 10G
|
205 |
+
```
|
206 |
+
|
207 |
+
#### Gradio界面部署
|
208 |
+
|
209 |
+
我们分别提供轻量版和标准版界面:
|
210 |
+
|
211 |
+
```shell
|
212 |
+
# std
|
213 |
+
python3 app.py
|
214 |
+
python3 app.py --save_memory
|
215 |
+
|
216 |
+
# lite
|
217 |
+
python3 app.py --use_lite
|
218 |
+
python3 app.py --use_lite --save_memory
|
219 |
+
```
|
220 |
+
|
221 |
+
Gradio界面体验地址为 http://0.0.0.0:8080. 这里 0.0.0.0 应当填写运行模型的机器IP地址。
|
222 |
+
|
223 |
+
## 相机参数
|
224 |
+
|
225 |
+
生成多视图视角固定为
|
226 |
+
|
227 |
+
+ Azimuth (relative to input view): `+0, +60, +120, +180, +240, +300`.
|
228 |
+
|
229 |
+
|
230 |
+
## 引用
|
231 |
+
|
232 |
+
如果我们的仓库对您有帮助,请引用我们的工作
|
233 |
+
```bibtex
|
234 |
+
@misc{yang2024tencent,
|
235 |
+
title={Tencent Hunyuan3D-1.0: A Unified Framework for Text-to-3D and Image-to-3D Generation},
|
236 |
+
author={Xianghui Yang and Huiwen Shi and Bowen Zhang and Fan Yang and Jiacheng Wang and Hongxu Zhao and Xinhai Liu and Xinzhou Wang and Qingxiang Lin and Jiaao Yu and Lifu Wang and Zhuo Chen and Sicong Liu and Yuhong Liu and Yong Yang and Di Wang and Jie Jiang and Chunchao Guo},
|
237 |
+
year={2024},
|
238 |
+
eprint={2411.02293},
|
239 |
+
archivePrefix={arXiv},
|
240 |
+
primaryClass={cs.CV}
|
241 |
+
}
|
242 |
+
```
|
app.py
CHANGED
@@ -32,9 +32,21 @@ import torch
|
|
32 |
import numpy as np
|
33 |
from PIL import Image
|
34 |
from einops import rearrange
|
|
|
35 |
|
36 |
from infer import seed_everything, save_gif
|
37 |
from infer import Text2Image, Removebg, Image2Views, Views2Mesh, GifRenderer
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
38 |
|
39 |
warnings.simplefilter('ignore', category=UserWarning)
|
40 |
warnings.simplefilter('ignore', category=FutureWarning)
|
@@ -58,33 +70,19 @@ CONST_MAX_QUEUE = 1
|
|
58 |
CONST_SERVER = '0.0.0.0'
|
59 |
|
60 |
CONST_HEADER = '''
|
61 |
-
<h2><
|
62 |
-
|
63 |
-
Code: <a href='https://github.com/tencent/Hunyuan3D-1' target='_blank'>GitHub</a>. Techenical report: <a href='https://arxiv.org/abs/placeholder' target='_blank'>ArXiv</a>.
|
64 |
-
|
65 |
-
❗️❗️❗️**Important Notes:**
|
66 |
-
- By default, our demo can export a .obj mesh with vertex colors or a .glb mesh.
|
67 |
-
- If you select "texture mapping," it will export a .obj mesh with a texture map or a .glb mesh.
|
68 |
-
- If you select "render GIF," it will export a GIF image rendering of the .glb file.
|
69 |
-
- If the result is unsatisfactory, please try a different seed value (Default: 0).
|
70 |
'''
|
71 |
|
72 |
-
|
73 |
-
|
74 |
-
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
-
|
81 |
-
year={2024},
|
82 |
-
eprint={2411.02293},
|
83 |
-
archivePrefix={arXiv},
|
84 |
-
primaryClass={cs.CV}
|
85 |
-
}
|
86 |
-
```
|
87 |
-
"""
|
88 |
|
89 |
################################################################
|
90 |
# prepare text examples and image examples
|
@@ -129,6 +127,13 @@ worker_v23 = Views2Mesh(
|
|
129 |
)
|
130 |
worker_gif = GifRenderer(args.device)
|
131 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
132 |
def stage_0_t2i(text, image, seed, step):
|
133 |
os.makedirs('./outputs/app_output', exist_ok=True)
|
134 |
exists = set(int(_) for _ in os.listdir('./outputs/app_output') if not _.startswith("."))
|
@@ -153,11 +158,11 @@ def stage_0_t2i(text, image, seed, step):
|
|
153 |
dst = worker_xbg(image, save_folder)
|
154 |
return dst, save_folder
|
155 |
|
156 |
-
def stage_1_xbg(image, save_folder):
|
157 |
if isinstance(image, str):
|
158 |
image = Image.open(image)
|
159 |
dst = save_folder + '/img_nobg.png'
|
160 |
-
rgba = worker_xbg(image)
|
161 |
rgba.save(dst)
|
162 |
return dst
|
163 |
|
@@ -181,12 +186,9 @@ def stage_3_v23(
|
|
181 |
seed,
|
182 |
save_folder,
|
183 |
target_face_count = 30000,
|
184 |
-
|
185 |
-
do_render =True
|
186 |
):
|
187 |
-
do_texture_mapping =
|
188 |
-
obj_dst = save_folder + '/mesh_with_colors.obj'
|
189 |
-
glb_dst = save_folder + '/mesh.glb'
|
190 |
worker_v23(
|
191 |
views_pil,
|
192 |
cond_pil,
|
@@ -195,149 +197,268 @@ def stage_3_v23(
|
|
195 |
target_face_count = target_face_count,
|
196 |
do_texture_mapping = do_texture_mapping
|
197 |
)
|
|
|
|
|
|
|
198 |
return obj_dst, glb_dst
|
199 |
|
200 |
-
def
|
201 |
-
if
|
202 |
-
|
203 |
-
|
204 |
-
|
205 |
-
|
206 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
207 |
return gif_dst
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
208 |
# ===============================================================
|
209 |
# gradio display
|
210 |
# ===============================================================
|
|
|
211 |
with gr.Blocks() as demo:
|
212 |
gr.Markdown(CONST_HEADER)
|
213 |
with gr.Row(variant="panel"):
|
|
|
|
|
|
|
214 |
with gr.Column(scale=2):
|
|
|
|
|
|
|
215 |
with gr.Tab("Text to 3D"):
|
216 |
with gr.Column():
|
217 |
-
text = gr.TextArea('一只黑白相间的熊猫在白色背景上居中坐着,呈现出卡通风格和可爱氛围。',
|
|
|
218 |
with gr.Row():
|
219 |
-
|
220 |
-
|
221 |
-
|
222 |
-
|
223 |
-
|
224 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
225 |
with gr.Row():
|
226 |
-
textgen_do_texture_mapping = gr.Checkbox(label="texture mapping", value=False, interactive=True)
|
227 |
-
textgen_do_render_gif = gr.Checkbox(label="Render gif", value=False, interactive=True)
|
228 |
textgen_submit = gr.Button("Generate", variant="primary")
|
229 |
|
230 |
with gr.Row():
|
231 |
-
gr.Examples(examples=example_ts, inputs=[text], label="
|
232 |
|
|
|
|
|
233 |
with gr.Tab("Image to 3D"):
|
234 |
-
with gr.
|
235 |
-
input_image = gr.Image(label="Input image",
|
236 |
-
|
237 |
-
|
238 |
-
|
239 |
-
|
240 |
-
|
241 |
-
|
242 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
243 |
|
244 |
-
with gr.Row():
|
245 |
-
imggen_do_texture_mapping = gr.Checkbox(label="texture mapping", value=False, interactive=True)
|
246 |
-
imggen_do_render_gif = gr.Checkbox(label="Render gif", value=False, interactive=True)
|
247 |
-
imggen_submit = gr.Button("Generate", variant="primary")
|
248 |
-
with gr.Row():
|
249 |
-
gr.Examples(
|
250 |
-
examples=example_is,
|
251 |
-
inputs=[input_image],
|
252 |
-
label="Img examples",
|
253 |
-
examples_per_page=10
|
254 |
-
)
|
255 |
-
|
256 |
with gr.Column(scale=3):
|
257 |
with gr.Row():
|
258 |
with gr.Column(scale=2):
|
259 |
-
rem_bg_image = gr.Image(
|
260 |
-
|
|
|
|
|
|
|
|
|
261 |
with gr.Column(scale=3):
|
262 |
-
result_image = gr.Image(
|
263 |
-
|
264 |
-
|
|
|
|
|
|
|
|
|
265 |
result_3dobj = gr.Model3D(
|
266 |
clear_color=[0.0, 0.0, 0.0, 0.0],
|
267 |
-
label="
|
268 |
show_label=True,
|
269 |
visible=True,
|
270 |
camera_position=[90, 90, None],
|
271 |
interactive=False
|
272 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
273 |
|
274 |
-
|
275 |
clear_color=[0.0, 0.0, 0.0, 0.0],
|
276 |
-
label="
|
277 |
show_label=True,
|
278 |
visible=True,
|
279 |
camera_position=[90, 90, None],
|
280 |
-
interactive=False
|
281 |
-
)
|
282 |
-
result_gif = gr.Image(label="Rendered GIF", interactive=False)
|
283 |
|
284 |
-
with gr.Row():
|
285 |
-
gr.Markdown(
|
286 |
-
|
287 |
-
|
288 |
-
|
289 |
-
|
290 |
-
|
291 |
-
#===============================================================
|
292 |
-
# gradio running code
|
293 |
-
#===============================================================
|
294 |
|
|
|
|
|
|
|
|
|
295 |
none = gr.State(None)
|
296 |
save_folder = gr.State()
|
297 |
cond_image = gr.State()
|
298 |
views_image = gr.State()
|
299 |
text_image = gr.State()
|
300 |
|
|
|
301 |
textgen_submit.click(
|
302 |
-
fn=stage_0_t2i,
|
|
|
303 |
outputs=[rem_bg_image, save_folder],
|
304 |
).success(
|
305 |
-
fn=stage_2_i2v,
|
|
|
306 |
outputs=[views_image, cond_image, result_image],
|
307 |
).success(
|
308 |
-
fn=stage_3_v23,
|
309 |
-
|
310 |
-
|
311 |
-
outputs=[result_3dobj, result_3dglb],
|
312 |
).success(
|
313 |
-
fn=
|
|
|
|
|
|
|
|
|
|
|
314 |
outputs=[result_gif],
|
315 |
).success(lambda: print('Text_to_3D Done ...'))
|
316 |
|
|
|
317 |
imggen_submit.click(
|
318 |
-
fn=stage_0_t2i,
|
|
|
319 |
outputs=[text_image, save_folder],
|
320 |
).success(
|
321 |
-
fn=stage_1_xbg,
|
|
|
322 |
outputs=[rem_bg_image],
|
323 |
).success(
|
324 |
-
fn=stage_2_i2v,
|
|
|
325 |
outputs=[views_image, cond_image, result_image],
|
326 |
).success(
|
327 |
-
fn=stage_3_v23,
|
328 |
-
|
329 |
-
|
330 |
-
|
|
|
|
|
|
|
331 |
).success(
|
332 |
-
fn=stage_4_gif,
|
|
|
333 |
outputs=[result_gif],
|
334 |
).success(lambda: print('Image_to_3D Done ...'))
|
335 |
|
336 |
-
#===============================================================
|
337 |
-
# start gradio server
|
338 |
-
#===============================================================
|
339 |
|
340 |
-
gr.Markdown(CONST_CITATION)
|
341 |
demo.queue(max_size=CONST_MAX_QUEUE)
|
342 |
demo.launch(server_name=CONST_SERVER, server_port=CONST_PORT)
|
343 |
|
|
|
32 |
import numpy as np
|
33 |
from PIL import Image
|
34 |
from einops import rearrange
|
35 |
+
import pandas as pd
|
36 |
|
37 |
from infer import seed_everything, save_gif
|
38 |
from infer import Text2Image, Removebg, Image2Views, Views2Mesh, GifRenderer
|
39 |
+
from third_party.check import check_bake_available
|
40 |
+
|
41 |
+
try:
|
42 |
+
from third_party.mesh_baker import MeshBaker
|
43 |
+
BAKE_AVAILEBLE = True
|
44 |
+
except Exception as err:
|
45 |
+
print(err)
|
46 |
+
print("import baking related fail, run without baking")
|
47 |
+
check_bake_available()
|
48 |
+
BAKE_AVAILEBLE = False
|
49 |
+
|
50 |
|
51 |
warnings.simplefilter('ignore', category=UserWarning)
|
52 |
warnings.simplefilter('ignore', category=FutureWarning)
|
|
|
70 |
CONST_SERVER = '0.0.0.0'
|
71 |
|
72 |
CONST_HEADER = '''
|
73 |
+
<h2><a href='https://github.com/tencent/Hunyuan3D-1' target='_blank'><b>Tencent Hunyuan3D-1.0: A Unified Framework for Text-to-3D and Image-to-3D Generation</b></a></h2>
|
74 |
+
⭐️Technical report: <a href='https://arxiv.org/pdf/2411.02293' target='_blank'>ArXiv</a>. ⭐️Code: <a href='https://github.com/tencent/Hunyuan3D-1' target='_blank'>GitHub</a>.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
75 |
'''
|
76 |
|
77 |
+
CONST_NOTE = '''
|
78 |
+
❗️❗️❗️Usage❗️❗️❗️<br>
|
79 |
+
|
80 |
+
Limited by format, the model can only export *.obj mesh with vertex colors. The "texture" mod can only work on *.glb.<br>
|
81 |
+
Please click "Do Rendering" to export a GIF.<br>
|
82 |
+
You can click "Do Baking" to bake multi-view imgaes onto the shape.<br>
|
83 |
+
|
84 |
+
If the results aren't satisfactory, please try a different radnom seed (default is 0).
|
85 |
+
'''
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
86 |
|
87 |
################################################################
|
88 |
# prepare text examples and image examples
|
|
|
127 |
)
|
128 |
worker_gif = GifRenderer(args.device)
|
129 |
|
130 |
+
|
131 |
+
if BAKE_AVAILEBLE:
|
132 |
+
worker_baker = MeshBaker()
|
133 |
+
|
134 |
+
|
135 |
+
### functional modules
|
136 |
+
|
137 |
def stage_0_t2i(text, image, seed, step):
|
138 |
os.makedirs('./outputs/app_output', exist_ok=True)
|
139 |
exists = set(int(_) for _ in os.listdir('./outputs/app_output') if not _.startswith("."))
|
|
|
158 |
dst = worker_xbg(image, save_folder)
|
159 |
return dst, save_folder
|
160 |
|
161 |
+
def stage_1_xbg(image, save_folder, force_remove):
|
162 |
if isinstance(image, str):
|
163 |
image = Image.open(image)
|
164 |
dst = save_folder + '/img_nobg.png'
|
165 |
+
rgba = worker_xbg(image, force=force_remove)
|
166 |
rgba.save(dst)
|
167 |
return dst
|
168 |
|
|
|
186 |
seed,
|
187 |
save_folder,
|
188 |
target_face_count = 30000,
|
189 |
+
texture_color = 'texture'
|
|
|
190 |
):
|
191 |
+
do_texture_mapping = texture_color == 'texture'
|
|
|
|
|
192 |
worker_v23(
|
193 |
views_pil,
|
194 |
cond_pil,
|
|
|
197 |
target_face_count = target_face_count,
|
198 |
do_texture_mapping = do_texture_mapping
|
199 |
)
|
200 |
+
glb_dst = save_folder + '/mesh.glb' if do_texture_mapping else None
|
201 |
+
obj_dst = save_folder + '/mesh.obj'
|
202 |
+
obj_dst = save_folder + '/mesh_vertex_colors.obj' # gradio just only can show vertex shading
|
203 |
return obj_dst, glb_dst
|
204 |
|
205 |
+
def stage_3p_baking(save_folder, color, bake):
|
206 |
+
if color == "texture" and bake:
|
207 |
+
obj_dst = worker_baker(save_folder)
|
208 |
+
glb_dst = obj_dst.replace(".obj", ".glb")
|
209 |
+
return glb_dst
|
210 |
+
else:
|
211 |
+
return None
|
212 |
+
|
213 |
+
def stage_4_gif(save_folder, color, bake, render):
|
214 |
+
if not render: return None
|
215 |
+
if os.path.exists(save_folder + '/view_1/bake/mesh.obj'):
|
216 |
+
obj_dst = save_folder + '/view_1/bake/mesh.obj'
|
217 |
+
elif os.path.exists(save_folder + '/view_0/bake/mesh.obj'):
|
218 |
+
obj_dst = save_folder + '/view_0/bake/mesh.obj'
|
219 |
+
elif os.path.exists(save_folder + '/mesh.obj'):
|
220 |
+
obj_dst = save_folder + '/mesh.obj'
|
221 |
+
else:
|
222 |
+
print(save_folder)
|
223 |
+
raise FileNotFoundError("mesh obj file not found")
|
224 |
+
gif_dst = obj_dst.replace(".obj", ".gif")
|
225 |
+
worker_gif(obj_dst, gif_dst_path=gif_dst)
|
226 |
return gif_dst
|
227 |
+
|
228 |
+
|
229 |
+
def check_image_available(image):
|
230 |
+
if image.mode == "RGBA":
|
231 |
+
data = np.array(image)
|
232 |
+
alpha_channel = data[:, :, 3]
|
233 |
+
unique_alpha_values = np.unique(alpha_channel)
|
234 |
+
if len(unique_alpha_values) == 1:
|
235 |
+
msg = "The alpha channel is missing or invalid. The background removal option is selected for you."
|
236 |
+
return msg, gr.update(value=True, interactive=False)
|
237 |
+
else:
|
238 |
+
msg = "The image has four channels, and you can choose to remove the background or not."
|
239 |
+
return msg, gr.update(value=False, interactive=True)
|
240 |
+
elif image.mode == "RGB":
|
241 |
+
msg = "The alpha channel is missing or invalid. The background removal option is selected for you."
|
242 |
+
return msg, gr.update(value=True, interactive=False)
|
243 |
+
else:
|
244 |
+
raise Exception("Image Error")
|
245 |
+
|
246 |
+
def update_bake_render(color):
|
247 |
+
if color == "vertex":
|
248 |
+
return gr.update(value=False, interactive=False), gr.update(value=False, interactive=False)
|
249 |
+
else:
|
250 |
+
return gr.update(interactive=True), gr.update(interactive=True)
|
251 |
+
|
252 |
# ===============================================================
|
253 |
# gradio display
|
254 |
# ===============================================================
|
255 |
+
|
256 |
with gr.Blocks() as demo:
|
257 |
gr.Markdown(CONST_HEADER)
|
258 |
with gr.Row(variant="panel"):
|
259 |
+
|
260 |
+
###### Input region
|
261 |
+
|
262 |
with gr.Column(scale=2):
|
263 |
+
|
264 |
+
### Text iutput region
|
265 |
+
|
266 |
with gr.Tab("Text to 3D"):
|
267 |
with gr.Column():
|
268 |
+
text = gr.TextArea('一只黑白相间的熊猫在白色背景上居中坐着,呈现出卡通风格和可爱氛围。',
|
269 |
+
lines=3, max_lines=20, label='Input text')
|
270 |
with gr.Row():
|
271 |
+
textgen_color = gr.Radio(choices=["vertex", "texture"], label="Color", value="texture")
|
272 |
+
with gr.Row():
|
273 |
+
textgen_render = gr.Checkbox(label="Do Rendering", value=True, interactive=True)
|
274 |
+
if BAKE_AVAILEBLE:
|
275 |
+
textgen_bake = gr.Checkbox(label="Do Baking", value=True, interactive=True)
|
276 |
+
else:
|
277 |
+
textgen_bake = gr.Checkbox(label="Do Baking", value=False, interactive=False)
|
278 |
+
|
279 |
+
textgen_color.change(
|
280 |
+
fn=update_bake_render,
|
281 |
+
inputs=textgen_color,
|
282 |
+
outputs=[textgen_bake, textgen_render]
|
283 |
+
)
|
284 |
+
|
285 |
+
with gr.Row():
|
286 |
+
textgen_seed = gr.Number(value=0, label="T2I seed", precision=0, interactive=True)
|
287 |
+
textgen_step = gr.Number(value=25, label="T2I steps", precision=0,
|
288 |
+
minimum=10, maximum=50, interactive=True)
|
289 |
+
textgen_SEED = gr.Number(value=0, label="Gen seed", precision=0, interactive=True)
|
290 |
+
textgen_STEP = gr.Number(value=50, label="Gen steps", precision=0,
|
291 |
+
minimum=40, maximum=100, interactive=True)
|
292 |
+
textgen_max_faces = gr.Number(value=90000, label="Face number", precision=0,
|
293 |
+
minimum=5000, maximum=1000000, interactive=True)
|
294 |
with gr.Row():
|
|
|
|
|
295 |
textgen_submit = gr.Button("Generate", variant="primary")
|
296 |
|
297 |
with gr.Row():
|
298 |
+
gr.Examples(examples=example_ts, inputs=[text], label="Text examples", examples_per_page=10)
|
299 |
|
300 |
+
### Image iutput region
|
301 |
+
|
302 |
with gr.Tab("Image to 3D"):
|
303 |
+
with gr.Row():
|
304 |
+
input_image = gr.Image(label="Input image", width=256, height=256, type="pil",
|
305 |
+
image_mode="RGBA", sources="upload", interactive=True)
|
306 |
+
with gr.Row():
|
307 |
+
alert_message = gr.Markdown("") # for warning
|
308 |
+
with gr.Row():
|
309 |
+
imggen_color = gr.Radio(choices=["vertex", "texture"], label="Color", value="texture")
|
310 |
+
with gr.Row():
|
311 |
+
imggen_removebg = gr.Checkbox(label="Remove Background", value=True, interactive=True)
|
312 |
+
imggen_render = gr.Checkbox(label="Do Rendering", value=True, interactive=True)
|
313 |
+
if BAKE_AVAILEBLE:
|
314 |
+
imggen_bake = gr.Checkbox(label="Do Baking", value=True, interactive=True)
|
315 |
+
else:
|
316 |
+
imggen_bake = gr.Checkbox(label="Do Baking", value=False, interactive=False)
|
317 |
+
|
318 |
+
input_image.change(
|
319 |
+
fn=check_image_available,
|
320 |
+
inputs=input_image,
|
321 |
+
outputs=[alert_message, imggen_removebg]
|
322 |
+
)
|
323 |
+
imggen_color.change(
|
324 |
+
fn=update_bake_render,
|
325 |
+
inputs=imggen_color,
|
326 |
+
outputs=[imggen_bake, imggen_render]
|
327 |
+
)
|
328 |
+
|
329 |
+
with gr.Row():
|
330 |
+
imggen_SEED = gr.Number(value=0, label="Gen seed", precision=0, interactive=True)
|
331 |
+
imggen_STEP = gr.Number(value=50, label="Gen steps", precision=0,
|
332 |
+
minimum=40, maximum=100, interactive=True)
|
333 |
+
imggen_max_faces = gr.Number(value=90000, label="Face number", precision=0,
|
334 |
+
minimum=5000, maximum=1000000, interactive=True)
|
335 |
+
with gr.Row():
|
336 |
+
imggen_submit = gr.Button("Generate", variant="primary")
|
337 |
+
|
338 |
+
with gr.Row():
|
339 |
+
gr.Examples(examples=example_is, inputs=[input_image],
|
340 |
+
label="Img examples", examples_per_page=10)
|
341 |
+
|
342 |
+
gr.Markdown(CONST_NOTE)
|
343 |
+
|
344 |
+
###### Output region
|
345 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
346 |
with gr.Column(scale=3):
|
347 |
with gr.Row():
|
348 |
with gr.Column(scale=2):
|
349 |
+
rem_bg_image = gr.Image(
|
350 |
+
label="Image without background",
|
351 |
+
type="pil",
|
352 |
+
image_mode="RGBA",
|
353 |
+
interactive=False
|
354 |
+
)
|
355 |
with gr.Column(scale=3):
|
356 |
+
result_image = gr.Image(
|
357 |
+
label="Multi-view images",
|
358 |
+
type="pil",
|
359 |
+
interactive=False
|
360 |
+
)
|
361 |
+
|
362 |
+
with gr.Row():
|
363 |
result_3dobj = gr.Model3D(
|
364 |
clear_color=[0.0, 0.0, 0.0, 0.0],
|
365 |
+
label="OBJ vertex color",
|
366 |
show_label=True,
|
367 |
visible=True,
|
368 |
camera_position=[90, 90, None],
|
369 |
interactive=False
|
370 |
)
|
371 |
+
result_gif = gr.Image(label="GIF", interactive=False)
|
372 |
+
|
373 |
+
with gr.Row():
|
374 |
+
result_3dglb_texture = gr.Model3D(
|
375 |
+
clear_color=[0.0, 0.0, 0.0, 0.0],
|
376 |
+
label="GLB texture color",
|
377 |
+
show_label=True,
|
378 |
+
visible=True,
|
379 |
+
camera_position=[90, 90, None],
|
380 |
+
interactive=False)
|
381 |
|
382 |
+
result_3dglb_baked = gr.Model3D(
|
383 |
clear_color=[0.0, 0.0, 0.0, 0.0],
|
384 |
+
label="GLB baked color",
|
385 |
show_label=True,
|
386 |
visible=True,
|
387 |
camera_position=[90, 90, None],
|
388 |
+
interactive=False)
|
|
|
|
|
389 |
|
390 |
+
with gr.Row():
|
391 |
+
gr.Markdown(
|
392 |
+
"Due to Gradio limitations, OBJ files are displayed with vertex shading only, "
|
393 |
+
"while GLB files can be viewed with texture shading. <br>For the best experience, "
|
394 |
+
"we recommend downloading the GLB files and opening them with 3D software "
|
395 |
+
"like Blender or MeshLab."
|
396 |
+
)
|
|
|
|
|
|
|
397 |
|
398 |
+
#===============================================================
|
399 |
+
# gradio running code
|
400 |
+
#===============================================================
|
401 |
+
|
402 |
none = gr.State(None)
|
403 |
save_folder = gr.State()
|
404 |
cond_image = gr.State()
|
405 |
views_image = gr.State()
|
406 |
text_image = gr.State()
|
407 |
|
408 |
+
|
409 |
textgen_submit.click(
|
410 |
+
fn=stage_0_t2i,
|
411 |
+
inputs=[text, none, textgen_seed, textgen_step],
|
412 |
outputs=[rem_bg_image, save_folder],
|
413 |
).success(
|
414 |
+
fn=stage_2_i2v,
|
415 |
+
inputs=[rem_bg_image, textgen_SEED, textgen_STEP, save_folder],
|
416 |
outputs=[views_image, cond_image, result_image],
|
417 |
).success(
|
418 |
+
fn=stage_3_v23,
|
419 |
+
inputs=[views_image, cond_image, textgen_SEED, save_folder, textgen_max_faces, textgen_color],
|
420 |
+
outputs=[result_3dobj, result_3dglb_texture],
|
|
|
421 |
).success(
|
422 |
+
fn=stage_3p_baking,
|
423 |
+
inputs=[save_folder, textgen_color, textgen_bake],
|
424 |
+
outputs=[result_3dglb_baked],
|
425 |
+
).success(
|
426 |
+
fn=stage_4_gif,
|
427 |
+
inputs=[save_folder, textgen_color, textgen_bake, textgen_render],
|
428 |
outputs=[result_gif],
|
429 |
).success(lambda: print('Text_to_3D Done ...'))
|
430 |
|
431 |
+
|
432 |
imggen_submit.click(
|
433 |
+
fn=stage_0_t2i,
|
434 |
+
inputs=[none, input_image, textgen_seed, textgen_step],
|
435 |
outputs=[text_image, save_folder],
|
436 |
).success(
|
437 |
+
fn=stage_1_xbg,
|
438 |
+
inputs=[text_image, save_folder, imggen_removebg],
|
439 |
outputs=[rem_bg_image],
|
440 |
).success(
|
441 |
+
fn=stage_2_i2v,
|
442 |
+
inputs=[rem_bg_image, imggen_SEED, imggen_STEP, save_folder],
|
443 |
outputs=[views_image, cond_image, result_image],
|
444 |
).success(
|
445 |
+
fn=stage_3_v23,
|
446 |
+
inputs=[views_image, cond_image, imggen_SEED, save_folder, imggen_max_faces, imggen_color],
|
447 |
+
outputs=[result_3dobj, result_3dglb_texture],
|
448 |
+
).success(
|
449 |
+
fn=stage_3p_baking,
|
450 |
+
inputs=[save_folder, imggen_color, imggen_bake],
|
451 |
+
outputs=[result_3dglb_baked],
|
452 |
).success(
|
453 |
+
fn=stage_4_gif,
|
454 |
+
inputs=[save_folder, imggen_color, imggen_bake, imggen_render],
|
455 |
outputs=[result_gif],
|
456 |
).success(lambda: print('Image_to_3D Done ...'))
|
457 |
|
458 |
+
#===============================================================
|
459 |
+
# start gradio server
|
460 |
+
#===============================================================
|
461 |
|
|
|
462 |
demo.queue(max_size=CONST_MAX_QUEUE)
|
463 |
demo.launch(server_name=CONST_SERVER, server_port=CONST_PORT)
|
464 |
|
env_install.sh
CHANGED
@@ -1,6 +1,6 @@
|
|
1 |
pip3 install diffusers transformers
|
2 |
pip3 install rembg tqdm omegaconf matplotlib opencv-python imageio jaxtyping einops
|
3 |
-
pip3 install SentencePiece accelerate trimesh PyMCubes xatlas libigl ninja gradio
|
4 |
pip3 install git+https://github.com/facebookresearch/pytorch3d@stable
|
5 |
pip3 install git+https://github.com/NVlabs/nvdiffrast
|
6 |
pip3 install open3d
|
|
|
1 |
pip3 install diffusers transformers
|
2 |
pip3 install rembg tqdm omegaconf matplotlib opencv-python imageio jaxtyping einops
|
3 |
+
pip3 install SentencePiece accelerate trimesh PyMCubes xatlas libigl ninja gradio roma
|
4 |
pip3 install git+https://github.com/facebookresearch/pytorch3d@stable
|
5 |
pip3 install git+https://github.com/NVlabs/nvdiffrast
|
6 |
pip3 install open3d
|
infer/gif_render.py
CHANGED
@@ -25,7 +25,7 @@
|
|
25 |
import os, sys
|
26 |
sys.path.insert(0, f"{os.path.dirname(os.path.dirname(os.path.abspath(__file__)))}")
|
27 |
|
28 |
-
from svrm.ldm.vis_util import
|
29 |
from infer.utils import seed_everything, timing_decorator
|
30 |
|
31 |
class GifRenderer():
|
@@ -40,14 +40,14 @@ class GifRenderer():
|
|
40 |
self,
|
41 |
obj_filename,
|
42 |
elev=0,
|
43 |
-
azim=
|
44 |
resolution=512,
|
45 |
gif_dst_path='',
|
46 |
n_views=120,
|
47 |
fps=30,
|
48 |
rgb=True
|
49 |
):
|
50 |
-
|
51 |
obj_filename,
|
52 |
elev=elev,
|
53 |
azim=azim,
|
|
|
25 |
import os, sys
|
26 |
sys.path.insert(0, f"{os.path.dirname(os.path.dirname(os.path.abspath(__file__)))}")
|
27 |
|
28 |
+
from svrm.ldm.vis_util import render_func
|
29 |
from infer.utils import seed_everything, timing_decorator
|
30 |
|
31 |
class GifRenderer():
|
|
|
40 |
self,
|
41 |
obj_filename,
|
42 |
elev=0,
|
43 |
+
azim=None,
|
44 |
resolution=512,
|
45 |
gif_dst_path='',
|
46 |
n_views=120,
|
47 |
fps=30,
|
48 |
rgb=True
|
49 |
):
|
50 |
+
render_func(
|
51 |
obj_filename,
|
52 |
elev=elev,
|
53 |
azim=azim,
|
infer/image_to_views.py
CHANGED
@@ -48,21 +48,26 @@ def save_gif(pils, save_path, df=False):
|
|
48 |
|
49 |
|
50 |
class Image2Views():
|
51 |
-
def __init__(self,
|
|
|
|
|
|
|
52 |
self.device = device
|
53 |
if use_lite:
|
|
|
54 |
self.pipe = Hunyuan3d_MVD_Lite_Pipeline.from_pretrained(
|
55 |
-
|
56 |
torch_dtype = torch.float16,
|
57 |
use_safetensors = True,
|
58 |
)
|
59 |
else:
|
|
|
60 |
self.pipe = HunYuan3D_MVD_Std_Pipeline.from_pretrained(
|
61 |
-
|
62 |
torch_dtype = torch.float16,
|
63 |
use_safetensors = True,
|
64 |
)
|
65 |
-
self.pipe = self.pipe.to(device)
|
66 |
self.order = [0, 1, 2, 3, 4, 5] if use_lite else [0, 2, 4, 5, 3, 1]
|
67 |
self.save_memory = save_memory
|
68 |
set_parameter_grad_false(self.pipe.unet)
|
|
|
48 |
|
49 |
|
50 |
class Image2Views():
|
51 |
+
def __init__(self,
|
52 |
+
device="cuda:0", use_lite=False, save_memory=False,
|
53 |
+
std_pretrain='./weights/mvd_std', lite_pretrain='./weights/mvd_lite'
|
54 |
+
):
|
55 |
self.device = device
|
56 |
if use_lite:
|
57 |
+
print("loading", lite_pretrain)
|
58 |
self.pipe = Hunyuan3d_MVD_Lite_Pipeline.from_pretrained(
|
59 |
+
lite_pretrain,
|
60 |
torch_dtype = torch.float16,
|
61 |
use_safetensors = True,
|
62 |
)
|
63 |
else:
|
64 |
+
print("loadding", std_pretrain)
|
65 |
self.pipe = HunYuan3D_MVD_Std_Pipeline.from_pretrained(
|
66 |
+
std_pretrain,
|
67 |
torch_dtype = torch.float16,
|
68 |
use_safetensors = True,
|
69 |
)
|
70 |
+
self.pipe = self.pipe if save_memory else self.pipe.to(device)
|
71 |
self.order = [0, 1, 2, 3, 4, 5] if use_lite else [0, 2, 4, 5, 3, 1]
|
72 |
self.save_memory = save_memory
|
73 |
set_parameter_grad_false(self.pipe.unet)
|
infer/text_to_image.py
CHANGED
@@ -46,8 +46,7 @@ class Text2Image():
|
|
46 |
)
|
47 |
set_parameter_grad_false(self.pipe.transformer)
|
48 |
print('text2image transformer model', get_parameter_number(self.pipe.transformer))
|
49 |
-
if
|
50 |
-
self.pipe = self.pipe.to(device)
|
51 |
self.neg_txt = "文本,特写,裁剪,出框,最差质量,低质量,JPEG伪影,PGLY,重复,病态,残缺,多余的手指,变异的手," \
|
52 |
"画得不好的手,画得不好的脸,变异,畸形,模糊,脱水,糟糕的解剖学,糟糕的比例,多余的肢体,克隆的脸," \
|
53 |
"毁容,恶心的比例,畸形的肢体,缺失的手臂,缺失的腿,额外的手臂,额外的腿,融合的手指,手指太多,长脖子"
|
|
|
46 |
)
|
47 |
set_parameter_grad_false(self.pipe.transformer)
|
48 |
print('text2image transformer model', get_parameter_number(self.pipe.transformer))
|
49 |
+
self.pipe = self.pipe if save_memory else self.pipe.to(device)
|
|
|
50 |
self.neg_txt = "文本,特写,裁剪,出框,最差质量,低质量,JPEG伪影,PGLY,重复,病态,残缺,多余的手指,变异的手," \
|
51 |
"画得不好的手,画得不好的脸,变异,畸形,模糊,脱水,糟糕的解剖学,糟糕的比例,多余的肢体,克隆的脸," \
|
52 |
"毁容,恶心的比例,畸形的肢体,缺失的手臂,缺失的腿,额外的手臂,额外的腿,融合的手指,手指太多,长脖子"
|
infer/utils.py
CHANGED
@@ -21,7 +21,8 @@
|
|
21 |
# optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
|
22 |
# fine-tuning enabling code and other elements of the foregoing made publicly available
|
23 |
# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.
|
24 |
-
|
|
|
25 |
import os
|
26 |
import time
|
27 |
import random
|
@@ -30,6 +31,7 @@ import torch
|
|
30 |
from torch.cuda.amp import autocast, GradScaler
|
31 |
from functools import wraps
|
32 |
|
|
|
33 |
def seed_everything(seed):
|
34 |
'''
|
35 |
seed everthing
|
@@ -39,6 +41,7 @@ def seed_everything(seed):
|
|
39 |
torch.manual_seed(seed)
|
40 |
os.environ["PL_GLOBAL_SEED"] = str(seed)
|
41 |
|
|
|
42 |
def timing_decorator(category: str):
|
43 |
'''
|
44 |
timing_decorator: record time
|
@@ -57,6 +60,7 @@ def timing_decorator(category: str):
|
|
57 |
return wrapper
|
58 |
return decorator
|
59 |
|
|
|
60 |
def auto_amp_inference(func):
|
61 |
'''
|
62 |
with torch.cuda.amp.autocast()"
|
@@ -69,11 +73,13 @@ def auto_amp_inference(func):
|
|
69 |
return output
|
70 |
return wrapper
|
71 |
|
|
|
72 |
def get_parameter_number(model):
|
73 |
total_num = sum(p.numel() for p in model.parameters())
|
74 |
trainable_num = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
75 |
return {'Total': total_num, 'Trainable': trainable_num}
|
76 |
|
|
|
77 |
def set_parameter_grad_false(model):
|
78 |
for p in model.parameters():
|
79 |
p.requires_grad = False
|
|
|
21 |
# optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
|
22 |
# fine-tuning enabling code and other elements of the foregoing made publicly available
|
23 |
# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.
|
24 |
+
import sys
|
25 |
+
import io
|
26 |
import os
|
27 |
import time
|
28 |
import random
|
|
|
31 |
from torch.cuda.amp import autocast, GradScaler
|
32 |
from functools import wraps
|
33 |
|
34 |
+
|
35 |
def seed_everything(seed):
|
36 |
'''
|
37 |
seed everthing
|
|
|
41 |
torch.manual_seed(seed)
|
42 |
os.environ["PL_GLOBAL_SEED"] = str(seed)
|
43 |
|
44 |
+
|
45 |
def timing_decorator(category: str):
|
46 |
'''
|
47 |
timing_decorator: record time
|
|
|
60 |
return wrapper
|
61 |
return decorator
|
62 |
|
63 |
+
|
64 |
def auto_amp_inference(func):
|
65 |
'''
|
66 |
with torch.cuda.amp.autocast()"
|
|
|
73 |
return output
|
74 |
return wrapper
|
75 |
|
76 |
+
|
77 |
def get_parameter_number(model):
|
78 |
total_num = sum(p.numel() for p in model.parameters())
|
79 |
trainable_num = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
80 |
return {'Total': total_num, 'Trainable': trainable_num}
|
81 |
|
82 |
+
|
83 |
def set_parameter_grad_false(model):
|
84 |
for p in model.parameters():
|
85 |
p.requires_grad = False
|
infer/views_to_mesh.py
CHANGED
@@ -47,11 +47,15 @@ class Views2Mesh():
|
|
47 |
use_lite: lite version
|
48 |
save_memory: cpu auto
|
49 |
'''
|
50 |
-
self.mv23d_predictor = MV23DPredictor(mv23d_ckt_path, mv23d_cfg_path, device=device)
|
51 |
-
self.mv23d_predictor.model.eval()
|
52 |
-
self.order = [0, 1, 2, 3, 4, 5] if use_lite else [0, 2, 4, 5, 3, 1]
|
53 |
self.device = device
|
54 |
self.save_memory = save_memory
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
55 |
set_parameter_grad_false(self.mv23d_predictor.model)
|
56 |
print('view2mesh model', get_parameter_number(self.mv23d_predictor.model))
|
57 |
|
@@ -109,7 +113,6 @@ class Views2Mesh():
|
|
109 |
do_texture_mapping = do_texture_mapping
|
110 |
)
|
111 |
torch.cuda.empty_cache()
|
112 |
-
return save_dir
|
113 |
|
114 |
|
115 |
if __name__ == "__main__":
|
|
|
47 |
use_lite: lite version
|
48 |
save_memory: cpu auto
|
49 |
'''
|
|
|
|
|
|
|
50 |
self.device = device
|
51 |
self.save_memory = save_memory
|
52 |
+
self.mv23d_predictor = MV23DPredictor(
|
53 |
+
mv23d_ckt_path,
|
54 |
+
mv23d_cfg_path,
|
55 |
+
device = "cpu" if save_memory else device
|
56 |
+
)
|
57 |
+
self.mv23d_predictor.model.eval()
|
58 |
+
self.order = [0, 1, 2, 3, 4, 5] if use_lite else [0, 2, 4, 5, 3, 1]
|
59 |
set_parameter_grad_false(self.mv23d_predictor.model)
|
60 |
print('view2mesh model', get_parameter_number(self.mv23d_predictor.model))
|
61 |
|
|
|
113 |
do_texture_mapping = do_texture_mapping
|
114 |
)
|
115 |
torch.cuda.empty_cache()
|
|
|
116 |
|
117 |
|
118 |
if __name__ == "__main__":
|
main.py
CHANGED
@@ -24,16 +24,28 @@
|
|
24 |
|
25 |
import os
|
26 |
import warnings
|
27 |
-
import torch
|
28 |
-
from PIL import Image
|
29 |
import argparse
|
30 |
-
|
31 |
-
from
|
|
|
32 |
|
33 |
warnings.simplefilter('ignore', category=UserWarning)
|
34 |
warnings.simplefilter('ignore', category=FutureWarning)
|
35 |
warnings.simplefilter('ignore', category=DeprecationWarning)
|
36 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
37 |
def get_args():
|
38 |
parser = argparse.ArgumentParser()
|
39 |
parser.add_argument(
|
@@ -73,8 +85,8 @@ def get_args():
|
|
73 |
"--gen_steps", default=50, type=int
|
74 |
)
|
75 |
parser.add_argument(
|
76 |
-
"--max_faces_num", default=
|
77 |
-
help="max num of face, suggest
|
78 |
)
|
79 |
parser.add_argument(
|
80 |
"--save_memory", default=False, action="store_true"
|
@@ -85,6 +97,13 @@ def get_args():
|
|
85 |
parser.add_argument(
|
86 |
"--do_render", default=False, action="store_true"
|
87 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
88 |
return parser.parse_args()
|
89 |
|
90 |
|
@@ -95,6 +114,7 @@ if __name__ == "__main__":
|
|
95 |
assert args.text_prompt or args.image_prompt, "Text and image can only be given to one"
|
96 |
|
97 |
# init model
|
|
|
98 |
rembg_model = Removebg()
|
99 |
image_to_views_model = Image2Views(
|
100 |
device=args.device,
|
@@ -116,9 +136,18 @@ if __name__ == "__main__":
|
|
116 |
device = args.device,
|
117 |
save_memory = args.save_memory
|
118 |
)
|
119 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
120 |
gif_renderer = GifRenderer(device=args.device)
|
121 |
-
|
|
|
|
|
122 |
# ---- ----- ---- ---- ---- ----
|
123 |
|
124 |
os.makedirs(args.save_folder, exist_ok=True)
|
@@ -136,7 +165,7 @@ if __name__ == "__main__":
|
|
136 |
|
137 |
# stage 2, remove back ground
|
138 |
res_rgba_pil = rembg_model(res_rgb_pil)
|
139 |
-
|
140 |
|
141 |
# stage 3, image to views
|
142 |
(views_grid_pil, cond_img), view_pil_list = image_to_views_model(
|
@@ -155,10 +184,29 @@ if __name__ == "__main__":
|
|
155 |
save_folder = args.save_folder,
|
156 |
do_texture_mapping = args.do_texture_mapping
|
157 |
)
|
158 |
-
|
159 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
160 |
if args.do_render:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
161 |
gif_renderer(
|
162 |
-
|
163 |
gif_dst_path = os.path.join(args.save_folder, 'output.gif'),
|
164 |
)
|
|
|
24 |
|
25 |
import os
|
26 |
import warnings
|
|
|
|
|
27 |
import argparse
|
28 |
+
import time
|
29 |
+
from PIL import Image
|
30 |
+
import torch
|
31 |
|
32 |
warnings.simplefilter('ignore', category=UserWarning)
|
33 |
warnings.simplefilter('ignore', category=FutureWarning)
|
34 |
warnings.simplefilter('ignore', category=DeprecationWarning)
|
35 |
|
36 |
+
from infer import Text2Image, Removebg, Image2Views, Views2Mesh, GifRenderer
|
37 |
+
from third_party.mesh_baker import MeshBaker
|
38 |
+
from third_party.check import check_bake_available
|
39 |
+
|
40 |
+
try:
|
41 |
+
from third_party.mesh_baker import MeshBaker
|
42 |
+
assert check_bake_available()
|
43 |
+
BAKE_AVAILEBLE = True
|
44 |
+
except Exception as err:
|
45 |
+
print(err)
|
46 |
+
print("import baking related fail, run without baking")
|
47 |
+
BAKE_AVAILEBLE = False
|
48 |
+
|
49 |
def get_args():
|
50 |
parser = argparse.ArgumentParser()
|
51 |
parser.add_argument(
|
|
|
85 |
"--gen_steps", default=50, type=int
|
86 |
)
|
87 |
parser.add_argument(
|
88 |
+
"--max_faces_num", default=90000, type=int,
|
89 |
+
help="max num of face, suggest 90000 for effect, 10000 for speed"
|
90 |
)
|
91 |
parser.add_argument(
|
92 |
"--save_memory", default=False, action="store_true"
|
|
|
97 |
parser.add_argument(
|
98 |
"--do_render", default=False, action="store_true"
|
99 |
)
|
100 |
+
parser.add_argument(
|
101 |
+
"--do_bake", default=False, action="store_true"
|
102 |
+
)
|
103 |
+
parser.add_argument(
|
104 |
+
"--bake_align_times", default=3, type=int,
|
105 |
+
help="align times between view image and mesh, suggest 1~6"
|
106 |
+
)
|
107 |
return parser.parse_args()
|
108 |
|
109 |
|
|
|
114 |
assert args.text_prompt or args.image_prompt, "Text and image can only be given to one"
|
115 |
|
116 |
# init model
|
117 |
+
st = time.time()
|
118 |
rembg_model = Removebg()
|
119 |
image_to_views_model = Image2Views(
|
120 |
device=args.device,
|
|
|
136 |
device = args.device,
|
137 |
save_memory = args.save_memory
|
138 |
)
|
139 |
+
|
140 |
+
if args.do_bake and BAKE_AVAILEBLE:
|
141 |
+
mesh_baker = MeshBaker(
|
142 |
+
device = args.device,
|
143 |
+
align_times = args.bake_align_times
|
144 |
+
)
|
145 |
+
|
146 |
+
if check_bake_available():
|
147 |
gif_renderer = GifRenderer(device=args.device)
|
148 |
+
|
149 |
+
print(f"Init Models cost {time.time()-st}s")
|
150 |
+
|
151 |
# ---- ----- ---- ---- ---- ----
|
152 |
|
153 |
os.makedirs(args.save_folder, exist_ok=True)
|
|
|
165 |
|
166 |
# stage 2, remove back ground
|
167 |
res_rgba_pil = rembg_model(res_rgb_pil)
|
168 |
+
res_rgba_pil.save(os.path.join(args.save_folder, "img_nobg.png"))
|
169 |
|
170 |
# stage 3, image to views
|
171 |
(views_grid_pil, cond_img), view_pil_list = image_to_views_model(
|
|
|
184 |
save_folder = args.save_folder,
|
185 |
do_texture_mapping = args.do_texture_mapping
|
186 |
)
|
187 |
+
|
188 |
+
# stage 5, baking
|
189 |
+
mesh_file_for_render = None
|
190 |
+
if args.do_bake and BAKE_AVAILEBLE:
|
191 |
+
mesh_file_for_render = mesh_baker(args.save_folder)
|
192 |
+
|
193 |
+
# stage 6, render gif
|
194 |
+
# todo fix: if init folder unclear, it maybe mistake rendering
|
195 |
if args.do_render:
|
196 |
+
if mesh_file_for_render and os.path.exists(mesh_file_for_render):
|
197 |
+
mesh_file_for_render = mesh_file_for_render
|
198 |
+
elif os.path.exists(os.path.join(args.save_folder, 'view_1/bake/mesh.obj')):
|
199 |
+
mesh_file_for_render = os.path.join(args.save_folder, 'view_1/bake/mesh.obj')
|
200 |
+
elif os.path.exists(os.path.join(args.save_folder, 'view_0/bake/mesh.obj')):
|
201 |
+
mesh_file_for_render = os.path.join(args.save_folder, 'view_0/bake/mesh.obj')
|
202 |
+
elif os.path.exists(os.path.join(args.save_folder, 'mesh.obj')):
|
203 |
+
mesh_file_for_render = os.path.join(args.save_folder, 'mesh.obj')
|
204 |
+
else:
|
205 |
+
raise FileNotFoundError("mesh_file_for_render not found")
|
206 |
+
|
207 |
+
print("Rendering 3d file:", mesh_file_for_render)
|
208 |
+
|
209 |
gif_renderer(
|
210 |
+
mesh_file_for_render,
|
211 |
gif_dst_path = os.path.join(args.save_folder, 'output.gif'),
|
212 |
)
|
requirements.txt
CHANGED
@@ -22,3 +22,4 @@ git+https://github.com/facebookresearch/pytorch3d@stable
|
|
22 |
git+https://github.com/NVlabs/nvdiffrast
|
23 |
open3d
|
24 |
ninja
|
|
|
|
22 |
git+https://github.com/NVlabs/nvdiffrast
|
23 |
open3d
|
24 |
ninja
|
25 |
+
roma
|
svrm/ldm/models/svrm.py
CHANGED
@@ -46,7 +46,7 @@ from ..modules.rendering_neus.rasterize import NVDiffRasterizerContext
|
|
46 |
|
47 |
from ..utils.ops import scale_tensor
|
48 |
from ..util import count_params, instantiate_from_config
|
49 |
-
from ..vis_util import
|
50 |
|
51 |
|
52 |
def unwrap_uv(v_pos, t_pos_idx):
|
@@ -58,7 +58,6 @@ def unwrap_uv(v_pos, t_pos_idx):
|
|
58 |
indices = indices.astype(np.int64, casting="same_kind")
|
59 |
return uvs, indices
|
60 |
|
61 |
-
|
62 |
def uv_padding(image, hole_mask, uv_padding_size = 2):
|
63 |
return cv2.inpaint(
|
64 |
(image.detach().cpu().numpy() * 255).astype(np.uint8),
|
@@ -120,14 +119,16 @@ class SVRMModel(torch.nn.Module):
|
|
120 |
out_dir = 'outputs/test'
|
121 |
):
|
122 |
"""
|
123 |
-
|
124 |
"""
|
125 |
|
126 |
-
obj_vertext_path = os.path.join(out_dir, '
|
127 |
-
|
128 |
-
|
129 |
-
|
130 |
-
|
|
|
|
|
131 |
|
132 |
st = time.time()
|
133 |
|
@@ -204,15 +205,13 @@ class SVRMModel(torch.nn.Module):
|
|
204 |
mesh = trimesh.load_mesh(obj_vertext_path)
|
205 |
print(f"=====> generate mesh with vertex shading time: {time.time() - st}")
|
206 |
st = time.time()
|
207 |
-
|
208 |
if not do_texture_mapping:
|
209 |
-
|
210 |
-
mesh.export(glb_path, file_type='glb')
|
211 |
-
return None
|
212 |
|
213 |
-
|
214 |
-
|
215 |
-
|
216 |
|
217 |
st = time.time()
|
218 |
|
@@ -238,12 +237,9 @@ class SVRMModel(torch.nn.Module):
|
|
238 |
|
239 |
# Interpolate world space position
|
240 |
gb_pos = ctx.interpolate_one(vtx_refine, rast[None, ...], faces_refine)[0][0]
|
241 |
-
|
242 |
with torch.no_grad():
|
243 |
gb_mask_pos_scale = scale_tensor(gb_pos.unsqueeze(0).view(1, -1, 3), (-1, 1), (-1, 1))
|
244 |
-
|
245 |
tex_map = self.render.forward_points(cur_triplane, gb_mask_pos_scale)['rgb']
|
246 |
-
|
247 |
tex_map = tex_map.float().squeeze(0) # (0, 1)
|
248 |
tex_map = tex_map.view((texture_res, texture_res, 3))
|
249 |
img = uv_padding(tex_map, hole_mask)
|
@@ -257,7 +253,7 @@ class SVRMModel(torch.nn.Module):
|
|
257 |
fid.write('newmtl material_0\n')
|
258 |
fid.write("Ka 1.000 1.000 1.000\n")
|
259 |
fid.write("Kd 1.000 1.000 1.000\n")
|
260 |
-
fid.write("Ks 0.
|
261 |
fid.write("d 1.0\n")
|
262 |
fid.write("illum 2\n")
|
263 |
fid.write(f'map_Kd texture.png\n')
|
@@ -278,4 +274,5 @@ class SVRMModel(torch.nn.Module):
|
|
278 |
mesh = trimesh.load_mesh(obj_path)
|
279 |
mesh.export(glb_path, file_type='glb')
|
280 |
print(f"=====> generate mesh with texture shading time: {time.time() - st}")
|
|
|
281 |
|
|
|
46 |
|
47 |
from ..utils.ops import scale_tensor
|
48 |
from ..util import count_params, instantiate_from_config
|
49 |
+
from ..vis_util import render_func
|
50 |
|
51 |
|
52 |
def unwrap_uv(v_pos, t_pos_idx):
|
|
|
58 |
indices = indices.astype(np.int64, casting="same_kind")
|
59 |
return uvs, indices
|
60 |
|
|
|
61 |
def uv_padding(image, hole_mask, uv_padding_size = 2):
|
62 |
return cv2.inpaint(
|
63 |
(image.detach().cpu().numpy() * 255).astype(np.uint8),
|
|
|
119 |
out_dir = 'outputs/test'
|
120 |
):
|
121 |
"""
|
122 |
+
do_texture_mapping: True for ray texture, False for vertices texture
|
123 |
"""
|
124 |
|
125 |
+
obj_vertext_path = os.path.join(out_dir, 'mesh_vertex_colors.obj')
|
126 |
+
|
127 |
+
if do_texture_mapping:
|
128 |
+
obj_path = os.path.join(out_dir, 'mesh.obj')
|
129 |
+
obj_texture_path = os.path.join(out_dir, 'texture.png')
|
130 |
+
obj_mtl_path = os.path.join(out_dir, 'texture.mtl')
|
131 |
+
glb_path = os.path.join(out_dir, 'mesh.glb')
|
132 |
|
133 |
st = time.time()
|
134 |
|
|
|
205 |
mesh = trimesh.load_mesh(obj_vertext_path)
|
206 |
print(f"=====> generate mesh with vertex shading time: {time.time() - st}")
|
207 |
st = time.time()
|
208 |
+
|
209 |
if not do_texture_mapping:
|
210 |
+
return obj_vertext_path, None
|
|
|
|
|
211 |
|
212 |
+
###########################################################
|
213 |
+
#------------- export texture -----------------------
|
214 |
+
###########################################################
|
215 |
|
216 |
st = time.time()
|
217 |
|
|
|
237 |
|
238 |
# Interpolate world space position
|
239 |
gb_pos = ctx.interpolate_one(vtx_refine, rast[None, ...], faces_refine)[0][0]
|
|
|
240 |
with torch.no_grad():
|
241 |
gb_mask_pos_scale = scale_tensor(gb_pos.unsqueeze(0).view(1, -1, 3), (-1, 1), (-1, 1))
|
|
|
242 |
tex_map = self.render.forward_points(cur_triplane, gb_mask_pos_scale)['rgb']
|
|
|
243 |
tex_map = tex_map.float().squeeze(0) # (0, 1)
|
244 |
tex_map = tex_map.view((texture_res, texture_res, 3))
|
245 |
img = uv_padding(tex_map, hole_mask)
|
|
|
253 |
fid.write('newmtl material_0\n')
|
254 |
fid.write("Ka 1.000 1.000 1.000\n")
|
255 |
fid.write("Kd 1.000 1.000 1.000\n")
|
256 |
+
fid.write("Ks 0.500 0.500 0.500\n")
|
257 |
fid.write("d 1.0\n")
|
258 |
fid.write("illum 2\n")
|
259 |
fid.write(f'map_Kd texture.png\n')
|
|
|
274 |
mesh = trimesh.load_mesh(obj_path)
|
275 |
mesh.export(glb_path, file_type='glb')
|
276 |
print(f"=====> generate mesh with texture shading time: {time.time() - st}")
|
277 |
+
return obj_path, glb_path
|
278 |
|
svrm/ldm/modules/attention.py
CHANGED
@@ -246,8 +246,11 @@ class CrossAttention(nn.Module):
|
|
246 |
class FlashAttention(nn.Module):
|
247 |
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.):
|
248 |
super().__init__()
|
249 |
-
print(
|
250 |
-
|
|
|
|
|
|
|
251 |
inner_dim = dim_head * heads
|
252 |
context_dim = default(context_dim, query_dim)
|
253 |
self.scale = dim_head ** -0.5
|
@@ -269,7 +272,12 @@ class FlashAttention(nn.Module):
|
|
269 |
k = self.to_k(context).to(dtype)
|
270 |
v = self.to_v(context).to(dtype)
|
271 |
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b n h d', h=h), (q, k, v)) # q is [b, 3079, 16, 64]
|
272 |
-
out = flash_attn_func(q, k, v,
|
|
|
|
|
|
|
|
|
|
|
273 |
out = rearrange(out, 'b n h d -> b n (h d)', h=h)
|
274 |
return self.to_out(out.float())
|
275 |
|
@@ -277,8 +285,11 @@ class MemoryEfficientCrossAttention(nn.Module):
|
|
277 |
# https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
|
278 |
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
|
279 |
super().__init__()
|
280 |
-
print(
|
281 |
-
|
|
|
|
|
|
|
282 |
inner_dim = dim_head * heads
|
283 |
context_dim = default(context_dim, query_dim)
|
284 |
|
@@ -327,10 +338,12 @@ class BasicTransformerBlock(nn.Module):
|
|
327 |
super().__init__()
|
328 |
self.disable_self_attn = disable_self_attn
|
329 |
self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout,
|
330 |
-
context_dim=context_dim if self.disable_self_attn else None)
|
|
|
331 |
self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
|
332 |
self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim,
|
333 |
-
heads=n_heads, dim_head=d_head, dropout=dropout)
|
|
|
334 |
self.norm1 = Fp32LayerNorm(dim)
|
335 |
self.norm2 = Fp32LayerNorm(dim)
|
336 |
self.norm3 = Fp32LayerNorm(dim)
|
@@ -451,7 +464,3 @@ class ImgToTriplaneTransformer(nn.Module):
|
|
451 |
x = self.norm(x)
|
452 |
return x
|
453 |
|
454 |
-
|
455 |
-
|
456 |
-
|
457 |
-
|
|
|
246 |
class FlashAttention(nn.Module):
|
247 |
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.):
|
248 |
super().__init__()
|
249 |
+
# print(
|
250 |
+
# f"Setting up {self.__class__.__name__}. Query dim is {query_dim}, "
|
251 |
+
# "context_dim is {context_dim} and using "
|
252 |
+
# f"{heads} heads."
|
253 |
+
# )
|
254 |
inner_dim = dim_head * heads
|
255 |
context_dim = default(context_dim, query_dim)
|
256 |
self.scale = dim_head ** -0.5
|
|
|
272 |
k = self.to_k(context).to(dtype)
|
273 |
v = self.to_v(context).to(dtype)
|
274 |
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b n h d', h=h), (q, k, v)) # q is [b, 3079, 16, 64]
|
275 |
+
out = flash_attn_func(q, k, v,
|
276 |
+
dropout_p=self.dropout,
|
277 |
+
softmax_scale=None,
|
278 |
+
causal=False,
|
279 |
+
window_size=(-1, -1)
|
280 |
+
) # out is same shape to q
|
281 |
out = rearrange(out, 'b n h d -> b n (h d)', h=h)
|
282 |
return self.to_out(out.float())
|
283 |
|
|
|
285 |
# https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
|
286 |
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
|
287 |
super().__init__()
|
288 |
+
# print(
|
289 |
+
# f"Setting up {self.__class__.__name__}. Query dim is {query_dim}, "
|
290 |
+
# "context_dim is {context_dim} and using "
|
291 |
+
# f"{heads} heads."
|
292 |
+
# )
|
293 |
inner_dim = dim_head * heads
|
294 |
context_dim = default(context_dim, query_dim)
|
295 |
|
|
|
338 |
super().__init__()
|
339 |
self.disable_self_attn = disable_self_attn
|
340 |
self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout,
|
341 |
+
context_dim=context_dim if self.disable_self_attn else None)
|
342 |
+
# is a self-attention if not self.disable_self_attn
|
343 |
self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
|
344 |
self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim,
|
345 |
+
heads=n_heads, dim_head=d_head, dropout=dropout)
|
346 |
+
# is self-attn if context is none
|
347 |
self.norm1 = Fp32LayerNorm(dim)
|
348 |
self.norm2 = Fp32LayerNorm(dim)
|
349 |
self.norm3 = Fp32LayerNorm(dim)
|
|
|
464 |
x = self.norm(x)
|
465 |
return x
|
466 |
|
|
|
|
|
|
|
|
svrm/ldm/vis_util.py
CHANGED
@@ -27,10 +27,10 @@ from pytorch3d.renderer import (
|
|
27 |
)
|
28 |
|
29 |
|
30 |
-
def
|
31 |
obj_filename,
|
32 |
elev=0,
|
33 |
-
azim=
|
34 |
resolution=512,
|
35 |
gif_dst_path='',
|
36 |
n_views=120,
|
@@ -49,7 +49,7 @@ def render(
|
|
49 |
mesh = load_objs_as_meshes([obj_filename], device=device)
|
50 |
meshes = mesh.extend(n_views)
|
51 |
|
52 |
-
if
|
53 |
elev = torch.linspace(elev, elev, n_views+1)[:-1]
|
54 |
azim = torch.linspace(0, 360, n_views+1)[:-1]
|
55 |
|
@@ -76,16 +76,15 @@ def render(
|
|
76 |
)
|
77 |
images = renderer(meshes)
|
78 |
|
79 |
-
|
80 |
-
|
81 |
-
|
82 |
-
|
83 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
84 |
|
85 |
-
# orbit frames rendering
|
86 |
-
with imageio.get_writer(uri=gif_dst_path, mode='I', duration=1. / fps * 1000, loop=0) as writer:
|
87 |
-
for i in range(n_views):
|
88 |
-
frame = images[i, ..., :3] if rgb else images[i, ...]
|
89 |
-
frame = Image.fromarray((frame.cpu().squeeze(0) * 255).numpy().astype("uint8"))
|
90 |
-
writer.append_data(frame)
|
91 |
-
return gif_dst_path
|
|
|
27 |
)
|
28 |
|
29 |
|
30 |
+
def render_func(
|
31 |
obj_filename,
|
32 |
elev=0,
|
33 |
+
azim=None,
|
34 |
resolution=512,
|
35 |
gif_dst_path='',
|
36 |
n_views=120,
|
|
|
49 |
mesh = load_objs_as_meshes([obj_filename], device=device)
|
50 |
meshes = mesh.extend(n_views)
|
51 |
|
52 |
+
if azim is None:
|
53 |
elev = torch.linspace(elev, elev, n_views+1)[:-1]
|
54 |
azim = torch.linspace(0, 360, n_views+1)[:-1]
|
55 |
|
|
|
76 |
)
|
77 |
images = renderer(meshes)
|
78 |
|
79 |
+
if gif_dst_path != '':
|
80 |
+
with imageio.get_writer(uri=gif_dst_path, mode='I', duration=1. / fps * 1000, loop=0) as writer:
|
81 |
+
for i in range(n_views):
|
82 |
+
frame = images[i, ..., :3] if rgb else images[i, ...]
|
83 |
+
frame = Image.fromarray((frame.cpu().squeeze(0) * 255).numpy().astype("uint8"))
|
84 |
+
writer.append_data(frame)
|
85 |
+
|
86 |
+
frame = images[..., :3] if rgb else images
|
87 |
+
frames = [Image.fromarray((fra.cpu().squeeze(0) * 255).numpy().astype("uint8")) for fra in frame]
|
88 |
+
return frames
|
89 |
+
|
90 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
svrm/predictor.py
CHANGED
@@ -33,7 +33,7 @@ from omegaconf import OmegaConf
|
|
33 |
from torchvision import transforms
|
34 |
from safetensors.torch import save_file, load_file
|
35 |
from .ldm.util import instantiate_from_config
|
36 |
-
from .ldm.vis_util import
|
37 |
|
38 |
class MV23DPredictor(object):
|
39 |
def __init__(self, ckpt_path, cfg_path, elevation=15, number_view=60,
|
@@ -46,9 +46,7 @@ class MV23DPredictor(object):
|
|
46 |
self.elevation_list = [0, 0, 0, 0, 0, 0, 0]
|
47 |
self.azimuth_list = [0, 60, 120, 180, 240, 300, 0]
|
48 |
|
49 |
-
st = time.time()
|
50 |
self.model = self.init_model(ckpt_path, cfg_path)
|
51 |
-
print(f"=====> mv23d model init time: {time.time() - st}")
|
52 |
|
53 |
self.input_view_transform = transforms.Compose([
|
54 |
transforms.Resize(504, interpolation=Image.BICUBIC),
|
|
|
33 |
from torchvision import transforms
|
34 |
from safetensors.torch import save_file, load_file
|
35 |
from .ldm.util import instantiate_from_config
|
36 |
+
from .ldm.vis_util import render_func
|
37 |
|
38 |
class MV23DPredictor(object):
|
39 |
def __init__(self, ckpt_path, cfg_path, elevation=15, number_view=60,
|
|
|
46 |
self.elevation_list = [0, 0, 0, 0, 0, 0, 0]
|
47 |
self.azimuth_list = [0, 60, 120, 180, 240, 300, 0]
|
48 |
|
|
|
49 |
self.model = self.init_model(ckpt_path, cfg_path)
|
|
|
50 |
|
51 |
self.input_view_transform = transforms.Compose([
|
52 |
transforms.Resize(504, interpolation=Image.BICUBIC),
|
third_party/check.py
ADDED
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import sys
|
3 |
+
import io
|
4 |
+
|
5 |
+
def check_bake_available():
|
6 |
+
is_ok = os.path.exists("./third_party/weights/DUSt3R_ViTLarge_BaseDecoder_512_dpt/model.safetensors")
|
7 |
+
is_ok = is_ok and os.path.exists("./third_party/dust3r")
|
8 |
+
is_ok = is_ok and os.path.exists("./third_party/dust3r/dust3r")
|
9 |
+
is_ok = is_ok and os.path.exists("./third_party/dust3r/croco/models")
|
10 |
+
if is_ok:
|
11 |
+
print("Baking is avaliable")
|
12 |
+
print("Baking is avaliable")
|
13 |
+
print("Baking is avaliable")
|
14 |
+
else:
|
15 |
+
print("Baking is unavailable, please download related files in README")
|
16 |
+
print("Baking is unavailable, please download related files in README")
|
17 |
+
print("Baking is unavailable, please download related files in README")
|
18 |
+
return is_ok
|
19 |
+
|
20 |
+
|
21 |
+
|
22 |
+
if __name__ == "__main__":
|
23 |
+
|
24 |
+
check_bake_available()
|
25 |
+
|
third_party/dust3r_utils.py
ADDED
@@ -0,0 +1,366 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import sys
|
2 |
+
import io
|
3 |
+
import os
|
4 |
+
import cv2
|
5 |
+
import math
|
6 |
+
import numpy as np
|
7 |
+
from scipy.signal import medfilt
|
8 |
+
from scipy.spatial import KDTree
|
9 |
+
from matplotlib import pyplot as plt
|
10 |
+
from PIL import Image
|
11 |
+
|
12 |
+
from dust3r.inference import inference
|
13 |
+
|
14 |
+
from dust3r.utils.image import load_images# , resize_images
|
15 |
+
from dust3r.image_pairs import make_pairs
|
16 |
+
from dust3r.cloud_opt import global_aligner, GlobalAlignerMode
|
17 |
+
from dust3r.utils.geometry import find_reciprocal_matches, xy_grid
|
18 |
+
|
19 |
+
from third_party.utils.camera_utils import remap_points
|
20 |
+
from third_party.utils.img_utils import rgba_to_rgb, resize_with_aspect_ratio
|
21 |
+
from third_party.utils.img_utils import compute_img_diff
|
22 |
+
|
23 |
+
from PIL.ImageOps import exif_transpose
|
24 |
+
import torchvision.transforms as tvf
|
25 |
+
ImgNorm = tvf.Compose([tvf.ToTensor(), tvf.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
|
26 |
+
|
27 |
+
|
28 |
+
def suppress_output(func):
|
29 |
+
def wrapper(*args, **kwargs):
|
30 |
+
original_stdout = sys.stdout
|
31 |
+
original_stderr = sys.stderr
|
32 |
+
sys.stdout = io.StringIO()
|
33 |
+
sys.stderr = io.StringIO()
|
34 |
+
try:
|
35 |
+
return func(*args, **kwargs)
|
36 |
+
finally:
|
37 |
+
sys.stdout = original_stdout
|
38 |
+
sys.stderr = original_stderr
|
39 |
+
return wrapper
|
40 |
+
|
41 |
+
def _resize_pil_image(img, long_edge_size):
|
42 |
+
S = max(img.size)
|
43 |
+
if S > long_edge_size:
|
44 |
+
interp = Image.LANCZOS
|
45 |
+
elif S <= long_edge_size:
|
46 |
+
interp = Image.BICUBIC
|
47 |
+
new_size = tuple(int(round(x*long_edge_size/S)) for x in img.size)
|
48 |
+
return img.resize(new_size, interp)
|
49 |
+
|
50 |
+
def resize_images(imgs_list, size, square_ok=False):
|
51 |
+
""" open and convert all images in a list or folder to proper input format for DUSt3R
|
52 |
+
"""
|
53 |
+
imgs = []
|
54 |
+
for img in imgs_list:
|
55 |
+
img = exif_transpose(Image.fromarray(img)).convert('RGB')
|
56 |
+
W1, H1 = img.size
|
57 |
+
if size == 224:
|
58 |
+
# resize short side to 224 (then crop)
|
59 |
+
img = _resize_pil_image(img, round(size * max(W1/H1, H1/W1)))
|
60 |
+
else:
|
61 |
+
# resize long side to 512
|
62 |
+
img = _resize_pil_image(img, size)
|
63 |
+
W, H = img.size
|
64 |
+
cx, cy = W//2, H//2
|
65 |
+
if size == 224:
|
66 |
+
half = min(cx, cy)
|
67 |
+
img = img.crop((cx-half, cy-half, cx+half, cy+half))
|
68 |
+
else:
|
69 |
+
halfw, halfh = ((2*cx)//16)*8, ((2*cy)//16)*8
|
70 |
+
if not (square_ok) and W == H:
|
71 |
+
halfh = 3*halfw/4
|
72 |
+
img = img.crop((cx-halfw, cy-halfh, cx+halfw, cy+halfh))
|
73 |
+
|
74 |
+
W2, H2 = img.size
|
75 |
+
imgs.append(dict(img=ImgNorm(img)[None], true_shape=np.int32(
|
76 |
+
[img.size[::-1]]), idx=len(imgs), instance=str(len(imgs))))
|
77 |
+
|
78 |
+
return imgs
|
79 |
+
|
80 |
+
@suppress_output
|
81 |
+
def infer_match(images, model, vis=False, niter=300, lr=0.01, schedule='cosine', device="cuda:0"):
|
82 |
+
batch_size = 1
|
83 |
+
schedule = 'cosine'
|
84 |
+
lr = 0.01
|
85 |
+
niter = 300
|
86 |
+
|
87 |
+
images_packed = resize_images(images, size=512, square_ok=True)
|
88 |
+
# images_packed = images
|
89 |
+
|
90 |
+
pairs = make_pairs(images_packed, scene_graph='complete', prefilter=None, symmetrize=True)
|
91 |
+
output = inference(pairs, model, device, batch_size=batch_size, verbose=False)
|
92 |
+
|
93 |
+
scene = global_aligner(output, device=device, mode=GlobalAlignerMode.PointCloudOptimizer)
|
94 |
+
loss = scene.compute_global_alignment(init="mst", niter=niter, schedule=schedule, lr=lr)
|
95 |
+
|
96 |
+
# retrieve useful values from scene:
|
97 |
+
imgs = scene.imgs
|
98 |
+
# focals = scene.get_focals()
|
99 |
+
# poses = scene.get_im_poses()
|
100 |
+
pts3d = scene.get_pts3d()
|
101 |
+
confidence_masks = scene.get_masks()
|
102 |
+
|
103 |
+
# visualize reconstruction
|
104 |
+
# scene.show()
|
105 |
+
|
106 |
+
# find 2D-2D matches between the two images
|
107 |
+
pts2d_list, pts3d_list = [], []
|
108 |
+
for i in range(2):
|
109 |
+
conf_i = confidence_masks[i].cpu().numpy()
|
110 |
+
pts2d_list.append(xy_grid(*imgs[i].shape[:2][::-1])[conf_i]) # imgs[i].shape[:2] = (H, W)
|
111 |
+
pts3d_list.append(pts3d[i].detach().cpu().numpy()[conf_i])
|
112 |
+
if pts3d_list[-1].shape[0] == 0:
|
113 |
+
return np.zeros((0, 2)), np.zeros((0, 2))
|
114 |
+
reciprocal_in_P2, nn2_in_P1, num_matches = find_reciprocal_matches(*pts3d_list)
|
115 |
+
|
116 |
+
matches_im1 = pts2d_list[1][reciprocal_in_P2]
|
117 |
+
matches_im0 = pts2d_list[0][nn2_in_P1][reciprocal_in_P2]
|
118 |
+
|
119 |
+
# visualize a few matches
|
120 |
+
if vis == True:
|
121 |
+
print(f'found {num_matches} matches')
|
122 |
+
n_viz = 20
|
123 |
+
match_idx_to_viz = np.round(np.linspace(0, num_matches - 1, n_viz)).astype(int)
|
124 |
+
viz_matches_im0, viz_matches_im1 = matches_im0[match_idx_to_viz], matches_im1[match_idx_to_viz]
|
125 |
+
|
126 |
+
H0, W0, H1, W1 = *imgs[0].shape[:2], *imgs[1].shape[:2]
|
127 |
+
img0 = np.pad(imgs[0], ((0, max(H1 - H0, 0)), (0, 0), (0, 0)), 'constant', constant_values=0)
|
128 |
+
img1 = np.pad(imgs[1], ((0, max(H0 - H1, 0)), (0, 0), (0, 0)), 'constant', constant_values=0)
|
129 |
+
img = np.concatenate((img0, img1), axis=1)
|
130 |
+
plt.figure()
|
131 |
+
plt.imshow(img)
|
132 |
+
cmap = plt.get_cmap('jet')
|
133 |
+
for i in range(n_viz):
|
134 |
+
(x0, y0), (x1, y1) = viz_matches_im0[i].T, viz_matches_im1[i].T
|
135 |
+
plt.plot([x0, x1 + W0], [y0, y1], '-+', color=cmap(i / (n_viz - 1)), scalex=False, scaley=False)
|
136 |
+
plt.show(block=True)
|
137 |
+
|
138 |
+
matches_im0 = remap_points(images[0].shape, matches_im0)
|
139 |
+
matches_im1 = remap_points(images[1].shape, matches_im1)
|
140 |
+
return matches_im0, matches_im1
|
141 |
+
|
142 |
+
|
143 |
+
def point_transform(H, pt):
|
144 |
+
"""
|
145 |
+
@param: H is homography matrix of dimension (3x3)
|
146 |
+
@param: pt is the (x, y) point to be transformed
|
147 |
+
|
148 |
+
Return:
|
149 |
+
returns a transformed point ptrans = H*pt.
|
150 |
+
"""
|
151 |
+
a = H[0, 0] * pt[0] + H[0, 1] * pt[1] + H[0, 2]
|
152 |
+
b = H[1, 0] * pt[0] + H[1, 1] * pt[1] + H[1, 2]
|
153 |
+
c = H[2, 0] * pt[0] + H[2, 1] * pt[1] + H[2, 2]
|
154 |
+
return [a / c, b / c]
|
155 |
+
|
156 |
+
|
157 |
+
def points_transform(H, pt_x, pt_y):
|
158 |
+
"""
|
159 |
+
@param: H is homography matrix of dimension (3x3)
|
160 |
+
@param: pt is the (x, y) point to be transformed
|
161 |
+
|
162 |
+
Return:
|
163 |
+
returns a transformed point ptrans = H*pt.
|
164 |
+
"""
|
165 |
+
a = H[0, 0] * pt_x + H[0, 1] * pt_y + H[0, 2]
|
166 |
+
b = H[1, 0] * pt_x + H[1, 1] * pt_y + H[1, 2]
|
167 |
+
c = H[2, 0] * pt_x + H[2, 1] * pt_y + H[2, 2]
|
168 |
+
return (a / c, b / c)
|
169 |
+
|
170 |
+
|
171 |
+
def motion_propagate(old_points, new_points, old_size, new_size, H_size=(21, 21)):
|
172 |
+
"""
|
173 |
+
@param: old_points are points in old_frame that are
|
174 |
+
matched feature points with new_frame
|
175 |
+
@param: new_points are points in new_frame that are
|
176 |
+
matched feature points with old_frame
|
177 |
+
@param: old_frame is the frame to which
|
178 |
+
motion mesh needs to be obtained
|
179 |
+
@param: H is the homography between old and new points
|
180 |
+
|
181 |
+
Return:
|
182 |
+
returns a motion mesh in x-direction
|
183 |
+
and y-direction for old_frame
|
184 |
+
"""
|
185 |
+
# spreads motion over the mesh for the old_frame
|
186 |
+
x_motion = np.zeros(H_size)
|
187 |
+
y_motion = np.zeros(H_size)
|
188 |
+
mesh_x_num, mesh_y_num = H_size[0], H_size[1]
|
189 |
+
pixels_x, pixels_y = (old_size[1]) / (mesh_x_num - 1), (old_size[0]) / (mesh_y_num - 1)
|
190 |
+
radius = max(pixels_x, pixels_y) * 5
|
191 |
+
sigma = radius / 3.0
|
192 |
+
|
193 |
+
H_global = None
|
194 |
+
if old_points.shape[0] > 3:
|
195 |
+
# pre-warping with global homography
|
196 |
+
H_global, _ = cv2.findHomography(old_points, new_points, cv2.RANSAC)
|
197 |
+
if H_global is None:
|
198 |
+
old_tmp = np.array([[0, 0], [0, old_size[0]], [old_size[1], 0], [old_size[1], old_size[0]]])
|
199 |
+
new_tmp = np.array([[0, 0], [0, new_size[0]], [new_size[1], 0], [new_size[1], new_size[0]]])
|
200 |
+
H_global, _ = cv2.findHomography(old_tmp, new_tmp, cv2.RANSAC)
|
201 |
+
|
202 |
+
for i in range(mesh_x_num):
|
203 |
+
for j in range(mesh_y_num):
|
204 |
+
pt = [pixels_x * i, pixels_y * j]
|
205 |
+
ptrans = point_transform(H_global, pt)
|
206 |
+
x_motion[i, j] = ptrans[0]
|
207 |
+
y_motion[i, j] = ptrans[1]
|
208 |
+
|
209 |
+
# disturbute feature motion vectors
|
210 |
+
weighted_move_x = np.zeros(H_size)
|
211 |
+
weighted_move_y = np.zeros(H_size)
|
212 |
+
# 构建 KDTree
|
213 |
+
tree = KDTree(old_points)
|
214 |
+
# 计算权重和移动值
|
215 |
+
for i in range(mesh_x_num):
|
216 |
+
for j in range(mesh_y_num):
|
217 |
+
vertex = [pixels_x * i, pixels_y * j]
|
218 |
+
neighbor_indices = tree.query_ball_point(vertex, radius, workers=-1)
|
219 |
+
if len(neighbor_indices) > 0:
|
220 |
+
pts = old_points[neighbor_indices]
|
221 |
+
sts = new_points[neighbor_indices]
|
222 |
+
ptrans_x, ptrans_y = points_transform(H_global, pts[:, 0], pts[:, 1])
|
223 |
+
moves_x = sts[:, 0] - ptrans_x
|
224 |
+
moves_y = sts[:, 1] - ptrans_y
|
225 |
+
|
226 |
+
dists = np.sqrt((vertex[0] - pts[:, 0]) ** 2 + (vertex[1] - pts[:, 1]) ** 2)
|
227 |
+
weights_x = np.exp(-(dists ** 2) / (2 * sigma ** 2))
|
228 |
+
weights_y = np.exp(-(dists ** 2) / (2 * sigma ** 2))
|
229 |
+
|
230 |
+
weighted_move_x[i, j] = np.sum(weights_x * moves_x) / (np.sum(weights_x) + 0.1)
|
231 |
+
weighted_move_y[i, j] = np.sum(weights_y * moves_y) / (np.sum(weights_y) + 0.1)
|
232 |
+
|
233 |
+
x_motion_mesh = x_motion + weighted_move_x
|
234 |
+
y_motion_mesh = y_motion + weighted_move_y
|
235 |
+
'''
|
236 |
+
# apply median filter (f-1) on obtained motion for each vertex
|
237 |
+
x_motion_mesh = np.zeros((mesh_x_num, mesh_y_num), dtype=float)
|
238 |
+
y_motion_mesh = np.zeros((mesh_x_num, mesh_y_num), dtype=float)
|
239 |
+
|
240 |
+
for key in x_motion.keys():
|
241 |
+
try:
|
242 |
+
temp_x_motion[key].sort()
|
243 |
+
x_motion_mesh[key] = x_motion[key]+temp_x_motion[key][len(temp_x_motion[key])//2]
|
244 |
+
except KeyError:
|
245 |
+
x_motion_mesh[key] = x_motion[key]
|
246 |
+
try:
|
247 |
+
temp_y_motion[key].sort()
|
248 |
+
y_motion_mesh[key] = y_motion[key]+temp_y_motion[key][len(temp_y_motion[key])//2]
|
249 |
+
except KeyError:
|
250 |
+
y_motion_mesh[key] = y_motion[key]
|
251 |
+
|
252 |
+
# apply second median filter (f-2) over the motion mesh for outliers
|
253 |
+
#x_motion_mesh = medfilt(x_motion_mesh, kernel_size=[3, 3])
|
254 |
+
#y_motion_mesh = medfilt(y_motion_mesh, kernel_size=[3, 3])
|
255 |
+
'''
|
256 |
+
return x_motion_mesh, y_motion_mesh
|
257 |
+
|
258 |
+
|
259 |
+
def mesh_warp_points(points, x_motion_mesh, y_motion_mesh, img_size):
|
260 |
+
ptrans = []
|
261 |
+
mesh_x_num, mesh_y_num = x_motion_mesh.shape
|
262 |
+
pixels_x, pixels_y = (img_size[1]) / (mesh_x_num - 1), (img_size[0]) / (mesh_y_num - 1)
|
263 |
+
for pt in points:
|
264 |
+
i = int(pt[0] // pixels_x)
|
265 |
+
j = int(pt[1] // pixels_y)
|
266 |
+
src = [[i * pixels_x, j * pixels_y],
|
267 |
+
[(i + 1) * pixels_x, j * pixels_y],
|
268 |
+
[i * pixels_x, (j + 1) * pixels_y],
|
269 |
+
[(i + 1) * pixels_x, (j + 1) * pixels_y]]
|
270 |
+
src = np.asarray(src)
|
271 |
+
dst = [[x_motion_mesh[i, j], y_motion_mesh[i, j]],
|
272 |
+
[x_motion_mesh[i + 1, j], y_motion_mesh[i + 1, j]],
|
273 |
+
[x_motion_mesh[i, j + 1], y_motion_mesh[i, j + 1]],
|
274 |
+
[x_motion_mesh[i + 1, j + 1], y_motion_mesh[i + 1, j + 1]]]
|
275 |
+
dst = np.asarray(dst)
|
276 |
+
H, _ = cv2.findHomography(src, dst, cv2.RANSAC)
|
277 |
+
x, y = points_transform(H, pt[0], pt[1])
|
278 |
+
ptrans.append([x, y])
|
279 |
+
|
280 |
+
return np.array(ptrans)
|
281 |
+
|
282 |
+
|
283 |
+
def mesh_warp_frame(frame, x_motion_mesh, y_motion_mesh, resize):
|
284 |
+
"""
|
285 |
+
@param: frame is the current frame
|
286 |
+
@param: x_motion_mesh is the motion_mesh to
|
287 |
+
be warped on frame along x-direction
|
288 |
+
@param: y_motion_mesh is the motion mesh to
|
289 |
+
be warped on frame along y-direction
|
290 |
+
@param: resize is the desired output size (tuple of width, height)
|
291 |
+
|
292 |
+
Returns:
|
293 |
+
returns a mesh warped frame according
|
294 |
+
to given motion meshes x_motion_mesh,
|
295 |
+
y_motion_mesh, resized to the specified size
|
296 |
+
"""
|
297 |
+
|
298 |
+
map_x = np.zeros(resize, np.float32)
|
299 |
+
map_y = np.zeros(resize, np.float32)
|
300 |
+
|
301 |
+
mesh_x_num, mesh_y_num = x_motion_mesh.shape
|
302 |
+
pixels_x, pixels_y = (resize[1]) / (mesh_x_num - 1), (resize[0]) / (mesh_y_num - 1)
|
303 |
+
|
304 |
+
for i in range(mesh_x_num - 1):
|
305 |
+
for j in range(mesh_y_num - 1):
|
306 |
+
src = [[i * pixels_x, j * pixels_y],
|
307 |
+
[(i + 1) * pixels_x, j * pixels_y],
|
308 |
+
[i * pixels_x, (j + 1) * pixels_y],
|
309 |
+
[(i + 1) * pixels_x, (j + 1) * pixels_y]]
|
310 |
+
src = np.asarray(src)
|
311 |
+
|
312 |
+
dst = [[x_motion_mesh[i, j], y_motion_mesh[i, j]],
|
313 |
+
[x_motion_mesh[i + 1, j], y_motion_mesh[i + 1, j]],
|
314 |
+
[x_motion_mesh[i, j + 1], y_motion_mesh[i, j + 1]],
|
315 |
+
[x_motion_mesh[i + 1, j + 1], y_motion_mesh[i + 1, j + 1]]]
|
316 |
+
dst = np.asarray(dst)
|
317 |
+
H, _ = cv2.findHomography(src, dst, cv2.RANSAC)
|
318 |
+
|
319 |
+
start_x = math.ceil(pixels_x * i)
|
320 |
+
end_x = math.ceil(pixels_x * (i + 1))
|
321 |
+
start_y = math.ceil(pixels_y * j)
|
322 |
+
end_y = math.ceil(pixels_y * (j + 1))
|
323 |
+
|
324 |
+
x, y = np.meshgrid(range(start_x, end_x), range(start_y, end_y), indexing='ij')
|
325 |
+
|
326 |
+
map_x[y, x], map_y[y, x] = points_transform(H, x, y)
|
327 |
+
|
328 |
+
# deforms mesh and directly outputs the resized frame
|
329 |
+
resized_frame = cv2.remap(frame, map_x, map_y,
|
330 |
+
interpolation=cv2.INTER_LINEAR, borderMode=cv2.BORDER_CONSTANT,
|
331 |
+
borderValue=(255, 255, 255))
|
332 |
+
return resized_frame
|
333 |
+
|
334 |
+
|
335 |
+
def infer_warp_mesh_img(src, dst, model, vis=False):
|
336 |
+
if isinstance(src, str):
|
337 |
+
image1 = cv2.imread(src, cv2.IMREAD_UNCHANGED)
|
338 |
+
image2 = cv2.imread(dst, cv2.IMREAD_UNCHANGED)
|
339 |
+
image1 = cv2.cvtColor(image1, cv2.COLOR_BGR2RGB)
|
340 |
+
image2 = cv2.cvtColor(image2, cv2.COLOR_BGR2RGB)
|
341 |
+
elif isinstance(src, Image.Image):
|
342 |
+
image1 = np.array(src)
|
343 |
+
image2 = np.array(dst)
|
344 |
+
else:
|
345 |
+
assert isinstance(src, np.ndarray)
|
346 |
+
|
347 |
+
image1 = rgba_to_rgb(image1)
|
348 |
+
image2 = rgba_to_rgb(image2)
|
349 |
+
|
350 |
+
image1_padded = resize_with_aspect_ratio(image1, image2)
|
351 |
+
resized_image1 = cv2.resize(image1_padded, (image2.shape[1], image2.shape[0]), interpolation=cv2.INTER_AREA)
|
352 |
+
|
353 |
+
matches_im0, matches_im1 = infer_match([resized_image1, image2], model, vis=vis)
|
354 |
+
matches_im0 = matches_im0 * image1_padded.shape[0] / resized_image1.shape[0]
|
355 |
+
|
356 |
+
# print('Estimate Mesh Grid')
|
357 |
+
mesh_x, mesh_y = motion_propagate(matches_im1, matches_im0, image2.shape[:2], image1_padded.shape[:2])
|
358 |
+
|
359 |
+
aligned_image = mesh_warp_frame(image1_padded, mesh_x, mesh_y, (image2.shape[0], image2.shape[1]))
|
360 |
+
|
361 |
+
matches_im0_from_im1 = mesh_warp_points(matches_im1, mesh_x, mesh_y, (image2.shape[1], image2.shape[0]))
|
362 |
+
|
363 |
+
info = compute_img_diff(aligned_image, image2, matches_im0, matches_im0_from_im1, vis=vis)
|
364 |
+
|
365 |
+
return aligned_image, info
|
366 |
+
|
third_party/gen_baking.py
ADDED
@@ -0,0 +1,288 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os, sys, time
|
2 |
+
from typing import List, Optional
|
3 |
+
from iopath.common.file_io import PathManager
|
4 |
+
|
5 |
+
import cv2
|
6 |
+
import imageio
|
7 |
+
import numpy as np
|
8 |
+
from PIL import Image
|
9 |
+
import matplotlib.pyplot as plt
|
10 |
+
|
11 |
+
import torch
|
12 |
+
import torch.nn.functional as F
|
13 |
+
from torchvision import transforms
|
14 |
+
|
15 |
+
import trimesh
|
16 |
+
from pytorch3d.io import load_objs_as_meshes, load_obj, save_obj
|
17 |
+
from pytorch3d.ops import interpolate_face_attributes
|
18 |
+
from pytorch3d.common.datatypes import Device
|
19 |
+
from pytorch3d.structures import Meshes
|
20 |
+
from pytorch3d.renderer import (
|
21 |
+
look_at_view_transform,
|
22 |
+
FoVPerspectiveCameras,
|
23 |
+
PointLights,
|
24 |
+
DirectionalLights,
|
25 |
+
AmbientLights,
|
26 |
+
Materials,
|
27 |
+
RasterizationSettings,
|
28 |
+
MeshRenderer,
|
29 |
+
MeshRasterizer,
|
30 |
+
SoftPhongShader,
|
31 |
+
TexturesUV,
|
32 |
+
TexturesVertex,
|
33 |
+
camera_position_from_spherical_angles,
|
34 |
+
BlendParams,
|
35 |
+
)
|
36 |
+
|
37 |
+
|
38 |
+
def erode_mask(src_mask, p=1 / 20.0):
|
39 |
+
monoMaskImage = cv2.split(src_mask)[0]
|
40 |
+
br = cv2.boundingRect(monoMaskImage)
|
41 |
+
k = int(min(br[2], br[3]) * p)
|
42 |
+
kernel = np.ones((k, k), dtype=np.uint8)
|
43 |
+
dst_mask = cv2.erode(src_mask, kernel, 1)
|
44 |
+
return dst_mask
|
45 |
+
|
46 |
+
def load_objs_as_meshes_fast(
|
47 |
+
verts,
|
48 |
+
faces,
|
49 |
+
aux,
|
50 |
+
device: Optional[Device] = None,
|
51 |
+
load_textures: bool = True,
|
52 |
+
create_texture_atlas: bool = False,
|
53 |
+
texture_atlas_size: int = 4,
|
54 |
+
texture_wrap: Optional[str] = "repeat",
|
55 |
+
path_manager: Optional[PathManager] = None,
|
56 |
+
):
|
57 |
+
tex = None
|
58 |
+
if create_texture_atlas:
|
59 |
+
# TexturesAtlas type
|
60 |
+
tex = TexturesAtlas(atlas=[aux.texture_atlas.to(device)])
|
61 |
+
else:
|
62 |
+
# TexturesUV type
|
63 |
+
tex_maps = aux.texture_images
|
64 |
+
if tex_maps is not None and len(tex_maps) > 0:
|
65 |
+
verts_uvs = aux.verts_uvs.to(device) # (V, 2)
|
66 |
+
faces_uvs = faces.textures_idx.to(device) # (F, 3)
|
67 |
+
image = list(tex_maps.values())[0].to(device)[None]
|
68 |
+
tex = TexturesUV(verts_uvs=[verts_uvs], faces_uvs=[faces_uvs], maps=image)
|
69 |
+
mesh = Meshes( verts=[verts.to(device)], faces=[faces.verts_idx.to(device)], textures=tex)
|
70 |
+
return mesh
|
71 |
+
|
72 |
+
|
73 |
+
def get_triangle_to_triangle(tri_1, tri_2, img_refined):
|
74 |
+
'''
|
75 |
+
args:
|
76 |
+
tri_1:
|
77 |
+
tri_2:
|
78 |
+
'''
|
79 |
+
r1 = cv2.boundingRect(tri_1)
|
80 |
+
r2 = cv2.boundingRect(tri_2)
|
81 |
+
|
82 |
+
tri_1_cropped = []
|
83 |
+
tri_2_cropped = []
|
84 |
+
for i in range(0, 3):
|
85 |
+
tri_1_cropped.append(((tri_1[i][1] - r1[1]), (tri_1[i][0] - r1[0])))
|
86 |
+
tri_2_cropped.append(((tri_2[i][1] - r2[1]), (tri_2[i][0] - r2[0])))
|
87 |
+
|
88 |
+
trans = cv2.getAffineTransform(np.float32(tri_1_cropped), np.float32(tri_2_cropped))
|
89 |
+
|
90 |
+
img_1_cropped = np.float32(img_refined[r1[0]:r1[0] + r1[2], r1[1]:r1[1] + r1[3]])
|
91 |
+
|
92 |
+
mask = np.zeros((r2[2], r2[3], 3), dtype=np.float32)
|
93 |
+
|
94 |
+
cv2.fillConvexPoly(mask, np.int32(tri_2_cropped), (1.0, 1.0, 1.0), 16, 0)
|
95 |
+
|
96 |
+
img_2_cropped = cv2.warpAffine(
|
97 |
+
img_1_cropped, trans, (r2[3], r2[2]), None,
|
98 |
+
flags = cv2.INTER_LINEAR,
|
99 |
+
borderMode = cv2.BORDER_REFLECT_101
|
100 |
+
)
|
101 |
+
return mask, img_2_cropped, r2
|
102 |
+
|
103 |
+
|
104 |
+
def back_projection(
|
105 |
+
obj_file,
|
106 |
+
init_texture_file,
|
107 |
+
front_view_file,
|
108 |
+
dst_dir,
|
109 |
+
render_resolution=512,
|
110 |
+
uv_resolution=600,
|
111 |
+
normalThreshold=0.3, # 0.3
|
112 |
+
rgb_thresh=820, # 520
|
113 |
+
views=None,
|
114 |
+
camera_dist=1.5,
|
115 |
+
erode_scale=1/100.0,
|
116 |
+
device="cuda:0"
|
117 |
+
):
|
118 |
+
# obj_file: 带有uv的obj
|
119 |
+
# init_texture_file: 初始展开的uv贴图
|
120 |
+
# render_resolution 正面视角渲染分辨率
|
121 |
+
# uv_resolution 贴图分辨率
|
122 |
+
# thres:normal threshold
|
123 |
+
|
124 |
+
os.makedirs(dst_dir, exist_ok=True)
|
125 |
+
|
126 |
+
if isinstance(front_view_file, str):
|
127 |
+
src = np.array(Image.open(front_view_file).convert("RGB"))
|
128 |
+
elif isinstance(front_view_file, Image.Image):
|
129 |
+
src = np.array(front_view_file.convert("RGB"))
|
130 |
+
else:
|
131 |
+
raise "need file_path or pil"
|
132 |
+
|
133 |
+
image_size = (render_resolution, render_resolution)
|
134 |
+
|
135 |
+
init_texture = Image.open(init_texture_file)
|
136 |
+
init_texture = init_texture.convert("RGB")
|
137 |
+
# init_texture = init_texture.resize((uv_resolution, uv_resolution))
|
138 |
+
init_texture = np.array(init_texture).astype(np.float32)
|
139 |
+
|
140 |
+
print("load obj", obj_file)
|
141 |
+
verts, faces, aux = load_obj(obj_file, device=device)
|
142 |
+
mesh = load_objs_as_meshes_fast(verts, faces, aux, device=device)
|
143 |
+
|
144 |
+
|
145 |
+
t0 = time.time()
|
146 |
+
verts_uvs = aux.verts_uvs
|
147 |
+
triangle_uvs = verts_uvs[faces.textures_idx]
|
148 |
+
triangle_uvs = torch.cat([
|
149 |
+
((1 - triangle_uvs[..., 1]) * uv_resolution).unsqueeze(2),
|
150 |
+
(triangle_uvs[..., 0] * uv_resolution).unsqueeze(2),
|
151 |
+
], dim=-1)
|
152 |
+
triangle_uvs = np.clip(np.round(np.float32(triangle_uvs.cpu())).astype(np.int64), 0, uv_resolution-1)
|
153 |
+
|
154 |
+
# import ipdb;ipdb.set_trace()
|
155 |
+
|
156 |
+
|
157 |
+
R0, T0 = look_at_view_transform(camera_dist, views[0][0], views[0][1])
|
158 |
+
|
159 |
+
cameras = FoVPerspectiveCameras(device=device, R=R0, T=T0, fov=49.1)
|
160 |
+
|
161 |
+
camera_normal = camera_position_from_spherical_angles(1, views[0][0], views[0][1]).to(device)
|
162 |
+
screen_coords = cameras.transform_points_screen(verts, image_size=image_size)[:, :2]
|
163 |
+
screen_coords = torch.cat([screen_coords[..., 1, None], screen_coords[..., 0, None]], dim=-1)
|
164 |
+
triangle_screen_coords = np.round(np.float32(screen_coords[faces.verts_idx].cpu())) # numpy.ndarray (90000, 3, 2)
|
165 |
+
triangle_screen_coords = np.clip(triangle_screen_coords.astype(np.int64), 0, render_resolution-1)
|
166 |
+
|
167 |
+
renderer = MeshRenderer(
|
168 |
+
rasterizer=MeshRasterizer(
|
169 |
+
cameras=cameras,
|
170 |
+
raster_settings= RasterizationSettings(
|
171 |
+
image_size=image_size,
|
172 |
+
blur_radius=0.0,
|
173 |
+
faces_per_pixel=1,
|
174 |
+
),
|
175 |
+
),
|
176 |
+
shader=SoftPhongShader(
|
177 |
+
device=device,
|
178 |
+
cameras=cameras,
|
179 |
+
lights= AmbientLights(device=device),
|
180 |
+
blend_params=BlendParams(background_color=(1.0, 1.0, 1.0)),
|
181 |
+
)
|
182 |
+
)
|
183 |
+
|
184 |
+
dst = renderer(mesh)
|
185 |
+
dst = (dst[..., :3] * 255).squeeze(0).cpu().numpy().astype(np.uint8)
|
186 |
+
|
187 |
+
src_mask = np.ones((src.shape[0], src.shape[1]), dst.dtype)
|
188 |
+
ids = np.where(dst.sum(-1) > 253 * 3)
|
189 |
+
ids2 = np.where(src.sum(-1) > 250 * 3)
|
190 |
+
src_mask[ids[0], ids[1]] = 0
|
191 |
+
src_mask[ids2[0], ids2[1]] = 0
|
192 |
+
src_mask = (src_mask > 0).astype(np.uint8) * 255
|
193 |
+
|
194 |
+
monoMaskImage = cv2.split(src_mask)[0] # reducing the mask to a monochrome
|
195 |
+
br = cv2.boundingRect(monoMaskImage) # bounding rect (x,y,width,height)
|
196 |
+
center = (br[0] + br[2] // 2, br[1] + br[3] // 2)
|
197 |
+
|
198 |
+
# seamlessClone
|
199 |
+
try:
|
200 |
+
images = cv2.seamlessClone(src, dst, src_mask, center, cv2.NORMAL_CLONE) # more qingxi
|
201 |
+
# images = cv2.seamlessClone(src, dst, src_mask, center, cv2.MIXED_CLONE)
|
202 |
+
except Exception as err:
|
203 |
+
print(f"\n\n Warning seamlessClone error: {err} \n\n")
|
204 |
+
images = src
|
205 |
+
|
206 |
+
Image.fromarray(src_mask).save(os.path.join(dst_dir, 'mask.jpeg'))
|
207 |
+
Image.fromarray(src).save(os.path.join(dst_dir, 'src.jpeg'))
|
208 |
+
Image.fromarray(dst).save(os.path.join(dst_dir, 'dst.jpeg'))
|
209 |
+
Image.fromarray(images).save(os.path.join(dst_dir, 'blend.jpeg'))
|
210 |
+
|
211 |
+
fragments_scaled = renderer.rasterizer(mesh) # pytorch3d.renderer.mesh.rasterizer.Fragments
|
212 |
+
faces_covered = fragments_scaled.pix_to_face.unique()[1:] # torch.Tensor torch.Size([30025])
|
213 |
+
face_normals = mesh.faces_normals_packed().to(device) # torch.Tensor torch.Size([90000, 3]) cuda:0
|
214 |
+
|
215 |
+
# faces: pytorch3d.io.obj_io.Faces
|
216 |
+
# faces.textures_idx: torch.Tensor torch.Size([90000, 3])
|
217 |
+
# verts_uvs: torch.Tensor torch.Size([49554, 2])
|
218 |
+
triangle_uvs = verts_uvs[faces.textures_idx]
|
219 |
+
triangle_uvs = [
|
220 |
+
((1 - triangle_uvs[..., 1]) * uv_resolution).unsqueeze(2),
|
221 |
+
(triangle_uvs[..., 0] * uv_resolution).unsqueeze(2),
|
222 |
+
]
|
223 |
+
triangle_uvs = torch.cat(triangle_uvs, dim=-1) # numpy.ndarray (90000, 3, 2)
|
224 |
+
triangle_uvs = np.clip(np.round(np.float32(triangle_uvs.cpu())).astype(np.int64), 0, uv_resolution-1)
|
225 |
+
|
226 |
+
t0 = time.time()
|
227 |
+
|
228 |
+
|
229 |
+
SOFT_NORM = True # process big angle-diff face, true:flase? coeff:skip
|
230 |
+
|
231 |
+
for k in faces_covered:
|
232 |
+
# todo: accelerate this for-loop
|
233 |
+
# if cosine between face-camera is too low, skip current face baking
|
234 |
+
face_normal = face_normals[k]
|
235 |
+
cosine = torch.sum((face_normal * camera_normal) ** 2)
|
236 |
+
if not SOFT_NORM and cosine < normalThreshold: continue
|
237 |
+
|
238 |
+
# if coord in screen out of subject, skip current face baking
|
239 |
+
out_of_subject = src_mask[triangle_screen_coords[k][0][0], triangle_screen_coords[k][0][1]]==0
|
240 |
+
if out_of_subject: continue
|
241 |
+
|
242 |
+
coeff, img_2_cropped, r2 = get_triangle_to_triangle(triangle_screen_coords[k], triangle_uvs[k], images)
|
243 |
+
|
244 |
+
# if color difference between new-old, skip current face baking
|
245 |
+
err = np.abs(init_texture[r2[0]:r2[0]+r2[2], r2[1]:r2[1]+r2[3]]- img_2_cropped)
|
246 |
+
err = (err * coeff).sum(-1)
|
247 |
+
|
248 |
+
# print(err.shape, np.max(err))
|
249 |
+
if (np.max(err) > rgb_thresh): continue
|
250 |
+
|
251 |
+
color_for_debug = None
|
252 |
+
# if (np.max(err) > 400): color_for_debug = [255, 0, 0]
|
253 |
+
# if (np.max(err) > 450): color_for_debug = [0, 255, 0]
|
254 |
+
# if (np.max(err) > 500): color_for_debug = [0, 0, 255]
|
255 |
+
|
256 |
+
coeff = coeff.clip(0, 1)
|
257 |
+
|
258 |
+
if SOFT_NORM:
|
259 |
+
coeff *= ((cosine.detach().cpu().numpy() - normalThreshold) / normalThreshold).clip(0,1)
|
260 |
+
|
261 |
+
coeff *= (((rgb_thresh - err[...,None]) / rgb_thresh)**0.4).clip(0,1)
|
262 |
+
|
263 |
+
if color_for_debug is None:
|
264 |
+
init_texture[r2[0]:r2[0]+r2[2], r2[1]:r2[1]+r2[3]] = \
|
265 |
+
init_texture[r2[0]:r2[0]+r2[2], r2[1]:r2[1]+r2[3]] * ((1.0,1.0,1.0)-coeff) + img_2_cropped * coeff
|
266 |
+
else:
|
267 |
+
init_texture[r2[0]:r2[0]+r2[2], r2[1]:r2[1]+r2[3]] = color_for_debug
|
268 |
+
|
269 |
+
print(f'View baking time: {time.time() - t0}')
|
270 |
+
|
271 |
+
bake_dir = os.path.join(dst_dir, 'bake')
|
272 |
+
os.makedirs(bake_dir, exist_ok=True)
|
273 |
+
os.system(f'cp {obj_file} {bake_dir}')
|
274 |
+
|
275 |
+
textute_img = Image.fromarray(init_texture.astype(np.uint8))
|
276 |
+
textute_img.save(os.path.join(bake_dir, init_texture_file.split("/")[-1]))
|
277 |
+
|
278 |
+
mtl_dir = obj_file.replace('.obj', '.mtl')
|
279 |
+
if not os.path.exists(mtl_dir): mtl_dir = obj_file.replace("mesh.obj" ,"material.mtl")
|
280 |
+
if not os.path.exists(mtl_dir): mtl_dir = obj_file.replace("mesh.obj" ,"texture.mtl")
|
281 |
+
if not os.path.exists(mtl_dir): import ipdb;ipdb.set_trace()
|
282 |
+
os.system(f'cp {mtl_dir} {bake_dir}')
|
283 |
+
|
284 |
+
# convert .obj to .glb file
|
285 |
+
new_obj_pth = os.path.join(bake_dir, obj_file.split('/')[-1])
|
286 |
+
new_glb_path = new_obj_pth.replace('.obj', '.glb')
|
287 |
+
mesh = trimesh.load_mesh(new_obj_pth)
|
288 |
+
mesh.export(new_glb_path, file_type='glb')
|
third_party/mesh_baker.py
ADDED
@@ -0,0 +1,142 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os, sys, time, traceback
|
2 |
+
print("sys path insert", os.path.join(os.path.dirname(__file__), "dust3r"))
|
3 |
+
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "dust3r"))
|
4 |
+
|
5 |
+
import cv2
|
6 |
+
import numpy as np
|
7 |
+
from PIL import Image, ImageSequence
|
8 |
+
from einops import rearrange
|
9 |
+
import torch
|
10 |
+
|
11 |
+
from infer.utils import seed_everything, timing_decorator
|
12 |
+
from infer.utils import get_parameter_number, set_parameter_grad_false
|
13 |
+
|
14 |
+
from dust3r.inference import inference
|
15 |
+
from dust3r.model import AsymmetricCroCo3DStereo
|
16 |
+
|
17 |
+
from third_party.gen_baking import back_projection
|
18 |
+
from third_party.dust3r_utils import infer_warp_mesh_img
|
19 |
+
from svrm.ldm.vis_util import render_func
|
20 |
+
|
21 |
+
|
22 |
+
class MeshBaker:
|
23 |
+
def __init__(
|
24 |
+
self,
|
25 |
+
align_model = "third_party/weights/DUSt3R_ViTLarge_BaseDecoder_512_dpt",
|
26 |
+
device = "cuda:0",
|
27 |
+
align_times = 1,
|
28 |
+
iou_thresh = 0.8,
|
29 |
+
force_baking_ele_list = None,
|
30 |
+
save_memory = False
|
31 |
+
):
|
32 |
+
self.device = device
|
33 |
+
self.save_memory = save_memory
|
34 |
+
self.align_model = AsymmetricCroCo3DStereo.from_pretrained(align_model)
|
35 |
+
self.align_model = self.align_model if save_memory else self.align_model.to(device)
|
36 |
+
self.align_times = align_times
|
37 |
+
self.align_model.eval()
|
38 |
+
self.iou_thresh = iou_thresh
|
39 |
+
self.force_baking_ele_list = [] if force_baking_ele_list is None else force_baking_ele_list
|
40 |
+
self.force_baking_ele_list = [int(_) for _ in self.force_baking_ele_list]
|
41 |
+
set_parameter_grad_false(self.align_model)
|
42 |
+
print('baking align model', get_parameter_number(self.align_model))
|
43 |
+
|
44 |
+
def align_and_check(self, src, dst, align_times=3):
|
45 |
+
try:
|
46 |
+
st = time.time()
|
47 |
+
best_baking_flag = False
|
48 |
+
best_aligned_image = aligned_image = src
|
49 |
+
best_info = {'match_num': 1000, "mask_iou": self.iou_thresh-0.1}
|
50 |
+
for i in range(align_times):
|
51 |
+
aligned_image, info = infer_warp_mesh_img(aligned_image, dst, self.align_model, vis=False)
|
52 |
+
aligned_image = Image.fromarray(aligned_image)
|
53 |
+
print(f"{i}-th time align process, mask-iou is {info['mask_iou']}")
|
54 |
+
if info['mask_iou'] > best_info['mask_iou']:
|
55 |
+
best_aligned_image, best_info = aligned_image, info
|
56 |
+
if info['mask_iou'] < self.iou_thresh:
|
57 |
+
break
|
58 |
+
print(f"Best Baking Info:{best_info['mask_iou']}")
|
59 |
+
best_baking_flag = best_info['mask_iou'] > self.iou_thresh
|
60 |
+
return best_aligned_image, best_info, best_baking_flag
|
61 |
+
except Exception as e:
|
62 |
+
print(f"Error processing image: {e}")
|
63 |
+
traceback.print_exc()
|
64 |
+
return None, None, None
|
65 |
+
|
66 |
+
@timing_decorator("baking mesh")
|
67 |
+
def __call__(self, *args, **kwargs):
|
68 |
+
if self.save_memory:
|
69 |
+
self.align_model = self.align_model.to(self.device)
|
70 |
+
torch.cuda.empty_cache()
|
71 |
+
res = self.call(*args, **kwargs)
|
72 |
+
self.align_model = self.align_model.to("cpu")
|
73 |
+
else:
|
74 |
+
res = self.call(*args, **kwargs)
|
75 |
+
torch.cuda.empty_cache()
|
76 |
+
return res
|
77 |
+
|
78 |
+
def call(self, save_folder):
|
79 |
+
obj_path = os.path.join(save_folder, "mesh.obj")
|
80 |
+
raw_texture_path = os.path.join(save_folder, "texture.png")
|
81 |
+
views_pil = os.path.join(save_folder, "views.jpg")
|
82 |
+
views_gif = os.path.join(save_folder, "views.gif")
|
83 |
+
cond_pil = os.path.join(save_folder, "img_nobg.png")
|
84 |
+
|
85 |
+
if os.path.exists(views_pil):
|
86 |
+
views_pil = Image.open(views_pil)
|
87 |
+
views = rearrange(np.asarray(views_pil, dtype=np.uint8), '(n h) (m w) c -> (n m) h w c', n=3, m=2)
|
88 |
+
views = [Image.fromarray(views[idx]).convert('RGB') for idx in [0,2,4,5,3,1]]
|
89 |
+
cond_pil = Image.open(cond_pil).resize((512,512))
|
90 |
+
elif os.path.exists(views_gif):
|
91 |
+
views_gif_pil = Image.open(views_gif)
|
92 |
+
views = [img.convert('RGB') for img in ImageSequence.Iterator(views_gif_pil)]
|
93 |
+
cond_pil, views = views[0], views[1:]
|
94 |
+
else:
|
95 |
+
raise FileNotFoundError("views file not found")
|
96 |
+
|
97 |
+
rendered_views = render_func(obj_path, elev=0, n_views=2)
|
98 |
+
|
99 |
+
for ele_idx, ele in enumerate([0, 180]):
|
100 |
+
|
101 |
+
if ele == 0:
|
102 |
+
aligned_cond, cond_info, _ = self.align_and_check(cond_pil, rendered_views[0], align_times=self.align_times)
|
103 |
+
aligned_cond.save(save_folder + f'/aligned_cond.jpg')
|
104 |
+
|
105 |
+
aligned_img, info, _ = self.align_and_check(views[0], rendered_views[0], align_times=self.align_times)
|
106 |
+
aligned_img.save(save_folder + f'/aligned_{ele}.jpg')
|
107 |
+
|
108 |
+
if info['mask_iou'] < cond_info['mask_iou']:
|
109 |
+
print("Using Cond Image to bake front view")
|
110 |
+
aligned_img = aligned_cond
|
111 |
+
info = cond_info
|
112 |
+
need_baking = info['mask_iou'] > self.iou_thresh
|
113 |
+
else:
|
114 |
+
aligned_img, info, need_baking = self.align_and_check(views[ele//60], rendered_views[ele_idx])
|
115 |
+
aligned_img.save(save_folder + f'/aligned_{ele}.jpg')
|
116 |
+
|
117 |
+
if need_baking or (ele in self.force_baking_ele_list):
|
118 |
+
st = time.time()
|
119 |
+
view1_res = back_projection(
|
120 |
+
obj_file = obj_path,
|
121 |
+
init_texture_file = raw_texture_path,
|
122 |
+
front_view_file = aligned_img,
|
123 |
+
dst_dir = os.path.join(save_folder, f"view_{ele_idx}"),
|
124 |
+
render_resolution = aligned_img.size[0],
|
125 |
+
uv_resolution = 1024,
|
126 |
+
views = [[0, ele]],
|
127 |
+
device = self.device
|
128 |
+
)
|
129 |
+
print(f"view_{ele_idx} elevation_{ele} baking finished at {time.time() - st}")
|
130 |
+
obj_path = os.path.join(save_folder, f"view_{ele_idx}/bake/mesh.obj")
|
131 |
+
raw_texture_path = os.path.join(save_folder, f"view_{ele_idx}/bake/texture.png")
|
132 |
+
else:
|
133 |
+
print(f"Skip view_{ele_idx} elevation_{ele} baking")
|
134 |
+
|
135 |
+
print("Baking Finished")
|
136 |
+
return obj_path
|
137 |
+
|
138 |
+
|
139 |
+
if __name__ == "__main__":
|
140 |
+
baker = MeshBaker()
|
141 |
+
obj_path = baker("./outputs/test")
|
142 |
+
print(obj_path)
|
third_party/utils/camera_utils.py
ADDED
@@ -0,0 +1,90 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
import numpy as np
|
3 |
+
|
4 |
+
def compute_extrinsic_matrix(elevation, azimuth, camera_distance):
|
5 |
+
# Convert angles to radians
|
6 |
+
elevation_rad = np.radians(elevation)
|
7 |
+
azimuth_rad = np.radians(azimuth)
|
8 |
+
|
9 |
+
R = np.array([
|
10 |
+
[np.cos(azimuth_rad), 0, -np.sin(azimuth_rad)],
|
11 |
+
[0, 1, 0],
|
12 |
+
[np.sin(azimuth_rad), 0, np.cos(azimuth_rad)],
|
13 |
+
], dtype=np.float32)
|
14 |
+
|
15 |
+
R = R @ np.array([
|
16 |
+
[1, 0, 0],
|
17 |
+
[0, np.cos(elevation_rad), -np.sin(elevation_rad)],
|
18 |
+
[0, np.sin(elevation_rad), np.cos(elevation_rad)]
|
19 |
+
], dtype=np.float32)
|
20 |
+
|
21 |
+
# Construct translation matrix T (3x1)
|
22 |
+
T = np.array([[camera_distance], [0], [0]], dtype=np.float32)
|
23 |
+
T = R @ T
|
24 |
+
|
25 |
+
# Combined into a 4x4 transformation matrix
|
26 |
+
extrinsic_matrix = np.vstack((np.hstack((R, T)), np.array([[0, 0, 0, 1]], dtype=np.float32)))
|
27 |
+
|
28 |
+
return extrinsic_matrix
|
29 |
+
|
30 |
+
|
31 |
+
def transform_camera_pose(im_pose, ori_pose, new_pose):
|
32 |
+
T = new_pose @ ori_pose.T
|
33 |
+
transformed_poses = []
|
34 |
+
|
35 |
+
for pose in im_pose:
|
36 |
+
transformed_pose = T @ pose
|
37 |
+
transformed_poses.append(transformed_pose)
|
38 |
+
|
39 |
+
return transformed_poses
|
40 |
+
|
41 |
+
def compute_fov(intrinsic_matrix):
|
42 |
+
# Get the focal length value in the internal parameter matrix
|
43 |
+
fx = intrinsic_matrix[0, 0]
|
44 |
+
fy = intrinsic_matrix[1, 1]
|
45 |
+
|
46 |
+
h, w = intrinsic_matrix[0,2]*2, intrinsic_matrix[1,2]*2
|
47 |
+
|
48 |
+
# Calculate horizontal and vertical FOV values
|
49 |
+
fov_x = 2 * math.atan(w / (2 * fx)) * 180 / math.pi
|
50 |
+
fov_y = 2 * math.atan(h / (2 * fy)) * 180 / math.pi
|
51 |
+
|
52 |
+
return fov_x, fov_y
|
53 |
+
|
54 |
+
|
55 |
+
|
56 |
+
def rotation_matrix_to_quaternion(rotation_matrix):
|
57 |
+
rot = Rotation.from_matrix(rotation_matrix)
|
58 |
+
quaternion = rot.as_quat()
|
59 |
+
return quaternion
|
60 |
+
|
61 |
+
def quaternion_to_rotation_matrix(quaternion):
|
62 |
+
rot = Rotation.from_quat(quaternion)
|
63 |
+
rotation_matrix = rot.as_matrix()
|
64 |
+
return rotation_matrix
|
65 |
+
|
66 |
+
def remap_points(img_size, match, size=512):
|
67 |
+
H, W, _ = img_size
|
68 |
+
|
69 |
+
S = max(W, H)
|
70 |
+
new_W = int(round(W * size / S))
|
71 |
+
new_H = int(round(H * size / S))
|
72 |
+
cx, cy = new_W // 2, new_H // 2
|
73 |
+
|
74 |
+
# Calculate the coordinates of the transformed image center point
|
75 |
+
halfw, halfh = ((2 * cx) // 16) * 8, ((2 * cy) // 16) * 8
|
76 |
+
|
77 |
+
dw, dh = cx - halfw, cy - halfh
|
78 |
+
|
79 |
+
# store point coordinates mapped back to the original image
|
80 |
+
new_match = np.zeros_like(match)
|
81 |
+
|
82 |
+
# Map the transformed point coordinates back to the original image
|
83 |
+
new_match[:, 0] = (match[:, 0] + dw) / new_W * W
|
84 |
+
new_match[:, 1] = (match[:, 1] + dh) / new_H * H
|
85 |
+
|
86 |
+
#print(dw,new_W,W,dh,new_H,H)
|
87 |
+
|
88 |
+
return new_match
|
89 |
+
|
90 |
+
|
third_party/utils/img_utils.py
ADDED
@@ -0,0 +1,211 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import cv2
|
3 |
+
import numpy as np
|
4 |
+
from skimage.metrics import hausdorff_distance
|
5 |
+
from matplotlib import pyplot as plt
|
6 |
+
|
7 |
+
|
8 |
+
def get_input_imgs_path(input_data_dir):
|
9 |
+
path = {}
|
10 |
+
names = ['000', 'ori_000']
|
11 |
+
for name in names:
|
12 |
+
jpg_path = os.path.join(input_data_dir, f"{name}.jpg")
|
13 |
+
png_path = os.path.join(input_data_dir, f"{name}.png")
|
14 |
+
if os.path.exists(jpg_path):
|
15 |
+
path[name] = jpg_path
|
16 |
+
elif os.path.exists(png_path):
|
17 |
+
path[name] = png_path
|
18 |
+
return path
|
19 |
+
|
20 |
+
|
21 |
+
def rgba_to_rgb(image, bg_color=[255, 255, 255]):
|
22 |
+
if image.shape[-1] == 3: return image
|
23 |
+
|
24 |
+
rgba = image.astype(float)
|
25 |
+
rgb = rgba[:, :, :3].copy()
|
26 |
+
alpha = rgba[:, :, 3] / 255.0
|
27 |
+
|
28 |
+
bg = np.ones((image.shape[0], image.shape[1], 3), dtype=np.float32)
|
29 |
+
bg = bg * np.array(bg_color, dtype=np.float32)
|
30 |
+
|
31 |
+
rgb = rgb * alpha[:, :, np.newaxis] + bg * (1 - alpha[:, :, np.newaxis])
|
32 |
+
rgb = rgb.astype(np.uint8)
|
33 |
+
|
34 |
+
return rgb
|
35 |
+
|
36 |
+
|
37 |
+
def resize_with_aspect_ratio(image1, image2, pad_value=[255, 255, 255]):
|
38 |
+
aspect_ratio1 = float(image1.shape[1]) / float(image1.shape[0])
|
39 |
+
aspect_ratio2 = float(image2.shape[1]) / float(image2.shape[0])
|
40 |
+
|
41 |
+
top_pad, bottom_pad, left_pad, right_pad = 0, 0, 0, 0
|
42 |
+
|
43 |
+
if aspect_ratio1 < aspect_ratio2:
|
44 |
+
new_width = (aspect_ratio2 * image1.shape[0])
|
45 |
+
right_pad = left_pad = int((new_width - image1.shape[1]) / 2)
|
46 |
+
else:
|
47 |
+
new_height = (image1.shape[1] / aspect_ratio2)
|
48 |
+
bottom_pad = top_pad = int((new_height - image1.shape[0]) / 2)
|
49 |
+
|
50 |
+
image1_padded = cv2.copyMakeBorder(
|
51 |
+
image1, top_pad, bottom_pad, left_pad, right_pad, cv2.BORDER_CONSTANT, value=pad_value
|
52 |
+
)
|
53 |
+
return image1_padded
|
54 |
+
|
55 |
+
|
56 |
+
def estimate_img_mask(image):
|
57 |
+
# to gray
|
58 |
+
gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
|
59 |
+
|
60 |
+
# segment
|
61 |
+
# _, thresh = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)
|
62 |
+
# mask_otsu = thresh.astype(bool)
|
63 |
+
# thresh_gray = 240
|
64 |
+
|
65 |
+
edges = cv2.Canny(gray, 20, 50)
|
66 |
+
|
67 |
+
kernel = np.ones((3, 3), np.uint8)
|
68 |
+
edges_dilated = cv2.dilate(edges, kernel, iterations=1)
|
69 |
+
|
70 |
+
contours, _ = cv2.findContours(edges_dilated, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
|
71 |
+
|
72 |
+
mask = np.zeros_like(gray, dtype=np.uint8)
|
73 |
+
|
74 |
+
cv2.drawContours(mask, contours, -1, 255, thickness=cv2.FILLED)
|
75 |
+
mask = mask.astype(bool)
|
76 |
+
|
77 |
+
return mask
|
78 |
+
|
79 |
+
|
80 |
+
def compute_img_diff(img1, img2, matches1, matches1_from_2, vis=False):
|
81 |
+
scale = 0.125
|
82 |
+
gray_trunc_thres = 25 / 255.0
|
83 |
+
|
84 |
+
# Match
|
85 |
+
if matches1.shape[0] > 0:
|
86 |
+
match_scale = np.max(np.ptp(matches1, axis=-1))
|
87 |
+
match_dists = np.sqrt(np.sum((matches1 - matches1_from_2) ** 2, axis=-1))
|
88 |
+
dist_threshold = match_scale * 0.01
|
89 |
+
match_num = np.sum(match_dists <= dist_threshold)
|
90 |
+
match_rate = np.mean(match_dists <= dist_threshold)
|
91 |
+
else:
|
92 |
+
match_num = 0
|
93 |
+
match_rate = 0
|
94 |
+
|
95 |
+
# IOU
|
96 |
+
img1_mask = estimate_img_mask(img1)
|
97 |
+
img2_mask = estimate_img_mask(img2)
|
98 |
+
img_intersection = (img1_mask == 1) & (img2_mask == 1)
|
99 |
+
img_union = (img1_mask == 1) | (img2_mask == 1)
|
100 |
+
intersection = np.sum(img_intersection == 1)
|
101 |
+
union = np.sum(img_union == 1)
|
102 |
+
mask_iou = intersection / union if union != 0 else 0
|
103 |
+
|
104 |
+
# Gray
|
105 |
+
height, width = img1.shape[:2]
|
106 |
+
img1_gray = cv2.cvtColor(img1, cv2.COLOR_BGR2GRAY)
|
107 |
+
img2_gray = cv2.cvtColor(img2, cv2.COLOR_BGR2GRAY)
|
108 |
+
img1_gray = cv2.GaussianBlur(img1_gray, (7, 7), 0)
|
109 |
+
img2_gray = cv2.GaussianBlur(img2_gray, (7, 7), 0)
|
110 |
+
|
111 |
+
# Gray Diff
|
112 |
+
img1_gray_small = cv2.resize(img1_gray, (int(width * scale), int(height * scale)),
|
113 |
+
interpolation=cv2.INTER_LINEAR) / 255.0
|
114 |
+
img2_gray_small = cv2.resize(img2_gray, (int(width * scale), int(height * scale)),
|
115 |
+
interpolation=cv2.INTER_LINEAR) / 255.0
|
116 |
+
img_gray_small_diff = np.abs(img1_gray_small - img2_gray_small)
|
117 |
+
gray_diff = img_gray_small_diff.sum() / (union * scale) if union != 0 else 1
|
118 |
+
|
119 |
+
img_gray_small_diff_trunc = img_gray_small_diff.copy()
|
120 |
+
img_gray_small_diff_trunc[img_gray_small_diff < gray_trunc_thres] = 0
|
121 |
+
gray_diff_trunc = img_gray_small_diff_trunc.sum() / (union * scale) if union != 0 else 1
|
122 |
+
|
123 |
+
# Edge
|
124 |
+
img1_edge = cv2.Canny(img1_gray, 100, 200)
|
125 |
+
img2_edge = cv2.Canny(img2_gray, 100, 200)
|
126 |
+
bw_edges1 = (img1_edge > 0).astype(bool)
|
127 |
+
bw_edges2 = (img2_edge > 0).astype(bool)
|
128 |
+
hausdorff_dist = hausdorff_distance(bw_edges1, bw_edges2)
|
129 |
+
if vis == True:
|
130 |
+
fig, axs = plt.subplots(1, 4, figsize=(15, 5))
|
131 |
+
axs[0].imshow(img1_gray, cmap='gray')
|
132 |
+
axs[0].set_title('Img1')
|
133 |
+
axs[1].imshow(img2_gray, cmap='gray')
|
134 |
+
axs[1].set_title('Img2')
|
135 |
+
axs[2].imshow(img1_mask)
|
136 |
+
axs[2].set_title('Mask1')
|
137 |
+
axs[3].imshow(img2_mask)
|
138 |
+
axs[3].set_title('Mask2')
|
139 |
+
plt.show()
|
140 |
+
plt.figure()
|
141 |
+
mask_cmp = np.zeros((height, width, 3))
|
142 |
+
mask_cmp[img_intersection, 1] = 1
|
143 |
+
mask_cmp[img_union, 0] = 1
|
144 |
+
plt.imshow(mask_cmp)
|
145 |
+
plt.show()
|
146 |
+
fig, axs = plt.subplots(1, 4, figsize=(15, 5))
|
147 |
+
axs[0].imshow(img1_gray_small, cmap='gray')
|
148 |
+
axs[0].set_title('Img1 Gray')
|
149 |
+
axs[1].imshow(img2_gray_small, cmap='gray')
|
150 |
+
axs[1].set_title('Img2 Gary')
|
151 |
+
axs[2].imshow(img_gray_small_diff, cmap='gray')
|
152 |
+
axs[2].set_title('diff')
|
153 |
+
axs[3].imshow(img_gray_small_diff_trunc, cmap='gray')
|
154 |
+
axs[3].set_title('diff_trunct')
|
155 |
+
plt.show()
|
156 |
+
fig, axs = plt.subplots(1, 2, figsize=(15, 5))
|
157 |
+
axs[0].imshow(img1_edge, cmap='gray')
|
158 |
+
axs[0].set_title('img1_edge')
|
159 |
+
axs[1].imshow(img2_edge, cmap='gray')
|
160 |
+
axs[1].set_title('img2_edge')
|
161 |
+
plt.show()
|
162 |
+
|
163 |
+
info = {}
|
164 |
+
info['match_num'] = match_num
|
165 |
+
info['match_rate'] = match_rate
|
166 |
+
info['mask_iou'] = mask_iou
|
167 |
+
info['gray_diff'] = gray_diff
|
168 |
+
info['gray_diff_trunc'] = gray_diff_trunc
|
169 |
+
info['hausdorff_dist'] = hausdorff_dist
|
170 |
+
return info
|
171 |
+
|
172 |
+
|
173 |
+
def predict_match_success_human(info):
|
174 |
+
match_num = info['match_num']
|
175 |
+
match_rate = info['match_rate']
|
176 |
+
mask_iou = info['mask_iou']
|
177 |
+
gray_diff = info['gray_diff']
|
178 |
+
gray_diff_trunc = info['gray_diff_trunc']
|
179 |
+
hausdorff_dist = info['hausdorff_dist']
|
180 |
+
|
181 |
+
if mask_iou > 0.95:
|
182 |
+
return True
|
183 |
+
|
184 |
+
if match_num < 20 or match_rate < 0.7:
|
185 |
+
return False
|
186 |
+
|
187 |
+
if mask_iou > 0.80 and gray_diff < 0.040 and gray_diff_trunc < 0.010:
|
188 |
+
return True
|
189 |
+
|
190 |
+
if mask_iou > 0.70 and gray_diff < 0.050 and gray_diff_trunc < 0.008:
|
191 |
+
return True
|
192 |
+
|
193 |
+
'''
|
194 |
+
if match_rate<0.70 or match_num<3000:
|
195 |
+
return False
|
196 |
+
|
197 |
+
if (mask_iou>0.85 and hausdorff_dist<20)or (gray_diff<0.015 and gray_diff_trunc<0.01) or match_rate>=0.90:
|
198 |
+
return True
|
199 |
+
'''
|
200 |
+
|
201 |
+
return False
|
202 |
+
|
203 |
+
|
204 |
+
def predict_match_success(info, model=None):
|
205 |
+
if model == None:
|
206 |
+
return predict_match_success_human(info)
|
207 |
+
else:
|
208 |
+
feat_name = ['match_num', 'match_rate', 'mask_iou', 'gray_diff', 'gray_diff_trunc', 'hausdorff_dist']
|
209 |
+
features = [info[f] for f in feat_name]
|
210 |
+
pred = model.predict([features])[0]
|
211 |
+
return pred >= 0.5
|