|
import sys |
|
import spaces |
|
sys.path.append("flash3d") |
|
|
|
from omegaconf import OmegaConf |
|
import gradio as gr |
|
import torch |
|
import torchvision.transforms as TT |
|
import torchvision.transforms.functional as TTF |
|
from huggingface_hub import hf_hub_download |
|
import numpy as np |
|
from einops import rearrange |
|
|
|
from networks.gaussian_predictor import GaussianPredictor |
|
from util.vis3d import save_ply |
|
|
|
def largest_divisor_at_most(total, requested):
    """Return the largest divisor of ``total`` that does not exceed ``requested``.

    Used to pick a Gaussians-per-pixel count that evenly splits the leading
    dimension of the predicted Gaussian tensors, so the reshape in the PLY
    export cannot fail. A ``requested`` value below 1 is treated as 1, so the
    result is always at least 1.

    Args:
        total: size of the dimension that must be split evenly (>= 1).
        requested: desired count; may be a float coming from a UI slider.

    Returns:
        int: the largest ``d`` with ``total % d == 0`` and ``d <= requested``.

    Raises:
        ValueError: if ``total`` is smaller than 1.
    """
    total = int(total)
    requested = max(1, int(requested))
    if total < 1:
        raise ValueError(f"total must be >= 1, got {total}")
    for candidate in range(min(requested, total), 0, -1):
        if total % candidate == 0:
            return candidate
    return 1  # unreachable: 1 divides every positive integer


def unwrap_gallery_images(items):
    """Return the bare images contained in a Gradio Gallery payload.

    A Gallery preprocesses its value into a list of ``(image, caption)`` pairs;
    bare (non-tuple) entries are passed through unchanged, so a plain list of
    images is accepted too. ``None`` or an empty payload yields ``[]``.
    """
    if not items:
        return []
    return [item[0] if isinstance(item, (tuple, list)) else item for item in items]


