dmitriitochilkin commited on
Commit
7e0376e
1 Parent(s): 4395771
Files changed (1) hide show
  1. gradio_app.py +115 -0
gradio_app.py ADDED
@@ -0,0 +1,115 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import os
3
+ import tempfile
4
+ import time
5
+
6
+ import gradio as gr
7
+ import numpy as np
8
+ import rembg
9
+ import torch
10
+ from PIL import Image
11
+
12
+ from tsr.system import TSR
13
+ from tsr.utils import remove_background, resize_foreground, to_gradio_3d_orientation
14
+
15
+ if torch.cuda.is_available():
16
+ device = "cuda:0"
17
+ else:
18
+ device = "cpu"
19
+
20
+ model = TSR.from_pretrained(
21
+ "stabilityai/TripoSR",
22
+ config_name="config.yaml",
23
+ weight_name="model.ckpt",
24
+ )
25
+ model.to(device)
26
+
27
+ rembg_session = rembg.new_session()
28
+
29
+
30
+ def check_input_image(input_image):
31
+ if input_image is None:
32
+ raise gr.Error("No image uploaded!")
33
+
34
+
35
+ def preprocess(image_path, do_remove_background, foreground_ratio):
36
+ def fill_background(image):
37
+ image = np.array(image).astype(np.float32) / 255.0
38
+ image = image[:, :, :3] * image[:, :, 3:4] + (1 - image[:, :, 3:4]) * 0.5
39
+ image = Image.fromarray((image * 255.0).astype(np.uint8))
40
+ return image
41
+
42
+ if do_remove_background:
43
+ image = remove_background(Image.open(image_path), rembg_session)
44
+ image = resize_foreground(image, foreground_ratio)
45
+ image = fill_background(image)
46
+ else:
47
+ image = Image.open(image_path)
48
+ if image.mode == "RGBA":
49
+ image = fill_background(image)
50
+ return image
51
+
52
+
53
+ def generate(image):
54
+ scene_codes = model(image, device=device)
55
+ mesh = model.extract_mesh(scene_codes)[0]
56
+ mesh.vertices = to_gradio_3d_orientation(mesh.vertices)
57
+ mesh_path = tempfile.NamedTemporaryFile(suffix=".obj", delete=False)
58
+ mesh.export(mesh_path.name)
59
+ return mesh_path.name
60
+
61
+
62
+ with gr.Blocks() as demo:
63
+ gr.Markdown(
64
+ """
65
+ ## TripoSR Demo
66
+ [TripoSR](https://github.com/VAST-AI-Research/TripoSR) is a state-of-the-art open-source model for **fast** feedforward 3D reconstruction from a single image, collaboratively developed by [Tripo AI](https://www.tripo3d.ai/) and [Stability AI](https://stability.ai/).
67
+ """
68
+ )
69
+ with gr.Row(variant="panel"):
70
+ with gr.Column():
71
+ with gr.Row():
72
+ input_image = gr.Image(
73
+ label="Input Image",
74
+ sources="upload",
75
+ type="filepath",
76
+ elem_id="content_image",
77
+ )
78
+ processed_image = gr.Image(label="Processed Image", interactive=False)
79
+ with gr.Row():
80
+ with gr.Group():
81
+ do_remove_background = gr.Checkbox(
82
+ label="Remove Background", value=True
83
+ )
84
+ foreground_ratio = gr.Slider(
85
+ label="Foreground Ratio",
86
+ minimum=0.5,
87
+ maximum=1.0,
88
+ value=0.85,
89
+ step=0.05,
90
+ )
91
+ with gr.Row():
92
+ submit = gr.Button("Generate", elem_id="generate", variant="primary")
93
+ with gr.Column():
94
+ with gr.Tab("Model"):
95
+ output_model = gr.Model3D(
96
+ label="Output Model",
97
+ interactive=False,
98
+ )
99
+ gr.Markdown(
100
+ """
101
+ Note: The model shown here will be flipped due to some visualization issues. Please download to get the correct result.
102
+ """
103
+ )
104
+ submit.click(fn=check_input_image, inputs=[input_image]).success(
105
+ fn=preprocess,
106
+ inputs=[input_image, do_remove_background, foreground_ratio],
107
+ outputs=[processed_image],
108
+ ).success(
109
+ fn=generate,
110
+ inputs=[processed_image],
111
+ outputs=[output_model],
112
+ )
113
+
114
+ demo.queue(max_size=1)
115
+ demo.launch()