michaelj commited on
Commit
23f5383
1 Parent(s): 946d0a6

, resolution=mc_resolution

Browse files
Files changed (1) hide show
  1. app.py +54 -52
app.py CHANGED
@@ -2,37 +2,60 @@ import logging
2
  import os
3
  import tempfile
4
  import time
 
5
  import gradio as gr
6
  import numpy as np
7
  import rembg
8
  import torch
9
  from PIL import Image
10
  from functools import partial
 
11
  from tsr.system import TSR
12
  from tsr.utils import remove_background, resize_foreground, to_gradio_3d_orientation
13
- import argparse
 
 
 
 
 
 
 
 
 
 
14
  if torch.cuda.is_available():
15
  device = "cuda:0"
16
  else:
17
  device = "cpu"
 
 
 
 
 
18
  model = TSR.from_pretrained(
19
  "stabilityai/TripoSR",
20
  config_name="config.yaml",
21
  weight_name="model.ckpt",
 
22
  )
23
- # adjust the chunk size to balance between speed and memory usage
24
- model.renderer.set_chunk_size(8192)
25
  model.to(device)
 
26
  rembg_session = rembg.new_session()
 
 
27
  def check_input_image(input_image):
28
  if input_image is None:
29
  raise gr.Error("No image uploaded!")
 
 
30
  def preprocess(input_image, do_remove_background, foreground_ratio):
31
  def fill_background(image):
32
  image = np.array(image).astype(np.float32) / 255.0
33
  image = image[:, :, :3] * image[:, :, 3:4] + (1 - image[:, :, 3:4]) * 0.5
34
  image = Image.fromarray((image * 255.0).astype(np.uint8))
35
  return image
 
36
  if do_remove_background:
37
  image = input_image.convert("RGB")
38
  image = remove_background(image, rembg_session)
@@ -43,26 +66,25 @@ def preprocess(input_image, do_remove_background, foreground_ratio):
43
  if image.mode == "RGBA":
44
  image = fill_background(image)
45
  return image
46
- def generate(image, mc_resolution, formats=["obj", "glb"]):
 
 
47
  scene_codes = model(image, device=device)
48
  mesh = model.extract_mesh(scene_codes, resolution=1024)[0]
49
  mesh = to_gradio_3d_orientation(mesh)
50
- rv = []
51
- for format in formats:
52
- mesh_path = tempfile.NamedTemporaryFile(suffix=f".{format}", delete=False)
53
- mesh.export(mesh_path.name)
54
- rv.append(mesh_path.name)
55
- return rv
56
  def run_example(image_pil):
57
  preprocessed = preprocess(image_pil, False, 0.9)
58
- mesh_name_obj, mesh_name_glb = generate(preprocessed, 256, ["obj", "glb"])
59
- return preprocessed, mesh_name_obj, mesh_name_glb
60
- with gr.Blocks(title="TripoSR") as demo:
61
- gr.Markdown(
62
- """
63
- 图像生成3d模型
64
- """
65
- )
66
  with gr.Row(variant="panel"):
67
  with gr.Column():
68
  with gr.Row():
@@ -86,51 +108,30 @@ with gr.Blocks(title="TripoSR") as demo:
86
  value=0.85,
87
  step=0.05,
88
  )
89
- mc_resolution = gr.Slider(
90
- label="Marching Cubes Resolution",
91
- minimum=32,
92
- maximum=1024,
93
- value=256,
94
- step=32
95
- )
96
  with gr.Row():
97
  submit = gr.Button("Generate", elem_id="generate", variant="primary")
98
  with gr.Column():
99
- with gr.Tab("OBJ"):
100
- output_model_obj = gr.Model3D(
101
- label="Output Model (OBJ Format)",
102
  interactive=False,
103
  )
104
- gr.Markdown("Note: The model shown here is flipped. Download to get correct results.")
105
- with gr.Tab("GLB"):
106
- output_model_glb = gr.Model3D(
107
- label="Output Model (GLB Format)",
108
  interactive=False,
109
  )
110
- gr.Markdown("Note: The model shown here has a darker appearance. Download to get correct results.")
111
  with gr.Row(variant="panel"):
112
  gr.Examples(
113
  examples=[
114
- "examples/hamburger.png",
115
- "examples/poly_fox.png",
116
- "examples/robot.png",
117
- "examples/teapot.png",
118
- "examples/tiger_girl.png",
119
- "examples/horse.png",
120
- "examples/flamingo.png",
121
- "examples/unicorn.png",
122
- "examples/chair.png",
123
- "examples/iso_house.png",
124
- "examples/marble.png",
125
- "examples/police_woman.png",
126
- "examples/captured.jpeg",
127
  ],
128
  inputs=[input_image],
129
- outputs=[processed_image, output_model_obj, output_model_glb],
130
- cache_examples=False,
131
  fn=partial(run_example),
132
  label="Examples",
133
- examples_per_page=20,
134
  )
