mrdas committed on
Commit
4026aa0
·
verified ·
1 Parent(s): 5b8bec7

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +187 -0
app.py ADDED
@@ -0,0 +1,187 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import os
3
+ import tempfile
4
+ import time
5
+
6
+ import gradio as gr
7
+ import numpy as np
8
+ import rembg
9
+ import torch
10
+ from PIL import Image
11
+ from functools import partial
12
+
13
+ from tsr.system import TSR
14
+ from tsr.utils import remove_background, resize_foreground, to_gradio_3d_orientation
15
+
16
+ import argparse
17
+
18
+
19
# Run on the first CUDA device when one is available, otherwise on the CPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"

# Load the pretrained TripoSR checkpoint and its config by name.
model = TSR.from_pretrained(
    "stabilityai/TripoSR", config_name="config.yaml", weight_name="model.ckpt"
)
# The renderer chunk size trades speed against memory usage.
model.renderer.set_chunk_size(8192)
model.to(device)

# One shared rembg session, reused by every background-removal call.
rembg_session = rembg.new_session()
35
+
36
+
37
def check_input_image(input_image):
    """Validate that an image was provided before the pipeline runs.

    Args:
        input_image: Value of the input image component; ``None`` when
            nothing has been uploaded.

    Raises:
        gr.Error: If ``input_image`` is ``None``.
    """
    if input_image is not None:
        return
    raise gr.Error("No image uploaded!")
40
+
41
+
42
def preprocess(input_image, do_remove_background, foreground_ratio):
    """Prepare an uploaded image for reconstruction.

    Args:
        input_image: The uploaded image (a PIL image is expected, since
            ``.convert`` and ``.mode`` are used).
        do_remove_background: When truthy, the background is removed and the
            foreground resized before compositing; otherwise the image is
            only composited if it is RGBA.
        foreground_ratio: Forwarded to ``resize_foreground``.

    Returns:
        The processed image. When background removal is off and the image is
        not RGBA, the input object itself is returned unchanged.
    """

    def fill_background(image):
        # Scale to [0, 1], then blend the colour channels over a mid-gray
        # (0.5) backdrop using the fourth channel as the alpha weight.
        rgba = np.array(image).astype(np.float32) / 255.0
        alpha = rgba[:, :, 3:4]
        blended = rgba[:, :, :3] * alpha + (1 - alpha) * 0.5
        return Image.fromarray((blended * 255.0).astype(np.uint8))

    if not do_remove_background:
        if input_image.mode == "RGBA":
            return fill_background(input_image)
        return input_image

    rgb_image = input_image.convert("RGB")
    # presumably remove_background yields an image with an alpha channel,
    # which fill_background relies on; verify against tsr.utils
    cutout = remove_background(rgb_image, rembg_session)
    cutout = resize_foreground(cutout, foreground_ratio)
    return fill_background(cutout)
59
+
60
+
61
def generate(image, mc_resolution, formats=("obj", "glb")):
    """Reconstruct a mesh from ``image`` and export it once per format.

    Args:
        image: Preprocessed input image, passed directly to the model.
        mc_resolution: Resolution forwarded to ``model.extract_mesh``.
        formats: Iterable of file extensions (without the leading dot) to
            export. The default is an immutable tuple so it cannot be
            mutated across calls.

    Returns:
        list[str]: Paths of the exported temporary files, in the same order
        as ``formats``. The files are not deleted by this function.
    """
    scene_codes = model(image, device=device)
    mesh = model.extract_mesh(scene_codes, resolution=mc_resolution)[0]
    mesh = to_gradio_3d_orientation(mesh)

    exported_paths = []
    for extension in formats:
        # delete=False keeps the file on disk after this function returns so
        # the caller can still read it. The handle is closed before exporting
        # so no descriptor leaks and the path can be reopened for writing on
        # platforms that forbid a second open of a temp file.
        with tempfile.NamedTemporaryFile(
            suffix=f".{extension}", delete=False
        ) as handle:
            mesh_path = handle.name
        mesh.export(mesh_path)
        exported_paths.append(mesh_path)
    return exported_paths
71
+
72
+
73
def run_example(image_pil):
    """Run the full pipeline on an example image.

    The image is preprocessed with background removal disabled and a
    foreground ratio of 0.9, then meshes are generated at resolution 256 in
    OBJ and GLB formats.

    Args:
        image_pil: The example image to process.

    Returns:
        tuple: ``(processed_image, obj_path, glb_path)``.
    """
    prepared = preprocess(image_pil, False, 0.9)
    obj_path, glb_path = generate(prepared, 256, ["obj", "glb"])
    return prepared, obj_path, glb_path
77
+
78
+
79
# Example images listed in the "Examples" panel of the interface.
example_images = [
    "examples/hamburger.png",
    "examples/poly_fox.png",
    "examples/robot.png",
    "examples/teapot.png",
    "examples/tiger_girl.png",
    "examples/horse.png",
    "examples/flamingo.png",
    "examples/unicorn.png",
    "examples/chair.png",
    "examples/iso_house.png",
    "examples/marble.png",
    "examples/police_woman.png",
    "examples/captured_p.png",
]

