File size: 7,321 Bytes
6db5fd9
 
 
 
 
 
 
 
 
 
 
 
 
d38e5ca
 
6db5fd9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
665b2f0
6db5fd9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
665b2f0
6db5fd9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
import gradio as gr
from gs_train import train
import os

# Name of the subdirectory under the cache path that is scanned for datasets;
# each of its immediate subfolders is offered as one selectable dataset.
DATASET_DIR = "colmap_data"

def get_dataset_folders(datasets_path):
    """Return the names of the immediate subdirectories of *datasets_path*.

    Each subdirectory is treated as one dataset. The names are sorted because
    ``os.listdir`` returns entries in arbitrary order, and callers use the
    first entry as the default selection.

    Args:
        datasets_path: Directory to scan.

    Returns:
        A sorted list of subdirectory names (plain files are skipped), or an
        empty list when *datasets_path* does not exist or is not a directory.
    """
    try:
        entries = os.listdir(datasets_path)
    except (FileNotFoundError, NotADirectoryError):
        # A missing (or mistyped) dataset root simply means "no datasets yet".
        return []
    return sorted(
        name
        for name in entries
        if os.path.isdir(os.path.join(datasets_path, name))
    )

def gs_demo_tab(cache_path):
    """Build the Gaussian Splatting training demo UI.

    The dataset dropdown lists the subfolders of ``<cache_path>/<DATASET_DIR>``
    (populated when "Refresh Datasets" is clicked). "Start Training" calls
    ``train`` on the selected folder with every parameter widget's value passed
    positionally. Afterwards the resulting video is shown and the
    "Load 3D Model" button is enabled to display the trained model.

    Args:
        cache_path: Root cache directory that contains ``DATASET_DIR``.

    Returns:
        The ``gr.Blocks`` object holding the demo.
    """
    # Directory whose immediate subfolders are offered as datasets.
    dataset_path = os.path.join(cache_path, DATASET_DIR)

    def get_context():
        # NOTE(review): delete_cache is a (frequency, age) tuple in seconds;
        # (True, True) evaluates to (1, 1) - confirm this aggressive purge of
        # Gradio's cached files is intended.
        return gr.Blocks(delete_cache=(True, True))

    with get_context() as gs_demo:
        gr.Markdown("""
        <style>
        .fixed-size-video video {
            max-height: 400px !important;
            height: 400px !important;
            object-fit: contain;
        }
        </style>
        """)
        gr.Markdown("# Gaussian Splatting Training Demo")

        refresh_button = gr.Button("Refresh Datasets", elem_classes="refresh-button")
        # Starts empty with no selection; filled by the refresh button below.
        dataset_dropdown = gr.Dropdown(label="Select Dataset", choices=[], value=None)

        def update_dataset_dropdown():
            """Rescan the dataset directory and return a refreshed dropdown."""
            print("update_dataset_dropdown, cache_path", cache_path)
            dataset_folders = get_dataset_folders(dataset_path)
            print("dataset_folders", dataset_folders)
            # Only preselect a value if there are folders available.
            default_value = dataset_folders[0] if dataset_folders else None
            return gr.Dropdown(label="Select Dataset", choices=dataset_folders, value=default_value)

        refresh_button.click(fn=update_dataset_dropdown, inputs=None, outputs=dataset_dropdown)

        with gr.Accordion("Model Parameters", open=False):
            with gr.Row():
                with gr.Column():
                    sh_degree = gr.Number(label="SH Degree", value=3)
                    model_path = gr.Textbox(label="Model Path", value="")
                    images = gr.Textbox(label="Images", value="images")
                    resolution = gr.Number(label="Resolution", value=-1)
                    white_background = gr.Checkbox(label="White Background", value=True)
                    data_device = gr.Dropdown(label="Data Device", choices=["cuda", "cpu"], value="cuda")
                    # Named eval_checkbox to avoid shadowing the builtin eval().
                    eval_checkbox = gr.Checkbox(label="Eval", value=False)

        with gr.Accordion("Pipeline Parameters", open=False):
            with gr.Row():
                with gr.Column():
                    convert_SHs_python = gr.Checkbox(label="Convert SHs Python", value=False)
                    compute_cov3D_python = gr.Checkbox(label="Compute Cov3D Python", value=False)
                    debug = gr.Checkbox(label="Debug", value=False)

        with gr.Accordion("Optimization Parameters", open=False):
            with gr.Row():
                with gr.Column():
                    iterations = gr.Number(label="Iterations", value=1000)
                    position_lr_init = gr.Number(label="Position LR Init", value=0.00016)
                    position_lr_final = gr.Number(label="Position LR Final", value=0.0000016)
                    position_lr_delay_mult = gr.Number(label="Position LR Delay Mult", value=0.01)
                    position_lr_max_steps = gr.Number(label="Position LR Max Steps", value=30000)
                with gr.Column():
                    feature_lr = gr.Number(label="Feature LR", value=0.0025)
                    opacity_lr = gr.Number(label="Opacity LR", value=0.05)
                    scaling_lr = gr.Number(label="Scaling LR", value=0.005)
                    rotation_lr = gr.Number(label="Rotation LR", value=0.001)
                    percent_dense = gr.Number(label="Percent Dense", value=0.01)
                with gr.Column():
                    lambda_dssim = gr.Number(label="Lambda DSSIM", value=0.2)
                    densification_interval = gr.Number(label="Densification Interval", value=100)
                    opacity_reset_interval = gr.Number(label="Opacity Reset Interval", value=3000)
                    densify_from_iter = gr.Number(label="Densify From Iter", value=500)
                    densify_until_iter = gr.Number(label="Densify Until Iter", value=15000)
                    densify_grad_threshold = gr.Number(label="Densify Grad Threshold", value=0.0002)
                    random_background = gr.Checkbox(label="Random Background", value=False)

        start_button = gr.Button("Start Training")

        # Holds the trained model's path between the training and load-model events.
        model_path_state = gr.State()

        # Training video output, with a fixed height so the layout does not jump.
        video_output = gr.Video(
            label="Training Progress",
            height=400,
            width="100%",
            autoplay=False,
            show_label=True,
            container=True,
            elem_classes="fixed-size-video",  # styled by the CSS block at the top
        )
        load_model_button = gr.Button("Load 3D Model", interactive=False)
        output = gr.Model3D(label="3D Model Output", visible=False)

        def handle_training_complete(selected_folder, *args):
            """Train on the selected dataset, then enable the load-model button.

            ``*args`` are the parameter widget values, in the order given by
            ``start_button.click(inputs=...)``; they are forwarded to ``train``
            unchanged.

            Returns:
                [video path, enabled load button, hidden 3D viewer, model path
                for ``model_path_state``].
            """
            if not selected_folder:
                # Without this guard, os.path.join(dataset_path, None) raises an opaque TypeError.
                raise gr.Error("Please select a dataset first (click 'Refresh Datasets').")
            selected_data_path = os.path.join(dataset_path, selected_folder)
            # Assumes train returns (video_path, model_path), as unpacked here - confirm in gs_train.
            video_path, trained_model_path = train(selected_data_path, *args)
            return [
                video_path,
                gr.Button(value="Load 3D Model", interactive=True),
                gr.Model3D(visible=False),  # keep the viewer hidden until explicitly loaded
                trained_model_path,
            ]

        def load_model(saved_model_path):
            """Show the 3D viewer with *saved_model_path*, or keep it hidden if unset."""
            if not saved_model_path:
                return gr.Model3D(visible=False)
            return gr.Model3D(value=saved_model_path, visible=True)

        # The order of these inputs must match the positional parameters expected by train.
        start_button.click(
            fn=handle_training_complete,
            inputs=[
                dataset_dropdown, sh_degree, model_path, images, resolution, white_background, data_device, eval_checkbox,
                convert_SHs_python, compute_cov3D_python, debug,
                iterations, position_lr_init, position_lr_final, position_lr_delay_mult,
                position_lr_max_steps, feature_lr, opacity_lr, scaling_lr, rotation_lr,
                percent_dense, lambda_dssim, densification_interval, opacity_reset_interval,
                densify_from_iter, densify_until_iter, densify_grad_threshold, random_background
            ],
            outputs=[video_output, load_model_button, output, model_path_state]
        )

        load_model_button.click(
            fn=load_model,
            inputs=[model_path_state],
            outputs=output
        )
    return gs_demo