Ryukijano committed on
Commit 7b6840d · 1 Parent(s): bbd1db4

Update app.py

Files changed (1)
  1. app.py +38 -19
app.py CHANGED
@@ -25,9 +25,11 @@ def main():
 
     # Download model configuration and weights from Hugging Face Hub
     print("[INFO] Downloading model configuration...")
-    model_cfg_path = hf_hub_download(repo_id="einsafutdinov/flash3d", filename="config_re10k_v1.yaml")
+    model_cfg_path = hf_hub_download(repo_id="einsafutdinov/flash3d",
+                                     filename="config_re10k_v1.yaml")
     print("[INFO] Downloading model weights...")
-    model_path = hf_hub_download(repo_id="einsafutdinov/flash3d", filename="model_re10k_v1.pth")
+    model_path = hf_hub_download(repo_id="einsafutdinov/flash3d",
+                                 filename="model_re10k_v1.pth")
 
     # Load model configuration using OmegaConf
     print("[INFO] Loading model configuration...")
@@ -59,7 +61,10 @@ def main():
     def preprocess(image):
         print("[DEBUG] Preprocessing image...")
         # Resize the image to the desired height and width specified in the configuration
-        image = TTF.resize(image, (cfg.dataset.height, cfg.dataset.width), interpolation=TT.InterpolationMode.BICUBIC)
+        image = TTF.resize(
+            image, (cfg.dataset.height, cfg.dataset.width),
+            interpolation=TT.InterpolationMode.BICUBIC
+        )
         # Apply padding to the image
         image = pad_border_fn(image)
         print("[INFO] Image preprocessing complete.")
@@ -67,15 +72,16 @@ def main():
 
     # Function to reconstruct the 3D model from the input image and export it as a PLY file
     @spaces.GPU(duration=120)  # Decorator to allocate a GPU for this function during execution
-    def reconstruct_and_export(image, num_gauss, batch_size, num_iterations):
+    def reconstruct_and_export(image):
+        """
+        Passes image through model, outputs reconstruction in form of a dict of tensors.
+        """
         print("[DEBUG] Starting reconstruction and export...")
         # Convert the preprocessed image to a tensor and move it to the specified device
         image = to_tensor(image).to(device).unsqueeze(0)
-        inputs = {("color_aug", 0, 0): image}
-
-        # Set the batch size and number of iterations in the model configuration
-        model.cfg.dataset.batch_size = batch_size
-        model.cfg.training.num_iterations = num_iterations
+        inputs = {
+            ("color_aug", 0, 0): image,
+        }
 
         # Pass the image through the model to get the output
         print("[INFO] Passing image through the model...")
@@ -83,11 +89,11 @@ def main():
 
         # Export the reconstruction to a PLY file
         print(f"[INFO] Saving output to {ply_out_path}...")
-        save_ply(outputs, ply_out_path, num_gauss=num_gauss)
+        save_ply(outputs, ply_out_path, num_gauss=2)
        print("[INFO] Reconstruction and export complete.")
 
         return ply_out_path
-
+
     # Path to save the output PLY file
     ply_out_path = f'./mesh.ply'
 
@@ -101,15 +107,26 @@ def main():
 
     # Create the Gradio user interface
     with gr.Blocks(css=css) as demo:
-        gr.Markdown("# Flash3D")
+        gr.Markdown(
+            """
+            # Flash3D
+            """
+        )
         with gr.Row(variant="panel"):
             with gr.Column(scale=1):
                 with gr.Row():
                     # Input image component for the user to upload an image
-                    input_image = gr.Image(label="Input Image", image_mode="RGBA", sources="upload", type="pil", elem_id="content_image")
+                    input_image = gr.Image(
+                        label="Input Image",
+                        image_mode="RGBA",
+                        sources="upload",
+                        type="pil",
+                        elem_id="content_image",
+                    )
                 with gr.Row():
                     # Button to trigger the generation process
                     submit = gr.Button("Generate", elem_id="generate", variant="primary")
+
                 with gr.Row(variant="panel"):
                     # Examples panel to provide sample images for users
                     gr.Examples(
@@ -126,18 +143,20 @@ def main():
                         label="Examples",
                         examples_per_page=20,
                     )
+
                 with gr.Row():
                     # Display the preprocessed image (after resizing and padding)
                     processed_image = gr.Image(label="Processed Image", interactive=False)
+
             with gr.Column(scale=2):
                 with gr.Row():
                     with gr.Tab("Reconstruction"):
                         # 3D model viewer to display the reconstructed model
-                        output_model = gr.Model3D(height=512, label="Output Model", interactive=False)
-                with gr.Row():
-                    num_gauss = gr.Slider(minimum=1, maximum=10, step=1, label="Number of Gaussian Components", value=2)
-                    batch_size = gr.Slider(minimum=1, maximum=32, step=1, label="Batch Size", value=1)
-                    num_iterations = gr.Slider(minimum=1, maximum=1000, step=10, label="Number of Iterations", value=100)
+                        output_model = gr.Model3D(
+                            height=512,
+                            label="Output Model",
+                            interactive=False
+                        )
 
         # Define the workflow for the Generate button
         submit.click(fn=check_input_image, inputs=[input_image]).success(
@@ -146,7 +165,7 @@ def main():
             outputs=[processed_image],
         ).success(
             fn=reconstruct_and_export,
-            inputs=[processed_image, num_gauss, batch_size, num_iterations],
+            inputs=[processed_image],
             outputs=[output_model],
         )
 
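
For context on the download step touched by this commit, a minimal sketch of how the fetched configuration and checkpoint would typically be consumed. Only the hf_hub_download calls and the OmegaConf load are implied by the diff; the torch.load line is an assumption for illustration, and the actual model class is outside this change.

from huggingface_hub import hf_hub_download
from omegaconf import OmegaConf
import torch

# Resolve local cache paths for the config and weights (same calls as in the diff)
model_cfg_path = hf_hub_download(repo_id="einsafutdinov/flash3d",
                                 filename="config_re10k_v1.yaml")
model_path = hf_hub_download(repo_id="einsafutdinov/flash3d",
                             filename="model_re10k_v1.pth")

# Parse the YAML config into a dot-accessible object; preprocess() reads
# cfg.dataset.height and cfg.dataset.width from it
cfg = OmegaConf.load(model_cfg_path)

# Hypothetical checkpoint load; instantiating the model itself is not part of this diff
state_dict = torch.load(model_path, map_location="cpu")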
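
And a usage sketch of the simplified callback signature outside Gradio, assuming preprocess and reconstruct_and_export from the updated app.py are in scope; the input file name is hypothetical.

from PIL import Image

# Hypothetical local test image; gr.Image(type="pil", image_mode="RGBA")
# would hand the callbacks a PIL image in the same format
img = Image.open("example.png").convert("RGBA")

# Resize to (cfg.dataset.height, cfg.dataset.width) and pad, as in the diff
processed = preprocess(img)

# Single-argument call: batch size, iteration count and num_gauss are no longer
# passed from the UI; save_ply() now uses num_gauss=2 internally
ply_path = reconstruct_and_export(processed)
print(f"PLY output written to {ply_path}")  # ./mesh.ply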