135
  submit.click(fn=check_input_image, inputs=[input_image]).success(
136
  fn=preprocess,
@@ -138,8 +139,9 @@ with gr.Blocks(title="TripoSR") as demo:
138
  outputs=[processed_image],
139
  ).success(
140
  fn=generate,
141
- inputs=[processed_image, mc_resolution],
142
- outputs=[output_model_obj, output_model_glb],
143
  )
 
144
  demo.queue(max_size=10)
145
- demo.launch()
 
2
  import os
3
  import tempfile
4
  import time
5
+
6
  import gradio as gr
7
  import numpy as np
8
  import rembg
9
  import torch
10
  from PIL import Image
11
  from functools import partial
12
+
13
  from tsr.system import TSR
14
  from tsr.utils import remove_background, resize_foreground, to_gradio_3d_orientation
15
+
16
+ #HF_TOKEN = os.getenv("HF_TOKEN")
17
+
18
+ HEADER = """
19
+ **TripoSR** is a state-of-the-art open-source model for **fast** feedforward 3D reconstruction from a single image, developed in collaboration between [Tripo AI](https://www.tripo3d.ai/) and [Stability AI](https://stability.ai/).
20
+ **Tips:**
21
+ 1. If you find the result is unsatisfied, please try to change the foreground ratio. It might improve the results.
22
+ 2. Please disable "Remove Background" option only if your input image is RGBA with transparent background, image contents are centered and occupy more than 70% of image width or height.
23
+ """
24
+
25
+
26
  if torch.cuda.is_available():
27
  device = "cuda:0"
28
  else:
29
  device = "cpu"
30
+
31
+ d = os.environ.get("DEVICE", None)
32
+ if d != None:
33
+ device = d
34
+
35
  model = TSR.from_pretrained(
36
  "stabilityai/TripoSR",
37
  config_name="config.yaml",
38
  weight_name="model.ckpt",
39
+ # token=HF_TOKEN
40
  )
41
+ model.renderer.set_chunk_size(131072)
 
42
  model.to(device)
43
+
44
  rembg_session = rembg.new_session()
45
+
46
+
47
  def check_input_image(input_image):
48
  if input_image is None:
49
  raise gr.Error("No image uploaded!")
50
+
51
+
52
  def preprocess(input_image, do_remove_background, foreground_ratio):
53
  def fill_background(image):
54
  image = np.array(image).astype(np.float32) / 255.0
55
  image = image[:, :, :3] * image[:, :, 3:4] + (1 - image[:, :, 3:4]) * 0.5
56
  image = Image.fromarray((image * 255.0).astype(np.uint8))
57
  return image
58
+
59
  if do_remove_background:
60
  image = input_image.convert("RGB")
61
  image = remove_background(image, rembg_session)
 
66
  if image.mode == "RGBA":
67
  image = fill_background(image)
68
  return image
69
+
70
+
71
+ def generate(image):
72
  scene_codes = model(image, device=device)
73
  mesh = model.extract_mesh(scene_codes, resolution=1024)[0]
74
  mesh = to_gradio_3d_orientation(mesh)
75
+ mesh_path = tempfile.NamedTemporaryFile(suffix=".obj", delete=False)
76
+ mesh_path2 = tempfile.NamedTemporaryFile(suffix=".glb", delete=False)
77
+ mesh.export(mesh_path.name)
78
+ mesh.export(mesh_path2.name)
79
+ return mesh_path.name, mesh_path2.name
80
+
81
  def run_example(image_pil):
82
  preprocessed = preprocess(image_pil, False, 0.9)
83
+ mesh_name, mesn_name2 = generate(preprocessed)
84
+ return preprocessed, mesh_name, mesh_name2
85
+
86
+ with gr.Blocks() as demo:
87
+ gr.Markdown(HEADER)
 
 
 
88
  with gr.Row(variant="panel"):
89
  with gr.Column():
90
  with gr.Row():
 
108
  value=0.85,
109
  step=0.05,
110
  )
 
 
 
 
 
 
 
111
  with gr.Row():
112
  submit = gr.Button("Generate", elem_id="generate", variant="primary")
113
  with gr.Column():
114
+ with gr.Tab("obj"):
115
+ output_model = gr.Model3D(
116
+ label="Output Model",
117
  interactive=False,
118
  )
119
+ with gr.Tab("glb"):
120
+ output_model2 = gr.Model3D(
121
+ label="Output Model",
 
122
  interactive=False,
123
  )
 
124
  with gr.Row(variant="panel"):
125
  gr.Examples(
126
  examples=[
127
+ os.path.join("examples", img_name) for img_name in sorted(os.listdir("examples"))
 
 
 
 
 
 
 
 
 
 
 
 
128
  ],
129
  inputs=[input_image],
130
+ outputs=[processed_image, output_model, output_model2],
131
+ #cache_examples=True,
132
  fn=partial(run_example),
133
  label="Examples",
134
+ examples_per_page=20
135
  )
136
  submit.click(fn=check_input_image, inputs=[input_image]).success(
137
  fn=preprocess,
 
139
  outputs=[processed_image],
140
  ).success(
141
  fn=generate,
142
+ inputs=[processed_image],
143
+ outputs=[output_model, output_model2],
144
  )
145
+
146
  demo.queue(max_size=10)
147
+ demo.launch()