JiantaoLin
commited on
Commit
·
4157d39
1
Parent(s):
a2907bc
new
Browse files- image_to_mesh_new.py +436 -0
- pipeline/kiss3d_wrapper.py +429 -0
- pipeline/pipeline_config/default.yaml +25 -0
- pipeline/run_hpc.sh +10 -0
- pipeline/utils.py +198 -0
- run.sh +2 -0
- run_hpc.sh +11 -0
- text_to_mesh_new.py +244 -0
- upload_huggingface.py +57 -0
image_to_mesh_new.py
ADDED
@@ -0,0 +1,436 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from einops import rearrange
|
3 |
+
from omegaconf import OmegaConf
|
4 |
+
import torch
|
5 |
+
import numpy as np
|
6 |
+
import trimesh
|
7 |
+
import torchvision
|
8 |
+
import torch.nn.functional as F
|
9 |
+
from PIL import Image
|
10 |
+
from torchvision import transforms
|
11 |
+
from torchvision.transforms import v2
|
12 |
+
from transformers import AutoProcessor, AutoModelForCausalLM
|
13 |
+
import rembg
|
14 |
+
from diffusers import FluxPipeline, FluxControlNetImg2ImgPipeline
|
15 |
+
from diffusers.models.controlnet_flux import FluxControlNetModel, FluxMultiControlNetModel
|
16 |
+
from diffusers import DiffusionPipeline, EulerAncestralDiscreteScheduler, HeunDiscreteScheduler
|
17 |
+
from pytorch_lightning import seed_everything
|
18 |
+
import os
|
19 |
+
|
20 |
+
from models.ISOMER.reconstruction_func import reconstruction
|
21 |
+
from models.ISOMER.projection_func import projection
|
22 |
+
from models.lrm.utils.infer_util import remove_background, resize_foreground, save_video
|
23 |
+
from models.lrm.utils.mesh_util import save_obj, save_obj_with_mtl
|
24 |
+
from models.lrm.utils.render_utils import rotate_x, rotate_y
|
25 |
+
from models.lrm.utils.train_util import instantiate_from_config
|
26 |
+
from models.lrm.utils.camera_util import get_zero123plus_input_cameras, get_custom_zero123plus_input_cameras, get_flux_input_cameras
|
27 |
+
from utils.tool import NormalTransfer, get_render_cameras_frames, load_mipmap
|
28 |
+
from utils.tool import get_background, get_render_cameras_video, render_frames, mask_fix
|
29 |
+
|
30 |
+
device = "cuda"
|
31 |
+
resolution = 512
|
32 |
+
save_dir = "./outputs"
|
33 |
+
zero123plus_diffusion_steps = 75
|
34 |
+
normal_transfer = NormalTransfer()
|
35 |
+
rembg_session = rembg.new_session()
|
36 |
+
isomer_azimuths = torch.from_numpy(np.array([270, 0, 90, 180])).to(device)
|
37 |
+
isomer_elevations = torch.from_numpy(np.array([5, 5, 5, 5])).to(device)
|
38 |
+
isomer_radius = 4.1
|
39 |
+
isomer_geo_weights = torch.from_numpy(np.array([1, 0.9, 1, 0.9])).float().to(device)
|
40 |
+
isomer_color_weights = torch.from_numpy(np.array([1, 0.5, 1, 0.5])).float().to(device)
|
41 |
+
# seed_everything(42)
|
42 |
+
|
43 |
+
# model initialization and loading
|
44 |
+
# flux
|
45 |
+
print('==> Loading Flux model ...')
|
46 |
+
flux_base_model_pth = "/hpc2hdd/JH_DATA/share/yingcongchen/PrivateShareGroup/yingcongchen_datasets/model_checkpoint/models--black-forest-labs--FLUX.1-dev"
|
47 |
+
flux_controlnet = FluxControlNetModel.from_pretrained("/hpc2hdd/JH_DATA/share/yingcongchen/PrivateShareGroup/yingcongchen_datasets/model_checkpoint/flux_controlnets/FLUX.1-dev-ControlNet-Union-Pro")
|
48 |
+
flux_pipe = FluxControlNetImg2ImgPipeline.from_pretrained(flux_base_model_pth, controlnet=[flux_controlnet], torch_dtype=torch.bfloat16).to(device=device, dtype=torch.bfloat16)
|
49 |
+
|
50 |
+
flux_pipe.load_lora_weights('./checkpoint/flux_lora/rgb_normal_large.safetensors')
|
51 |
+
|
52 |
+
|
53 |
+
flux_pipe.to(device=device, dtype=torch.bfloat16)
|
54 |
+
generator = torch.Generator(device=device).manual_seed(0)
|
55 |
+
|
56 |
+
# lrm
|
57 |
+
print('==> Loading LRM model ...')
|
58 |
+
config = OmegaConf.load("./models/lrm/config/PRM_inference.yaml")
|
59 |
+
model_config = config.model_config
|
60 |
+
infer_config = config.infer_config
|
61 |
+
model = instantiate_from_config(model_config)
|
62 |
+
model_ckpt_path = "./checkpoint/lrm/final_ckpt.ckpt"
|
63 |
+
state_dict = torch.load(model_ckpt_path, map_location='cpu')['state_dict']
|
64 |
+
state_dict = {k[14:]: v for k, v in state_dict.items() if k.startswith('lrm_generator.')}
|
65 |
+
model.load_state_dict(state_dict, strict=True)
|
66 |
+
|
67 |
+
model = model.to(device)
|
68 |
+
model.init_flexicubes_geometry(device, fovy=50.0)
|
69 |
+
model = model.eval()
|
70 |
+
|
71 |
+
# zero123++
|
72 |
+
print('==> Loading diffusion model ...')
|
73 |
+
zero123plus_pipeline = DiffusionPipeline.from_pretrained(
|
74 |
+
"sudo-ai/zero123plus-v1.2",
|
75 |
+
custom_pipeline="./models/zero123plus",
|
76 |
+
torch_dtype=torch.float16,
|
77 |
+
)
|
78 |
+
zero123plus_pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(
|
79 |
+
zero123plus_pipeline.scheduler.config, timestep_spacing='trailing'
|
80 |
+
)
|
81 |
+
unet_ckpt_path = "./checkpoint/zero123++/flexgen_19w.ckpt"
|
82 |
+
state_dict = torch.load(unet_ckpt_path, map_location='cpu')['state_dict']
|
83 |
+
state_dict = {k[10:]: v for k, v in state_dict.items() if k.startswith('unet.unet.')}
|
84 |
+
zero123plus_pipeline.unet.load_state_dict(state_dict, strict=True)
|
85 |
+
zero123plus_pipeline = zero123plus_pipeline.to(device)
|
86 |
+
|
87 |
+
# unet_ckpt_path = "checkpoint/zero123++/diffusion_pytorch_model.bin"
|
88 |
+
# state_dict = torch.load(unet_ckpt_path, map_location='cpu')
|
89 |
+
# zero123plus_pipeline.unet.load_state_dict(state_dict, strict=True)
|
90 |
+
# zero123plus_pipeline = zero123plus_pipeline.to(device)
|
91 |
+
|
92 |
+
# florence
|
93 |
+
caption_model = AutoModelForCausalLM.from_pretrained(
|
94 |
+
"/hpc2hdd/home/jlin695/.cache/huggingface/hub/models--multimodalart--Florence-2-large-no-flash-attn/snapshots/8db3793cf5b453b2ccfb3a4f613b403b2e6b7ca2", torch_dtype=torch.bfloat16, trust_remote_code=True,
|
95 |
+
).to(device)
|
96 |
+
caption_processor = AutoProcessor.from_pretrained("/hpc2hdd/home/jlin695/.cache/huggingface/hub/models--multimodalart--Florence-2-large-no-flash-attn/snapshots/8db3793cf5b453b2ccfb3a4f613b403b2e6b7ca2", trust_remote_code=True)
|
97 |
+
|
98 |
+
# Flux multi-view generation
|
99 |
+
def multi_view_rgb_normal_generation_with_controlnet(prompt, image, strength=1.0,
|
100 |
+
control_image=[],
|
101 |
+
control_mode=[],
|
102 |
+
control_guidance_start=None,
|
103 |
+
control_guidance_end=None,
|
104 |
+
controlnet_conditioning_scale=None,
|
105 |
+
lora_scale=1.0
|
106 |
+
):
|
107 |
+
control_mode_dict = {
|
108 |
+
'canny': 0,
|
109 |
+
'tile': 1,
|
110 |
+
'depth': 2,
|
111 |
+
'blur': 3,
|
112 |
+
'pose': 4,
|
113 |
+
'gray': 5,
|
114 |
+
'lq': 6,
|
115 |
+
} # for https://huggingface.co/InstantX/FLUX.1-dev-Controlnet-Union only
|
116 |
+
|
117 |
+
hparam_dict = {
|
118 |
+
'prompt': prompt,
|
119 |
+
'image': image,
|
120 |
+
'strength': strength,
|
121 |
+
'num_inference_steps': 30,
|
122 |
+
'guidance_scale': 3.5,
|
123 |
+
'num_images_per_prompt': 1,
|
124 |
+
'width': resolution*4,
|
125 |
+
'height': resolution*2,
|
126 |
+
'output_type': 'np',
|
127 |
+
'generator': generator,
|
128 |
+
'joint_attention_kwargs': {"scale": lora_scale}
|
129 |
+
}
|
130 |
+
|
131 |
+
# append controlnet hparams
|
132 |
+
if len(control_image) > 0:
|
133 |
+
assert len(control_mode) == len(control_image) # the count of image should be the same as control mode
|
134 |
+
|
135 |
+
ctrl_hparams = {
|
136 |
+
'control_mode': [control_mode_dict[mode_] for mode_ in control_mode],
|
137 |
+
'control_image': control_image,
|
138 |
+
'control_guidance_start': control_guidance_start or [0.0 for i in range(len(control_image))],
|
139 |
+
'control_guidance_end': control_guidance_end or [1.0 for i in range(len(control_image))],
|
140 |
+
'controlnet_conditioning_scale': controlnet_conditioning_scale or [1.0 for i in range(len(control_image))],
|
141 |
+
}
|
142 |
+
|
143 |
+
hparam_dict.update(ctrl_hparams)
|
144 |
+
|
145 |
+
# generate multi-view images
|
146 |
+
with torch.no_grad():
|
147 |
+
image = flux_pipe(
|
148 |
+
**hparam_dict
|
149 |
+
).images
|
150 |
+
return image
|
151 |
+
|
152 |
+
# captioning
|
153 |
+
def run_captioning(image):
|
154 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
155 |
+
torch_dtype = torch.bfloat16
|
156 |
+
|
157 |
+
if isinstance(image, str): # If image is a file path
|
158 |
+
image = Image.open(image).convert("RGB")
|
159 |
+
|
160 |
+
prompt = "<MORE_DETAILED_CAPTION>"
|
161 |
+
inputs = caption_processor(text=prompt, images=image, return_tensors="pt").to(device, torch_dtype)
|
162 |
+
# print(f"inputs {inputs}")
|
163 |
+
|
164 |
+
generated_ids = caption_model.generate(
|
165 |
+
input_ids=inputs["input_ids"], pixel_values=inputs["pixel_values"], max_new_tokens=1024, num_beams=3
|
166 |
+
)
|
167 |
+
|
168 |
+
generated_text = caption_processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
|
169 |
+
parsed_answer = caption_processor.post_process_generation(
|
170 |
+
generated_text, task=prompt, image_size=(image.width, image.height)
|
171 |
+
)
|
172 |
+
# print(f"parsed_answer = {parsed_answer}")
|
173 |
+
caption_text = parsed_answer["<MORE_DETAILED_CAPTION>"].replace("The image is ", "")
|
174 |
+
return caption_text
|
175 |
+
|
176 |
+
|
177 |
+
# zero123++ multi-view generation
|
178 |
+
def multi_view_rgb_generation(cond_img):
|
179 |
+
# generate multi-view images
|
180 |
+
with torch.no_grad():
|
181 |
+
output_image = zero123plus_pipeline(
|
182 |
+
cond_img,
|
183 |
+
num_inference_steps=zero123plus_diffusion_steps,
|
184 |
+
width=resolution*2,
|
185 |
+
height=resolution*2,
|
186 |
+
).images[0]
|
187 |
+
return output_image
|
188 |
+
|
189 |
+
# lrm reconstructions
|
190 |
+
def lrm_reconstructions(image, input_cameras, save_path=None, name="temp", export_texmap=False, if_save_video=False, render_azimuths=None, render_elevations=None, render_radius=None, render_fov=30):
|
191 |
+
images = image.unsqueeze(0).to(device)
|
192 |
+
images = v2.functional.resize(images, 512, interpolation=3, antialias=True).clamp(0, 1)
|
193 |
+
# breakpoint()
|
194 |
+
with torch.no_grad():
|
195 |
+
# get triplane
|
196 |
+
planes = model.forward_planes(images, input_cameras)
|
197 |
+
|
198 |
+
mesh_path_idx = os.path.join(save_path, f'{name}.obj')
|
199 |
+
|
200 |
+
mesh_out = model.extract_mesh(
|
201 |
+
planes,
|
202 |
+
use_texture_map=export_texmap,
|
203 |
+
**infer_config,
|
204 |
+
)
|
205 |
+
if export_texmap:
|
206 |
+
vertices, faces, uvs, mesh_tex_idx, tex_map = mesh_out
|
207 |
+
save_obj_with_mtl(
|
208 |
+
vertices.data.cpu().numpy(),
|
209 |
+
uvs.data.cpu().numpy(),
|
210 |
+
faces.data.cpu().numpy(),
|
211 |
+
mesh_tex_idx.data.cpu().numpy(),
|
212 |
+
tex_map.permute(1, 2, 0).data.cpu().numpy(),
|
213 |
+
mesh_path_idx,
|
214 |
+
)
|
215 |
+
else:
|
216 |
+
vertices, faces, vertex_colors = mesh_out
|
217 |
+
save_obj(vertices, faces, vertex_colors, mesh_path_idx)
|
218 |
+
print(f"Mesh saved to {mesh_path_idx}")
|
219 |
+
|
220 |
+
render_size = 512
|
221 |
+
if if_save_video:
|
222 |
+
video_path_idx = os.path.join(save_path, f'{name}.mp4')
|
223 |
+
render_size = infer_config.render_resolution
|
224 |
+
ENV = load_mipmap("models/lrm/env_mipmap/6")
|
225 |
+
materials = (0.0,0.9)
|
226 |
+
|
227 |
+
all_mv, all_mvp, all_campos = get_render_cameras_video(
|
228 |
+
batch_size=1,
|
229 |
+
M=240,
|
230 |
+
radius=4.5,
|
231 |
+
elevation=(90, 60.0),
|
232 |
+
is_flexicubes=True,
|
233 |
+
fov=30
|
234 |
+
)
|
235 |
+
|
236 |
+
frames, albedos, pbr_spec_lights, pbr_diffuse_lights, normals, alphas = render_frames(
|
237 |
+
model,
|
238 |
+
planes,
|
239 |
+
render_cameras=all_mvp,
|
240 |
+
camera_pos=all_campos,
|
241 |
+
env=ENV,
|
242 |
+
materials=materials,
|
243 |
+
render_size=render_size,
|
244 |
+
chunk_size=20,
|
245 |
+
is_flexicubes=True,
|
246 |
+
)
|
247 |
+
normals = (torch.nn.functional.normalize(normals) + 1) / 2
|
248 |
+
normals = normals * alphas + (1-alphas)
|
249 |
+
all_frames = torch.cat([frames, albedos, pbr_spec_lights, pbr_diffuse_lights, normals], dim=3)
|
250 |
+
|
251 |
+
# breakpoint()
|
252 |
+
save_video(
|
253 |
+
all_frames,
|
254 |
+
video_path_idx,
|
255 |
+
fps=30,
|
256 |
+
)
|
257 |
+
print(f"Video saved to {video_path_idx}")
|
258 |
+
|
259 |
+
if render_azimuths is not None and render_elevations is not None and render_radius is not None:
|
260 |
+
render_size = infer_config.render_resolution
|
261 |
+
ENV = load_mipmap("models/lrm/env_mipmap/6")
|
262 |
+
materials = (0.0,0.9)
|
263 |
+
all_mv, all_mvp, all_campos, identity_mv = get_render_cameras_frames(
|
264 |
+
batch_size=1,
|
265 |
+
radius=render_radius,
|
266 |
+
azimuths=render_azimuths,
|
267 |
+
elevations=render_elevations,
|
268 |
+
fov=30
|
269 |
+
)
|
270 |
+
frames, albedos, pbr_spec_lights, pbr_diffuse_lights, normals, alphas = render_frames(
|
271 |
+
model,
|
272 |
+
planes,
|
273 |
+
render_cameras=all_mvp,
|
274 |
+
camera_pos=all_campos,
|
275 |
+
env=ENV,
|
276 |
+
materials=materials,
|
277 |
+
render_size=render_size,
|
278 |
+
render_mv = all_mv,
|
279 |
+
local_normal=True,
|
280 |
+
identity_mv=identity_mv,
|
281 |
+
)
|
282 |
+
else:
|
283 |
+
normals = None
|
284 |
+
frames = None
|
285 |
+
albedos = None
|
286 |
+
|
287 |
+
return vertices, faces, normals, frames, albedos
|
288 |
+
|
289 |
+
|
290 |
+
def transform_normal(input_normal, azimuths_deg, elevations_deg, radius=4.5, is_global_to_local=False):
|
291 |
+
"""
|
292 |
+
input_normal: in range [-1, 1], shape (b c h w)
|
293 |
+
"""
|
294 |
+
|
295 |
+
input_normal = input_normal.permute(0, 2, 3, 1).cpu()
|
296 |
+
|
297 |
+
azimuths_deg = np.array(azimuths_deg)
|
298 |
+
elevations_deg = np.array(elevations_deg)
|
299 |
+
|
300 |
+
if is_global_to_local:
|
301 |
+
local_normal = normal_transfer.trans_global_2_local(input_normal, azimuths_deg, elevations_deg)
|
302 |
+
return local_normal.permute(0, 3, 1, 2)
|
303 |
+
else:
|
304 |
+
global_normal = normal_transfer.trans_local_2_global(input_normal, azimuths_deg, elevations_deg, radius=radius, for_lotus=False)
|
305 |
+
global_normal[..., 0] *= -1
|
306 |
+
return global_normal.permute(0, 3, 1, 2)
|
307 |
+
|
308 |
+
def local_normal_global_transform(local_normal_images,azimuths_deg,elevations_deg):
|
309 |
+
if local_normal_images.min() >= 0:
|
310 |
+
local_normal = local_normal_images.float() * 2 - 1
|
311 |
+
else:
|
312 |
+
local_normal = local_normal_images.float()
|
313 |
+
global_normal = normal_transfer.trans_local_2_global(local_normal, azimuths_deg, elevations_deg, radius=4.5, for_lotus=False)
|
314 |
+
global_normal[...,0] *= -1
|
315 |
+
global_normal = (global_normal + 1) / 2
|
316 |
+
global_normal = global_normal.permute(0, 3, 1, 2)
|
317 |
+
return global_normal
|
318 |
+
|
319 |
+
def main():
|
320 |
+
image_pth = "examples/蓝色小怪物.webp"
|
321 |
+
save_dir_path = os.path.join(save_dir, image_pth.split("/")[-1].split(".")[0])
|
322 |
+
os.makedirs(save_dir_path, exist_ok=True)
|
323 |
+
input_image = Image.open(image_pth)
|
324 |
+
# if not args.no_rembg:
|
325 |
+
input_image = remove_background(input_image, rembg_session)
|
326 |
+
input_image = resize_foreground(input_image, 0.85)
|
327 |
+
|
328 |
+
# generate caption
|
329 |
+
image_caption = run_captioning(image_pth)
|
330 |
+
|
331 |
+
# generate multi-view images
|
332 |
+
output_image = multi_view_rgb_generation(input_image)
|
333 |
+
|
334 |
+
# lrm reconstructions
|
335 |
+
rgb_multi_view = np.asarray(output_image, dtype=np.float32) / 255.0
|
336 |
+
rgb_multi_view = torch.from_numpy(rgb_multi_view).squeeze(0).permute(2, 0, 1).contiguous().float() # (3, 1024, 2048)
|
337 |
+
rgb_multi_view = rearrange(rgb_multi_view, 'c (n h) (m w) -> (n m) c h w', n=2, m=2) # (8, 3, 512, 512)
|
338 |
+
|
339 |
+
input_cameras = get_custom_zero123plus_input_cameras(batch_size=1, radius=3.5, fov=30).to(device)
|
340 |
+
|
341 |
+
vertices, faces, lrm_multi_view_normals, lrm_multi_view_rgb, lrm_multi_view_albedo = \
|
342 |
+
lrm_reconstructions(rgb_multi_view, input_cameras, save_path=save_dir_path, name='lrm',
|
343 |
+
export_texmap=False, if_save_video=False, render_azimuths=isomer_azimuths,
|
344 |
+
render_elevations=isomer_elevations, render_radius=isomer_radius, render_fov=30)
|
345 |
+
|
346 |
+
vertices = torch.from_numpy(vertices).to(device)
|
347 |
+
faces = torch.from_numpy(faces).to(device)
|
348 |
+
vertices = vertices @ rotate_x(np.pi / 2, device=vertices.device)[:3, :3]
|
349 |
+
vertices = vertices @ rotate_y(np.pi / 2, device=vertices.device)[:3, :3]
|
350 |
+
|
351 |
+
|
352 |
+
# lrm_3D_bundle_image = torchvision.utils.make_grid(torch.cat([lrm_multi_view_rgb.cpu(), (lrm_multi_view_normals.cpu() + 1) / 2], dim=0), nrow=4, padding=0).unsqueeze(0) # range [0, 1]
|
353 |
+
lrm_3D_bundle_image = torchvision.utils.make_grid(torch.cat([rgb_multi_view[[3,0,1,2]].cpu(), (lrm_multi_view_normals.cpu() + 1) / 2], dim=0), nrow=4, padding=0).unsqueeze(0) # range [0, 1]
|
354 |
+
# rgb_multi_view[[3,0,1,2]] : (B,3,H,W)
|
355 |
+
# lrm_multi_view_normals : (B,3,H,W)
|
356 |
+
# combined_images = 0.5 * rgb_multi_view[[3,0,1,2]].cpu() + 0.5 * (lrm_multi_view_normals.cpu() + 1) / 2
|
357 |
+
# torchvision.utils.save_image(combined_images, os.path.join("debug_output", 'combined.png'))
|
358 |
+
# breakpoint()
|
359 |
+
# Use the low-quality controlnet by default, feel free to try the others
|
360 |
+
control_image = [lrm_3D_bundle_image * 2 - 1]
|
361 |
+
control_mode = ['tile']
|
362 |
+
control_guidance_start = [0.0]
|
363 |
+
control_guidance_end = [0.3]
|
364 |
+
controlnet_conditioning_scale = [0.8]
|
365 |
+
|
366 |
+
flux_pipe.controlnet = FluxMultiControlNetModel([flux_controlnet for _ in control_mode])
|
367 |
+
# breakpoint()
|
368 |
+
rgb_normal_grid = multi_view_rgb_normal_generation_with_controlnet(
|
369 |
+
prompt= ' '.join(['A grid of 2x4 multi-view image, elevation 5. White background.', image_caption]),
|
370 |
+
image=lrm_3D_bundle_image,
|
371 |
+
strength=0.6,
|
372 |
+
control_image=control_image,
|
373 |
+
control_mode=control_mode,
|
374 |
+
control_guidance_start=control_guidance_start,
|
375 |
+
control_guidance_end=control_guidance_end,
|
376 |
+
controlnet_conditioning_scale=controlnet_conditioning_scale,
|
377 |
+
lora_scale=1.0
|
378 |
+
) # noted that rgb_normal_grid is a (b, h, w, c) numpy array
|
379 |
+
|
380 |
+
rgb_normal_grid = torch.from_numpy(rgb_normal_grid).contiguous().float()
|
381 |
+
rgb_normal_grid = rearrange(rgb_normal_grid.squeeze(0), '(n h) (m w) c-> (n m) c h w', n=2, m=4) # (8, 3, 512, 512)
|
382 |
+
rgb_multi_view = rgb_normal_grid[:4, :3, :, :].cuda()
|
383 |
+
normal_multi_view = rgb_normal_grid[4:, :3, :, :].cuda()
|
384 |
+
multi_view_mask = get_background(normal_multi_view).cuda()
|
385 |
+
rgb_multi_view = rgb_multi_view * multi_view_mask + (1-multi_view_mask)
|
386 |
+
|
387 |
+
# local normal to global normal
|
388 |
+
global_normal = local_normal_global_transform(normal_multi_view.permute(0, 2, 3, 1).cpu(), isomer_azimuths, isomer_elevations).cuda()
|
389 |
+
|
390 |
+
global_normal = global_normal * multi_view_mask + (1-multi_view_mask)
|
391 |
+
|
392 |
+
global_normal = global_normal.permute(0,2,3,1)
|
393 |
+
multi_view_mask = multi_view_mask.squeeze(1)
|
394 |
+
rgb_multi_view = rgb_multi_view.permute(0,2,3,1)
|
395 |
+
# global_normal: B,H,W,3
|
396 |
+
# multi_view_mask: B,H,W
|
397 |
+
# rgb_multi_view: B,H,W,3
|
398 |
+
|
399 |
+
|
400 |
+
meshes = reconstruction(
|
401 |
+
normal_pils=global_normal,
|
402 |
+
masks=multi_view_mask,
|
403 |
+
weights=isomer_geo_weights,
|
404 |
+
fov=30,
|
405 |
+
radius=isomer_radius,
|
406 |
+
camera_angles_azi=isomer_azimuths,
|
407 |
+
camera_angles_ele=isomer_elevations,
|
408 |
+
expansion_weight_stage1=0.1,
|
409 |
+
init_type="file",
|
410 |
+
init_verts=vertices,
|
411 |
+
init_faces=faces,
|
412 |
+
stage1_steps=0,
|
413 |
+
stage2_steps=50,
|
414 |
+
start_edge_len_stage1=0.1,
|
415 |
+
end_edge_len_stage1=0.02,
|
416 |
+
start_edge_len_stage2=0.02,
|
417 |
+
end_edge_len_stage2=0.005,
|
418 |
+
)
|
419 |
+
|
420 |
+
save_glb_addr = projection(
|
421 |
+
meshes=meshes,
|
422 |
+
masks=multi_view_mask,
|
423 |
+
images=rgb_multi_view,
|
424 |
+
azimuths=isomer_azimuths,
|
425 |
+
elevations=isomer_elevations,
|
426 |
+
weights=isomer_color_weights,
|
427 |
+
fov=30,
|
428 |
+
radius=isomer_radius,
|
429 |
+
save_dir=f"{save_dir_path}/ISOMER/",
|
430 |
+
)
|
431 |
+
print(f'saved to {save_glb_addr}')
|
432 |
+
|
433 |
+
|
434 |
+
|
435 |
+
if __name__ == '__main__':
|
436 |
+
main()
|
pipeline/kiss3d_wrapper.py
ADDED
@@ -0,0 +1,429 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# The kiss3d pipeline wrapper for inference
|
2 |
+
|
3 |
+
import os
|
4 |
+
import numpy as np
|
5 |
+
import torch
|
6 |
+
import yaml
|
7 |
+
import uuid
|
8 |
+
from typing import Union, Any, Dict
|
9 |
+
from einops import rearrange
|
10 |
+
from PIL import Image
|
11 |
+
|
12 |
+
from pipeline.utils import logger, TMP_DIR, OUT_DIR
|
13 |
+
from pipeline.utils import lrm_reconstruct, isomer_reconstruct
|
14 |
+
|
15 |
+
import torch
|
16 |
+
import torchvision
|
17 |
+
|
18 |
+
# for reconstruction model
|
19 |
+
from omegaconf import OmegaConf
|
20 |
+
from models.lrm.utils.train_util import instantiate_from_config
|
21 |
+
from models.lrm.utils.render_utils import rotate_x, rotate_y
|
22 |
+
from utils.tool import get_background
|
23 |
+
|
24 |
+
# for florence2
|
25 |
+
from transformers import AutoProcessor, AutoModelForCausalLM
|
26 |
+
|
27 |
+
from diffusers import FluxPipeline, FluxControlNetImg2ImgPipeline, FluxImg2ImgPipeline, DiffusionPipeline, EulerAncestralDiscreteScheduler
|
28 |
+
from diffusers.models.controlnet_flux import FluxControlNetModel, FluxMultiControlNetModel
|
29 |
+
|
30 |
+
|
31 |
+
def init_wrapper_from_config(config_path):
|
32 |
+
with open(config_path, 'r') as config_file:
|
33 |
+
config_ = yaml.load(config_file, yaml.FullLoader)
|
34 |
+
|
35 |
+
# init flux_pipeline
|
36 |
+
logger.info('==> Loading Flux model ...')
|
37 |
+
flux_device = config_['flux'].get('device', 'cpu')
|
38 |
+
flux_base_model_pth = config_['flux'].get('base_model', None)
|
39 |
+
flux_controlnet_pth = config_['flux'].get('controlnet', None)
|
40 |
+
flux_lora_pth = config_['flux'].get('lora', None)
|
41 |
+
|
42 |
+
# load flux model and controlnet
|
43 |
+
if flux_controlnet_pth is not None:
|
44 |
+
flux_controlnet = FluxControlNetModel.from_pretrained(flux_controlnet_pth)
|
45 |
+
flux_pipe = FluxControlNetImg2ImgPipeline.from_pretrained(flux_base_model_pth, controlnet=[flux_controlnet], \
|
46 |
+
torch_dtype=torch.bfloat16)
|
47 |
+
else:
|
48 |
+
flux_pipe = FluxImg2ImgPipeline(flux_base_model_pth, torch_dtype=torch.bfloat16)
|
49 |
+
|
50 |
+
# load lora weights
|
51 |
+
flux_pipe.load_lora_weights(flux_lora_pth)
|
52 |
+
flux_pipe.to(device=flux_device, dtype=torch.bfloat16)
|
53 |
+
|
54 |
+
# TODO: load redux model
|
55 |
+
# FluxPriorReduxPipeline.from_pretrained()
|
56 |
+
|
57 |
+
# TODO: load pulid model
|
58 |
+
|
59 |
+
# init multiview model
|
60 |
+
logger.info('==> Loading multiview diffusion model ...')
|
61 |
+
multiview_device = config_['multiview'].get('device', 'cpu')
|
62 |
+
multiview_pipeline = DiffusionPipeline.from_pretrained(
|
63 |
+
config_['multiview']['base_model'],
|
64 |
+
custom_pipeline=config_['multiview']['custom_pipeline'],
|
65 |
+
torch_dtype=torch.float16,
|
66 |
+
)
|
67 |
+
multiview_pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(
|
68 |
+
multiview_pipeline.scheduler.config, timestep_spacing='trailing'
|
69 |
+
)
|
70 |
+
|
71 |
+
unet_ckpt_path = config_['multiview'].get('unet', None)
|
72 |
+
if unet_ckpt_path is not None:
|
73 |
+
state_dict = torch.load(unet_ckpt_path, map_location='cpu')['state_dict']
|
74 |
+
state_dict = {k[10:]: v for k, v in state_dict.items() if k.startswith('unet.unet.')}
|
75 |
+
multiview_pipeline.unet.load_state_dict(state_dict, strict=True)
|
76 |
+
|
77 |
+
multiview_pipeline.to(multiview_device)
|
78 |
+
|
79 |
+
# load caption model
|
80 |
+
logger.info('==> Loading caption model ...')
|
81 |
+
caption_device = config_['caption'].get('device', 'cpu')
|
82 |
+
caption_model = AutoModelForCausalLM.from_pretrained(config_['caption']['base_model'], \
|
83 |
+
torch_dtype=torch.bfloat16, trust_remote_code=True).to(caption_device)
|
84 |
+
caption_processor = AutoProcessor.from_pretrained(config_['caption']['base_model'], trust_remote_code=True)
|
85 |
+
|
86 |
+
# load reconstruction model
|
87 |
+
logger.info('==> Loading reconstruction model ...')
|
88 |
+
recon_device = config_['reconstruction'].get('device', 'cpu')
|
89 |
+
recon_model_config = OmegaConf.load(config_['reconstruction']['model_config'])
|
90 |
+
recon_model = instantiate_from_config(recon_model_config.model_config)
|
91 |
+
# load recon model checkpoint
|
92 |
+
state_dict = torch.load(config_['reconstruction']['base_model'], map_location='cpu')['state_dict']
|
93 |
+
state_dict = {k[14:]: v for k, v in state_dict.items() if k.startswith('lrm_generator.')}
|
94 |
+
recon_model.load_state_dict(state_dict, strict=True)
|
95 |
+
recon_model.to(recon_device)
|
96 |
+
recon_model.init_flexicubes_geometry(recon_device, fovy=50.0)
|
97 |
+
recon_model.eval()
|
98 |
+
|
99 |
+
return kiss3d_wrapper(
|
100 |
+
config = config_,
|
101 |
+
flux_pipeline = flux_pipe,
|
102 |
+
multiview_pipeline = multiview_pipeline,
|
103 |
+
caption_processor = caption_processor,
|
104 |
+
caption_model = caption_model,
|
105 |
+
reconstruction_model_config = recon_model_config,
|
106 |
+
reconstruction_model = recon_model,
|
107 |
+
)
|
108 |
+
|
109 |
+
class kiss3d_wrapper(object):
|
110 |
+
def __init__(self,
|
111 |
+
config: Dict,
|
112 |
+
flux_pipeline: Union[FluxPipeline, FluxControlNetImg2ImgPipeline],
|
113 |
+
multiview_pipeline: DiffusionPipeline,
|
114 |
+
caption_processor: AutoProcessor,
|
115 |
+
caption_model: AutoModelForCausalLM,
|
116 |
+
reconstruction_model_config: Any,
|
117 |
+
reconstruction_model: Any,
|
118 |
+
):
|
119 |
+
self.config = config
|
120 |
+
self.flux_pipeline = flux_pipeline
|
121 |
+
self.multiview_pipeline = multiview_pipeline
|
122 |
+
self.caption_model = caption_model
|
123 |
+
self.caption_processor = caption_processor
|
124 |
+
self.recon_model_config = reconstruction_model_config
|
125 |
+
self.recon_model = reconstruction_model
|
126 |
+
|
127 |
+
self.renew_uuid()
|
128 |
+
|
129 |
+
def renew_uuid(self):
|
130 |
+
self.uuid = uuid.uuid4()
|
131 |
+
|
132 |
+
def context(self):
|
133 |
+
if self.config['use_zero_gpu']:
|
134 |
+
import spaces
|
135 |
+
return spaces.GPU()
|
136 |
+
else:
|
137 |
+
return torch.no_grad()
|
138 |
+
|
139 |
+
def get_image_caption(self, image):
|
140 |
+
"""
|
141 |
+
image: PIL image or path of PIL image
|
142 |
+
"""
|
143 |
+
torch_dtype = torch.bfloat16
|
144 |
+
caption_device = self.config['caption'].get('device', 'cpu')
|
145 |
+
|
146 |
+
if isinstance(image, str): # If image is a file path
|
147 |
+
image = Image.open(image).convert("RGB")
|
148 |
+
elif isinstance(image, Image):
|
149 |
+
image = image.convert("RGB")
|
150 |
+
else:
|
151 |
+
raise NotImplementedError('unexpected image type')
|
152 |
+
|
153 |
+
prompt = "<MORE_DETAILED_CAPTION>"
|
154 |
+
inputs = self.caption_processor(text=prompt, images=image, return_tensors="pt").to(caption_device, torch_dtype)
|
155 |
+
|
156 |
+
generated_ids = self.caption_model.generate(
|
157 |
+
input_ids=inputs["input_ids"], pixel_values=inputs["pixel_values"], max_new_tokens=1024, num_beams=3
|
158 |
+
)
|
159 |
+
|
160 |
+
generated_text = self.caption_processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
|
161 |
+
parsed_answer = self.caption_processor.post_process_generation(
|
162 |
+
generated_text, task=prompt, image_size=(image.width, image.height)
|
163 |
+
)
|
164 |
+
caption_text = parsed_answer["<MORE_DETAILED_CAPTION>"].replace("The image is ", "")
|
165 |
+
return caption_text
|
166 |
+
|
167 |
+
def generate_multiview(self, image):
|
168 |
+
with self.context():
|
169 |
+
mv_image = self.multiview_pipeline(image,
|
170 |
+
num_inference_steps=self.config['multiview']['num_inference_steps'],
|
171 |
+
width=512*2, height=512*2).images[0]
|
172 |
+
return mv_image
|
173 |
+
|
174 |
+
def reconstruct_from_multiview(self, mv_image):
|
175 |
+
"""
|
176 |
+
mv_image: PIL.Image
|
177 |
+
"""
|
178 |
+
recon_device = self.config['reconstruction'].get('device', 'cpu')
|
179 |
+
|
180 |
+
rgb_multi_view = np.asarray(mv_image, dtype=np.float32) / 255.0
|
181 |
+
rgb_multi_view = torch.from_numpy(rgb_multi_view).squeeze(0).permute(2, 0, 1).contiguous().float() # (3, 1024, 2048)
|
182 |
+
rgb_multi_view = rearrange(rgb_multi_view, 'c (n h) (m w) -> (n m) c h w', n=2, m=2).unsqueeze(0).to(recon_device)
|
183 |
+
|
184 |
+
with self.context():
|
185 |
+
vertices, faces, lrm_multi_view_normals, lrm_multi_view_rgb, lrm_multi_view_albedo = \
|
186 |
+
lrm_reconstruct(self.recon_model, self.recon_model_config.infer_config,
|
187 |
+
rgb_multi_view, name=self.uuid)
|
188 |
+
|
189 |
+
return vertices, faces, lrm_multi_view_normals, lrm_multi_view_rgb, lrm_multi_view_albedo
|
190 |
+
|
191 |
+
def generate_reference_3D_bundle_image_zero123(self, image, save_intermediate_results=True):
|
192 |
+
"""
|
193 |
+
input: image, PIL.Image
|
194 |
+
return: ref_3D_bundle_image, Tensor of shape (1, 3, 1024, 2048)
|
195 |
+
"""
|
196 |
+
mv_image = self.generate_multiview(image)
|
197 |
+
|
198 |
+
if save_intermediate_results:
|
199 |
+
mv_image.save(os.path.join(TMP_DIR, f'{self.uuid}_mv_image.png'))
|
200 |
+
|
201 |
+
vertices, faces, lrm_multi_view_normals, lrm_multi_view_rgb, lrm_multi_view_albedo = self.reconstruct_from_multiview(mv_image)
|
202 |
+
|
203 |
+
ref_3D_bundle_image = torchvision.utils.make_grid(torch.cat([lrm_multi_view_rgb.cpu(), (lrm_multi_view_normals.cpu() + 1) / 2], dim=0), nrow=4, padding=0).unsqueeze(0) # range [0, 1]
|
204 |
+
|
205 |
+
if save_intermediate_results:
|
206 |
+
save_path = os.path.join(TMP_DIR, f'{self.uuid}_ref_3d_bundle_image.png')
|
207 |
+
torchvision.utils.save_image(ref_3D_bundle_image, save_path)
|
208 |
+
|
209 |
+
logger.info(f"Save reference 3D bundle image to {save_path}")
|
210 |
+
|
211 |
+
return ref_3D_bundle_image, save_path
|
212 |
+
|
213 |
+
return ref_3D_bundle_image
|
214 |
+
|
215 |
+
def generate_3d_bundle_image_controlnet(self,
|
216 |
+
prompt,
|
217 |
+
image=None,
|
218 |
+
strength=1.0,
|
219 |
+
control_image=[],
|
220 |
+
control_mode=[],
|
221 |
+
control_guidance_start=None,
|
222 |
+
control_guidance_end=None,
|
223 |
+
controlnet_conditioning_scale=None,
|
224 |
+
lora_scale=1.0,
|
225 |
+
save_intermediate_results=True,
|
226 |
+
**kwargs):
|
227 |
+
control_mode_dict = {
|
228 |
+
'canny': 0,
|
229 |
+
'tile': 1,
|
230 |
+
'depth': 2,
|
231 |
+
'blur': 3,
|
232 |
+
'pose': 4,
|
233 |
+
'gray': 5,
|
234 |
+
'lq': 6,
|
235 |
+
} # for https://huggingface.co/InstantX/FLUX.1-dev-Controlnet-Union only
|
236 |
+
|
237 |
+
flux_device = self.config['flux'].get('device', 'cpu')
|
238 |
+
seed = self.config['flux'].get('seed', 0)
|
239 |
+
|
240 |
+
generator = torch.Generator(device=flux_device).manual_seed(seed)
|
241 |
+
|
242 |
+
hparam_dict = {
|
243 |
+
'prompt': ' '.join(['A grid of 2x4 multi-view image, elevation 5. White background.', prompt]),
|
244 |
+
'image': image or torch.zeros((1, 3, 1024, 2048), dtype=torch.float32, device=flux_device),
|
245 |
+
'strength': strength,
|
246 |
+
'num_inference_steps': 30,
|
247 |
+
'guidance_scale': 3.5,
|
248 |
+
'num_images_per_prompt': 1,
|
249 |
+
'width': 2048,
|
250 |
+
'height': 1024,
|
251 |
+
'output_type': 'np',
|
252 |
+
'generator': generator,
|
253 |
+
'joint_attention_kwargs': {"scale": lora_scale}
|
254 |
+
}
|
255 |
+
hparam_dict.update(kwargs)
|
256 |
+
|
257 |
+
# append controlnet hparams
|
258 |
+
if len(control_image) > 0:
|
259 |
+
assert isinstance(self.flux_pipeline, FluxControlNetImg2ImgPipeline)
|
260 |
+
assert len(control_mode) == len(control_image) # the count of image should be the same as control mode
|
261 |
+
|
262 |
+
flux_ctrl_net = self.flux_pipeline.controlnet.nets[0]
|
263 |
+
self.flux_pipeline.controlnet = FluxMultiControlNetModel([flux_ctrl_net for i in range(len(control_image))])
|
264 |
+
|
265 |
+
ctrl_hparams = {
|
266 |
+
'control_mode': [control_mode_dict[mode_] for mode_ in control_mode],
|
267 |
+
'control_image': control_image,
|
268 |
+
'control_guidance_start': control_guidance_start or [0.0 for i in range(len(control_image))],
|
269 |
+
'control_guidance_end': control_guidance_end or [1.0 for i in range(len(control_image))],
|
270 |
+
'controlnet_conditioning_scale': controlnet_conditioning_scale or [1.0 for i in range(len(control_image))],
|
271 |
+
}
|
272 |
+
|
273 |
+
hparam_dict.update(ctrl_hparams)
|
274 |
+
|
275 |
+
with self.context():
|
276 |
+
gen_3d_bundle_image = self.flux_pipeline(**hparam_dict).images
|
277 |
+
|
278 |
+
gen_3d_bundle_image_ = torch.from_numpy(gen_3d_bundle_image).squeeze(0).permute(2, 0, 1).contiguous().float() # (3, 1024, 2048)
|
279 |
+
|
280 |
+
if save_intermediate_results:
|
281 |
+
save_path = os.path.join(TMP_DIR, f'{self.uuid}_gen_3d_bundle_image.png')
|
282 |
+
torchvision.utils.save_image(gen_3d_bundle_image_, save_path)
|
283 |
+
logger.info(f"Save generated 3D bundle image to {save_path}")
|
284 |
+
return gen_3d_bundle_image_, save_path
|
285 |
+
|
286 |
+
return gen_3d_bundle_image_
|
287 |
+
|
288 |
+
|
289 |
+
def generate_3d_bundle_image_text(self,
|
290 |
+
prompt,
|
291 |
+
image=None,
|
292 |
+
strength=1.0,
|
293 |
+
lora_scale=1.0,
|
294 |
+
num_inference_steps=30,
|
295 |
+
save_intermediate_results=True,
|
296 |
+
**kwargs):
|
297 |
+
|
298 |
+
"""
|
299 |
+
return: gen_3d_bundle_image, torch.Tensor of shape (3, 1024, 2048), range [0., 1.]
|
300 |
+
"""
|
301 |
+
|
302 |
+
if isinstance(self.flux_pipeline, FluxControlNetImg2ImgPipeline):
|
303 |
+
flux_pipeline = FluxImg2ImgPipeline(
|
304 |
+
scheduler = self.flux_pipeline.scheduler,
|
305 |
+
vae = self.flux_pipeline.vae,
|
306 |
+
text_encoder = self.flux_pipeline.text_encoder,
|
307 |
+
tokenizer = self.flux_pipeline.tokenizer,
|
308 |
+
text_encoder_2 = self.flux_pipeline.text_encoder_2,
|
309 |
+
tokenizer_2 = self.flux_pipeline.tokenizer_2,
|
310 |
+
transformer = self.flux_pipeline.transformer
|
311 |
+
)
|
312 |
+
else:
|
313 |
+
flux_pipeline = self.flux_pipeline
|
314 |
+
|
315 |
+
flux_device = self.config['flux'].get('device', 'cpu')
|
316 |
+
seed = self.config['flux'].get('seed', 0)
|
317 |
+
|
318 |
+
generator = torch.Generator(device=flux_device).manual_seed(seed)
|
319 |
+
|
320 |
+
hparam_dict = {
|
321 |
+
'prompt': ' '.join(['A grid of 2x4 multi-view image, elevation 5. White background.', prompt]),
|
322 |
+
'image': image or torch.zeros((1, 3, 1024, 2048), dtype=torch.float32, device=flux_device),
|
323 |
+
'strength': strength,
|
324 |
+
'num_inference_steps': num_inference_steps,
|
325 |
+
'guidance_scale': 3.5,
|
326 |
+
'num_images_per_prompt': 1,
|
327 |
+
'width': 2048,
|
328 |
+
'height': 1024,
|
329 |
+
'output_type': 'np',
|
330 |
+
'generator': generator,
|
331 |
+
'joint_attention_kwargs': {"scale": lora_scale}
|
332 |
+
}
|
333 |
+
hparam_dict.update(kwargs)
|
334 |
+
|
335 |
+
with self.context():
|
336 |
+
gen_3d_bundle_image = flux_pipeline(**hparam_dict).images
|
337 |
+
|
338 |
+
gen_3d_bundle_image_ = torch.from_numpy(gen_3d_bundle_image).squeeze(0).permute(2, 0, 1).contiguous().float() # (3, 1024, 2048)
|
339 |
+
|
340 |
+
if save_intermediate_results:
|
341 |
+
save_path = os.path.join(TMP_DIR, f'{self.uuid}_gen_3d_bundle_image.png')
|
342 |
+
torchvision.utils.save_image(gen_3d_bundle_image_, save_path)
|
343 |
+
logger.info(f"Save generated 3D bundle image to {save_path}")
|
344 |
+
return gen_3d_bundle_image_, save_path
|
345 |
+
|
346 |
+
return gen_3d_bundle_image_
|
347 |
+
|
348 |
+
def reconstruct_3d_bundle_image(self, image, save_intermediate_results=True):
|
349 |
+
"""
|
350 |
+
image: torch.Tensor, range [0., 1.], (3, 1024, 2048)
|
351 |
+
"""
|
352 |
+
recon_device = self.config['reconstruction'].get('device', 'cpu')
|
353 |
+
|
354 |
+
# split rgb and normal
|
355 |
+
images = rearrange(image, 'c (n h) (m w) -> (n m) c h w', n=2, m=4) # (3, 1024, 2048) -> (8, 3, 512, 512)
|
356 |
+
rgb_multi_view, normal_multi_view = images.chunk(2, dim=0)
|
357 |
+
multi_view_mask = get_background(normal_multi_view).to(recon_device)
|
358 |
+
rgb_multi_view = rgb_multi_view.to(recon_device) * multi_view_mask + (1 - multi_view_mask)
|
359 |
+
|
360 |
+
with self.context():
|
361 |
+
vertices, faces, lrm_multi_view_normals, lrm_multi_view_rgb, lrm_multi_view_albedo = \
|
362 |
+
lrm_reconstruct(self.recon_model, self.recon_model_config.infer_config,
|
363 |
+
rgb_multi_view.unsqueeze(0).to(recon_device), name=self.uuid,
|
364 |
+
input_camera_type='kiss3d', render_3d_bundle_image=save_intermediate_results,
|
365 |
+
render_azimuths=[0, 90, 180, 270])
|
366 |
+
|
367 |
+
if save_intermediate_results:
|
368 |
+
recon_3D_bundle_image = torchvision.utils.make_grid(torch.cat([lrm_multi_view_rgb.cpu(), (lrm_multi_view_normals.cpu() + 1) / 2], dim=0), nrow=4, padding=0).unsqueeze(0) # range [0, 1]
|
369 |
+
torchvision.utils.save_image(recon_3D_bundle_image, os.path.join(TMP_DIR, f'{k3d_wrapper.uuid})_lrm_recon_3d_bundle_image.png'))
|
370 |
+
|
371 |
+
recon_mesh_path = os.path.join(TMP_DIR, f"{self.uuid}_isomer_recon_mesh.obj")
|
372 |
+
|
373 |
+
return isomer_reconstruct(rgb_multi_view=rgb_multi_view,
|
374 |
+
normal_multi_view=normal_multi_view,
|
375 |
+
multi_view_mask=multi_view_mask,
|
376 |
+
vertices=vertices,
|
377 |
+
faces=faces,
|
378 |
+
save_path=recon_mesh_path)
|
379 |
+
|
380 |
+
|
381 |
+
def run_text_to_3d(k3d_wrapper,
|
382 |
+
prompt,
|
383 |
+
init_image_path=None):
|
384 |
+
# ======================================= Example of text to 3D generation ======================================
|
385 |
+
|
386 |
+
# Renew The uuid
|
387 |
+
k3d_wrapper.renew_uuid()
|
388 |
+
|
389 |
+
# FOR Text to 3D (also for image to image) with init image
|
390 |
+
init_image = None
|
391 |
+
if init_image_path is not None:
|
392 |
+
init_image = Image.open(init_image_path)
|
393 |
+
|
394 |
+
gen_3d_bundle_image, gen_save_path = k3d_wrapper.generate_3d_bundle_image_text(prompt,
|
395 |
+
image=init_image,
|
396 |
+
strength=1.0,
|
397 |
+
save_intermediate_results=True)
|
398 |
+
|
399 |
+
# recon from 3D Bundle image
|
400 |
+
recon_mesh_path = k3d_wrapper.reconstruct_3d_bundle_image(gen_3d_bundle_image, save_intermediate_results=False)
|
401 |
+
|
402 |
+
return gen_save_path, recon_mesh_path
|
403 |
+
|
404 |
+
def run_image_to_3d(k3d_wrapper, init_image_path):
|
405 |
+
# ======================================= Example of image to 3D generation ======================================
|
406 |
+
|
407 |
+
# Renew The uuid
|
408 |
+
k3d_wrapper.renew_uuid()
|
409 |
+
|
410 |
+
# FOR IMAGE TO 3D: generate reference 3D bundle image from a single input image
|
411 |
+
input_image = Image.open(init_image_path)
|
412 |
+
reference_3d_bundle_image, reference_save_path = k3d_wrapper.generate_reference_3D_bundle_image_zero123(input_image)
|
413 |
+
caption = k3d_wrapper.get_image_caption(input_image)
|
414 |
+
|
415 |
+
|
416 |
+
import pdb
|
417 |
+
pdb.set_trace()
|
418 |
+
|
419 |
+
|
420 |
+
if __name__ == "__main__":
|
421 |
+
k3d_wrapper = init_wrapper_from_config('/hpc2hdd/home/jlin695/code/Kiss3DGen/pipeline/pipeline_config/default.yaml')
|
422 |
+
|
423 |
+
# Example of loading existing 3D bundle Image
|
424 |
+
# demo_image = Image.open('/hpc2hdd/home/jlin695/code/github/Kiss3DGen/outputs/tmp/ea25bc9b-d775-46bb-9827-660a9a6540c8_gen_3d_bundle_image.png')
|
425 |
+
# gen_3d_bundle_image = torchvision.transforms.functional.to_tensor(demo_image)
|
426 |
+
|
427 |
+
run_image_to_3d(k3d_wrapper, '/hpc2hdd/home/jlin695/code/Kiss3DGen/examples/蓝色小怪物.webp')
|
428 |
+
# run_text_to_3d(k3d_wrapper, prompt='A doll of a girl in Harry Potter')
|
429 |
+
|
pipeline/pipeline_config/default.yaml
ADDED
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
flux:
|
2 |
+
base_model: "/hpc2hdd/JH_DATA/share/yingcongchen/PrivateShareGroup/yingcongchen_datasets/model_checkpoint/models--black-forest-labs--FLUX.1-dev"
|
3 |
+
lora: "./checkpoint/flux_lora/rgb_normal_doll_object.safetensors"
|
4 |
+
controlnet: "/hpc2hdd/JH_DATA/share/yingcongchen/PrivateShareGroup/yingcongchen_datasets/model_checkpoint/flux_controlnets/FLUX.1-dev-ControlNet-Union-Pro"
|
5 |
+
seed: 0
|
6 |
+
device: 'cuda:0'
|
7 |
+
|
8 |
+
multiview:
|
9 |
+
base_model: "sudo-ai/zero123plus-v1.2"
|
10 |
+
custom_pipeline: "./models/zero123plus"
|
11 |
+
unet: "./checkpoint/zero123++/flexgen_19w.ckpt"
|
12 |
+
num_inference_steps: 75
|
13 |
+
device: 'cuda:0'
|
14 |
+
|
15 |
+
reconstruction:
|
16 |
+
model_config: "./models/lrm/config/PRM_inference.yaml"
|
17 |
+
base_model: "./checkpoint/lrm/final_ckpt.ckpt"
|
18 |
+
device: 'cuda:0'
|
19 |
+
|
20 |
+
caption:
|
21 |
+
base_model: "/hpc2hdd/home/jlin695/.cache/huggingface/hub/models--multimodalart--Florence-2-large-no-flash-attn/snapshots/8db3793cf5b453b2ccfb3a4f613b403b2e6b7ca2"
|
22 |
+
device: 'cuda:0'
|
23 |
+
|
24 |
+
use_zero_gpu: false # for huggingface demo only
|
25 |
+
3d_bundle_templates: '/hpc2hdd/home/jlin695/code/github/Kiss3DGen/init_3d_Bundle'
|
pipeline/run_hpc.sh
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
source /hpc2ssd/softwares/anaconda3/bin/activate kiss3dgen
|
2 |
+
module load cuda/12.1 compilers/gcc-11.1.0 compilers/icc-2023.1.0 cmake/3.27.0
|
3 |
+
export CXX=$(which g++)
|
4 |
+
export CC=$(which gcc)
|
5 |
+
export CPLUS_INCLUDE_PATH=/hpc2ssd/softwares/cuda/cuda-12.1/targets/x86_64-linux/include:$CPLUS_INCLUDE_PATH
|
6 |
+
export CUDA_LAUNCH_BLOCKING=1
|
7 |
+
export NCCL_TIMEOUT=3600
|
8 |
+
export CUDA_VISIBLE_DEVICES="0"
|
9 |
+
|
10 |
+
python ./pipeline/kiss3d_wrapper.py
|
pipeline/utils.py
ADDED
@@ -0,0 +1,198 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import sys
|
3 |
+
import logging
|
4 |
+
|
5 |
+
__workdir__ = '/'.join(os.path.abspath(__file__).split('/')[:-2])
|
6 |
+
sys.path.insert(0, __workdir__)
|
7 |
+
|
8 |
+
print(__workdir__)
|
9 |
+
|
10 |
+
import numpy as np
|
11 |
+
import torch
|
12 |
+
from torchvision.transforms import v2
|
13 |
+
|
14 |
+
from models.lrm.online_render.render_single import load_mipmap
|
15 |
+
from models.lrm.utils.camera_util import get_zero123plus_input_cameras, get_custom_zero123plus_input_cameras, get_flux_input_cameras
|
16 |
+
from models.lrm.utils.render_utils import rotate_x, rotate_y
|
17 |
+
from models.lrm.utils.mesh_util import save_obj, save_obj_with_mtl
|
18 |
+
|
19 |
+
from models.ISOMER.reconstruction_func import reconstruction
|
20 |
+
from models.ISOMER.projection_func import projection
|
21 |
+
|
22 |
+
from utils.tool import NormalTransfer, get_render_cameras_frames, get_background, get_render_cameras_video, render_frames, mask_fix
|
23 |
+
|
24 |
+
|
25 |
+
logging.basicConfig(
|
26 |
+
level = logging.INFO
|
27 |
+
)
|
28 |
+
logger = logging.getLogger('kiss3d_wrapper')
|
29 |
+
|
30 |
+
OUT_DIR = './outputs'
|
31 |
+
TMP_DIR = './outputs/tmp'
|
32 |
+
|
33 |
+
os.makedirs(TMP_DIR, exist_ok=True)
|
34 |
+
|
35 |
+
def lrm_reconstruct(model, infer_config, images,
|
36 |
+
name='', export_texmap=False,
|
37 |
+
input_camera_type='zero123',
|
38 |
+
render_3d_bundle_image=True,
|
39 |
+
render_azimuths=[270, 0, 90, 180],
|
40 |
+
render_elevations=[5, 5, 5, 5],
|
41 |
+
render_radius=4.5):
|
42 |
+
"""
|
43 |
+
image: Tensor, shape (1, c, h, w)
|
44 |
+
"""
|
45 |
+
|
46 |
+
mesh_path_idx = os.path.join(TMP_DIR, f'{name}_recon_from_{input_camera_type}.obj')
|
47 |
+
|
48 |
+
device = images.device
|
49 |
+
if input_camera_type == 'zero123':
|
50 |
+
input_cameras = get_custom_zero123plus_input_cameras(batch_size=1, radius=3.5, fov=30).to(device)
|
51 |
+
elif input_camera_type == 'kiss3d':
|
52 |
+
input_cameras = get_flux_input_cameras(batch_size=1, radius=4.2, fov=30).to(device)
|
53 |
+
else:
|
54 |
+
raise NotImplementedError(f'Unexpected input camera type: {input_camera_type}')
|
55 |
+
|
56 |
+
images = v2.functional.resize(images, 512, interpolation=3, antialias=True).clamp(0, 1)
|
57 |
+
|
58 |
+
logger.info(f"==> Runing LRM reconstruction ...")
|
59 |
+
planes = model.forward_planes(images, input_cameras)
|
60 |
+
mesh_out = model.extract_mesh(
|
61 |
+
planes,
|
62 |
+
use_texture_map=export_texmap,
|
63 |
+
**infer_config,
|
64 |
+
)
|
65 |
+
if export_texmap:
|
66 |
+
vertices, faces, uvs, mesh_tex_idx, tex_map = mesh_out
|
67 |
+
save_obj_with_mtl(
|
68 |
+
vertices.data.cpu().numpy(),
|
69 |
+
uvs.data.cpu().numpy(),
|
70 |
+
faces.data.cpu().numpy(),
|
71 |
+
mesh_tex_idx.data.cpu().numpy(),
|
72 |
+
tex_map.permute(1, 2, 0).data.cpu().numpy(),
|
73 |
+
mesh_path_idx,
|
74 |
+
)
|
75 |
+
else:
|
76 |
+
vertices, faces, vertex_colors = mesh_out
|
77 |
+
save_obj(vertices, faces, vertex_colors, mesh_path_idx)
|
78 |
+
logger.info(f"Mesh saved to {mesh_path_idx}")
|
79 |
+
|
80 |
+
if render_3d_bundle_image:
|
81 |
+
assert render_azimuths is not None and render_elevations is not None and render_radius is not None
|
82 |
+
render_azimuths = torch.Tensor(render_azimuths).to(device)
|
83 |
+
render_elevations = torch.Tensor(render_elevations).to(device)
|
84 |
+
|
85 |
+
render_size = infer_config.render_resolution
|
86 |
+
ENV = load_mipmap("models/lrm/env_mipmap/6")
|
87 |
+
materials = (0.0,0.9)
|
88 |
+
all_mv, all_mvp, all_campos, identity_mv = get_render_cameras_frames(
|
89 |
+
batch_size=1,
|
90 |
+
radius=render_radius,
|
91 |
+
azimuths=render_azimuths,
|
92 |
+
elevations=render_elevations,
|
93 |
+
fov=30
|
94 |
+
)
|
95 |
+
frames, albedos, pbr_spec_lights, pbr_diffuse_lights, normals, alphas = render_frames(
|
96 |
+
model,
|
97 |
+
planes,
|
98 |
+
render_cameras=all_mvp,
|
99 |
+
camera_pos=all_campos,
|
100 |
+
env=ENV,
|
101 |
+
materials=materials,
|
102 |
+
render_size=render_size,
|
103 |
+
render_mv = all_mv,
|
104 |
+
local_normal=True,
|
105 |
+
identity_mv=identity_mv,
|
106 |
+
)
|
107 |
+
else:
|
108 |
+
normals = None
|
109 |
+
frames = None
|
110 |
+
albedos = None
|
111 |
+
|
112 |
+
|
113 |
+
vertices = torch.from_numpy(vertices).to(device)
|
114 |
+
faces = torch.from_numpy(faces).to(device)
|
115 |
+
vertices = vertices @ rotate_x(np.pi / 2, device=device)[:3, :3]
|
116 |
+
vertices = vertices @ rotate_y(np.pi / 2, device=device)[:3, :3]
|
117 |
+
|
118 |
+
return vertices.cpu(), faces.cpu(), normals, frames, albedos
|
119 |
+
|
120 |
+
normal_transfer = NormalTransfer()
|
121 |
+
|
122 |
+
def local_normal_global_transform(local_normal_images,azimuths_deg,elevations_deg):
|
123 |
+
if local_normal_images.min() >= 0:
|
124 |
+
local_normal = local_normal_images.float() * 2 - 1
|
125 |
+
else:
|
126 |
+
local_normal = local_normal_images.float()
|
127 |
+
global_normal = normal_transfer.trans_local_2_global(local_normal, azimuths_deg, elevations_deg, radius=4.5, for_lotus=False)
|
128 |
+
global_normal[...,0] *= -1
|
129 |
+
global_normal = (global_normal + 1) / 2
|
130 |
+
global_normal = global_normal.permute(0, 3, 1, 2)
|
131 |
+
return global_normal
|
132 |
+
|
133 |
+
|
134 |
+
def isomer_reconstruct(
|
135 |
+
rgb_multi_view,
|
136 |
+
normal_multi_view,
|
137 |
+
multi_view_mask,
|
138 |
+
vertices,
|
139 |
+
faces,
|
140 |
+
save_path=None,
|
141 |
+
azimuths=[0, 90, 180, 270],
|
142 |
+
elevations=[5, 5, 5, 5],
|
143 |
+
geo_weights=[1, 0.9, 1, 0.9],
|
144 |
+
color_weights=[1, 0.5, 1, 0.5],
|
145 |
+
reconstruction_stage1_steps=50,
|
146 |
+
reconstruction_stage2_steps=50,
|
147 |
+
radius=4.1):
|
148 |
+
|
149 |
+
device = rgb_multi_view.device
|
150 |
+
to_tensor_ = lambda x: torch.Tensor(x).float().to(device)
|
151 |
+
|
152 |
+
# local normal to global normal
|
153 |
+
global_normal = local_normal_global_transform(normal_multi_view.permute(0, 2, 3, 1).cpu(), to_tensor_(azimuths), to_tensor_(elevations)).to(device)
|
154 |
+
global_normal = global_normal * multi_view_mask + (1-multi_view_mask)
|
155 |
+
|
156 |
+
global_normal = global_normal.permute(0,2,3,1)
|
157 |
+
multi_view_mask = multi_view_mask.squeeze(1)
|
158 |
+
rgb_multi_view = rgb_multi_view.permute(0,2,3,1)
|
159 |
+
|
160 |
+
logger.info(f"==> Runing ISOMER reconstruction ...")
|
161 |
+
meshes = reconstruction(
|
162 |
+
normal_pils=global_normal,
|
163 |
+
masks=multi_view_mask,
|
164 |
+
weights=to_tensor_(geo_weights),
|
165 |
+
fov=30,
|
166 |
+
radius=radius,
|
167 |
+
camera_angles_azi=to_tensor_(azimuths),
|
168 |
+
camera_angles_ele=to_tensor_(elevations),
|
169 |
+
expansion_weight_stage1=0.1,
|
170 |
+
init_type="file",
|
171 |
+
init_verts=vertices,
|
172 |
+
init_faces=faces,
|
173 |
+
stage1_steps=reconstruction_stage1_steps,
|
174 |
+
stage2_steps=reconstruction_stage2_steps,
|
175 |
+
start_edge_len_stage1=0.1,
|
176 |
+
end_edge_len_stage1=0.02,
|
177 |
+
start_edge_len_stage2=0.02,
|
178 |
+
end_edge_len_stage2=0.005,
|
179 |
+
)
|
180 |
+
|
181 |
+
multi_view_mask_proj = mask_fix(multi_view_mask, erode_dilate=-10, blur=5)
|
182 |
+
|
183 |
+
logger.info(f"==> Runing ISOMER projection ...")
|
184 |
+
save_glb_addr = projection(
|
185 |
+
meshes,
|
186 |
+
masks=multi_view_mask_proj.to(device),
|
187 |
+
images=rgb_multi_view.to(device),
|
188 |
+
azimuths=to_tensor_(azimuths),
|
189 |
+
elevations=to_tensor_(elevations),
|
190 |
+
weights=to_tensor_(color_weights),
|
191 |
+
fov=30,
|
192 |
+
radius=radius,
|
193 |
+
save_dir=TMP_DIR,
|
194 |
+
save_glb_addr=save_path
|
195 |
+
)
|
196 |
+
|
197 |
+
logger.info(f"==> Save mesh to {save_glb_addr} ...")
|
198 |
+
return save_glb_addr
|
run.sh
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
export CUDA_VISIBLE_DEVICES="0"
|
2 |
+
python text_to_mesh.py
|
run_hpc.sh
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
source /hpc2ssd/softwares/anaconda3/bin/activate kiss3dgen
|
2 |
+
module load cuda/12.1 compilers/gcc-11.1.0 compilers/icc-2023.1.0 cmake/3.27.0
|
3 |
+
export CXX=$(which g++)
|
4 |
+
export CC=$(which gcc)
|
5 |
+
export CPLUS_INCLUDE_PATH=/hpc2ssd/softwares/cuda/cuda-12.1/targets/x86_64-linux/include:$CPLUS_INCLUDE_PATH
|
6 |
+
export CUDA_LAUNCH_BLOCKING=1
|
7 |
+
export NCCL_TIMEOUT=3600
|
8 |
+
export CUDA_VISIBLE_DEVICES="0"
|
9 |
+
# python app.py
|
10 |
+
python text_to_mesh.py
|
11 |
+
# python image_to_mesh.py
|
text_to_mesh_new.py
ADDED
@@ -0,0 +1,244 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from einops import rearrange
|
3 |
+
from omegaconf import OmegaConf
|
4 |
+
import torch
|
5 |
+
import numpy as np
|
6 |
+
import trimesh
|
7 |
+
import torchvision
|
8 |
+
import torch.nn.functional as F
|
9 |
+
from PIL import Image
|
10 |
+
from torchvision import transforms
|
11 |
+
from torchvision.transforms import v2
|
12 |
+
from diffusers import HeunDiscreteScheduler
|
13 |
+
from diffusers import FluxPipeline
|
14 |
+
from pytorch_lightning import seed_everything
|
15 |
+
import os
|
16 |
+
|
17 |
+
import time
|
18 |
+
|
19 |
+
from models.lrm.utils.infer_util import save_video
|
20 |
+
from models.lrm.utils.mesh_util import save_obj, save_obj_with_mtl
|
21 |
+
from models.lrm.utils.render_utils import rotate_x, rotate_y
|
22 |
+
from models.lrm.utils.train_util import instantiate_from_config
|
23 |
+
from models.lrm.utils.camera_util import get_flux_input_cameras
|
24 |
+
from models.ISOMER.reconstruction_func import reconstruction
|
25 |
+
from models.ISOMER.projection_func import projection
|
26 |
+
from utils.tool import NormalTransfer, load_mipmap
|
27 |
+
from utils.tool import get_background, get_render_cameras_video, render_frames, mask_fix
|
28 |
+
|
29 |
+
device = "cuda"
|
30 |
+
resolution = 512
|
31 |
+
save_dir = "./outputs/text2"
|
32 |
+
normal_transfer = NormalTransfer()
|
33 |
+
isomer_azimuths = torch.from_numpy(np.array([0, 90, 180, 270])).float().to(device)
|
34 |
+
isomer_elevations = torch.from_numpy(np.array([5, 5, 5, 5])).float().to(device)
|
35 |
+
isomer_radius = 4.5
|
36 |
+
isomer_geo_weights = torch.from_numpy(np.array([1, 0.9, 1, 0.9])).float().to(device)
|
37 |
+
isomer_color_weights = torch.from_numpy(np.array([1, 0.5, 1, 0.5])).float().to(device)
|
38 |
+
|
39 |
+
# model initialization and loading
|
40 |
+
# flux
|
41 |
+
flux_pipe = FluxPipeline.from_pretrained("/hpc2hdd/JH_DATA/share/yingcongchen/PrivateShareGroup/yingcongchen_datasets/model_checkpoint/models--black-forest-labs--FLUX.1-dev", torch_dtype=torch.bfloat16).to(device=device, dtype=torch.bfloat16)
|
42 |
+
flux_pipe.load_lora_weights('./checkpoint/flux_lora/rgb_normal_large.safetensors')
|
43 |
+
|
44 |
+
flux_pipe.to(device=device, dtype=torch.bfloat16)
|
45 |
+
generator = torch.Generator(device=device).manual_seed(10)
|
46 |
+
|
47 |
+
# lrm
|
48 |
+
config = OmegaConf.load("./models/lrm/config/PRM_inference.yaml")
|
49 |
+
model_config = config.model_config
|
50 |
+
infer_config = config.infer_config
|
51 |
+
model = instantiate_from_config(model_config)
|
52 |
+
model_ckpt_path = "./checkpoint/lrm/final_ckpt.ckpt"
|
53 |
+
state_dict = torch.load(model_ckpt_path, map_location='cpu')['state_dict']
|
54 |
+
state_dict = {k[14:]: v for k, v in state_dict.items() if k.startswith('lrm_generator.')}
|
55 |
+
model.load_state_dict(state_dict, strict=True)
|
56 |
+
|
57 |
+
model = model.to(device)
|
58 |
+
model.init_flexicubes_geometry(device, fovy=50.0)
|
59 |
+
model = model.eval()
|
60 |
+
|
61 |
+
# Flux multi-view generation
|
62 |
+
def multi_view_rgb_normal_generation(prompt, save_path=None):
|
63 |
+
# generate multi-view images
|
64 |
+
with torch.no_grad():
|
65 |
+
image = flux_pipe(
|
66 |
+
prompt=prompt,
|
67 |
+
num_inference_steps=30,
|
68 |
+
guidance_scale=3.5,
|
69 |
+
num_images_per_prompt=1,
|
70 |
+
width=resolution*4,
|
71 |
+
height=resolution*2,
|
72 |
+
output_type='np',
|
73 |
+
generator=generator
|
74 |
+
).images
|
75 |
+
return image
|
76 |
+
|
77 |
+
# lrm reconstructions
|
78 |
+
def lrm_reconstructions(image, input_cameras, save_path=None, name="temp", export_texmap=False, if_save_video=False):
|
79 |
+
images = image.unsqueeze(0).to(device)
|
80 |
+
images = v2.functional.resize(images, 512, interpolation=3, antialias=True).clamp(0, 1)
|
81 |
+
# breakpoint()
|
82 |
+
with torch.no_grad():
|
83 |
+
# get triplane
|
84 |
+
planes = model.forward_planes(images, input_cameras)
|
85 |
+
|
86 |
+
mesh_path_idx = os.path.join(save_path, f'{name}.obj')
|
87 |
+
|
88 |
+
mesh_out = model.extract_mesh(
|
89 |
+
planes,
|
90 |
+
use_texture_map=export_texmap,
|
91 |
+
**infer_config,
|
92 |
+
)
|
93 |
+
if export_texmap:
|
94 |
+
vertices, faces, uvs, mesh_tex_idx, tex_map = mesh_out
|
95 |
+
save_obj_with_mtl(
|
96 |
+
vertices.data.cpu().numpy(),
|
97 |
+
uvs.data.cpu().numpy(),
|
98 |
+
faces.data.cpu().numpy(),
|
99 |
+
mesh_tex_idx.data.cpu().numpy(),
|
100 |
+
tex_map.permute(1, 2, 0).data.cpu().numpy(),
|
101 |
+
mesh_path_idx,
|
102 |
+
)
|
103 |
+
else:
|
104 |
+
vertices, faces, vertex_colors = mesh_out
|
105 |
+
save_obj(vertices, faces, vertex_colors, mesh_path_idx)
|
106 |
+
print(f"Mesh saved to {mesh_path_idx}")
|
107 |
+
|
108 |
+
render_size = 512
|
109 |
+
if if_save_video:
|
110 |
+
video_path_idx = os.path.join(save_path, f'{name}.mp4')
|
111 |
+
render_size = infer_config.render_resolution
|
112 |
+
ENV = load_mipmap("models/lrm/env_mipmap/6")
|
113 |
+
materials = (0.0,0.9)
|
114 |
+
|
115 |
+
all_mv, all_mvp, all_campos = get_render_cameras_video(
|
116 |
+
batch_size=1,
|
117 |
+
M=240,
|
118 |
+
radius=4.5,
|
119 |
+
elevation=(90, 60.0),
|
120 |
+
is_flexicubes=True,
|
121 |
+
fov=30
|
122 |
+
)
|
123 |
+
|
124 |
+
frames, albedos, pbr_spec_lights, pbr_diffuse_lights, normals, alphas = render_frames(
|
125 |
+
model,
|
126 |
+
planes,
|
127 |
+
render_cameras=all_mvp,
|
128 |
+
camera_pos=all_campos,
|
129 |
+
env=ENV,
|
130 |
+
materials=materials,
|
131 |
+
render_size=render_size,
|
132 |
+
chunk_size=20,
|
133 |
+
is_flexicubes=True,
|
134 |
+
)
|
135 |
+
normals = (torch.nn.functional.normalize(normals) + 1) / 2
|
136 |
+
normals = normals * alphas + (1-alphas)
|
137 |
+
all_frames = torch.cat([frames, albedos, pbr_spec_lights, pbr_diffuse_lights, normals], dim=3)
|
138 |
+
|
139 |
+
save_video(
|
140 |
+
all_frames,
|
141 |
+
video_path_idx,
|
142 |
+
fps=30,
|
143 |
+
)
|
144 |
+
print(f"Video saved to {video_path_idx}")
|
145 |
+
|
146 |
+
return vertices, faces
|
147 |
+
|
148 |
+
|
149 |
+
def local_normal_global_transform(local_normal_images, azimuths_deg, elevations_deg):
|
150 |
+
if local_normal_images.min() >= 0:
|
151 |
+
local_normal = local_normal_images.float() * 2 - 1
|
152 |
+
else:
|
153 |
+
local_normal = local_normal_images.float()
|
154 |
+
global_normal = normal_transfer.trans_local_2_global(local_normal, azimuths_deg, elevations_deg, radius=4.5, for_lotus=False)
|
155 |
+
global_normal[...,0] *= -1
|
156 |
+
global_normal = (global_normal + 1) / 2
|
157 |
+
global_normal = global_normal.permute(0, 3, 1, 2)
|
158 |
+
return global_normal
|
159 |
+
|
160 |
+
def main(prompt = "a owl wearing a hat."):
|
161 |
+
fix_prompt = 'a grid of 2x4 multi-view image. elevation 5. white background.'
|
162 |
+
# user prompt
|
163 |
+
|
164 |
+
save_dir_path = os.path.join(save_dir, prompt.split(".")[0].replace(" ", "_"))
|
165 |
+
os.makedirs(save_dir_path, exist_ok=True)
|
166 |
+
prompt = fix_prompt+" "+prompt
|
167 |
+
# generate multi-view images
|
168 |
+
rgb_normal_grid = multi_view_rgb_normal_generation(prompt)
|
169 |
+
# lrm reconstructions
|
170 |
+
images = torch.from_numpy(rgb_normal_grid).squeeze(0).permute(2, 0, 1).contiguous().float() # (3, 1024, 2048)
|
171 |
+
images = rearrange(images, 'c (n h) (m w) -> (n m) c h w', n=2, m=4) # (8, 3, 512, 512)
|
172 |
+
rgb_multi_view = images[:4, :3, :, :]
|
173 |
+
normal_multi_view = images[4:, :3, :, :]
|
174 |
+
multi_view_mask = get_background(normal_multi_view)
|
175 |
+
rgb_multi_view = rgb_multi_view * rgb_multi_view + (1-multi_view_mask)
|
176 |
+
input_cameras = get_flux_input_cameras(batch_size=1, radius=4.2, fov=30).to(device)
|
177 |
+
vertices, faces = lrm_reconstructions(rgb_multi_view, input_cameras, save_path=save_dir_path, name='lrm', export_texmap=False, if_save_video=False)
|
178 |
+
# local normal to global normal
|
179 |
+
|
180 |
+
global_normal = local_normal_global_transform(normal_multi_view.permute(0, 2, 3, 1), isomer_azimuths, isomer_elevations)
|
181 |
+
global_normal = global_normal * multi_view_mask + (1-multi_view_mask)
|
182 |
+
|
183 |
+
global_normal = global_normal.permute(0,2,3,1)
|
184 |
+
rgb_multi_view = rgb_multi_view.permute(0,2,3,1)
|
185 |
+
multi_view_mask = multi_view_mask.permute(0,2,3,1).squeeze(-1)
|
186 |
+
vertices = torch.from_numpy(vertices).to(device)
|
187 |
+
faces = torch.from_numpy(faces).to(device)
|
188 |
+
vertices = vertices @ rotate_x(np.pi / 2, device=vertices.device)[:3, :3]
|
189 |
+
vertices = vertices @ rotate_y(np.pi / 2, device=vertices.device)[:3, :3]
|
190 |
+
|
191 |
+
# global_normal: B,H,W,3
|
192 |
+
# multi_view_mask: B,H,W
|
193 |
+
# rgb_multi_view: B,H,W,3
|
194 |
+
|
195 |
+
multi_view_mask_proj = mask_fix(multi_view_mask, erode_dilate=-6, blur=5)
|
196 |
+
|
197 |
+
meshes = reconstruction(
|
198 |
+
normal_pils=global_normal,
|
199 |
+
masks=multi_view_mask,
|
200 |
+
weights=isomer_geo_weights,
|
201 |
+
fov=30,
|
202 |
+
radius=isomer_radius,
|
203 |
+
camera_angles_azi=isomer_azimuths,
|
204 |
+
camera_angles_ele=isomer_elevations,
|
205 |
+
expansion_weight_stage1=0.1,
|
206 |
+
init_type="file",
|
207 |
+
init_verts=vertices,
|
208 |
+
init_faces=faces,
|
209 |
+
stage1_steps=0,
|
210 |
+
stage2_steps=50,
|
211 |
+
start_edge_len_stage1=0.1,
|
212 |
+
end_edge_len_stage1=0.02,
|
213 |
+
start_edge_len_stage2=0.02,
|
214 |
+
end_edge_len_stage2=0.005,
|
215 |
+
)
|
216 |
+
|
217 |
+
|
218 |
+
multi_view_mask_proj = mask_fix(multi_view_mask, erode_dilate=-10, blur=5)
|
219 |
+
|
220 |
+
save_glb_addr = projection(
|
221 |
+
meshes,
|
222 |
+
masks=multi_view_mask_proj,
|
223 |
+
images=rgb_multi_view,
|
224 |
+
azimuths=isomer_azimuths,
|
225 |
+
elevations=isomer_elevations,
|
226 |
+
weights=isomer_color_weights,
|
227 |
+
fov=30,
|
228 |
+
radius=isomer_radius,
|
229 |
+
save_dir=f"{save_dir_path}/ISOMER/",
|
230 |
+
)
|
231 |
+
print(f'saved to {save_glb_addr}')
|
232 |
+
|
233 |
+
|
234 |
+
|
235 |
+
if __name__ == '__main__':
|
236 |
+
import time
|
237 |
+
start_time = time.time()
|
238 |
+
prompts = ["A red dragon soaring", "A running Chihuahua", "A dancing rabbit", "A girl with blue hair and white dress", "A teacher", "A tiger playing guitar", "A red rose", "A red peony", "A rose in a vase", "A golden retriever sitting", "A golden retriever running"]
|
239 |
+
for prompt in prompts:
|
240 |
+
main(prompt)
|
241 |
+
end_time = time.time()
|
242 |
+
print(f"Time taken: {end_time - start_time:.2f} seconds for {len(prompts)} prompts")
|
243 |
+
|
244 |
+
breakpoint()
|
upload_huggingface.py
ADDED
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from huggingface_hub import HfApi, HfFolder, Repository, create_repo, upload_file
|
2 |
+
import os
|
3 |
+
|
4 |
+
# 登录到 Hugging Face
|
5 |
+
from huggingface_hub import login
|
6 |
+
login()
|
7 |
+
|
8 |
+
# 创建或指定现有的 Repository
|
9 |
+
repo_name = "xxx-ckpt"
|
10 |
+
username = "LTT"
|
11 |
+
repo_id = f"{username}/{repo_name}"
|
12 |
+
|
13 |
+
# 创建仓库(如果它不存在)
|
14 |
+
create_repo(repo_id, exist_ok=True)
|
15 |
+
|
16 |
+
# 文件夹
|
17 |
+
# 上传整个文件夹
|
18 |
+
def upload_folder(folder_path, repo_id):
|
19 |
+
"""
|
20 |
+
递归上传文件夹及其内容到 Hugging Face 仓库。
|
21 |
+
"""
|
22 |
+
for root, _, files in os.walk(folder_path):
|
23 |
+
for file in files:
|
24 |
+
# 文件完整路径
|
25 |
+
full_file_path = os.path.join(root, file)
|
26 |
+
# 相对于文件夹的相对路径(保留文件夹结构)
|
27 |
+
relative_path = os.path.relpath(full_file_path, folder_path)
|
28 |
+
|
29 |
+
# 上传文件到仓库
|
30 |
+
print(f"Uploading {relative_path}...")
|
31 |
+
upload_file(
|
32 |
+
path_or_fileobj=full_file_path,
|
33 |
+
path_in_repo=relative_path,
|
34 |
+
repo_id=repo_id
|
35 |
+
)
|
36 |
+
print(f"Uploaded {relative_path} successfully.")
|
37 |
+
|
38 |
+
|
39 |
+
# 上传模型文件
|
40 |
+
model_path = "checkpoint/zero123++/flexgen_19w.ckpt"
|
41 |
+
upload_file(path_or_fileobj=model_path, path_in_repo="flexgen_19w.ckpt", repo_id=repo_id)
|
42 |
+
|
43 |
+
# # 上传数据文件
|
44 |
+
# data_path = "/hpc2hdd/home/jlin695/data/env_map/data/env_mipmap_large.tar.gz"
|
45 |
+
# upload_file(path_or_fileobj=data_path, path_in_repo="env_mipmap_large.tar.gz", repo_id=repo_id)
|
46 |
+
|
47 |
+
# # 上传数据文件
|
48 |
+
# data_path = "/hpc2hdd/home/jlin695/data/env_map/data/env_map_light_large.tar.gz"
|
49 |
+
# upload_file(path_or_fileobj=data_path, path_in_repo="env_map_light_large.tar.gz", repo_id=repo_id)
|
50 |
+
|
51 |
+
# # 定义要上传的文件夹路径
|
52 |
+
# folder_path = "checkpoint/flux_lora"
|
53 |
+
|
54 |
+
# # 调用上传文件夹的函数
|
55 |
+
# upload_folder(folder_path, repo_id)
|
56 |
+
|
57 |
+
# print("模型和数据文件已上传到 Hugging Face。")
|