File size: 7,667 Bytes
68cd723
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bf1f021
68cd723
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3a05849
68cd723
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bf1f021
68cd723
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bf1f021
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
68cd723
 
3a05849
bf1f021
 
 
fee0483
bf1f021
3a05849
68cd723
3a05849
bf1f021
 
 
3a05849
68cd723
bf1f021
 
 
 
68cd723
bf1f021
68cd723
3a05849
 
 
 
 
bf1f021
68cd723
 
bf1f021
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
68cd723
bf1f021
68cd723
d0e49e5
68cd723
 
 
 
d0e49e5
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
import os, sys, time, traceback
print("sys path insert", os.path.join(os.path.dirname(__file__), "dust3r"))
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "dust3r"))

import cv2
import numpy as np
from PIL import Image, ImageSequence
from einops import rearrange
import torch

from infer.utils import seed_everything, timing_decorator
from infer.utils import get_parameter_number, set_parameter_grad_false

from dust3r.inference import inference
from dust3r.model import AsymmetricCroCo3DStereo

from third_party.gen_baking import back_projection
from third_party.dust3r_utils import infer_warp_mesh_img
from svrm.ldm.vis_util import render_func


class MeshBaker:
    """Bake generated view images into the UV texture of a textured mesh.

    For each requested view angle, the matching generated image is first
    warped onto a render of the current mesh with a DUSt3R-based aligner
    (``infer_warp_mesh_img``). If the best alignment's mask IoU exceeds
    ``iou_thresh`` (or ``force`` is set), the aligned image is back-projected
    into the texture with ``back_projection``. Every successful bake writes a
    new mesh/texture pair under ``<save_folder>/view_<idx>/bake/``, and the
    next view is baked on top of that result.
    """

    # Angles (degrees) that can be baked, in baking order. The code calls each
    # one ``ele``; it is passed to ``back_projection`` as ``views=[[0, ele]]``.
    _BAKE_ANGLES = (0, 60, 120, 180, 240, 300)

    def __init__(
        self,
        align_model="third_party/weights/DUSt3R_ViTLarge_BaseDecoder_512_dpt",
        device="cuda:0",
        align_times=1,
        iou_thresh=0.8,
        save_memory=False,
    ):
        """Load and freeze the alignment model.

        Args:
            align_model: path/id passed to ``AsymmetricCroCo3DStereo.from_pretrained``.
            device: torch device used for alignment and baking.
            align_times: maximum number of iterative alignment passes per image.
            iou_thresh: minimum mask IoU for an aligned view to be baked.
            save_memory: if True, the aligner is not moved to ``device`` here;
                ``__call__`` moves it to ``device`` for the duration of a call
                and back to ``"cpu"`` afterwards.
        """
        self.device = device
        self.save_memory = save_memory
        self.align_model = AsymmetricCroCo3DStereo.from_pretrained(align_model)
        if not save_memory:
            self.align_model = self.align_model.to(device)
        self.align_times = align_times
        self.align_model.eval()
        self.iou_thresh = iou_thresh
        set_parameter_grad_false(self.align_model)
        print('baking align model', get_parameter_number(self.align_model))

    def align_and_check(self, src, dst, align_times=3):
        """Iteratively warp ``src`` onto ``dst`` and keep the best-scoring result.

        Each pass feeds the previous pass's output back into
        ``infer_warp_mesh_img``. The loop stops early as soon as a pass scores
        a mask IoU below ``self.iou_thresh``.

        Args:
            src: PIL image to align (a generated view or the input image).
            dst: reference render of the current mesh (from ``render_func``).
            align_times: maximum number of alignment passes.

        Returns:
            ``(best_image, best_info, bakeable)``, where ``bakeable`` is
            ``best_info['mask_iou'] > self.iou_thresh``. If no pass scores
            above ``self.iou_thresh - 0.1``, ``best_image`` is ``src`` and
            ``best_info`` is a placeholder dict. On any exception the error is
            printed and ``(None, None, None)`` is returned.
        """
        try:
            best_aligned_image = aligned_image = src
            # Placeholder baseline: a pass must beat thresh - 0.1 to be kept.
            best_info = {'match_num': 1000, "mask_iou": self.iou_thresh - 0.1}
            for i in range(align_times):
                aligned_image, info = infer_warp_mesh_img(aligned_image, dst, self.align_model, vis=False)
                aligned_image = Image.fromarray(aligned_image)
                print(f"{i}-th time align process, mask-iou is {info['mask_iou']}")
                if info['mask_iou'] > best_info['mask_iou']:
                    best_aligned_image, best_info = aligned_image, info
                if info['mask_iou'] < self.iou_thresh:
                    break
            return best_aligned_image, best_info, best_info['mask_iou'] > self.iou_thresh
        except Exception as e:
            print(f"Error processing image: {e}")
            traceback.print_exc()
            return None, None, None

    @timing_decorator("baking mesh")
    def __call__(self, *args, **kwargs):
        """Run :meth:`call`, managing aligner placement when ``save_memory`` is set.

        The aligner is moved back to CPU and the CUDA cache is cleared even
        when baking raises, so a failed call does not leave the model on the GPU.
        """
        if self.save_memory:
            self.align_model = self.align_model.to(self.device)
            torch.cuda.empty_cache()
        try:
            return self.call(*args, **kwargs)
        finally:
            if self.save_memory:
                self.align_model = self.align_model.to("cpu")
            torch.cuda.empty_cache()

    @staticmethod
    def _load_views(save_folder):
        """Load ``(cond_image, views)`` from ``save_folder``.

        Prefers ``views.jpg``: a 3x2 grid whose tiles are reordered by
        ``[0, 2, 4, 5, 3, 1]``, with the conditioning image taken from
        ``img_nobg.png`` and resized to 512x512. Otherwise falls back to
        ``views.gif``, whose first frame is the conditioning image and whose
        remaining frames are the views.

        Raises:
            FileNotFoundError: if neither views file exists.
        """
        grid_path = os.path.join(save_folder, "views.jpg")
        gif_path = os.path.join(save_folder, "views.gif")
        cond_path = os.path.join(save_folder, "img_nobg.png")

        if os.path.exists(grid_path):
            with Image.open(grid_path) as grid:
                tiles = rearrange(np.asarray(grid, dtype=np.uint8), '(n h) (m w) c -> (n m) h w c', n=3, m=2)
            views = [Image.fromarray(tiles[idx]).convert('RGB') for idx in [0, 2, 4, 5, 3, 1]]
            with Image.open(cond_path) as cond:
                cond_pil = cond.resize((512, 512))
        elif os.path.exists(gif_path):
            with Image.open(gif_path) as gif:
                frames = [frame.convert('RGB') for frame in ImageSequence.Iterator(gif)]
            cond_pil, views = frames[0], frames[1:]
        else:
            raise FileNotFoundError("views file not found")
        return cond_pil, views

    def _align_front(self, cond_pil, front_view, rendered_front, front, save_folder):
        """Align the 0° candidate image(s) selected by ``front``.

        Returns ``(aligned_img, info)``, or ``(None, None)`` when no candidate
        could be aligned. In ``'auto'`` mode the candidate with the higher
        mask IoU wins, and if one candidate fails to align the other is used.
        """
        cond_img = cond_info = None
        if front in ('input image', 'auto'):
            cond_img, cond_info, _ = self.align_and_check(cond_pil, rendered_front, align_times=self.align_times)
            if cond_info is not None:
                # Converted before saving as JPEG; the input image is
                # presumably RGBA (it comes from img_nobg.png or a GIF frame).
                cond_img.convert("RGB").save(os.path.join(save_folder, 'aligned_cond.jpg'))
            if front == 'input image':
                if cond_info is not None:
                    print("Using input image to bake front view")
                return cond_img, cond_info

        mv_img, mv_info, _ = self.align_and_check(front_view, rendered_front, align_times=self.align_times)
        if mv_info is not None:
            mv_img.save(os.path.join(save_folder, 'aligned_0.jpg'))
            print("Using multi-view front view image to bake front view")

        # cond_info is only set in 'auto' mode here.
        if cond_info is not None and (mv_info is None or mv_info['mask_iou'] < cond_info['mask_iou']):
            print("Auto using Cond Image to bake front view")
            return cond_img, cond_info
        return mv_img, mv_info

    def call(self, save_folder, force=False, front='auto', others=('180°',), align_times=3, seed=0):
        """Bake the front view plus the requested extra views into the mesh texture.

        Args:
            save_folder: folder containing ``mesh.obj``, ``texture.png`` and
                either ``views.jpg`` (+ ``img_nobg.png``) or ``views.gif``.
            force: bake every requested view even if its mask IoU is not above
                ``self.iou_thresh``.
            front: source for the 0° view: ``'input image'``,
                ``'multi-view front view'``, or ``'auto'`` (best IoU of both).
            others: extra angles such as ``'180°'`` (ints are also accepted).
                Only angles in ``_BAKE_ANGLES`` are baked.
            align_times: accepted for backward compatibility but unused;
                alignment passes are governed by ``self.align_times``.
            seed: passed to ``seed_everything`` before baking.

        Returns:
            Path of the final mesh: the last successfully baked
            ``view_<idx>/bake/mesh.obj``, or the original ``mesh.obj`` if
            nothing was baked.

        Raises:
            ValueError: if ``front`` is not one of the supported modes.
            FileNotFoundError: if no views file exists in ``save_folder``.
        """
        if front not in ('input image', 'multi-view front view', 'auto'):
            raise ValueError(f"unknown front mode: {front!r}")

        obj_path = os.path.join(save_folder, "mesh.obj")
        raw_texture_path = os.path.join(save_folder, "texture.png")
        cond_pil, views = self._load_views(save_folder)

        angles = [int(str(x).replace("°", "")) for x in others]
        # Render only as many reference views as the request needs.
        if not angles:
            n_views = 1
        elif angles == [180]:
            n_views = 2
        else:
            n_views = 6
        rendered_views = render_func(obj_path, elev=0, n_views=n_views)

        print(f"Need baking views are {angles}")
        angles = [0] + angles

        seed_everything(seed)

        for ele_idx, ele in enumerate(self._BAKE_ANGLES):
            if ele not in angles:
                continue

            print(f"\n Baking view ele_{ele} ...")

            # Assumes render_func returns n_views renders evenly spaced from 0°
            # (the original code mapped index 1 -> 180° for 2 views and
            # ele // 60 for 6 views), so angle ``ele`` is render ele*n/360.
            rendered = rendered_views[ele * n_views // 360]
            if ele == 0:
                aligned_img, info = self._align_front(cond_pil, views[0], rendered, front, save_folder)
            else:
                aligned_img, info, _ = self.align_and_check(views[ele // 60], rendered, align_times=self.align_times)
                if info is not None:
                    aligned_img.save(os.path.join(save_folder, f'aligned_{ele}.jpg'))
            if info is None:
                continue

            if not (force or info['mask_iou'] > self.iou_thresh):
                print(f"Skip view_{ele_idx} elevation_{ele} baking")
                continue

            # Best-effort: a failed view is reported and the next one is tried
            # on top of the last successful bake.
            try:
                st = time.time()
                back_projection(
                    obj_file=obj_path,
                    init_texture_file=raw_texture_path,
                    front_view_file=aligned_img,
                    dst_dir=os.path.join(save_folder, f"view_{ele_idx}"),
                    render_resolution=aligned_img.size[0],
                    uv_resolution=1024,
                    views=[[0, ele]],
                    device=self.device,
                )
            except Exception as err:
                print(err)
                traceback.print_exc()
                continue
            print(f"view_{ele_idx} elevation_{ele} baking finished at {time.time() - st}")
            # Chain: the next view bakes onto this view's output.
            obj_path = os.path.join(save_folder, f"view_{ele_idx}/bake/mesh.obj")
            raw_texture_path = os.path.join(save_folder, f"view_{ele_idx}/bake/texture.png")

        print("\nBaking Finished\n")
        return obj_path


if __name__ == "__main__":
    # Smoke run: bake the example output folder with default settings and
    # report where the final mesh was written.
    result_mesh = MeshBaker()("./outputs/test")
    print(result_mesh)