def main():
    """Load the Flash3D model from the Hugging Face Hub and serve a Gradio demo.

    Downloads the re10k v1 config and weights, builds a ``GaussianPredictor``
    on ``cuda:0`` when CUDA is available (CPU otherwise), then launches a
    Blocks UI that resizes and pads each uploaded image, runs the model on it,
    and exports every reconstruction as a ``.ply`` file. The first file is
    shown in the 3D viewer, and all of them are offered for download.
    """
    print("[INFO] Starting main function...")

    if torch.cuda.is_available():
        device = torch.device("cuda:0")
        print("[INFO] CUDA is available. Using GPU device.")
    else:
        device = torch.device("cpu")
        print("[INFO] CUDA is not available. Using CPU device.")

    repo_id = "einsafutdinov/flash3d"
    print("[INFO] Downloading model configuration...")
    model_cfg_path = hf_hub_download(repo_id=repo_id, filename="config_re10k_v1.yaml")
    print("[INFO] Downloading model weights...")
    model_path = hf_hub_download(repo_id=repo_id, filename="model_re10k_v1.pth")

    print("[INFO] Loading model configuration...")
    cfg = OmegaConf.load(model_cfg_path)

    print("[INFO] Initializing GaussianPredictor model...")
    model = GaussianPredictor(cfg)
    try:
        model.to(device)
    except Exception as e:
        print(f"[ERROR] Failed to set device: {e}")
        raise

    print("[INFO] Loading model weights...")
    model.load_model(model_path)

    # Border padding stored in the checkpoint's config; used as the slider
    # default. NOTE(review): presumably the model was trained with this
    # padding, so other values may degrade results -- confirm.
    default_padding = int(cfg.dataset.pad_border_aug)
    to_tensor = TT.ToTensor()

    def check_input_image(input_images):
        """Abort the event chain with a UI error when nothing was uploaded."""
        print("[DEBUG] Checking input images...")
        if not input_images:
            print("[ERROR] No images uploaded!")
            raise gr.Error("No images uploaded!")
        print("[INFO] Input images are valid.")

    def preprocess(images, padding_value):
        """Resize every uploaded image to the model resolution and pad its border.

        Returns the processed PIL images twice: once for the preview gallery
        and once for the ``gr.State`` feeding reconstruction. This way the
        model sees these exact pixels rather than a re-encoded gallery copy.
        """
        padding = int(padding_value)  # TT.Pad needs integral widths
        pad_border_fn = TT.Pad((padding, padding))
        processed = []
        for image in unwrap_gallery_images(images):
            print("[DEBUG] Preprocessing image...")
            # Uploads may be RGBA or palette images; ToTensor would then yield
            # a non-3-channel tensor. NOTE(review): assumes the model expects
            # RGB input -- confirm against the Flash3D encoder.
            image = image.convert("RGB")
            image = TTF.resize(
                image,
                (cfg.dataset.height, cfg.dataset.width),
                interpolation=TT.InterpolationMode.BICUBIC,
            )
            image = pad_border_fn(image)
            print("[INFO] Image preprocessing complete.")
            processed.append(image)
        return processed, processed

    @spaces.GPU(duration=120)
    def reconstruct_and_export(images, num_gauss):
        """Run the model on every preprocessed image and export each result.

        Args:
            images: PIL images produced by ``preprocess``.
            num_gauss: requested Gaussians per pixel for ``save_ply``; lowered
                to the largest value that evenly divides the model output.

        Returns:
            ``(first_ply_path, all_ply_paths)`` for the 3D viewer and the
            download list respectively.

        Raises:
            gr.Error: if there is nothing to reconstruct.
        """
        if not images:
            raise gr.Error("No processed images to reconstruct!")
        requested = max(1, int(num_gauss))
        ply_paths = []
        for index, image in enumerate(images):
            print("[DEBUG] Starting reconstruction and export...")
            tensor = to_tensor(image).to(device).unsqueeze(0)
            inputs = {("color_aug", 0, 0): tensor}

            print("[INFO] Passing image through the model...")
            # Inference only: skip autograd bookkeeping to save GPU memory.
            with torch.no_grad():
                outputs = model(inputs)

            gauss_means = outputs[("gauss_means", 0, 0)]
            print(f"[DEBUG] gauss_means tensor shape: {gauss_means.shape}")
            # NOTE(review): like the original check, this assumes dim 0 of
            # gauss_means must be a multiple of the Gaussians-per-pixel count.
            usable = largest_divisor_at_most(gauss_means.size(0), requested)
            if usable != requested:
                print(f"[WARNING] Adjusting num_gauss from {requested} to {usable} to avoid shape mismatch.")

            ply_out_path = f"./mesh_{index}.ply"
            print(f"[INFO] Saving output to {ply_out_path}...")
            save_ply(outputs, ply_out_path, num_gauss=usable)
            ply_paths.append(ply_out_path)

        print("[INFO] Reconstruction and export complete.")
        return ply_paths[0], ply_paths

    css = """
h1 {
text-align: center;
display:block;
}
"""

    example_paths = [
        "./demo_examples/bedroom_01.png",
        "./demo_examples/kitti_02.png",
        "./demo_examples/kitti_03.png",
        "./demo_examples/re10k_04.jpg",
        "./demo_examples/re10k_05.jpg",
        "./demo_examples/re10k_06.jpg",
    ]

    with gr.Blocks(css=css) as demo:
        gr.Markdown("# Flash3D")

        # Carries the processed PIL images from preprocess to reconstruction.
        processed_state = gr.State()

        with gr.Row(variant="panel"):
            with gr.Column(scale=1):
                with gr.Row():
                    # Gradio has no gr.Images; an interactive Gallery is its
                    # multi-image upload component.
                    input_images = gr.Gallery(
                        label="Input Images",
                        type="pil",
                        elem_id="content_images",
                        interactive=True,
                    )
                with gr.Row():
                    num_gauss = gr.Slider(minimum=1, maximum=20, step=1, label="Number of Gaussians per Pixel", value=1)
                    padding_value = gr.Slider(minimum=0, maximum=128, step=8, label="Padding Amount for Output Processing", value=default_padding)
                with gr.Row():
                    submit = gr.Button("Generate", elem_id="generate", variant="primary")
                with gr.Row(variant="panel"):
                    gr.Examples(
                        # One example row = one Gallery value = a list of images.
                        examples=[[[path]] for path in example_paths],
                        inputs=[input_images],
                        cache_examples=False,
                        label="Examples",
                        examples_per_page=20,
                    )
                with gr.Row():
                    processed_images = gr.Gallery(label="Processed Images", interactive=False)

            with gr.Column(scale=2):
                with gr.Row():
                    with gr.Tab("Reconstruction"):
                        output_model = gr.Model3D(
                            height=512,
                            label="Output Model",
                            interactive=False,
                        )
                        output_files = gr.File(
                            label="All Reconstructions (.ply)",
                            file_count="multiple",
                            interactive=False,
                        )

        submit.click(fn=check_input_image, inputs=[input_images]).success(
            fn=preprocess,
            inputs=[input_images, padding_value],
            outputs=[processed_images, processed_state],
        ).success(
            fn=reconstruct_and_export,
            inputs=[processed_state, num_gauss],
            outputs=[output_model, output_files],
        )

    demo.queue(max_size=1)
    print("[INFO] Launching Gradio demo...")
    demo.launch(share=True)
|
|
|
if __name__ == "__main__":
    # Start the demo only when executed as a script, never on import.
    print("[INFO] Running application...")
    main()