Spaces:
Running
on
L40S
Running
on
L40S
File size: 7,667 Bytes
68cd723 bf1f021 68cd723 3a05849 68cd723 bf1f021 68cd723 bf1f021 68cd723 3a05849 bf1f021 fee0483 bf1f021 3a05849 68cd723 3a05849 bf1f021 3a05849 68cd723 bf1f021 68cd723 bf1f021 68cd723 3a05849 bf1f021 68cd723 bf1f021 68cd723 bf1f021 68cd723 d0e49e5 68cd723 d0e49e5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 |
import os, sys, time, traceback
print("sys path insert", os.path.join(os.path.dirname(__file__), "dust3r"))
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "dust3r"))
import cv2
import numpy as np
from PIL import Image, ImageSequence
from einops import rearrange
import torch
from infer.utils import seed_everything, timing_decorator
from infer.utils import get_parameter_number, set_parameter_grad_false
from dust3r.inference import inference
from dust3r.model import AsymmetricCroCo3DStereo
from third_party.gen_baking import back_projection
from third_party.dust3r_utils import infer_warp_mesh_img
from svrm.ldm.vis_util import render_func
class MeshBaker:
    """Bake aligned view images onto a mesh texture.

    Each requested view image is warped onto a rendering of the mesh with a
    DUSt3R alignment model (``infer_warp_mesh_img``). Views whose mask IoU
    exceeds ``iou_thresh`` are then projected onto the texture with
    ``back_projection``.
    """

    # Accepted values for the ``front`` argument of :meth:`call`.
    _FRONT_MODES = ('auto', 'input image', 'multi-view front view')

    # Angles (degrees) that can be baked; the loop position of each angle is
    # used as the ``view_{idx}`` output folder index.
    _VIEW_ANGLES = (0, 60, 120, 180, 240, 300)

    def __init__(
        self,
        align_model = "third_party/weights/DUSt3R_ViTLarge_BaseDecoder_512_dpt",
        device = "cuda:0",
        align_times = 1,
        iou_thresh = 0.8,
        save_memory = False
    ):
        """Load the alignment model and store baking settings.

        Args:
            align_model: Path or identifier passed to
                ``AsymmetricCroCo3DStereo.from_pretrained``.
            device: Device the model runs on and that is passed to
                ``back_projection``.
            align_times: Number of alignment iterations used by :meth:`call`.
            iou_thresh: Mask-IoU a view must exceed to be baked.
            save_memory: When true the model is not moved to ``device`` here;
                ``__call__`` moves it there for each run and back to the CPU
                afterwards.
        """
        self.device = device
        self.save_memory = save_memory
        self.align_times = align_times
        self.iou_thresh = iou_thresh

        model = AsymmetricCroCo3DStereo.from_pretrained(align_model)
        if not save_memory:
            model = model.to(device)
        model.eval()
        set_parameter_grad_false(model)
        self.align_model = model
        print('baking align model', get_parameter_number(self.align_model))

    def align_and_check(self, src, dst, align_times=3):
        """Iteratively warp ``src`` onto ``dst`` and keep the best result.

        Args:
            src: Image to align (a PIL image in the calls made by this class).
            dst: Target the image is aligned to (a rendered mesh view in the
                calls made by this class).
            align_times: Maximum number of alignment iterations; each
                iteration starts from the previous iteration's output.

        Returns:
            ``(best_image, best_info, should_bake)``. ``best_info`` is the
            info dict of the iteration with the highest ``mask_iou``; if no
            iteration beats ``iou_thresh - 0.1`` the original ``src`` and a
            placeholder info dict are returned. ``should_bake`` is true when
            the best ``mask_iou`` exceeds ``iou_thresh``. If anything raises,
            the error is printed and ``(None, None, None)`` is returned.
        """
        try:
            best_aligned_image = aligned_image = src
            # Placeholder baseline: results below (iou_thresh - 0.1) never
            # replace the unaligned source image.
            best_info = {'match_num': 1000, "mask_iou": self.iou_thresh-0.1}
            for i in range(align_times):
                aligned_image, info = infer_warp_mesh_img(aligned_image, dst, self.align_model, vis=False)
                aligned_image = Image.fromarray(aligned_image)
                print(f"{i}-th time align process, mask-iou is {info['mask_iou']}")
                if info['mask_iou'] > best_info['mask_iou']:
                    best_aligned_image, best_info = aligned_image, info
                # Stop iterating as soon as an iteration falls below the threshold.
                if info['mask_iou'] < self.iou_thresh:
                    break
            best_baking_flag = best_info['mask_iou'] > self.iou_thresh
            return best_aligned_image, best_info, best_baking_flag
        except Exception as e:
            # Best effort: callers treat a None info as "skip this view".
            print(f"Error processing image: {e}")
            traceback.print_exc()
            return None, None, None

    @timing_decorator("baking mesh")
    def __call__(self, *args, **kwargs):
        """Run :meth:`call`, handling model placement in ``save_memory`` mode.

        In ``save_memory`` mode the alignment model is moved to ``self.device``
        before the run and back to the CPU afterwards. The move back and the
        CUDA cache flush happen in a ``finally`` so a failing run does not
        leave the model on the GPU.
        """
        if self.save_memory:
            self.align_model = self.align_model.to(self.device)
            torch.cuda.empty_cache()
        try:
            return self.call(*args, **kwargs)
        finally:
            if self.save_memory:
                self.align_model = self.align_model.to("cpu")
            torch.cuda.empty_cache()

    def _load_views(self, save_folder):
        """Load the condition image and the multi-view images from ``save_folder``.

        ``views.jpg`` is read as a 3x2 grid of tiles and paired with
        ``img_nobg.png`` resized to 512x512. Otherwise ``views.gif`` is read,
        its first frame being the condition image and the rest the views.

        Returns:
            ``(cond_pil, views)``.

        Raises:
            FileNotFoundError: If neither ``views.jpg`` nor ``views.gif`` exists.
        """
        views_jpg_path = os.path.join(save_folder, "views.jpg")
        views_gif_path = os.path.join(save_folder, "views.gif")
        cond_path = os.path.join(save_folder, "img_nobg.png")

        if os.path.exists(views_jpg_path):
            # Context managers close the underlying files once pixels are copied.
            with Image.open(views_jpg_path) as grid_img:
                grid = np.asarray(grid_img, dtype=np.uint8)
            tiles = rearrange(grid, '(n h) (m w) c -> (n m) h w c', n=3, m=2)
            views = [Image.fromarray(tiles[idx]).convert('RGB') for idx in [0,2,4,5,3,1]]
            with Image.open(cond_path) as cond_img:
                cond_pil = cond_img.resize((512,512))
            return cond_pil, views

        if os.path.exists(views_gif_path):
            with Image.open(views_gif_path) as gif_img:
                frames = [frame.convert('RGB') for frame in ImageSequence.Iterator(gif_img)]
            return frames[0], frames[1:]

        raise FileNotFoundError("views file not found")

    def _align_front_view(self, front, cond_pil, front_view, rendered_front, save_folder):
        """Align the image(s) used for the 0-degree view according to ``front``.

        Returns:
            ``(aligned_img, info)``, or ``(None, None)`` when a required
            alignment failed and the front view should be skipped.
        """
        aligned_img = info = None
        aligned_cond = cond_info = None

        if front in ('input image', 'auto'):
            aligned_cond, cond_info, _ = self.align_and_check(
                cond_pil, rendered_front, align_times=self.align_times
            )
            if cond_info is None:
                return None, None
            aligned_cond.convert("RGB").save(os.path.join(save_folder, 'aligned_cond.jpg'))
            if front == 'input image':
                aligned_img, info = aligned_cond, cond_info
                print("Using input image to bake front view")

        if front in ('multi-view front view', 'auto'):
            aligned_img, info, _ = self.align_and_check(
                front_view, rendered_front, align_times=self.align_times
            )
            if info is None:
                return None, None
            aligned_img.save(os.path.join(save_folder, 'aligned_0.jpg'))
            print("Using multi-view front view image to bake front view")
            # In 'auto' mode prefer whichever candidate aligned better.
            if front == 'auto' and info['mask_iou'] < cond_info['mask_iou']:
                print("Auto using Cond Image to bake front view")
                aligned_img, info = aligned_cond, cond_info

        return aligned_img, info

    def call(self, save_folder, force=False, front='auto', others=('180°',), align_times=3, seed=0):
        """Bake the selected views onto the mesh stored in ``save_folder``.

        Args:
            save_folder: Folder containing ``mesh.obj``, ``texture.png`` and
                either ``views.jpg`` + ``img_nobg.png`` or ``views.gif``.
                Aligned images and baked results are written here too.
            force: Bake a view even when its mask IoU is below the threshold.
            front: Source for the 0-degree view: ``'input image'``,
                ``'multi-view front view'`` or ``'auto'`` (whichever of the
                two aligns with the higher mask IoU).
            others: Extra view angles to bake, as strings such as ``'180°'``.
            align_times: Accepted for interface compatibility; the value
                configured on the instance (``self.align_times``) is what is
                actually used.
            seed: Seed passed to ``seed_everything`` before baking.

        Returns:
            Path of the mesh produced by the last successful bake, or the
            original ``mesh.obj`` path if no view was baked.

        Raises:
            ValueError: If ``front`` is not one of the supported modes.
            FileNotFoundError: If no views file exists in ``save_folder``.
        """
        if front not in self._FRONT_MODES:
            raise ValueError(
                f"front must be one of {self._FRONT_MODES}, got {front!r}"
            )

        obj_path = os.path.join(save_folder, "mesh.obj")
        raw_texture_path = os.path.join(save_folder, "texture.png")
        cond_pil, views = self._load_views(save_folder)

        others = [int(x.replace("°", "")) for x in others]
        # Render only as many mesh views as the selection needs.
        if not others:
            n_render_views = 1
        elif others == [180]:
            n_render_views = 2
        else:
            n_render_views = 6
        rendered_views = render_func(obj_path, elev=0, n_views=n_render_views)

        print(f"Need baking views are {others}")
        others = [0] + others  # the front view is always processed
        seed_everything(seed)

        for ele_idx, ele in enumerate(self._VIEW_ANGLES):
            if ele not in others:
                continue

            print(f"\n Baking view ele_{ele} ...")

            if ele == 0:
                aligned_img, info = self._align_front_view(
                    front, cond_pil, views[0], rendered_views[0], save_folder
                )
                if info is None:
                    continue
                need_baking = info['mask_iou'] > self.iou_thresh
            else:
                # NOTE(review): the min() clamp maps 180 degrees to index 1
                # of the two-view render. For other partial selections (e.g.
                # only 120 degrees) the clamped index may not be the render
                # for `ele` - confirm against render_func's view ordering.
                aligned_img, info, need_baking = self.align_and_check(
                    views[ele//60],
                    rendered_views[min(ele//60, len(others)-1)],
                    align_times=self.align_times
                )
                if info is None:
                    continue
                aligned_img.save(os.path.join(save_folder, f'aligned_{ele}.jpg'))

            if not (need_baking or force):
                print(f"Skip view_{ele_idx} elevation_{ele} baking")
                continue

            try:
                st = time.time()
                back_projection(
                    obj_file = obj_path,
                    init_texture_file = raw_texture_path,
                    front_view_file = aligned_img,
                    dst_dir = os.path.join(save_folder, f"view_{ele_idx}"),
                    render_resolution = aligned_img.size[0],
                    uv_resolution = 1024,
                    views = [[0, ele]],
                    device = self.device
                )
                print(f"view_{ele_idx} elevation_{ele} baking finished at {time.time() - st}")
                # Chain the bakes: the next view starts from this result.
                obj_path = os.path.join(save_folder, f"view_{ele_idx}/bake/mesh.obj")
                raw_texture_path = os.path.join(save_folder, f"view_{ele_idx}/bake/texture.png")
            except Exception as err:
                # Best effort: a failed view is reported and skipped so the
                # remaining views can still be baked.
                print(err)
                traceback.print_exc()

        print("\nBaking Finished\n")
        return obj_path
if __name__ == "__main__":
    # Manual run: bake the sample output folder with default settings and
    # print the path of the resulting mesh.
    mesh_baker = MeshBaker()
    baked_mesh_path = mesh_baker("./outputs/test")
    print(baked_mesh_path)
|