import gradio as gr
import torch
import numpy as np
import requests
from PIL import Image
import torch
from inference import predict, random_sample, overlay_images, post_process


def main():
    """Build and launch the SFM inference Gradio demo.

    The UI lets the user pick a task ('Fault' or 'Facies'), load a random
    sample via ``random_sample``, run ``predict`` on it, overlay the
    prediction on the input via ``overlay_images``, and pick a
    post-processing method whose parameter slider is revealed on selection.
    """
    # The order of this list must match the order of ``pp_sliders`` below:
    # update_slider returns exactly one update per entry, and Gradio maps
    # returned updates to outputs positionally.
    pp_options = [
        "None",
        "Thresholding",
        "Closing",
        "Opening",
        "Canny Edge",
        "Gaussian Smoothing",
        "Hysteresis",
    ]

    def update_slider(selected):
        """Return one ``gr.update`` per post-processing slider.

        Only the slider belonging to *selected* is made visible. All sliders
        are hidden when *selected* is "None" or when the radio has no value.

        Raises:
            ValueError: if *selected* is not one of ``pp_options``.
        """
        if selected is not None and selected not in pp_options:
            # Explicit raise instead of ``assert``, which is stripped under -O.
            raise ValueError(f"Unknown post-processing method: {selected!r}")
        return [
            gr.update(visible=(option == selected and option != pp_options[0]))
            for option in pp_options
        ]

    with gr.Blocks() as demo:
        # Per-session state used to pass raw (non-image) data between callbacks.
        seismic_data = gr.State()
        prediction_data = gr.State()
        # NOTE(review): no callback writes this yet. It looks reserved for the
        # output of post-processing; confirm before relying on it.
        processed_prediction_data = gr.State()

        gr.Markdown("## SFM Inference Demo")
        gr.Markdown("### Select a task and run inference on seismic data")
        with gr.Row():
            # Task selector consumed by random_sample and predict.
            task = gr.Radio(
                choices=['Fault', 'Facies'], label="Select Task", value='Fault'
            )

        gr.Markdown("### Upload your seismic data or sample from dataset")
        with gr.Row():
            seismic_image = gr.Image(label="Seismic Data")
            prediction_image = gr.Image(label="Prediction Result")

        with gr.Row():
            random_sample_button = gr.Button(
                "Upload Random Sample", elem_id="random-sample-button"
            )
            random_sample_button.click(
                fn=random_sample,
                inputs=[task],
                outputs=[seismic_image, seismic_data],
            )

        with gr.Row():
            predict_button = gr.Button("Run Inference", elem_id="predict-button")
            predict_button.click(
                fn=predict,
                inputs=[seismic_data, task],
                outputs=[prediction_image, prediction_data],
            )

        with gr.Row():
            overlay_image = gr.Image(label="Overlay Result")
            with gr.Column():
                gr.Markdown("### Overlay Seismic Data with Prediction Result")
                overlay_button = gr.Button("Overlay Result", elem_id="overlay-button")
                overlay_button.click(
                    fn=overlay_images,
                    inputs=[seismic_image, prediction_image],
                    outputs=[overlay_image],
                )

        gr.Markdown("### Post Processing")
        with gr.Row():
            # Named pp_method (not ``post_process``) so the post_process
            # function imported from ``inference`` is not shadowed in main.
            pp_method = gr.Radio(
                choices=pp_options,
                value='None',
                elem_id="post-processing",
                label="Post Processing Method",
            )
            # The "None" slider is a placeholder that keeps this list aligned
            # with pp_options; update_slider never makes it visible.
            slider_none = gr.Slider(minimum=0, maximum=255, value=128, label="None Value", visible=False)
            slider_thresh = gr.Slider(minimum=0, maximum=255, value=128, label="Threshold Value", visible=False)
            slider_close = gr.Slider(minimum=0, maximum=64, value=32, label="Closing Value", visible=False)
            slider_open = gr.Slider(minimum=0, maximum=64, value=32, label="Opening Value", visible=False)
            slider_canny = gr.Slider(minimum=0, maximum=255, value=128, label="Canny Edge Value", visible=False)
            slider_gauss = gr.Slider(minimum=0, maximum=255, value=128, label="Sigma", visible=False)
            slider_hyst = gr.Slider(minimum=0, maximum=255, value=128, label="Hysteresis Min Value", visible=False)
            pp_sliders = [
                slider_none,
                slider_thresh,
                slider_close,
                slider_open,
                slider_canny,
                slider_gauss,
                slider_hyst,
            ]

        pp_method.change(fn=update_slider, inputs=[pp_method], outputs=pp_sliders)

        # TODO(review): the imported post_process is not invoked anywhere yet.
        # This button also has no click handler, so it currently does nothing.
        gr.Button("Download Processed Image", elem_id="download-processed-button")

    demo.launch()


if __name__ == "__main__":
    main()