LTT committed on
Commit
72e5710
·
1 Parent(s): 84fce77
Files changed (1)
  1. demo.py +325 -0
demo.py ADDED
@@ -0,0 +1,325 @@
+ import gradio as gr
+ import os
+ import subprocess
+ import shlex
+ import spaces
+ import torch
+ import numpy as np
+ access_token = os.getenv("HUGGINGFACE_TOKEN")
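+
+ # Startup installs: a prebuilt pytorch3d wheel from the official wheel index, plus
+ # two local CUDA extension wheels (nvdiffrast, renderutils). The wheel tags assume
+ # the runtime is Python 3.10 / torch 2.1 / CUDA 12.1.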
+ subprocess.run(
+     shlex.split(
+         "pip install --no-index --no-cache-dir pytorch3d -f https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/py310_cu121_pyt210/download.html"
+     )
+ )
+
+ subprocess.run(
+     shlex.split(
+         "pip install ./extension/nvdiffrast-0.3.1+torch-py3-none-any.whl --force-reinstall --no-deps"
+     )
+ )
+
+ subprocess.run(
+     shlex.split(
+         "pip install ./extension/renderutils_plugin-1.0-cp310-cp310-linux_x86_64.whl --force-reinstall --no-deps"
+     )
+ )
+
+ def install_cuda_toolkit():
+     # CUDA_TOOLKIT_URL = "https://developer.download.nvidia.com/compute/cuda/11.8.0/local_installers/cuda_11.8.0_520.61.05_linux.run"
+     # CUDA_TOOLKIT_URL = "https://developer.download.nvidia.com/compute/cuda/12.2.0/local_installers/cuda_12.2.0_535.54.03_linux.run"
+     CUDA_TOOLKIT_URL = "https://developer.download.nvidia.com/compute/cuda/12.1.0/local_installers/cuda_12.1.0_530.30.02_linux.run"
+     CUDA_TOOLKIT_FILE = "/tmp/%s" % os.path.basename(CUDA_TOOLKIT_URL)
+     subprocess.call(["wget", "-q", CUDA_TOOLKIT_URL, "-O", CUDA_TOOLKIT_FILE])
+     subprocess.call(["chmod", "+x", CUDA_TOOLKIT_FILE])
+     subprocess.call([CUDA_TOOLKIT_FILE, "--silent", "--toolkit"])
+
+     os.environ["CUDA_HOME"] = "/usr/local/cuda"
+     os.environ["PATH"] = "%s/bin:%s" % (os.environ["CUDA_HOME"], os.environ["PATH"])
+     os.environ["LD_LIBRARY_PATH"] = "%s/lib:%s" % (
+         os.environ["CUDA_HOME"],
+         "" if "LD_LIBRARY_PATH" not in os.environ else os.environ["LD_LIBRARY_PATH"],
+     )
+     # Fix: arch_list[-1] += '+PTX'; IndexError: list index out of range
+     os.environ["TORCH_CUDA_ARCH_LIST"] = "8.0;8.6"
+     print("==> finish install")
+
+ # install_cuda_toolkit()
+
+ @spaces.GPU
+ def check_gpu():
+     os.environ['CUDA_HOME'] = '/usr/local/cuda-12.1'
+     os.environ['PATH'] += ':/usr/local/cuda-12.1/bin'
+     # os.environ['LD_LIBRARY_PATH'] += ':/usr/local/cuda-12.1/lib64'
+     os.environ['LD_LIBRARY_PATH'] = "/usr/local/cuda-12.1/lib64:" + os.environ.get('LD_LIBRARY_PATH', '')
+     subprocess.run(['nvidia-smi'])  # check whether CUDA is available
+     print(f"torch.cuda.is_available:{torch.cuda.is_available()}")
+
+ check_gpu()
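+
+ # Heavyweight imports are deferred until after the extension wheels are installed;
+ # models.lrm and models.ISOMER presumably pull in pytorch3d / nvdiffrast at import time.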
+ from PIL import Image
+ from einops import rearrange
+ from omegaconf import OmegaConf
+ import trimesh
+ import torchvision
+ import torch.nn.functional as F
+ from torchvision import transforms
+ from torchvision.transforms import v2
+ from diffusers import FluxPipeline, DiffusionPipeline, FlowMatchEulerDiscreteScheduler, AutoencoderTiny, AutoencoderKL
+ from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
+ from pytorch_lightning import seed_everything
+ from huggingface_hub import hf_hub_download
+
+ from models.lrm.utils.camera_util import get_flux_input_cameras
+ from models.lrm.utils.infer_util import save_video
+ from models.lrm.utils.mesh_util import save_obj, save_obj_with_mtl
+ from models.lrm.utils.render_utils import rotate_x, rotate_y
+ from models.lrm.utils.train_util import instantiate_from_config
+ from models.ISOMER.reconstruction_func import reconstruction
+ from models.ISOMER.projection_func import projection
+ from live_preview_helpers import calculate_shift, retrieve_timesteps, flux_pipe_call_that_returns_an_iterable_of_images
+ from utils.tool import NormalTransfer, get_background, get_render_cameras_video, load_mipmap, render_frames
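+
+ # Shared settings: FLUX generation runs on cuda:0, LRM/ISOMER reconstruction on cuda:1.
+ # The four ISOMER views sit at azimuths 0/90/180/270 with 5-degree elevation; the
+ # front/back views (azimuth 0/180) get slightly higher geometry and color weights.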
+ device_0 = "cuda:0"
+ device_1 = "cuda:1"
+ resolution = 512
+ save_dir = "./outputs"
+ normal_transfer = NormalTransfer()
+ isomer_azimuths = torch.from_numpy(np.array([0, 90, 180, 270])).float().to(device_1)
+ isomer_elevations = torch.from_numpy(np.array([5, 5, 5, 5])).float().to(device_1)
+ isomer_radius = 4.5
+ isomer_geo_weights = torch.from_numpy(np.array([1, 0.9, 1, 0.9])).float().to(device_1)
+ isomer_color_weights = torch.from_numpy(np.array([1, 0.5, 1, 0.5])).float().to(device_1)
+
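+ # FLUX.1-dev with a LoRA ("rgb_normal_large") generates the paired RGB + normal views.
+ # taef1 (a tiny distilled VAE) is swapped in for cheap intermediate decoding, while
+ # good_vae (the full FLUX VAE) decodes the final image; __get__ binds the live-preview
+ # generator from live_preview_helpers onto the pipeline as a method.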
+ # model initialization and loading
+ # flux
+ taef1 = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=torch.bfloat16).to(device_0)
+ good_vae = AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="vae", torch_dtype=torch.bfloat16, token=access_token).to(device_0)
+ # flux_pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16, token=access_token).to(device=device_0, dtype=torch.bfloat16)
+ flux_pipe = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16, vae=taef1, token=access_token).to(device_0)
+ flux_lora_ckpt_path = hf_hub_download(repo_id="LTT/xxx-ckpt", filename="rgb_normal_large.safetensors", repo_type="model")
+ flux_pipe.load_lora_weights(flux_lora_ckpt_path)
+ # flux_pipe.to(device=device_0, dtype=torch.bfloat16)
+ torch.cuda.empty_cache()
+ flux_pipe.flux_pipe_call_that_returns_an_iterable_of_images = flux_pipe_call_that_returns_an_iterable_of_images.__get__(flux_pipe)
+
+ # lrm
+ config = OmegaConf.load("./models/lrm/config/PRM_inference.yaml")
+ model_config = config.model_config
+ infer_config = config.infer_config
+ model = instantiate_from_config(model_config)
+ model_ckpt_path = hf_hub_download(repo_id="LTT/PRM", filename="final_ckpt.ckpt", repo_type="model")
+ state_dict = torch.load(model_ckpt_path, map_location='cpu')['state_dict']
+ state_dict = {k[14:]: v for k, v in state_dict.items() if k.startswith('lrm_generator.')}  # strip the 14-char "lrm_generator." prefix from the Lightning checkpoint keys
+ model.load_state_dict(state_dict, strict=True)
+ model = model.to(device_1)
+ torch.cuda.empty_cache()
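+
+ # LRM stage of PRM: encode the four RGB views into triplanes, extract a FlexiCubes
+ # mesh, and optionally render a 360-degree PBR turntable video.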
+ @spaces.GPU
+ def lrm_reconstructions(image, input_cameras, save_path=None, name="temp", export_texmap=False, if_save_video=False):
+     images = image.unsqueeze(0).to(device_1)
+     images = v2.functional.resize(images, 512, interpolation=3, antialias=True).clamp(0, 1)
+     # breakpoint()
+     with torch.no_grad():
+         # get triplane
+         planes = model.forward_planes(images, input_cameras)
+
+         mesh_path_idx = os.path.join(save_path, f'{name}.obj')
+
+         mesh_out = model.extract_mesh(
+             planes,
+             use_texture_map=export_texmap,
+             **infer_config,
+         )
+         if export_texmap:
+             vertices, faces, uvs, mesh_tex_idx, tex_map = mesh_out
+             save_obj_with_mtl(
+                 vertices.data.cpu().numpy(),
+                 uvs.data.cpu().numpy(),
+                 faces.data.cpu().numpy(),
+                 mesh_tex_idx.data.cpu().numpy(),
+                 tex_map.permute(1, 2, 0).data.cpu().numpy(),
+                 mesh_path_idx,
+             )
+         else:
+             vertices, faces, vertex_colors = mesh_out
+             save_obj(vertices, faces, vertex_colors, mesh_path_idx)
+         print(f"Mesh saved to {mesh_path_idx}")
+
+         render_size = 512
+         if if_save_video:
+             video_path_idx = os.path.join(save_path, f'{name}.mp4')
+             render_size = infer_config.render_resolution
+             ENV = load_mipmap("models/lrm/env_mipmap/6")
+             materials = (0.0, 0.9)
+
+             all_mv, all_mvp, all_campos = get_render_cameras_video(
+                 batch_size=1,
+                 M=240,
+                 radius=4.5,
+                 elevation=(90, 60.0),
+                 is_flexicubes=True,
+                 fov=30
+             )
+
+             frames, albedos, pbr_spec_lights, pbr_diffuse_lights, normals, alphas = render_frames(
+                 model,
+                 planes,
+                 render_cameras=all_mvp,
+                 camera_pos=all_campos,
+                 env=ENV,
+                 materials=materials,
+                 render_size=render_size,
+                 chunk_size=20,
+                 is_flexicubes=True,
+             )
+             normals = (torch.nn.functional.normalize(normals) + 1) / 2
+             normals = normals * alphas + (1 - alphas)
+             all_frames = torch.cat([frames, albedos, pbr_spec_lights, pbr_diffuse_lights, normals], dim=3)
+
+             save_video(
+                 all_frames,
+                 video_path_idx,
+                 fps=30,
+             )
+             print(f"Video saved to {video_path_idx}")
+
+     return vertices, faces
+
+
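+ # The LoRA outputs camera-space ("local") normals per view; ISOMER expects world-space
+ # ("global") normals, so convert them here (the X sign flip presumably matches the
+ # target renderer's axis convention).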
+ def local_normal_global_transform(local_normal_images, azimuths_deg, elevations_deg):
+     if local_normal_images.min() >= 0:
+         # inputs in [0, 1] -> map to [-1, 1] normal vectors
+         local_normal = local_normal_images.float() * 2 - 1
+     else:
+         local_normal = local_normal_images.float()
+     global_normal = normal_transfer.trans_local_2_global(local_normal, azimuths_deg, elevations_deg, radius=4.5, for_lotus=False)
+     global_normal[..., 0] *= -1
+     global_normal = (global_normal + 1) / 2
+     global_normal = global_normal.permute(0, 3, 1, 2)  # B,H,W,3 -> B,3,H,W
+     return global_normal
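+
+ # Text -> multi-view grid: one 2048x1024 FLUX image laid out as a 2x4 grid, the top
+ # row holding four RGB views and the bottom row the four corresponding normal maps.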
+ # Generate multi-view images
+ @spaces.GPU(duration=120)
+ def generate_multi_view_images(prompt, seed):
+     # torch.cuda.empty_cache()
+     # generator = torch.manual_seed(seed)
+     generator = torch.Generator().manual_seed(seed)
+     with torch.no_grad():
+         # images = flux_pipe(
+         #     prompt=prompt,
+         #     num_inference_steps=10,
+         #     guidance_scale=3.5,
+         #     num_images_per_prompt=1,
+         #     width=resolution * 4,
+         #     height=resolution * 2,
+         #     output_type='np',
+         #     generator=generator,
+         #     good_vae=good_vae,
+         # ).images
+         # drain the preview iterator; the last yielded image is the final full decode
+         for img in flux_pipe.flux_pipe_call_that_returns_an_iterable_of_images(
+             prompt=prompt,
+             guidance_scale=3.5,
+             num_inference_steps=10,
+             width=resolution * 4,
+             height=resolution * 2,
+             generator=generator,
+             output_type="np",
+             good_vae=good_vae,
+         ):
+             pass
+     # Return the final image and seed (handled by the caller)
+     return img
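+
+ # Full reconstruction: split the grid into views, mask out the background, run the
+ # LRM for a coarse mesh, convert normals to world space, refine the geometry with
+ # ISOMER, then bake colors via projection and export a GLB.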
+ # Reconstruct the 3D model
+ @spaces.GPU
+ def reconstruct_3d_model(images, prompt):
+     global model
+     model.init_flexicubes_geometry(device_1, fovy=50.0)
+     model = model.eval()
+     rgb_normal_grid = images
+     save_dir_path = os.path.join(save_dir, prompt.replace(" ", "_"))
+     os.makedirs(save_dir_path, exist_ok=True)
+
+     images = torch.from_numpy(rgb_normal_grid).squeeze(0).permute(2, 0, 1).contiguous().float()  # (3, 1024, 2048)
+     images = rearrange(images, 'c (n h) (m w) -> (n m) c h w', n=2, m=4)  # (8, 3, 512, 512)
+     rgb_multi_view = images[:4, :3, :, :]
+     normal_multi_view = images[4:, :3, :, :]
+     multi_view_mask = get_background(normal_multi_view)
+     # composite the RGB views onto a white background
+     rgb_multi_view = rgb_multi_view * multi_view_mask + (1 - multi_view_mask)
+     input_cameras = get_flux_input_cameras(batch_size=1, radius=4.2, fov=30).to(device_1)
+     vertices, faces = lrm_reconstructions(rgb_multi_view, input_cameras, save_path=save_dir_path, name='lrm', export_texmap=False, if_save_video=False)
+
+     # local normal to global normal
+     global_normal = local_normal_global_transform(normal_multi_view.permute(0, 2, 3, 1), isomer_azimuths, isomer_elevations)
+     global_normal = global_normal * multi_view_mask + (1 - multi_view_mask)
+
+     global_normal = global_normal.permute(0, 2, 3, 1)
+     rgb_multi_view = rgb_multi_view.permute(0, 2, 3, 1)
+     multi_view_mask = multi_view_mask.permute(0, 2, 3, 1).squeeze(-1)
+     vertices = torch.from_numpy(vertices).to(device_1)
+     faces = torch.from_numpy(faces).to(device_1)
+     vertices = vertices @ rotate_x(np.pi / 2, device=vertices.device)[:3, :3]
+     vertices = vertices @ rotate_y(np.pi / 2, device=vertices.device)[:3, :3]
+
+     # global_normal: B,H,W,3
+     # multi_view_mask: B,H,W
+     # rgb_multi_view: B,H,W,3
+
+     meshes = reconstruction(
+         normal_pils=global_normal,
+         masks=multi_view_mask,
+         weights=isomer_geo_weights,
+         fov=30,
+         radius=isomer_radius,
+         camera_angles_azi=isomer_azimuths,
+         camera_angles_ele=isomer_elevations,
+         expansion_weight_stage1=0.1,
+         init_type="file",
+         init_verts=vertices,
+         init_faces=faces,
+         stage1_steps=0,
+         stage2_steps=50,
+         start_edge_len_stage1=0.1,
+         end_edge_len_stage1=0.02,
+         start_edge_len_stage2=0.02,
+         end_edge_len_stage2=0.005,
+     )
+
+     save_glb_addr = projection(
+         meshes,
+         masks=multi_view_mask,
+         images=rgb_multi_view,
+         azimuths=isomer_azimuths,
+         elevations=isomer_elevations,
+         weights=isomer_color_weights,
+         fov=30,
+         radius=isomer_radius,
+         save_dir=f"{save_dir_path}/ISOMER/",
+     )
+
+     return save_glb_addr
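+
+ # Gradio entry point: maps (prompt, seed) to (preview image, GLB path), the shapes a
+ # gr.Interface with Image and Model3D outputs would expect; the UI wiring itself is
+ # not part of this file.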
+ # Gradio interface function
+ @spaces.GPU
+ def gradio_pipeline(prompt, seed):
+     # Generate multi-view images
+     rgb_normal_grid = generate_multi_view_images(prompt, seed)
+     image_preview = Image.fromarray((rgb_normal_grid * 255).astype(np.uint8))
+
+     # 3d reconstruction: rebuild the model and return the glb path
+     save_glb_addr = reconstruct_3d_model(rgb_normal_grid, prompt)
+
+     return image_preview, save_glb_addr
+
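+ # Running the file directly exercises the pipeline once as a smoke test; no Gradio
+ # server is launched here.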
+ if __name__ == "__main__":
+     prompt_input = "an owl"
+     sample_seed = 42
+     gradio_pipeline(prompt_input, sample_seed)