MaxMilan1 committed
Commit f75e089
1 Parent(s): 49db696
Files changed (2)
  1. app.py +2 -261
  2. util/instantmesh.py +210 -0
app.py CHANGED
@@ -1,256 +1,7 @@
-import spaces
-
-import os
-import imageio
-import numpy as np
-import torch
-import rembg
-from PIL import Image
-from torchvision.transforms import v2
-from pytorch_lightning import seed_everything
-from omegaconf import OmegaConf
-from einops import rearrange, repeat
-from tqdm import tqdm
-from diffusers import DiffusionPipeline, EulerAncestralDiscreteScheduler
-
-from src.utils.train_util import instantiate_from_config
-from src.utils.camera_util import (
-    FOV_to_intrinsics,
-    get_zero123plus_input_cameras,
-    get_circular_camera_poses,
-)
-from src.utils.mesh_util import save_obj
-from src.utils.infer_util import remove_background, resize_foreground, images_to_video
-
-import tempfile
-from functools import partial
-
-from huggingface_hub import hf_hub_download
-
 import gradio as gr
+import os
 
+from util.instantmesh import generate_mvs, make3d, preprocess, check_input_image
-
-def get_render_cameras(batch_size=1, M=120, radius=2.5, elevation=10.0, is_flexicubes=False):
-    """
-    Get the rendering camera parameters.
-    """
-    c2ws = get_circular_camera_poses(M=M, radius=radius, elevation=elevation)
-    if is_flexicubes:
-        cameras = torch.linalg.inv(c2ws)
-        cameras = cameras.unsqueeze(0).repeat(batch_size, 1, 1, 1)
-    else:
-        extrinsics = c2ws.flatten(-2)
-        intrinsics = FOV_to_intrinsics(50.0).unsqueeze(0).repeat(M, 1, 1).float().flatten(-2)
-        cameras = torch.cat([extrinsics, intrinsics], dim=-1)
-        cameras = cameras.unsqueeze(0).repeat(batch_size, 1, 1)
-    return cameras
-
-
-def images_to_video(images, output_path, fps=30):
-    # images: (N, C, H, W)
-    os.makedirs(os.path.dirname(output_path), exist_ok=True)
-    frames = []
-    for i in range(images.shape[0]):
-        frame = (images[i].permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8).clip(0, 255)
-        assert frame.shape[0] == images.shape[2] and frame.shape[1] == images.shape[3], \
-            f"Frame shape mismatch: {frame.shape} vs {images.shape}"
-        assert frame.min() >= 0 and frame.max() <= 255, \
-            f"Frame value out of range: {frame.min()} ~ {frame.max()}"
-        frames.append(frame)
-    imageio.mimwrite(output_path, np.stack(frames), fps=fps, codec='h264')
-
-
-###############################################################################
-# Configuration.
-###############################################################################
-
-import shutil
-
-def find_cuda():
-    # Check if CUDA_HOME or CUDA_PATH environment variables are set
-    cuda_home = os.environ.get('CUDA_HOME') or os.environ.get('CUDA_PATH')
-
-    if cuda_home and os.path.exists(cuda_home):
-        return cuda_home
-
-    # Search for the nvcc executable in the system's PATH
-    nvcc_path = shutil.which('nvcc')
-
-    if nvcc_path:
-        # Remove the 'bin/nvcc' part to get the CUDA installation path
-        cuda_path = os.path.dirname(os.path.dirname(nvcc_path))
-        return cuda_path
-
-    return None
-
-cuda_path = find_cuda()
-
-if cuda_path:
-    print(f"CUDA installation found at: {cuda_path}")
-else:
-    print("CUDA installation not found")
-
-config_path = 'configs/instant-mesh-large.yaml'
-config = OmegaConf.load(config_path)
-config_name = os.path.basename(config_path).replace('.yaml', '')
-model_config = config.model_config
-infer_config = config.infer_config
-
-IS_FLEXICUBES = True if config_name.startswith('instant-mesh') else False
-
-device = torch.device('cuda')
-
-# load diffusion model
-print('Loading diffusion model ...')
-pipeline = DiffusionPipeline.from_pretrained(
-    "sudo-ai/zero123plus-v1.2",
-    custom_pipeline="zero123plus",
-    torch_dtype=torch.float16,
-)
-pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(
-    pipeline.scheduler.config, timestep_spacing='trailing'
-)
-
-# load custom white-background UNet
-unet_ckpt_path = hf_hub_download(repo_id="TencentARC/InstantMesh", filename="diffusion_pytorch_model.bin", repo_type="model")
-state_dict = torch.load(unet_ckpt_path, map_location='cpu')
-pipeline.unet.load_state_dict(state_dict, strict=True)
-
-pipeline = pipeline.to(device)
-
-# load reconstruction model
-print('Loading reconstruction model ...')
-model_ckpt_path = hf_hub_download(repo_id="TencentARC/InstantMesh", filename="instant_mesh_large.ckpt", repo_type="model")
-model = instantiate_from_config(model_config)
-state_dict = torch.load(model_ckpt_path, map_location='cpu')['state_dict']
-state_dict = {k[14:]: v for k, v in state_dict.items() if k.startswith('lrm_generator.') and 'source_camera' not in k}
-model.load_state_dict(state_dict, strict=True)
-
-model = model.to(device)
-
-print('Loading Finished!')
-
-
-def check_input_image(input_image):
-    if input_image is None:
-        raise gr.Error("No image uploaded!")
-
-
-def preprocess(input_image, do_remove_background):
-
-    rembg_session = rembg.new_session() if do_remove_background else None
-
-    if do_remove_background:
-        input_image = remove_background(input_image, rembg_session)
-        input_image = resize_foreground(input_image, 0.85)
-
-    return input_image
-
-
-@spaces.GPU
-def generate_mvs(input_image, sample_steps, sample_seed):
-
-    seed_everything(sample_seed)
-
-    # sampling
-    z123_image = pipeline(
-        input_image,
-        num_inference_steps=sample_steps
-    ).images[0]
-
-    show_image = np.asarray(z123_image, dtype=np.uint8)
-    show_image = torch.from_numpy(show_image)  # (960, 640, 3)
-    show_image = rearrange(show_image, '(n h) (m w) c -> (n m) h w c', n=3, m=2)
-    show_image = rearrange(show_image, '(n m) h w c -> (n h) (m w) c', n=2, m=3)
-    show_image = Image.fromarray(show_image.numpy())
-
-    return z123_image, show_image
-
-
-@spaces.GPU
-def make3d(images):
-
-    global model
-    if IS_FLEXICUBES:
-        model.init_flexicubes_geometry(device, use_renderer=False)
-    model = model.eval()
-
-    images = np.asarray(images, dtype=np.float32) / 255.0
-    images = torch.from_numpy(images).permute(2, 0, 1).contiguous().float()  # (3, 960, 640)
-    images = rearrange(images, 'c (n h) (m w) -> (n m) c h w', n=3, m=2)  # (6, 3, 320, 320)
-
-    input_cameras = get_zero123plus_input_cameras(batch_size=1, radius=4.0).to(device)
-    render_cameras = get_render_cameras(batch_size=1, radius=2.5, is_flexicubes=IS_FLEXICUBES).to(device)
-
-    images = images.unsqueeze(0).to(device)
-    images = v2.functional.resize(images, (320, 320), interpolation=3, antialias=True).clamp(0, 1)
-
-    mesh_fpath = tempfile.NamedTemporaryFile(suffix=f".obj", delete=False).name
-    print(mesh_fpath)
-    mesh_basename = os.path.basename(mesh_fpath).split('.')[0]
-    mesh_dirname = os.path.dirname(mesh_fpath)
-    video_fpath = os.path.join(mesh_dirname, f"{mesh_basename}.mp4")
-
-    with torch.no_grad():
-        # get triplane
-        planes = model.forward_planes(images, input_cameras)
-
-        # # get video
-        # chunk_size = 20 if IS_FLEXICUBES else 1
-        # render_size = 384
-
-        # frames = []
-        # for i in tqdm(range(0, render_cameras.shape[1], chunk_size)):
-        #     if IS_FLEXICUBES:
-        #         frame = model.forward_geometry(
-        #             planes,
-        #             render_cameras[:, i:i+chunk_size],
-        #             render_size=render_size,
-        #         )['img']
-        #     else:
-        #         frame = model.synthesizer(
-        #             planes,
-        #             cameras=render_cameras[:, i:i+chunk_size],
-        #             render_size=render_size,
-        #         )['images_rgb']
-        #     frames.append(frame)
-        # frames = torch.cat(frames, dim=1)
-
-        # images_to_video(
-        #     frames[0],
-        #     video_fpath,
-        #     fps=30,
-        # )
-
-        # print(f"Video saved to {video_fpath}")
-
-        # get mesh
-        mesh_out = model.extract_mesh(
-            planes,
-            use_texture_map=False,
-            **infer_config,
-        )
-
-        vertices, faces, vertex_colors = mesh_out
-        vertices = vertices[:, [1, 2, 0]]
-        vertices[:, -1] *= -1
-        faces = faces[:, [2, 1, 0]]
-
-        save_obj(vertices, faces, vertex_colors, mesh_fpath)
-
-        print(f"Mesh saved to {mesh_fpath}")
-
-        return mesh_fpath
-
-
-_HEADER_ = '''
-<h2><b>Official 🤗 Gradio Demo</b></h2><h2><a href='https://github.com/TencentARC/InstantMesh' target='_blank'><b>InstantMesh: Efficient 3D Mesh Generation from a Single Image with Sparse-view Large Reconstruction Models</b></a></h2>
-'''
-
-_LINKS_ = '''
-<h3>Code is available at <a href='https://github.com/TencentARC/InstantMesh' target='_blank'>GitHub</a></h3>
-<h3>Report is available at <a href='https://arxiv.org/abs/2404.07191' target='_blank'>ArXiv</a></h3>
-'''
-
 
 _CITE_ = r"""
 ```bibtex
@@ -265,7 +16,6 @@ _CITE_ = r"""
 
 
 with gr.Blocks() as demo:
-    gr.Markdown(_HEADER_)
     with gr.Row(variant="panel"):
         with gr.Column():
             with gr.Row():
@@ -327,14 +77,6 @@ with gr.Blocks() as demo:
                     interactive=False
                 )
 
-            # with gr.Column():
-            #     output_video = gr.Video(
-            #         label="video", format="mp4",
-            #         width=379,
-            #         autoplay=True,
-            #         interactive=False
-            #     )
-
         with gr.Row():
             output_model_obj = gr.Model3D(
                 label="Output Model (OBJ Format)",
@@ -344,7 +86,6 @@ with gr.Blocks() as demo:
     with gr.Row():
        gr.Markdown('''Try a different <b>seed value</b> if the result is unsatisfying (Default: 42).''')
 
-    gr.Markdown(_LINKS_)
    gr.Markdown(_CITE_)
 
    mv_images = gr.State()
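After this change, `app.py` keeps only the Gradio layout and imports its four handlers from `util.instantmesh`. The event wiring itself is not visible in the hunks above, so the following is a hypothetical minimal sketch of how the refactored app presumably chains those handlers; every component name (`input_image`, `processed_image`, `sample_steps`, `sample_seed`, `mv_show_images`, `output_model_obj`, `submit`) is an assumption, and only the four imported functions come from the commit:

```python
# Hypothetical wiring sketch; component names are assumed, the four
# handlers are the ones this commit moves into util/instantmesh.py.
import gradio as gr

from util.instantmesh import generate_mvs, make3d, preprocess, check_input_image

with gr.Blocks() as demo:
    input_image = gr.Image(label="Input Image", type="pil")
    do_remove_background = gr.Checkbox(label="Remove Background", value=True)
    processed_image = gr.Image(label="Processed Image", type="pil", interactive=False)
    sample_steps = gr.Slider(label="Sample Steps", minimum=30, maximum=75, value=75, step=5)
    sample_seed = gr.Number(label="Seed", value=42, precision=0)
    submit = gr.Button("Generate")
    mv_show_images = gr.Image(label="Generated Multi-views", interactive=False)
    output_model_obj = gr.Model3D(label="Output Model (OBJ Format)", interactive=False)
    mv_images = gr.State()

    # check_input_image raises gr.Error on a missing upload, so chaining
    # with .success() gates each later stage on the previous one passing.
    submit.click(fn=check_input_image, inputs=[input_image]).success(
        fn=preprocess,
        inputs=[input_image, do_remove_background],
        outputs=[processed_image],
    ).success(
        fn=generate_mvs,
        inputs=[processed_image, sample_steps, sample_seed],
        outputs=[mv_images, mv_show_images],
    ).success(
        fn=make3d,
        inputs=[mv_images],
        outputs=[output_model_obj],
    )

demo.launch()
```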
util/instantmesh.py ADDED
@@ -0,0 +1,210 @@
+import spaces
+
+import os
+import imageio
+import numpy as np
+import torch
+import rembg
+from PIL import Image
+from torchvision.transforms import v2
+from pytorch_lightning import seed_everything
+from omegaconf import OmegaConf
+from einops import rearrange, repeat
+from tqdm import tqdm
+from diffusers import DiffusionPipeline, EulerAncestralDiscreteScheduler
+
+from src.utils.train_util import instantiate_from_config
+from src.utils.camera_util import (
+    FOV_to_intrinsics,
+    get_zero123plus_input_cameras,
+    get_circular_camera_poses,
+)
+from src.utils.mesh_util import save_obj
+from src.utils.infer_util import remove_background, resize_foreground, images_to_video
+
+import tempfile
+from functools import partial
+
+from huggingface_hub import hf_hub_download
+
+import gradio as gr
+
+def get_render_cameras(batch_size=1, M=120, radius=2.5, elevation=10.0, is_flexicubes=False):
+    """
+    Get the rendering camera parameters.
+    """
+    c2ws = get_circular_camera_poses(M=M, radius=radius, elevation=elevation)
+    if is_flexicubes:
+        cameras = torch.linalg.inv(c2ws)
+        cameras = cameras.unsqueeze(0).repeat(batch_size, 1, 1, 1)
+    else:
+        extrinsics = c2ws.flatten(-2)
+        intrinsics = FOV_to_intrinsics(50.0).unsqueeze(0).repeat(M, 1, 1).float().flatten(-2)
+        cameras = torch.cat([extrinsics, intrinsics], dim=-1)
+        cameras = cameras.unsqueeze(0).repeat(batch_size, 1, 1)
+    return cameras
+
+
+def images_to_video(images, output_path, fps=30):
+    # images: (N, C, H, W)
+    os.makedirs(os.path.dirname(output_path), exist_ok=True)
+    frames = []
+    for i in range(images.shape[0]):
+        frame = (images[i].permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8).clip(0, 255)
+        assert frame.shape[0] == images.shape[2] and frame.shape[1] == images.shape[3], \
+            f"Frame shape mismatch: {frame.shape} vs {images.shape}"
+        assert frame.min() >= 0 and frame.max() <= 255, \
+            f"Frame value out of range: {frame.min()} ~ {frame.max()}"
+        frames.append(frame)
+    imageio.mimwrite(output_path, np.stack(frames), fps=fps, codec='h264')
+
+###############################################################################
+# Configuration.
+###############################################################################
+
+import shutil
+
+def find_cuda():
+    # Check if CUDA_HOME or CUDA_PATH environment variables are set
+    cuda_home = os.environ.get('CUDA_HOME') or os.environ.get('CUDA_PATH')
+
+    if cuda_home and os.path.exists(cuda_home):
+        return cuda_home
+
+    # Search for the nvcc executable in the system's PATH
+    nvcc_path = shutil.which('nvcc')
+
+    if nvcc_path:
+        # Remove the 'bin/nvcc' part to get the CUDA installation path
+        cuda_path = os.path.dirname(os.path.dirname(nvcc_path))
+        return cuda_path
+
+    return None
+
+cuda_path = find_cuda()
+
+if cuda_path:
+    print(f"CUDA installation found at: {cuda_path}")
+else:
+    print("CUDA installation not found")
+
+config_path = 'configs/instant-mesh-large.yaml'
+config = OmegaConf.load(config_path)
+config_name = os.path.basename(config_path).replace('.yaml', '')
+model_config = config.model_config
+infer_config = config.infer_config
+
+IS_FLEXICUBES = True if config_name.startswith('instant-mesh') else False
+
+device = torch.device('cuda')
+
+# load diffusion model
+print('Loading diffusion model ...')
+pipeline = DiffusionPipeline.from_pretrained(
+    "sudo-ai/zero123plus-v1.2",
+    custom_pipeline="zero123plus",
+    torch_dtype=torch.float16,
+)
+pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(
+    pipeline.scheduler.config, timestep_spacing='trailing'
+)
+
+# load custom white-background UNet
+unet_ckpt_path = hf_hub_download(repo_id="TencentARC/InstantMesh", filename="diffusion_pytorch_model.bin", repo_type="model")
+state_dict = torch.load(unet_ckpt_path, map_location='cpu')
+pipeline.unet.load_state_dict(state_dict, strict=True)
+
+pipeline = pipeline.to(device)
+
+# load reconstruction model
+print('Loading reconstruction model ...')
+model_ckpt_path = hf_hub_download(repo_id="TencentARC/InstantMesh", filename="instant_mesh_large.ckpt", repo_type="model")
+model = instantiate_from_config(model_config)
+state_dict = torch.load(model_ckpt_path, map_location='cpu')['state_dict']
+state_dict = {k[14:]: v for k, v in state_dict.items() if k.startswith('lrm_generator.') and 'source_camera' not in k}
+model.load_state_dict(state_dict, strict=True)
+
+model = model.to(device)
+
+print('Loading Finished!')
+
+def check_input_image(input_image):
+    if input_image is None:
+        raise gr.Error("No image uploaded!")
+
+
+def preprocess(input_image, do_remove_background):
+
+    rembg_session = rembg.new_session() if do_remove_background else None
+
+    if do_remove_background:
+        input_image = remove_background(input_image, rembg_session)
+        input_image = resize_foreground(input_image, 0.85)
+
+    return input_image
+
+@spaces.GPU
+def generate_mvs(input_image, sample_steps, sample_seed):
+
+    seed_everything(sample_seed)
+
+    # sampling
+    z123_image = pipeline(
+        input_image,
+        num_inference_steps=sample_steps
+    ).images[0]
+
+    show_image = np.asarray(z123_image, dtype=np.uint8)
+    show_image = torch.from_numpy(show_image)  # (960, 640, 3)
+    show_image = rearrange(show_image, '(n h) (m w) c -> (n m) h w c', n=3, m=2)
+    show_image = rearrange(show_image, '(n m) h w c -> (n h) (m w) c', n=2, m=3)
+    show_image = Image.fromarray(show_image.numpy())
+
+    return z123_image, show_image
+
+
+@spaces.GPU
+def make3d(images):
+
+    global model
+    if IS_FLEXICUBES:
+        model.init_flexicubes_geometry(device, use_renderer=False)
+    model = model.eval()
+
+    images = np.asarray(images, dtype=np.float32) / 255.0
+    images = torch.from_numpy(images).permute(2, 0, 1).contiguous().float()  # (3, 960, 640)
+    images = rearrange(images, 'c (n h) (m w) -> (n m) c h w', n=3, m=2)  # (6, 3, 320, 320)
+
+    input_cameras = get_zero123plus_input_cameras(batch_size=1, radius=4.0).to(device)
+    render_cameras = get_render_cameras(batch_size=1, radius=2.5, is_flexicubes=IS_FLEXICUBES).to(device)
+
+    images = images.unsqueeze(0).to(device)
+    images = v2.functional.resize(images, (320, 320), interpolation=3, antialias=True).clamp(0, 1)
+
+    mesh_fpath = tempfile.NamedTemporaryFile(suffix=f".obj", delete=False).name
+    print(mesh_fpath)
+    mesh_basename = os.path.basename(mesh_fpath).split('.')[0]
+    mesh_dirname = os.path.dirname(mesh_fpath)
+    video_fpath = os.path.join(mesh_dirname, f"{mesh_basename}.mp4")
+
+    with torch.no_grad():
+        # get triplane
+        planes = model.forward_planes(images, input_cameras)
+
+        # get mesh
+        mesh_out = model.extract_mesh(
+            planes,
+            use_texture_map=False,
+            **infer_config,
+        )
+
+        vertices, faces, vertex_colors = mesh_out
+        vertices = vertices[:, [1, 2, 0]]
+        vertices[:, -1] *= -1
+        faces = faces[:, [2, 1, 0]]
+
+        save_obj(vertices, faces, vertex_colors, mesh_fpath)
+
+        print(f"Mesh saved to {mesh_fpath}")
+
+        return mesh_fpath
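Both `generate_mvs` and `make3d` treat the 960×640 Zero123Plus output as a 3×2 grid of 320×320 views, and `generate_mvs` additionally re-tiles it to a 2×3 grid for display. A minimal, self-contained sketch of the `einops.rearrange` patterns used above (the random tensor is a stand-in for the actual diffusion output):

```python
import torch
from einops import rearrange

# Stand-in for the Zero123Plus output: a 960x640 RGB image holding
# six 320x320 views tiled as 3 rows x 2 columns.
grid = torch.rand(960, 640, 3)

# Split the 3x2 grid into a batch of six individual views.
views = rearrange(grid, '(n h) (m w) c -> (n m) h w c', n=3, m=2)
assert views.shape == (6, 320, 320, 3)

# Re-tile the same six views as 2 rows x 3 columns for display,
# as generate_mvs does before handing the image back to Gradio.
display = rearrange(views, '(n m) h w c -> (n h) (m w) c', n=2, m=3)
assert display.shape == (640, 960, 3)

# make3d uses the channel-first variant on a (3, 960, 640) tensor
# to get the (6, 3, 320, 320) batch fed to the reconstruction model.
chw = grid.permute(2, 0, 1)
batch = rearrange(chw, 'c (n h) (m w) -> (n m) c h w', n=3, m=2)
assert batch.shape == (6, 3, 320, 320)
```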