File size: 2,754 Bytes
9c98fd2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3aa18cd
 
 
 
9c98fd2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3aa18cd
 
 
 
 
 
 
 
 
 
9c98fd2
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
import logging
import os
import tempfile
import time

import numpy as np
import rembg
import torch
from PIL import Image
from functools import partial

from tsr.system import TSR
from tsr.utils import remove_background, resize_foreground, to_gradio_3d_orientation

import argparse


# Run on the first CUDA GPU when torch reports one; otherwise fall back to CPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"

# Load the TripoSR model once at import time so every call below reuses it.
# The weights come from "stabilityai/TripoSR" using its config.yaml and
# model.ckpt (presumably fetched from the Hugging Face Hub on first use --
# TODO confirm against TSR.from_pretrained).
model = TSR.from_pretrained(
    "stabilityai/TripoSR",
    config_name="config.yaml",
    weight_name="model.ckpt",
)

# adjust the chunk size to balance between speed and memory usage
model.renderer.set_chunk_size(8192)
# Move the model onto the device selected above (GPU if available, else CPU).
model.to(device)

# Single rembg session, shared by every preprocess() call that removes the
# background.
rembg_session = rembg.new_session()




def preprocess(input_image, do_remove_background, foreground_ratio):
    """Prepare an input image for the model.

    Args:
        input_image: Source image; it must expose ``.mode`` and ``.convert``
            (a PIL image in this file).
        do_remove_background: When true, the image is converted to RGB, run
            through ``remove_background`` with the shared rembg session,
            passed to ``resize_foreground`` and then flattened onto gray.
        foreground_ratio: Forwarded to ``resize_foreground``; unused when
            ``do_remove_background`` is false.

    Returns:
        The processed image. When ``do_remove_background`` is false, an RGBA
        input is flattened onto gray and any other mode is returned as-is.
    """

    def fill_background(image):
        # Alpha-composite the RGBA pixels over a constant 0.5 (mid-gray)
        # background, working in float [0, 1] and converting back to uint8.
        rgba = np.array(image).astype(np.float32) / 255.0
        blended = rgba[:, :, :3] * rgba[:, :, 3:4] + (1 - rgba[:, :, 3:4]) * 0.5
        return Image.fromarray((blended * 255.0).astype(np.uint8))

    if not do_remove_background:
        # Only images carrying an alpha channel need flattening.
        if input_image.mode == "RGBA":
            return fill_background(input_image)
        return input_image

    cutout = remove_background(input_image.convert("RGB"), rembg_session)
    cutout = resize_foreground(cutout, foreground_ratio)
    return fill_background(cutout)


def generate(image, mc_resolution, formats=("obj", "glb"), path="output.obj"):
    """Run the model on *image* and export the extracted mesh once per format.

    Args:
        image: Preprocessed image passed to the module-level ``model``.
        mc_resolution: Forwarded to ``model.extract_mesh`` as ``resolution``.
        formats: Iterable of file extensions (without the leading dot) to
            export, e.g. ``("obj", "glb")``.
        path: Output path. Its final extension is replaced by each entry of
            ``formats`` to build the individual file names.

    Returns:
        List of the exported file paths, in the same order as ``formats``.
    """
    scene_codes = model(image, device=device)
    # extract_mesh returns a sequence; only the first mesh is used here.
    mesh = model.extract_mesh(scene_codes, resolution=mc_resolution)[0]
    mesh = to_gradio_3d_orientation(mesh)

    # Swap only the trailing extension. A plain str.replace(".obj", ...) would
    # also rewrite ".obj" inside directory or file stems, and would leave a
    # path without ".obj" untouched so every format overwrote the same file.
    stem, _ = os.path.splitext(path)

    exported = []
    for fmt in formats:
        mesh_path = f"{stem}.{fmt}"
        mesh.export(mesh_path)
        exported.append(mesh_path)
    return exported


def run_example(image_pil):
    """Preprocess *image_pil* without background removal and build both meshes.

    Returns:
        A 3-tuple ``(preprocessed_image, obj_path, glb_path)``.
    """
    prepared = preprocess(image_pil, False, 0.9)
    mesh_paths = generate(prepared, 256, ["obj", "glb"])
    # Exactly two formats were requested, so exactly two paths come back.
    obj_path, glb_path = mesh_paths
    return prepared, obj_path, glb_path

def generate_obj_from_image(image_pil, path="output.obj"):
    """Create an .obj mesh from *image_pil* and return the path it was written to.

    Args:
        image_pil: Input image handed to ``preprocess``.
        path: Destination path forwarded to ``generate``.

    Returns:
        The first path returned by ``generate``, or ``None`` if preprocessing
        or mesh generation raised (the error is printed, not re-raised).
    """
    try:
        print("Preprocessing image")
        # Background removal is enabled here, with a foreground ratio of 0.9.
        prepared = preprocess(image_pil, True, 0.9)
        print("Generating mesh")
        # Only the "obj" format is requested, at resolution 256.
        obj_paths = generate(prepared, 256, ["obj"], path)
    except Exception as err:
        print(f"Error generating mesh: {err}")
        return None
    else:
        # The single requested format yields a single path.
        return obj_paths[0]

if __name__ == "__main__":
    # Manual smoke test: read "output.png" from the current working directory
    # and write the resulting mesh to "output.obj".
    # run a test
    image_path = "output.png"
    image = Image.open(image_path)
    # The returned path (or None on failure) is not used here.
    generate_obj_from_image(image, "output.obj")
    # move the .obj file to the output directory
    # NOTE(review): no code performing that move is visible after this
    # comment -- confirm whether it is still to be written or was dropped.