Spaces:
Build error
Build error
Nikhil0987
commited on
Commit
•
ebf16ac
1
Parent(s):
aedd5aa
Update app.py
Browse files
app.py
CHANGED
@@ -1,35 +1,186 @@
|
|
1 |
-
import
|
2 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
3 |
|
4 |
-
#
|
5 |
-
|
|
|
6 |
|
7 |
-
|
8 |
-
|
|
|
|
|
9 |
|
10 |
-
|
11 |
-
|
12 |
|
13 |
-
|
14 |
-
|
15 |
-
|
16 |
-
# Handle events
|
17 |
-
for event in pygame.event.get():
|
18 |
-
if event.type == QUIT:
|
19 |
-
running = False
|
20 |
|
21 |
-
|
|
|
22 |
|
23 |
-
|
24 |
-
|
25 |
|
26 |
-
|
27 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
28 |
|
29 |
-
|
|
|
30 |
|
31 |
-
|
32 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
33 |
|
34 |
-
|
35 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import numpy as np
|
3 |
+
import torch
|
4 |
+
import torch.nn.functional as F
|
5 |
+
import torchvision.transforms.functional as TF
|
6 |
+
from safetensors.torch import load_file
|
7 |
+
import rembg
|
8 |
+
import gradio as gr
|
9 |
|
10 |
+
# download checkpoints
|
11 |
+
from huggingface_hub import hf_hub_download
|
12 |
+
ckpt_path = hf_hub_download(repo_id="ashawkey/LGM", filename="model_fp16.safetensors")
|
13 |
|
14 |
+
try:
|
15 |
+
import diff_gaussian_rasterization
|
16 |
+
except ImportError:
|
17 |
+
os.system("pip install ./diff-gaussian-rasterization")
|
18 |
|
19 |
+
import kiui
|
20 |
+
from kiui.op import recenter
|
21 |
|
22 |
+
from core.options import Options
|
23 |
+
from core.models import LGM
|
24 |
+
from mvdream.pipeline_mvdream import MVDreamPipeline
|
|
|
|
|
|
|
|
|
25 |
|
26 |
+
IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
|
27 |
+
IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
|
28 |
|
29 |
+
TMP_DIR = '/tmp'
|
30 |
+
os.makedirs(TMP_DIR, exist_ok=True)
|
31 |
|
32 |
+
# opt = tyro.cli(AllConfigs)
|
33 |
+
opt = Options(
|
34 |
+
input_size=256,
|
35 |
+
up_channels=(1024, 1024, 512, 256, 128), # one more decoder
|
36 |
+
up_attention=(True, True, True, False, False),
|
37 |
+
splat_size=128,
|
38 |
+
output_size=512, # render & supervise Gaussians at a higher resolution.
|
39 |
+
batch_size=8,
|
40 |
+
num_views=8,
|
41 |
+
gradient_accumulation_steps=1,
|
42 |
+
mixed_precision='bf16',
|
43 |
+
resume=ckpt_path,
|
44 |
+
)
|
45 |
|
46 |
+
# model
|
47 |
+
model = LGM(opt)
|
48 |
|
49 |
+
# resume pretrained checkpoint
|
50 |
+
if opt.resume is not None:
|
51 |
+
if opt.resume.endswith('safetensors'):
|
52 |
+
ckpt = load_file(opt.resume, device='cpu')
|
53 |
+
else:
|
54 |
+
ckpt = torch.load(opt.resume, map_location='cpu')
|
55 |
+
model.load_state_dict(ckpt, strict=False)
|
56 |
+
print(f'[INFO] Loaded checkpoint from {opt.resume}')
|
57 |
+
else:
|
58 |
+
print(f'[WARN] model randomly initialized, are you sure?')
|
59 |
|
60 |
+
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
61 |
+
model = model.half().to(device)
|
62 |
+
model.eval()
|
63 |
+
|
64 |
+
tan_half_fov = np.tan(0.5 * np.deg2rad(opt.fovy))
|
65 |
+
proj_matrix = torch.zeros(4, 4, dtype=torch.float32, device=device)
|
66 |
+
proj_matrix[0, 0] = -1 / tan_half_fov
|
67 |
+
proj_matrix[1, 1] = -1 / tan_half_fov
|
68 |
+
proj_matrix[2, 2] = - (opt.zfar + opt.znear) / (opt.zfar - opt.znear)
|
69 |
+
proj_matrix[3, 2] = - (opt.zfar * opt.znear) / (opt.zfar - opt.znear)
|
70 |
+
proj_matrix[2, 3] = 1
|
71 |
+
|
72 |
+
# load dreams
|
73 |
+
pipe_text = MVDreamPipeline.from_pretrained(
|
74 |
+
'ashawkey/mvdream-sd2.1-diffusers', # remote weights
|
75 |
+
torch_dtype=torch.float16,
|
76 |
+
trust_remote_code=True,
|
77 |
+
# local_files_only=True,
|
78 |
+
)
|
79 |
+
pipe_text = pipe_text.to(device)
|
80 |
+
|
81 |
+
pipe_image = MVDreamPipeline.from_pretrained(
|
82 |
+
"ashawkey/imagedream-ipmv-diffusers", # remote weights
|
83 |
+
torch_dtype=torch.float16,
|
84 |
+
trust_remote_code=True,
|
85 |
+
# local_files_only=True,
|
86 |
+
)
|
87 |
+
pipe_image = pipe_image.to(device)
|
88 |
+
|
89 |
+
# load rembg
|
90 |
+
bg_remover = rembg.new_session()
|
91 |
+
|
92 |
+
# process function
|
93 |
+
def run(input_image):
|
94 |
+
prompt_neg = "ugly, blurry, pixelated obscure, unnatural colors, poor lighting, dull, unclear, cropped, lowres, low quality, artifacts, duplicate"
|
95 |
+
|
96 |
+
# seed
|
97 |
+
kiui.seed_everything(42)
|
98 |
+
|
99 |
+
output_ply_path = os.path.join(TMP_DIR, 'output.ply')
|
100 |
+
|
101 |
+
input_image = np.array(input_image) # uint8
|
102 |
+
# bg removal
|
103 |
+
carved_image = rembg.remove(input_image, session=bg_remover) # [H, W, 4]
|
104 |
+
mask = carved_image[..., -1] > 0
|
105 |
+
image = recenter(carved_image, mask, border_ratio=0.2)
|
106 |
+
image = image.astype(np.float32) / 255.0
|
107 |
+
image = image[..., :3] * image[..., 3:4] + (1 - image[..., 3:4])
|
108 |
+
mv_image = pipe_image("", image, negative_prompt=prompt_neg, num_inference_steps=30, guidance_scale=5.0, elevation=0)
|
109 |
+
|
110 |
+
# generate gaussians
|
111 |
+
input_image = np.stack([mv_image[1], mv_image[2], mv_image[3], mv_image[0]], axis=0) # [4, 256, 256, 3], float32
|
112 |
+
input_image = torch.from_numpy(input_image).permute(0, 3, 1, 2).float().to(device) # [4, 3, 256, 256]
|
113 |
+
input_image = F.interpolate(input_image, size=(opt.input_size, opt.input_size), mode='bilinear', align_corners=False)
|
114 |
+
input_image = TF.normalize(input_image, IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)
|
115 |
+
|
116 |
+
rays_embeddings = model.prepare_default_rays(device, elevation=0)
|
117 |
+
input_image = torch.cat([input_image, rays_embeddings], dim=1).unsqueeze(0) # [1, 4, 9, H, W]
|
118 |
+
|
119 |
+
with torch.no_grad():
|
120 |
+
with torch.autocast(device_type='cuda', dtype=torch.float16):
|
121 |
+
# generate gaussians
|
122 |
+
gaussians = model.forward_gaussians(input_image)
|
123 |
+
|
124 |
+
# save gaussians
|
125 |
+
model.gs.save_ply(gaussians, output_ply_path)
|
126 |
+
|
127 |
+
return output_ply_path
|
128 |
+
|
129 |
+
# gradio UI
|
130 |
+
|
131 |
+
_TITLE = '''LGM Mini'''
|
132 |
+
|
133 |
+
_DESCRIPTION = '''
|
134 |
+
<div>
|
135 |
+
A lightweight version of <a href="https://huggingface.co/spaces/ashawkey/LGM">LGM: Large Multi-View Gaussian Model for High-Resolution 3D Content Creation</a>.
|
136 |
+
</div>
|
137 |
+
'''
|
138 |
+
|
139 |
+
css = '''
|
140 |
+
#duplicate-button {
|
141 |
+
margin: auto;
|
142 |
+
color: white;
|
143 |
+
background: #1565c0;
|
144 |
+
border-radius: 100vh;
|
145 |
+
}
|
146 |
+
'''
|
147 |
+
|
148 |
+
block = gr.Blocks(title=_TITLE, css=css)
|
149 |
+
with block:
|
150 |
+
gr.DuplicateButton(value="Duplicate Space for private use", elem_id="duplicate-button")
|
151 |
+
|
152 |
+
with gr.Row():
|
153 |
+
with gr.Column(scale=1):
|
154 |
+
gr.Markdown('# ' + _TITLE)
|
155 |
+
gr.Markdown(_DESCRIPTION)
|
156 |
+
|
157 |
+
with gr.Row(variant='panel'):
|
158 |
+
with gr.Column(scale=1):
|
159 |
+
# input image
|
160 |
+
input_image = gr.Image(label="image", type='pil', height=320)
|
161 |
+
# gen button
|
162 |
+
button_gen = gr.Button("Generate")
|
163 |
+
|
164 |
+
|
165 |
+
with gr.Column(scale=1):
|
166 |
+
output_splat = gr.Model3D(label="3D Gaussians")
|
167 |
+
|
168 |
+
button_gen.click(fn=run, inputs=[input_image], outputs=[output_splat])
|
169 |
+
|
170 |
+
gr.Examples(
|
171 |
+
examples=[
|
172 |
+
"data_test/frog_sweater.jpg",
|
173 |
+
"data_test/bird.jpg",
|
174 |
+
"data_test/boy.jpg",
|
175 |
+
"data_test/cat_statue.jpg",
|
176 |
+
"data_test/dragontoy.jpg",
|
177 |
+
"data_test/gso_rabbit.jpg",
|
178 |
+
],
|
179 |
+
inputs=[input_image],
|
180 |
+
outputs=[output_splat],
|
181 |
+
fn=lambda x: run(input_image=x),
|
182 |
+
cache_examples=True,
|
183 |
+
label='Image-to-3D Examples'
|
184 |
+
)
|
185 |
+
|
186 |
+
block.queue().launch(debug=True, share=True)
|