Ryukijano committed on
Commit bbd1db4
1 Parent(s): 8403619

Update app.py


Exposed a few parameters :)

Files changed (1)
  1. app.py +19 -38
app.py CHANGED
@@ -25,11 +25,9 @@ def main():
 
     # Download model configuration and weights from Hugging Face Hub
     print("[INFO] Downloading model configuration...")
-    model_cfg_path = hf_hub_download(repo_id="einsafutdinov/flash3d",
-                                     filename="config_re10k_v1.yaml")
+    model_cfg_path = hf_hub_download(repo_id="einsafutdinov/flash3d", filename="config_re10k_v1.yaml")
     print("[INFO] Downloading model weights...")
-    model_path = hf_hub_download(repo_id="einsafutdinov/flash3d",
-                                 filename="model_re10k_v1.pth")
+    model_path = hf_hub_download(repo_id="einsafutdinov/flash3d", filename="model_re10k_v1.pth")
 
     # Load model configuration using OmegaConf
     print("[INFO] Loading model configuration...")
@@ -61,10 +59,7 @@ def main():
     def preprocess(image):
         print("[DEBUG] Preprocessing image...")
         # Resize the image to the desired height and width specified in the configuration
-        image = TTF.resize(
-            image, (cfg.dataset.height, cfg.dataset.width),
-            interpolation=TT.InterpolationMode.BICUBIC
-        )
+        image = TTF.resize(image, (cfg.dataset.height, cfg.dataset.width), interpolation=TT.InterpolationMode.BICUBIC)
         # Apply padding to the image
         image = pad_border_fn(image)
         print("[INFO] Image preprocessing complete.")
@@ -72,16 +67,15 @@ def main():
 
     # Function to reconstruct the 3D model from the input image and export it as a PLY file
     @spaces.GPU(duration=120)  # Decorator to allocate a GPU for this function during execution
-    def reconstruct_and_export(image):
-        """
-        Passes image through model, outputs reconstruction in form of a dict of tensors.
-        """
+    def reconstruct_and_export(image, num_gauss, batch_size, num_iterations):
         print("[DEBUG] Starting reconstruction and export...")
         # Convert the preprocessed image to a tensor and move it to the specified device
         image = to_tensor(image).to(device).unsqueeze(0)
-        inputs = {
-            ("color_aug", 0, 0): image,
-        }
+        inputs = {("color_aug", 0, 0): image}
+
+        # Set the batch size and number of iterations in the model configuration
+        model.cfg.dataset.batch_size = batch_size
+        model.cfg.training.num_iterations = num_iterations
 
         # Pass the image through the model to get the output
         print("[INFO] Passing image through the model...")
@@ -89,11 +83,11 @@ def main():
 
         # Export the reconstruction to a PLY file
         print(f"[INFO] Saving output to {ply_out_path}...")
-        save_ply(outputs, ply_out_path, num_gauss=2)
+        save_ply(outputs, ply_out_path, num_gauss=num_gauss)
         print("[INFO] Reconstruction and export complete.")
 
         return ply_out_path
-
+
     # Path to save the output PLY file
     ply_out_path = f'./mesh.ply'
 
@@ -107,26 +101,15 @@ def main():
 
     # Create the Gradio user interface
     with gr.Blocks(css=css) as demo:
-        gr.Markdown(
-            """
-            # Flash3D
-            """
-        )
+        gr.Markdown("# Flash3D")
         with gr.Row(variant="panel"):
             with gr.Column(scale=1):
                 with gr.Row():
                     # Input image component for the user to upload an image
-                    input_image = gr.Image(
-                        label="Input Image",
-                        image_mode="RGBA",
-                        sources="upload",
-                        type="pil",
-                        elem_id="content_image",
-                    )
+                    input_image = gr.Image(label="Input Image", image_mode="RGBA", sources="upload", type="pil", elem_id="content_image")
                 with gr.Row():
                     # Button to trigger the generation process
                     submit = gr.Button("Generate", elem_id="generate", variant="primary")
-
                 with gr.Row(variant="panel"):
                     # Examples panel to provide sample images for users
                     gr.Examples(
@@ -143,20 +126,18 @@ def main():
                         label="Examples",
                        examples_per_page=20,
                    )
-
                with gr.Row():
                    # Display the preprocessed image (after resizing and padding)
                    processed_image = gr.Image(label="Processed Image", interactive=False)
-
            with gr.Column(scale=2):
                with gr.Row():
                    with gr.Tab("Reconstruction"):
                        # 3D model viewer to display the reconstructed model
-                        output_model = gr.Model3D(
-                            height=512,
-                            label="Output Model",
-                            interactive=False
-                        )
+                        output_model = gr.Model3D(height=512, label="Output Model", interactive=False)
+                        with gr.Row():
+                            num_gauss = gr.Slider(minimum=1, maximum=10, step=1, label="Number of Gaussian Components", value=2)
+                            batch_size = gr.Slider(minimum=1, maximum=32, step=1, label="Batch Size", value=1)
+                            num_iterations = gr.Slider(minimum=1, maximum=1000, step=10, label="Number of Iterations", value=100)
 
        # Define the workflow for the Generate button
        submit.click(fn=check_input_image, inputs=[input_image]).success(
@@ -165,7 +146,7 @@ def main():
            outputs=[processed_image],
        ).success(
            fn=reconstruct_and_export,
-            inputs=[processed_image],
+            inputs=[processed_image, num_gauss, batch_size, num_iterations],
            outputs=[output_model],
        )
 
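In short, three values that were previously fixed in code (the num_gauss=2 passed to save_ply, plus a batch size and iteration count on the model config) are now driven by Gradio sliders and passed into reconstruct_and_export. The standalone sketch below illustrates only that slider-to-function wiring pattern; fake_reconstruct and its return string are placeholders standing in for the real Flash3D model, not code from this repository.

import gradio as gr

# Placeholder for the real reconstruction step; only the parameter wiring is shown.
def fake_reconstruct(image, num_gauss, batch_size, num_iterations):
    # The real app runs the model here and calls save_ply(outputs, path, num_gauss=num_gauss).
    return f"would export a mesh with num_gauss={num_gauss}, batch_size={batch_size}, num_iterations={num_iterations}"

with gr.Blocks() as demo:
    input_image = gr.Image(label="Input Image", type="pil")
    num_gauss = gr.Slider(minimum=1, maximum=10, step=1, value=2, label="Number of Gaussian Components")
    batch_size = gr.Slider(minimum=1, maximum=32, step=1, value=1, label="Batch Size")
    num_iterations = gr.Slider(minimum=1, maximum=1000, step=10, value=100, label="Number of Iterations")
    submit = gr.Button("Generate")
    result = gr.Textbox(label="Result")

    # Slider components are listed in `inputs` in the same order as the function's
    # parameters, which is how the diff wires reconstruct_and_export.
    submit.click(fn=fake_reconstruct,
                 inputs=[input_image, num_gauss, batch_size, num_iterations],
                 outputs=[result])

if __name__ == "__main__":
    demo.launch()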