# Build the Gradio UI: inputs and options on the left, 3D viewers on the
# right, an examples panel underneath, and the click -> preprocess -> generate
# event chain.
with gr.Blocks(title="TripoSR") as interface:
    gr.Markdown(
        """
    # TripoSR Demo
    [TripoSR](https://github.com/VAST-AI-Research/TripoSR) is a state-of-the-art open-source model for **fast** feedforward 3D reconstruction from a single image, collaboratively developed by [Tripo AI](https://www.tripo3d.ai/) and [Stability AI](https://stability.ai/).

    **Tips:**
    1. If you find the result is unsatisfied, please try to change the foreground ratio. It might improve the results.
    2. You can disable "Remove Background" for the provided examples since they have been already preprocessed.
    3. Otherwise, please disable "Remove Background" option only if your input image is RGBA with transparent background, image contents are centered and occupy more than 70% of image width or height.
    """
    )
    with gr.Row(variant="panel"):
        # Left column: image upload, processed preview, and options.
        with gr.Column():
            with gr.Row():
                input_image = gr.Image(
                    label="Input Image",
                    image_mode="RGBA",
                    sources="upload",
                    type="pil",
                    elem_id="content_image",
                )
                processed_image = gr.Image(label="Processed Image", interactive=False)
            with gr.Row():
                with gr.Group():
                    do_remove_background = gr.Checkbox(label="Remove Background", value=True)
                    foreground_ratio = gr.Slider(
                        label="Foreground Ratio", minimum=0.5, maximum=1.0, value=0.85, step=0.05
                    )
                    mc_resolution = gr.Slider(
                        label="Marching Cubes Resolution", minimum=32, maximum=320, value=256, step=32
                    )
            with gr.Row():
                submit = gr.Button("Generate", elem_id="generate", variant="primary")
        # Right column: one tab per exported mesh format.
        with gr.Column():
            with gr.Tab("OBJ"):
                output_model_obj = gr.Model3D(label="Output Model (OBJ Format)", interactive=False)
                gr.Markdown("Note: The model shown here is flipped. Download to get correct results.")
            with gr.Tab("GLB"):
                output_model_glb = gr.Model3D(label="Output Model (GLB Format)", interactive=False)
                gr.Markdown("Note: The model shown here has a darker appearance. Download to get correct results.")
    with gr.Row(variant="panel"):
        gr.Examples(
            examples=example_images,
            inputs=[input_image],
            outputs=[processed_image, output_model_obj, output_model_glb],
            cache_examples=False,
            fn=partial(run_example),
            label="Examples",
            examples_per_page=20,
        )

    # Each stage is attached with .success(), so it only runs when the
    # previous stage finished without raising.
    after_check = submit.click(fn=check_input_image, inputs=[input_image])
    after_preprocess = after_check.success(
        fn=preprocess,
        inputs=[input_image, do_remove_background, foreground_ratio],
        outputs=[processed_image],
    )
    after_preprocess.success(
        fn=generate,
        inputs=[processed_image, mc_resolution],
        outputs=[output_model_obj, output_model_glb],
    )
169
+
170
+
171
+
172
if __name__ == '__main__':
    # Command-line entry point: parse server options, then queue and launch
    # the Gradio interface.
    parser = argparse.ArgumentParser()
    parser.add_argument('--username', type=str, default=None, help='Username for authentication')
    parser.add_argument('--password', type=str, default=None, help='Password for authentication')
    parser.add_argument('--port', type=int, default=8005, help='Port to run the server listener on')
    # --listen used to be store_true with default="0.0.0.0": a truthy string,
    # so the flag was a no-op and the server always bound to all interfaces.
    # It is now a real boolean that still defaults to listening (unchanged
    # behaviour for existing invocations), with --no-listen as the opt-out.
    parser.add_argument(
        "--listen",
        dest="listen",
        action='store_true',
        default=True,
        help="launch gradio with 0.0.0.0 as server name, allowing to respond to network requests (default: enabled)",
    )
    parser.add_argument(
        "--no-listen",
        dest="listen",
        action='store_false',
        help="do not force 0.0.0.0 as server name; let gradio use its default host",
    )
    parser.add_argument("--share", action='store_true', help="use share=True for gradio and make the UI accessible through their site")
    parser.add_argument("--queuesize", type=int, default=1, help="launch gradio queue max_size")
    args = parser.parse_args()

    interface.queue(max_size=args.queuesize)
    interface.launch(
        # Authentication is enabled only when both credentials are supplied.
        auth=(args.username, args.password) if (args.username and args.password) else None,
        share=args.share,
        # None leaves the host choice to Gradio's own default.
        server_name="0.0.0.0" if args.listen else None,
        server_port=args.port,
    )
187
+ )