HongFangzhou
commited on
Commit
•
8ee45cc
1
Parent(s):
6fcfbfd
3DTopia test
Browse files- app.py +145 -92
- requirements.txt +4 -1
app.py
CHANGED
@@ -2,7 +2,10 @@ import os
|
|
2 |
import sys
|
3 |
import cv2
|
4 |
import time
|
|
|
5 |
import json
|
|
|
|
|
6 |
import torch
|
7 |
import mcubes
|
8 |
import trimesh
|
@@ -11,7 +14,6 @@ import argparse
|
|
11 |
import subprocess
|
12 |
import numpy as np
|
13 |
import gradio as gr
|
14 |
-
from tqdm import tqdm
|
15 |
import imageio.v2 as imageio
|
16 |
import pytorch_lightning as pl
|
17 |
from omegaconf import OmegaConf
|
@@ -28,10 +30,90 @@ from utility.initialize import instantiate_from_config, get_obj_from_str
|
|
28 |
from utility.triplane_renderer.eg3d_renderer import sample_from_planes, generate_planes
|
29 |
from utility.triplane_renderer.renderer import get_rays, to8b
|
30 |
|
|
|
|
|
|
|
31 |
import warnings
|
32 |
warnings.filterwarnings("ignore", category=UserWarning)
|
33 |
warnings.filterwarnings("ignore", category=DeprecationWarning)
|
34 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
35 |
def add_text(rgb, caption):
|
36 |
font = cv2.FONT_HERSHEY_SIMPLEX
|
37 |
# org
|
@@ -51,76 +133,6 @@ def add_text(rgb, caption):
|
|
51 |
cv2.putText(rgb, bci, (gap, gap*(i+1)), font, fontScale, color, thickness, cv2.LINE_AA)
|
52 |
return rgb
|
53 |
|
54 |
-
config = "3DTopia/configs/default.yaml"
|
55 |
-
# local_ckpt = "3DTopia/checkpoints/3dtopia_diffusion_state_dict.ckpt"
|
56 |
-
local_ckpt = "/data/3DTopia_all/3DTopia_code/checkpoints/model.safetensors"
|
57 |
-
if os.path.exists(local_ckpt):
|
58 |
-
ckpt = local_ckpt
|
59 |
-
else:
|
60 |
-
ckpt = hf_hub_download(repo_id="hongfz16/3DTopia", filename="model.safetensors")
|
61 |
-
configs = OmegaConf.load(config)
|
62 |
-
os.makedirs("tmp", exist_ok=True)
|
63 |
-
|
64 |
-
import sys
|
65 |
-
import traceback
|
66 |
-
|
67 |
-
try:
|
68 |
-
if ckpt.endswith(".ckpt"):
|
69 |
-
model = get_obj_from_str(configs.model["target"]).load_from_checkpoint(ckpt, map_location='cpu', strict=False, **configs.model.params)
|
70 |
-
elif ckpt.endswith(".safetensors"):
|
71 |
-
model = get_obj_from_str(configs.model["target"])(**configs.model.params)
|
72 |
-
print("download finish")
|
73 |
-
model_ckpt = load_file(ckpt)
|
74 |
-
print("download finish")
|
75 |
-
model.load_state_dict(model_ckpt)
|
76 |
-
print("download finish")
|
77 |
-
else:
|
78 |
-
raise NotImplementedError
|
79 |
-
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
80 |
-
model = model.to(device)
|
81 |
-
print("download finish")
|
82 |
-
sampler = DDIMSampler(model)
|
83 |
-
|
84 |
-
img_size = configs.model.params.unet_config.params.image_size
|
85 |
-
channels = configs.model.params.unet_config.params.in_channels
|
86 |
-
shape = [channels, img_size, img_size * 3]
|
87 |
-
|
88 |
-
pose_folder = '3DTopia/assets/sample_data/pose'
|
89 |
-
poses_fname = sorted([os.path.join(pose_folder, f) for f in os.listdir(pose_folder)])
|
90 |
-
batch_rays_list = []
|
91 |
-
H = 128
|
92 |
-
ratio = 512 // H
|
93 |
-
for p in poses_fname:
|
94 |
-
c2w = np.loadtxt(p).reshape(4, 4)
|
95 |
-
c2w[:3, 3] *= 2.2
|
96 |
-
c2w = np.array([
|
97 |
-
[1, 0, 0, 0],
|
98 |
-
[0, 0, -1, 0],
|
99 |
-
[0, 1, 0, 0],
|
100 |
-
[0, 0, 0, 1]
|
101 |
-
]) @ c2w
|
102 |
-
|
103 |
-
k = np.array([
|
104 |
-
[560 / ratio, 0, H * 0.5],
|
105 |
-
[0, 560 / ratio, H * 0.5],
|
106 |
-
[0, 0, 1]
|
107 |
-
])
|
108 |
-
|
109 |
-
rays_o, rays_d = get_rays(H, H, torch.Tensor(k), torch.Tensor(c2w[:3, :4]))
|
110 |
-
coords = torch.stack(torch.meshgrid(torch.linspace(0, H-1, H), torch.linspace(0, H-1, H), indexing='ij'), -1)
|
111 |
-
coords = torch.reshape(coords, [-1,2]).long()
|
112 |
-
rays_o = rays_o[coords[:, 0], coords[:, 1]]
|
113 |
-
rays_d = rays_d[coords[:, 0], coords[:, 1]]
|
114 |
-
batch_rays = torch.stack([rays_o, rays_d], 0)
|
115 |
-
batch_rays_list.append(batch_rays)
|
116 |
-
batch_rays_list = torch.stack(batch_rays_list, 0)
|
117 |
-
except Exception as e:
|
118 |
-
print(e)
|
119 |
-
print(traceback.format_exc())
|
120 |
-
print(sys.exc_info()[2])
|
121 |
-
|
122 |
-
|
123 |
-
print("download finish")
|
124 |
def marching_cube(b, text, global_info):
|
125 |
# prepare volumn for marching cube
|
126 |
res = 128
|
@@ -169,7 +181,7 @@ def marching_cube(b, text, global_info):
|
|
169 |
]
|
170 |
rgb_final = None
|
171 |
diff_final = None
|
172 |
-
for rays_o in tqdm(rays_o_list):
|
173 |
rays_o = torch.from_numpy(rays_o.reshape(1, 3)).repeat(vertices.shape[0], 1).float().to(device)
|
174 |
rays_d = pt_vertices.reshape(-1, 3) - rays_o
|
175 |
rays_d = rays_d / torch.norm(rays_d, dim=-1).reshape(-1, 1)
|
@@ -246,7 +258,7 @@ def infer(prompt, samples, steps, scale, seed, global_info):
|
|
246 |
|
247 |
view_num = len(batch_rays_list)
|
248 |
video_list = []
|
249 |
-
for v in tqdm(range(view_num//8*3, view_num//8*5, 2)):
|
250 |
rgb_sample = render_img(v)
|
251 |
video_list.append(rgb_sample)
|
252 |
big_video_list.append(video_list)
|
@@ -287,25 +299,62 @@ def infer(prompt, samples, steps, scale, seed, global_info):
|
|
287 |
|
288 |
return global_info, path
|
289 |
|
290 |
-
def infer_stage2(prompt, selection, seed, global_info):
|
291 |
prompt = prompt.replace('/', '')
|
292 |
mesh_path = marching_cube(int(selection), prompt, global_info)
|
293 |
mesh_name = mesh_path.split('/')[-1][:-4]
|
294 |
-
|
295 |
-
if2_cmd
|
296 |
-
|
297 |
-
#
|
298 |
-
subprocess.Popen(if2_cmd, shell=True).wait()
|
299 |
-
torch.cuda.empty_cache()
|
300 |
-
|
301 |
video_path = f"tmp/{prompt.replace(' ', '_')}_{str(datetime.datetime.now()).replace(' ', '_')}.mp4"
|
302 |
-
render_cmd = f"kire {os.path.join('tmp', mesh_name + '_if2.glb')} --save_video {video_path} --wogui --force_cuda_rast --H 256 --W 256"
|
303 |
-
print(render_cmd)
|
304 |
-
#
|
305 |
-
|
|
|
|
|
306 |
torch.cuda.empty_cache()
|
307 |
|
308 |
-
return video_path,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
309 |
|
310 |
markdown=f'''
|
311 |
# 3DTopia
|
@@ -315,7 +364,7 @@ markdown=f'''
|
|
315 |
First enter prompt for a 3D object, hit "Generate 3D". Then choose one candidate from the dropdown options for the second stage refinement and hit "Start Refinement". The final mesh can be downloaded from the bottom right box.
|
316 |
|
317 |
### Runtime:
|
318 |
-
The first stage takes 30s if generating 4 samples. The second stage takes roughly
|
319 |
|
320 |
### Useful links:
|
321 |
[Github Repo](https://github.com/3DTopia/3DTopia)
|
@@ -337,7 +386,7 @@ with block:
|
|
337 |
)
|
338 |
btn = gr.Button("Generate 3D")
|
339 |
gallery = gr.Video(height=512)
|
340 |
-
# advanced_button = gr.Button("Advanced
|
341 |
with gr.Row(elem_id="advanced-options"):
|
342 |
with gr.Tab("Advanced options"):
|
343 |
samples = gr.Slider(label="Number of Samples", minimum=1, maximum=4, value=4, step=1)
|
@@ -361,11 +410,15 @@ with block:
|
|
361 |
with gr.Column():
|
362 |
with gr.Row():
|
363 |
dropdown = gr.Dropdown(
|
364 |
-
['0', '1', '2', '3'], label="Choose a
|
365 |
)
|
366 |
btn_stage2 = gr.Button("Start Refinement")
|
367 |
gallery = gr.Video(height=512)
|
368 |
-
|
369 |
-
|
|
|
|
|
|
|
|
|
370 |
|
371 |
-
block.launch(share=True
|
|
|
2 |
import sys
|
3 |
import cv2
|
4 |
import time
|
5 |
+
import tyro
|
6 |
import json
|
7 |
+
import kiui
|
8 |
+
import tqdm
|
9 |
import torch
|
10 |
import mcubes
|
11 |
import trimesh
|
|
|
14 |
import subprocess
|
15 |
import numpy as np
|
16 |
import gradio as gr
|
|
|
17 |
import imageio.v2 as imageio
|
18 |
import pytorch_lightning as pl
|
19 |
from omegaconf import OmegaConf
|
|
|
30 |
from utility.triplane_renderer.eg3d_renderer import sample_from_planes, generate_planes
|
31 |
from utility.triplane_renderer.renderer import get_rays, to8b
|
32 |
|
33 |
+
from threefiner.gui import GUI
|
34 |
+
from threefiner.opt import config_defaults, config_doc, check_options, Options
|
35 |
+
|
36 |
import warnings
|
37 |
warnings.filterwarnings("ignore", category=UserWarning)
|
38 |
warnings.filterwarnings("ignore", category=DeprecationWarning)
|
39 |
|
40 |
+
###################################### INIT STAGE 1 #########################################
|
41 |
+
config = "3DTopia/configs/default.yaml"
|
42 |
+
download_ckpt = "3DTopia/checkpoints/3dtopia_diffusion_state_dict.ckpt"
|
43 |
+
if not os.path.exists(download_ckpt):
|
44 |
+
ckpt = hf_hub_download(repo_id="hongfz16/3DTopia", filename="model.safetensors")
|
45 |
+
else:
|
46 |
+
ckpt = download_ckpt
|
47 |
+
configs = OmegaConf.load(config)
|
48 |
+
os.makedirs("tmp", exist_ok=True)
|
49 |
+
|
50 |
+
if ckpt.endswith(".ckpt"):
|
51 |
+
model = get_obj_from_str(configs.model["target"]).load_from_checkpoint(ckpt, map_location='cpu', strict=False, **configs.model.params)
|
52 |
+
elif ckpt.endswith(".safetensors"):
|
53 |
+
model = get_obj_from_str(configs.model["target"])(**configs.model.params)
|
54 |
+
model_ckpt = load_file(ckpt)
|
55 |
+
model.load_state_dict(model_ckpt)
|
56 |
+
else:
|
57 |
+
raise NotImplementedError
|
58 |
+
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
59 |
+
model = model.to(device)
|
60 |
+
sampler = DDIMSampler(model)
|
61 |
+
|
62 |
+
img_size = configs.model.params.unet_config.params.image_size
|
63 |
+
channels = configs.model.params.unet_config.params.in_channels
|
64 |
+
shape = [channels, img_size, img_size * 3]
|
65 |
+
|
66 |
+
pose_folder = '3DTopia/assets/sample_data/pose'
|
67 |
+
poses_fname = sorted([os.path.join(pose_folder, f) for f in os.listdir(pose_folder)])
|
68 |
+
batch_rays_list = []
|
69 |
+
H = 128
|
70 |
+
ratio = 512 // H
|
71 |
+
for p in poses_fname:
|
72 |
+
c2w = np.loadtxt(p).reshape(4, 4)
|
73 |
+
c2w[:3, 3] *= 2.2
|
74 |
+
c2w = np.array([
|
75 |
+
[1, 0, 0, 0],
|
76 |
+
[0, 0, -1, 0],
|
77 |
+
[0, 1, 0, 0],
|
78 |
+
[0, 0, 0, 1]
|
79 |
+
]) @ c2w
|
80 |
+
|
81 |
+
k = np.array([
|
82 |
+
[560 / ratio, 0, H * 0.5],
|
83 |
+
[0, 560 / ratio, H * 0.5],
|
84 |
+
[0, 0, 1]
|
85 |
+
])
|
86 |
+
|
87 |
+
rays_o, rays_d = get_rays(H, H, torch.Tensor(k), torch.Tensor(c2w[:3, :4]))
|
88 |
+
coords = torch.stack(torch.meshgrid(torch.linspace(0, H-1, H), torch.linspace(0, H-1, H), indexing='ij'), -1)
|
89 |
+
coords = torch.reshape(coords, [-1,2]).long()
|
90 |
+
rays_o = rays_o[coords[:, 0], coords[:, 1]]
|
91 |
+
rays_d = rays_d[coords[:, 0], coords[:, 1]]
|
92 |
+
batch_rays = torch.stack([rays_o, rays_d], 0)
|
93 |
+
batch_rays_list.append(batch_rays)
|
94 |
+
batch_rays_list = torch.stack(batch_rays_list, 0)
|
95 |
+
###################################### INIT STAGE 1 #########################################
|
96 |
+
|
97 |
+
###################################### INIT STAGE 2 #########################################
|
98 |
+
GRADIO_SAVE_PATH_MESH = 'gradio_output.glb'
|
99 |
+
GRADIO_SAVE_PATH_VIDEO = 'gradio_output.mp4'
|
100 |
+
|
101 |
+
# opt = tyro.cli(tyro.extras.subcommand_type_from_defaults(config_defaults, config_doc))
|
102 |
+
opt = Options(
|
103 |
+
mode='IF2',
|
104 |
+
iters=400,
|
105 |
+
)
|
106 |
+
|
107 |
+
# hacks for not loading mesh at initialization
|
108 |
+
# opt.mesh = 'tmp/_2024-01-25_19:33:03.110191_if2.glb'
|
109 |
+
opt.save = GRADIO_SAVE_PATH_MESH
|
110 |
+
opt.prompt = ''
|
111 |
+
opt.text_dir = True
|
112 |
+
opt.front_dir = '+z'
|
113 |
+
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
114 |
+
gui = GUI(opt)
|
115 |
+
###################################### INIT STAGE 2 #########################################
|
116 |
+
|
117 |
def add_text(rgb, caption):
|
118 |
font = cv2.FONT_HERSHEY_SIMPLEX
|
119 |
# org
|
|
|
133 |
cv2.putText(rgb, bci, (gap, gap*(i+1)), font, fontScale, color, thickness, cv2.LINE_AA)
|
134 |
return rgb
|
135 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
136 |
def marching_cube(b, text, global_info):
|
137 |
# prepare volumn for marching cube
|
138 |
res = 128
|
|
|
181 |
]
|
182 |
rgb_final = None
|
183 |
diff_final = None
|
184 |
+
for rays_o in tqdm.tqdm(rays_o_list):
|
185 |
rays_o = torch.from_numpy(rays_o.reshape(1, 3)).repeat(vertices.shape[0], 1).float().to(device)
|
186 |
rays_d = pt_vertices.reshape(-1, 3) - rays_o
|
187 |
rays_d = rays_d / torch.norm(rays_d, dim=-1).reshape(-1, 1)
|
|
|
258 |
|
259 |
view_num = len(batch_rays_list)
|
260 |
video_list = []
|
261 |
+
for v in tqdm.tqdm(range(view_num//8*3, view_num//8*5, 2)):
|
262 |
rgb_sample = render_img(v)
|
263 |
video_list.append(rgb_sample)
|
264 |
big_video_list.append(video_list)
|
|
|
299 |
|
300 |
return global_info, path
|
301 |
|
302 |
+
def infer_stage2(prompt, selection, seed, global_info, iters):
|
303 |
prompt = prompt.replace('/', '')
|
304 |
mesh_path = marching_cube(int(selection), prompt, global_info)
|
305 |
mesh_name = mesh_path.split('/')[-1][:-4]
|
306 |
+
# if2_cmd = f"threefiner if2 --mesh {mesh_path} --prompt \"{prompt}\" --outdir tmp --save {mesh_name}_if2.glb --text_dir --front_dir=-y"
|
307 |
+
# print(if2_cmd)
|
308 |
+
# subprocess.Popen(if2_cmd, shell=True).wait()
|
309 |
+
# torch.cuda.empty_cache()
|
|
|
|
|
|
|
310 |
video_path = f"tmp/{prompt.replace(' ', '_')}_{str(datetime.datetime.now()).replace(' ', '_')}.mp4"
|
311 |
+
# render_cmd = f"kire {os.path.join('tmp', mesh_name + '_if2.glb')} --save_video {video_path} --wogui --force_cuda_rast --H 256 --W 256"
|
312 |
+
# print(render_cmd)
|
313 |
+
# subprocess.Popen(render_cmd, shell=True).wait()
|
314 |
+
# torch.cuda.empty_cache()
|
315 |
+
|
316 |
+
process_stage2(mesh_path, prompt, "down", iters, f'tmp/{mesh_name}_if2.glb', video_path)
|
317 |
torch.cuda.empty_cache()
|
318 |
|
319 |
+
return video_path, f'tmp/{mesh_name}_if2.glb'
|
320 |
+
|
321 |
+
def process_stage2(input_model, input_text, input_dir, iters, output_model, output_video):
|
322 |
+
# set front facing direction (map from gradio model3D's mysterious coordinate system to OpenGL...)
|
323 |
+
opt.text_dir = True
|
324 |
+
if input_dir == 'front':
|
325 |
+
opt.front_dir = '-z'
|
326 |
+
elif input_dir == 'back':
|
327 |
+
opt.front_dir = '+z'
|
328 |
+
elif input_dir == 'left':
|
329 |
+
opt.front_dir = '+x'
|
330 |
+
elif input_dir == 'right':
|
331 |
+
opt.front_dir = '-x'
|
332 |
+
elif input_dir == 'up':
|
333 |
+
opt.front_dir = '+y'
|
334 |
+
elif input_dir == 'down':
|
335 |
+
opt.front_dir = '-y'
|
336 |
+
else:
|
337 |
+
# turn off text_dir
|
338 |
+
opt.text_dir = False
|
339 |
+
opt.front_dir = '+z'
|
340 |
+
|
341 |
+
# set mesh path
|
342 |
+
opt.mesh = input_model
|
343 |
+
|
344 |
+
# load mesh!
|
345 |
+
gui.renderer = gui.renderer_class(opt, device).to(device)
|
346 |
+
|
347 |
+
# set prompt
|
348 |
+
gui.prompt = opt.positive_prompt + ', ' + input_text
|
349 |
+
|
350 |
+
# train
|
351 |
+
gui.prepare_train() # update optimizer and prompt embeddings
|
352 |
+
for i in tqdm.trange(iters):
|
353 |
+
gui.train_step()
|
354 |
+
|
355 |
+
# save mesh & video
|
356 |
+
gui.save_model(output_model)
|
357 |
+
gui.save_model(output_video)
|
358 |
|
359 |
markdown=f'''
|
360 |
# 3DTopia
|
|
|
364 |
First enter prompt for a 3D object, hit "Generate 3D". Then choose one candidate from the dropdown options for the second stage refinement and hit "Start Refinement". The final mesh can be downloaded from the bottom right box.
|
365 |
|
366 |
### Runtime:
|
367 |
+
The first stage takes 30s if generating 4 samples. The second stage takes roughly 1m30s.
|
368 |
|
369 |
### Useful links:
|
370 |
[Github Repo](https://github.com/3DTopia/3DTopia)
|
|
|
386 |
)
|
387 |
btn = gr.Button("Generate 3D")
|
388 |
gallery = gr.Video(height=512)
|
389 |
+
# advanced_button = gr.Button("Advanced options", elem_id="advanced-btn")
|
390 |
with gr.Row(elem_id="advanced-options"):
|
391 |
with gr.Tab("Advanced options"):
|
392 |
samples = gr.Slider(label="Number of Samples", minimum=1, maximum=4, value=4, step=1)
|
|
|
410 |
with gr.Column():
|
411 |
with gr.Row():
|
412 |
dropdown = gr.Dropdown(
|
413 |
+
['0', '1', '2', '3'], label="Choose a Candidate For Stage2", value='0'
|
414 |
)
|
415 |
btn_stage2 = gr.Button("Start Refinement")
|
416 |
gallery = gr.Video(height=512)
|
417 |
+
with gr.Row(elem_id="advanced-options"):
|
418 |
+
with gr.Tab("Advanced options"):
|
419 |
+
# input_dir = gr.Radio(['front', 'back', 'left', 'right', 'up', 'down'], value='down', label="front-facing direction")
|
420 |
+
iters = gr.Slider(minimum=100, maximum=1000, step=100, value=400, label="Refine iterations")
|
421 |
+
download = gr.File(label="Download Mesh", file_count="single", height=100)
|
422 |
+
gr.on([btn_stage2.click], infer_stage2, inputs=[text, dropdown, seed, global_info, iters], outputs=[gallery, download])
|
423 |
|
424 |
+
block.launch(share=True)
|
requirements.txt
CHANGED
@@ -54,4 +54,7 @@ trimesh
|
|
54 |
vit-pytorch
|
55 |
wandb
|
56 |
wcwidth
|
57 |
-
zipp
|
|
|
|
|
|
|
|
54 |
vit-pytorch
|
55 |
wandb
|
56 |
wcwidth
|
57 |
+
zipp
|
58 |
+
git+https://github.com/NVlabs/tiny-cuda-nn/#subdirectory=bindings/torch
|
59 |
+
git+https://github.com/NVlabs/nvdiffrast
|
60 |
+
git+https://github.com/3DTopia/threefiner
|