File size: 4,303 Bytes
8e0b903
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
import logging
import time

import numpy as np
import rembg
import torch
from PIL import Image
from rotate import rotate
import streamlit as st

import sys
import os


from tsr.system import TSR
from x3D_utils import remove_background, resize_foreground

import logging
import time
import streamlit as st
import torch
from datetime import datetime


# Hàm tùy chỉnh để hiển thị thông báo kèm thời gian tương tự logging
def st_info_with_logging_format(message):
    current_time = datetime.now().strftime('%Y-%m-%d %H:%M:%S:%f')[:-3]  # Định dạng thời gian giống logging
    formatted_message = f"{current_time} - INFO - {message}"
    st.info(formatted_message)

def st_success_with_logging_format(message):
    current_time = datetime.now().strftime('%Y-%m-%d %H:%M:%S:%f')[:-3]  # Định dạng thời gian giống logging
    formatted_message = f"{current_time} - INFO - {message}"
    st.success(formatted_message)

# Thay thế các thông báo st.info bằng hàm mới
def convert_to_3d(image_path, output_filename='', isHuman=False, isCloth=False, cloth_cat=''):
    with st.expander("ImageTo3D Extract Infomation"):
        class Timer:
            def __init__(self):
                self.items = {}
                self.time_scale = 1000.0  # ms
                self.time_unit = "seconds"

            def start(self, name: str) -> None:
                if torch.cuda.is_available():
                    torch.cuda.synchronize()
                self.items[name] = time.time()
                st_info_with_logging_format(f"{name} ...")

            def end(self, name: str) -> float:
                if name not in self.items:
                    return
                if torch.cuda.is_available():
                    torch.cuda.synchronize()
                start_time = self.items.pop(name)
                delta = time.time() - start_time
                t = delta * self.time_scale
                st_success_with_logging_format(f"{name} finished in {(t / 1000):.2f} {self.time_unit}.")

        timer = Timer()

        logging.basicConfig(
            format="%(asctime)s - %(levelname)s - %(message)s", level=logging.INFO
        )

        if not torch.cuda.is_available():
            device = "cpu"
        else:
            device = "cuda:0"

        timer.start("Initializing model")
        model = TSR.from_pretrained(
            config_path=r"config.yaml",
            weight_path=r"model.ckpt"
        )

        model.renderer.set_chunk_size(10_000)  # 0 for no chunking; default is 8192
        model.to(device)
        timer.end("Initializing model")

        timer.start("Removing background")
        if isHuman:
            rembg_session = rembg.new_session(model_name="u2net_human_seg")
        elif isCloth:
            rembg_session = rembg.new_session(model_name="u2net_cloth_seg")
        else:
            rembg_session = rembg.new_session()

        if isCloth and cloth_cat != '':
            image = remove_background(Image.open(image_path), rembg_session, cloth_category=cloth_cat)
        else:
            image = remove_background(Image.open(image_path), rembg_session)
        image = resize_foreground(image, 0.85)
        image = np.array(image).astype(np.float32) / 255.0
        image = image[:, :, :3] * image[:, :, 3:4] + (1 - image[:, :, 3:4]) * 0.5
        image = Image.fromarray((image * 255.0).astype(np.uint8))

        timer.end("Removing background")

        timer.start("Running model on image")
        with torch.no_grad():
            scene_codes = model([image], device=device)
        timer.end("Running model on image")

        timer.start("Extracting mesh")
        mesh = model.extract_mesh(scene_codes, resolution=256)[0]
        timer.end("Extracting mesh")

        timer.start("Rotating object")
        mesh = rotate(mesh)
        timer.end("Rotating object")

        timer.start("Saving generated object")
        if output_filename == '':
            output_filename = f"{image_path.split('.')[-2]}_out"
        output_filepath = f"{output_filename}.glb"
        mesh.export(output_filepath)
        timer.end("Saving generated object")

        return output_filepath