# Gradio App Code (based on paste.txt) with Triton Integration and Fallback import psutil import gradio as gr import numpy as np import os import cv2 import matplotlib.pyplot as plt from huggingface_hub import snapshot_download import rasterio from rasterio.enums import Resampling from rasterio.plot import reshape_as_image import sys import time # For potential timeouts/delays # --- Triton Client Imports --- try: import tritonclient.http as httpclient import tritonclient.utils as triton_utils # For InferenceServerException TRITON_CLIENT_AVAILABLE = True except ImportError: print("WARNING: tritonclient is not installed. Triton inference will not be available.") print("Install using: pip install tritonclient[all]") TRITON_CLIENT_AVAILABLE = False httpclient = None # Define dummy to avoid NameErrors later triton_utils = None # --- Configuration --- # Download the entire repository for local fallback and utils repo_id = "truthdotphd/cloud-detection" repo_subdir = "." print(f"Downloading/Checking Hugging Face repo '{repo_id}'...") repo_dir = snapshot_download(repo_id=repo_id, local_dir=repo_subdir, local_dir_use_symlinks=False) # Use False for symlinks in Gradio/Docker usually print(f"Repo downloaded/cached at: {repo_dir}") # Add the repository directory to the Python path for local modules sys.path.append(repo_dir) # Import the necessary functions from the downloaded modules for LOCAL fallback try: # Adjust path if omnicloudmask is inside a subfolder omnicloudmask_path = os.path.join(repo_dir, "omnicloudmask") if os.path.isdir(omnicloudmask_path): sys.path.append(omnicloudmask_path) # Add subfolder if exists from omnicloudmask import predict_from_array LOCAL_MODEL_AVAILABLE = True print("Local omnicloudmask module loaded successfully.") except ImportError as e: print(f"ERROR: Could not import local 'predict_from_array' from omnicloudmask: {e}") print("Local fallback will not be available.") LOCAL_MODEL_AVAILABLE = False predict_from_array = None # Define dummy # --- Triton Server Configuration --- TRITON_IP = "206.123.129.87" # Use the public IP provided HTTP_TRITON_URL = f"{TRITON_IP}:8000" # GRPC_TRITON_URL = f"{TRITON_IP}:8001" # Keep for potential future use TRITON_MODEL_NAME = "cloud-detection" # Ensure this matches your deployed model name TRITON_INPUT_NAME = "input_jp2_bytes" # Ensure this matches your model's config.pbtxt TRITON_OUTPUT_NAME = "output_mask" # Ensure this matches your model's config.pbtxt TRITON_TIMEOUT_SECONDS = 300 # 5 minutes timeout for connection/network # --- Utility Functions (mostly from paste.txt) --- def visualize_rgb(red_file, green_file, blue_file): """ Create and display an RGB visualization immediately after images are uploaded. (Modified slightly: doesn't need nir_file) """ if not all([red_file, green_file, blue_file]): return None try: # Load bands (using load_band utility) # Get target shape from red band with rasterio.open(red_file) as src: target_height = src.height target_width = src.width blue_data = load_band(blue_file) green_data = load_band(green_file) red_data = load_band(red_file) # Compute max values for scaling (simple approach) red_max = np.percentile(red_data[red_data>0], 98) if np.any(red_data>0) else 1.0 green_max = np.percentile(green_data[green_data>0], 98) if np.any(green_data>0) else 1.0 blue_max = np.percentile(blue_data[blue_data>0], 98) if np.any(blue_data>0) else 1.0 # Create RGB image for visualization with dynamic normalization rgb_image = np.zeros((red_data.shape[0], red_data.shape[1], 3), dtype=np.float32) epsilon = 1e-10 rgb_image[:, :, 0] = np.clip(red_data / (red_max + epsilon), 0, 1) rgb_image[:, :, 1] = np.clip(green_data / (green_max + epsilon), 0, 1) rgb_image[:, :, 2] = np.clip(blue_data / (blue_max + epsilon), 0, 1) # Simple brightness/contrast adjustment (gamma correction) gamma = 1.8 rgb_image_enhanced = np.power(rgb_image, 1/gamma) # Convert to uint8 for display rgb_display = (rgb_image_enhanced * 255).astype(np.uint8) return rgb_display except Exception as e: print(f"Error generating RGB preview: {e}") import traceback traceback.print_exc() return None def visualize_jp2(file_path): """ Visualize a single JP2 file. (Unchanged from paste.txt) """ try: with rasterio.open(file_path) as src: data = src.read(1) # Check if data is all zero or invalid if np.all(data == 0) or np.ptp(data) == 0: print(f"Warning: Data in {file_path} is constant or zero. Cannot normalize.") # Return a black image or handle as appropriate return np.zeros((src.height, src.width, 3), dtype=np.uint8) # Normalize the data for visualization data_norm = (data - np.min(data)) / (np.max(data) - np.min(data)) # Apply a colormap for better visualization cmap = plt.get_cmap('viridis') colored_image = cmap(data_norm) # Convert to 8-bit for display return (colored_image[:, :, :3] * 255).astype(np.uint8) except Exception as e: print(f"Error visualizing JP2 file {file_path}: {e}") return None def load_band(file_path, resample=False, target_height=None, target_width=None): """ Load a single band from a raster file with optional resampling. (Unchanged from paste.txt) """ try: with rasterio.open(file_path) as src: if resample and target_height is not None and target_width is not None: # Ensure output shape matches target channels (1 for single band) out_shape = (1, target_height, target_width) band_data = src.read( out_shape=out_shape, resampling=Resampling.bilinear )[0].astype(np.float32) # Read only the first band after resampling else: band_data = src.read(1).astype(np.float32) # Read only the first band return band_data except Exception as e: print(f"Error loading band {file_path}: {e}") raise # Re-raise error to be caught by calling function def prepare_input_array(red_file, green_file, blue_file, nir_file): """ Prepare a stacked array (R, G, NIR) for the LOCAL model and an RGB image for visualization. (Slightly modified from paste.txt to handle potential loading errors) Returns: prediction_array (np.ndarray): Stacked array (R,G,NIR) for local model, or None on error. rgb_image_enhanced (np.ndarray): RGB image (0-1 float) for visualization, or None on error. """ try: # Get dimensions from red band to use for resampling with rasterio.open(red_file) as src: target_height = src.height target_width = src.width # Load bands (resample NIR band to match 10m resolution) blue_data = load_band(blue_file) # Needed for RGB viz green_data = load_band(green_file) red_data = load_band(red_file) nir_data = load_band( nir_file, resample=True, target_height=target_height, target_width=target_width ) # --- Prepare RGB Image for Visualization (similar to visualize_rgb but returns float array) --- red_max = np.percentile(red_data[red_data>0], 98) if np.any(red_data>0) else 1.0 green_max = np.percentile(green_data[green_data>0], 98) if np.any(green_data>0) else 1.0 blue_max = np.percentile(blue_data[blue_data>0], 98) if np.any(blue_data>0) else 1.0 epsilon = 1e-10 rgb_image = np.zeros((target_height, target_width, 3), dtype=np.float32) rgb_image[:, :, 0] = np.clip(red_data / (red_max + epsilon), 0, 1) rgb_image[:, :, 1] = np.clip(green_data / (green_max + epsilon), 0, 1) rgb_image[:, :, 2] = np.clip(blue_data / (blue_max + epsilon), 0, 1) # Apply gamma correction for enhancement gamma = 1.8 rgb_image_enhanced = np.power(rgb_image, 1/gamma) # --- End RGB Image Preparation --- # Stack bands in CHW format for LOCAL cloud mask prediction (red, green, nir) # Ensure all bands have the same shape before stacking if not (red_data.shape == green_data.shape == nir_data.shape): print("ERROR: Band shapes mismatch after loading/resampling!") print(f"Shapes - Red: {red_data.shape}, Green: {green_data.shape}, NIR: {nir_data.shape}") return None, None # Indicate error prediction_array = np.stack([red_data, green_data, nir_data], axis=0) # CHW format print(f"Local prediction array shape: {prediction_array.shape}") print(f"RGB visualization image shape: {rgb_image_enhanced.shape}") return prediction_array, rgb_image_enhanced except Exception as e: print(f"Error during input preparation: {e}") import traceback traceback.print_exc() return None, None # Indicate error def visualize_cloud_mask(rgb_image, pred_mask): """ Create a visualization of the cloud mask overlaid on the RGB image. (Unchanged from paste.txt, but added error checks) """ if rgb_image is None or pred_mask is None: print("Cannot visualize cloud mask: Missing RGB image or prediction mask.") return None try: # Ensure pred_mask has the right dimensions (H, W) if pred_mask.ndim == 3 and pred_mask.shape[0] == 1: # Squeeze channel dim if present pred_mask = np.squeeze(pred_mask, axis=0) elif pred_mask.ndim != 2: print(f"ERROR: Unexpected prediction mask dimension: {pred_mask.ndim}, shape: {pred_mask.shape}") # Attempt to squeeze if possible, otherwise fail try: pred_mask = np.squeeze(pred_mask) if pred_mask.ndim != 2: raise ValueError("Still not 2D after squeeze") except Exception as sq_err: print(f"Could not convert mask to 2D: {sq_err}") return None # Cannot visualize print(f"Visualization - RGB image shape: {rgb_image.shape}, Pred mask shape: {pred_mask.shape}") # Ensure mask has the same spatial dimensions as the image if pred_mask.shape != rgb_image.shape[:2]: print(f"Warning: Resizing prediction mask from {pred_mask.shape} to {rgb_image.shape[:2]} for visualization.") # Ensure mask is integer type for nearest neighbor interpolation if not np.issubdtype(pred_mask.dtype, np.integer): print("Warning: Prediction mask is not integer type, casting to uint8 for resize.") pred_mask = pred_mask.astype(np.uint8) pred_mask_resized = cv2.resize( pred_mask, (rgb_image.shape[1], rgb_image.shape[0]), # Target shape (width, height) for cv2.resize interpolation=cv2.INTER_NEAREST # Use nearest to preserve class labels ) pred_mask = pred_mask_resized print(f"Resized mask shape: {pred_mask.shape}") # Define colors for each class colors = { 0: [0, 255, 0], # Clear - Green 1: [255, 0, 0], # Thick Cloud - Red (Changed from White for better contrast) 2: [255, 255, 0], # Thin Cloud - Yellow (Changed from Gray) 3: [0, 0, 255] # Cloud Shadow - Blue (Changed from Gray) } # Create a color-coded mask visualization mask_vis = np.zeros((pred_mask.shape[0], pred_mask.shape[1], 3), dtype=np.uint8) for class_idx, color in colors.items(): # Handle potential out-of-bounds class indices in mask mask_vis[pred_mask == class_idx] = color # Create a blended visualization alpha = 0.4 # Transparency of the mask overlay # Ensure rgb_image is uint8 for blending rgb_uint8 = (np.clip(rgb_image, 0, 1) * 255).astype(np.uint8) blended = cv2.addWeighted(rgb_uint8, 1-alpha, mask_vis, alpha, 0) # --- Create Legend --- legend_height = 100 legend_width = blended.shape[1] # Match image width legend = np.ones((legend_height, legend_width, 3), dtype=np.uint8) * 255 # White background legend_text = ["Clear", "Thick Cloud", "Thin Cloud", "Cloud Shadow"] legend_colors = [colors.get(i, [0,0,0]) for i in range(4)] # Use .get for safety box_size = 15 text_offset_x = 40 start_y = 15 padding_y = 20 for i, (text, color) in enumerate(zip(legend_text, legend_colors)): # Draw color box cv2.rectangle(legend, (10, start_y + i*padding_y - box_size // 2), (10 + box_size, start_y + i*padding_y + box_size // 2), color, -1) # Draw text cv2.putText(legend, text, (text_offset_x, start_y + i*padding_y + box_size // 4), # Adjust vertical alignment cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 0), 1, cv2.LINE_AA) # --- End Legend --- # Combine image and legend final_output = np.vstack([blended, legend]) return final_output except Exception as e: print(f"Error during visualization: {e}") import traceback traceback.print_exc() return None # Return None if visualization fails # --- Triton Client Functions (Adapted from paste-2.txt) --- def is_triton_server_healthy(url=HTTP_TRITON_URL): """Checks if the Triton Inference Server is live.""" if not TRITON_CLIENT_AVAILABLE: return False try: triton_client = httpclient.InferenceServerClient(url=url, connection_timeout=10.0) # Short timeout for health check server_live = triton_client.is_server_live() if server_live: print(f"Triton server at {url} is live.") # Optionally check readiness: # server_ready = triton_client.is_server_ready() # print(f"Triton server at {url} is ready: {server_ready}") # return server_ready else: print(f"Triton server at {url} is not live.") return server_live except Exception as e: print(f"Could not connect to Triton server at {url}: {e}") return False def get_jp2_bytes_for_triton(red_file_path, green_file_path, nir_file_path): """ Reads the raw bytes of Red, Green, and NIR JP2 files for Triton. Order: Red, Green, NIR (must match Triton model input expectation) """ byte_list = [] files_to_read = [red_file_path, green_file_path, nir_file_path] band_names = ['Red', 'Green', 'NIR'] for file_path, band_name in zip(files_to_read, band_names): try: with open(file_path, "rb") as f: file_bytes = f.read() byte_list.append(file_bytes) print(f"Read {len(file_bytes)} bytes for {band_name} band from {os.path.basename(file_path)}") except FileNotFoundError: print(f"ERROR: File not found: {file_path}") raise # Propagate error except Exception as e: print(f"ERROR: Could not read file {file_path}: {e}") raise # Propagate error # Create NumPy array of object type to hold bytes input_byte_array = np.array(byte_list, dtype=object) # Expected shape is (3,) -> a 1D array containing 3 byte objects print(f"Prepared Triton input byte array with shape: {input_byte_array.shape} and dtype: {input_byte_array.dtype}") return input_byte_array def run_inference_triton_http(input_byte_array): """ Run inference using Triton HTTP client with raw JP2 bytes. """ if not TRITON_CLIENT_AVAILABLE: raise RuntimeError("Triton client library not available.") print("Attempting inference using Triton HTTP client...") try: client = httpclient.InferenceServerClient( url=HTTP_TRITON_URL, verbose=False, connection_timeout=TRITON_TIMEOUT_SECONDS, network_timeout=TRITON_TIMEOUT_SECONDS ) except Exception as e: print(f"ERROR: Couldn't create Triton HTTP client: {e}") raise # Propagate error # Prepare input tensor (BYTES type) # Shape [3] matches the 1D numpy array holding 3 byte strings inputs = [httpclient.InferInput(TRITON_INPUT_NAME, input_byte_array.shape, "BYTES")] inputs[0].set_data_from_numpy(input_byte_array, binary_data=True) # binary_data=True is important for BYTES # Prepare output tensor request outputs = [httpclient.InferRequestedOutput(TRITON_OUTPUT_NAME, binary_data=True)] # Send inference request try: print(f"Sending inference request to Triton model '{TRITON_MODEL_NAME}' at {HTTP_TRITON_URL}...") response = client.infer( model_name=TRITON_MODEL_NAME, inputs=inputs, outputs=outputs, request_id=str(os.getpid()), # Optional request ID timeout=TRITON_TIMEOUT_SECONDS ) print("Triton inference request successful.") mask = response.as_numpy(TRITON_OUTPUT_NAME) print(f"Received output mask from Triton with shape: {mask.shape}, dtype: {mask.dtype}") return mask except triton_utils.InferenceServerException as e: print(f"ERROR: Triton server failed inference: Status code {e.status()}, message: {e.message()}") print(f"Debug details: {e.debug_details()}") raise # Propagate error to trigger fallback except Exception as e: print(f"ERROR: An unexpected error occurred during Triton HTTP inference: {e}") import traceback traceback.print_exc() raise # Propagate error to trigger fallback # --- Main Processing Function with Fallback Logic --- def process_satellite_images(red_file, green_file, blue_file, nir_file, batch_size, patch_size, patch_overlap): """ Process satellite images: Try Triton first, fallback to local model. """ if not all([red_file, green_file, blue_file, nir_file]): return None, None, "ERROR: Please upload all four channel files (Red, Green, Blue, NIR)" # Store file paths from Gradio Image components red_file_path = red_file if isinstance(red_file, str) else red_file.name green_file_path = green_file if isinstance(green_file, str) else green_file.name blue_file_path = blue_file if isinstance(blue_file, str) else blue_file.name nir_file_path = nir_file if isinstance(nir_file, str) else nir_file.name print("\n--- Starting Cloud Detection Process ---") print(f"Input files: R={os.path.basename(red_file_path)}, G={os.path.basename(green_file_path)}, B={os.path.basename(blue_file_path)}, N={os.path.basename(nir_file_path)}") pred_mask = None status_message = "" rgb_display_image = None # For the raw RGB output panel rgb_float_image = None # For overlay visualization # 1. Prepare Visualization Image (always needed) & Local Input Array (needed for fallback) print("Preparing visualization image and local model input array...") local_input_array, rgb_float_image = prepare_input_array(red_file_path, green_file_path, blue_file_path, nir_file_path) if rgb_float_image is not None: # Convert float image (0-1) to uint8 (0-255) for the RGB output panel rgb_display_image = (np.clip(rgb_float_image, 0, 1) * 255).astype(np.uint8) else: print("ERROR: Failed to create RGB visualization image.") # Return early if visualization prep failed, as likely indicates file loading issues return None, None, "ERROR: Failed to load or process input band files." # 2. Check Triton Server Health use_triton = False if TRITON_CLIENT_AVAILABLE: print(f"Checking Triton server health at {HTTP_TRITON_URL}...") if is_triton_server_healthy(HTTP_TRITON_URL): use_triton = True else: print("Triton server is not healthy or unavailable.") status_message += "Triton server unavailable. " else: print("Triton client library not installed. Skipping Triton check.") status_message += "Triton client not installed. " # 3. Attempt Triton Inference if Healthy if use_triton: try: print("Preparing JP2 bytes for Triton...") # Use Red, Green, NIR file paths triton_byte_input = get_jp2_bytes_for_triton(red_file_path, green_file_path, nir_file_path) pred_mask = run_inference_triton_http(triton_byte_input) status_message += "Inference performed using Triton Server. " print("Triton inference successful.") except Exception as e: print(f"Triton inference failed: {e}. Falling back to local model.") status_message += f"Triton inference failed ({type(e).__name__}). " pred_mask = None # Ensure mask is None to trigger fallback use_triton = False # Explicitly mark Triton as not used # 4. Fallback to Local Model if Triton failed or wasn't available/healthy if pred_mask is None: # Check if mask wasn't obtained from Triton status_message += "Falling back to local inference. " if LOCAL_MODEL_AVAILABLE and local_input_array is not None: print("Running local inference using omnicloudmask...") try: # Predict cloud mask using local omnicloudmask pred_mask = predict_from_array( local_input_array, batch_size=batch_size, patch_size=patch_size, patch_overlap=patch_overlap ) print(f"Local prediction successful. Output mask shape: {pred_mask.shape}, dtype: {pred_mask.dtype}") status_message += "Local inference successful." except Exception as e: print(f"ERROR: Local inference failed: {e}") import traceback traceback.print_exc() status_message += f"Local inference FAILED: {e}" # Keep pred_mask as None elif not LOCAL_MODEL_AVAILABLE: status_message += "Local model not available. Cannot perform inference." print("ERROR: Local model could not be loaded.") elif local_input_array is None: status_message += "Local input data preparation failed. Cannot perform local inference." print("ERROR: Failed to prepare input array for local model.") else: status_message += "Unknown state, cannot perform inference." # Should not happen # 5. Process Results (Stats and Visualization) if mask was generated if pred_mask is not None: # Ensure mask is squeezed to 2D if necessary (local model might return extra dim) if pred_mask.ndim == 3 and pred_mask.shape[0] == 1: flat_mask = np.squeeze(pred_mask, axis=0) elif pred_mask.ndim == 2: flat_mask = pred_mask else: print(f"ERROR: Unexpected mask shape after inference: {pred_mask.shape}") status_message += " ERROR: Invalid mask shape received." flat_mask = None # Invalidate mask if flat_mask is not None: # Calculate class distribution clear_pixels = np.sum(flat_mask == 0) thick_cloud_pixels = np.sum(flat_mask == 1) thin_cloud_pixels = np.sum(flat_mask == 2) cloud_shadow_pixels = np.sum(flat_mask == 3) total_pixels = flat_mask.size stats = f""" Cloud Mask Statistics ({'Triton' if use_triton else 'Local'}): - Clear: {clear_pixels} pixels ({clear_pixels/total_pixels*100:.2f}%) - Thick Cloud: {thick_cloud_pixels} pixels ({thick_cloud_pixels/total_pixels*100:.2f}%) - Thin Cloud: {thin_cloud_pixels} pixels ({thin_cloud_pixels/total_pixels*100:.2f}%) - Cloud Shadow: {cloud_shadow_pixels} pixels ({cloud_shadow_pixels/total_pixels*100:.2f}%) - Total Cloud Cover (Thick+Thin): {(thick_cloud_pixels + thin_cloud_pixels)/total_pixels*100:.2f}% """ status_message += f"\nMask stats calculated. Total pixels: {total_pixels}." # Visualize the cloud mask on the original image print("Generating final visualization...") visualization = visualize_cloud_mask(rgb_float_image, flat_mask) # Use float image for viz function if visualization is None: status_message += " ERROR: Failed to generate visualization." print("--- Cloud Detection Process Finished ---") return rgb_display_image, visualization, status_message + "\n" + stats else: # Mask had wrong shape return rgb_display_image, None, status_message + "\nERROR: Could not process prediction mask." else: # Inference failed both ways or initial loading failed print("--- Cloud Detection Process Failed ---") return rgb_display_image, None, status_message + "\nERROR: Could not generate cloud mask." # --- Gradio Interface (from paste.txt) --- def check_cpu_usage(): """Check and return the current CPU usage.""" return f"CPU Usage: {psutil.cpu_percent()}%" # --- Build Gradio App --- print("Building Gradio interface...") with gr.Blocks(title="Satellite Cloud Detection (Triton/Local)") as demo: gr.Markdown(""" # Satellite Cloud Detection (with Triton Fallback) Upload separate JP2 files for Red (e.g., B04), Green (e.g., B03), Blue (e.g., B02), and NIR (e.g., B8A) channels. The application will **first attempt** to use a remote Triton Inference Server. If the server is unavailable or inference fails, it will **fall back** to using the local OmniCloudMask model. **Pixel Classification:** - Clear (Green) - Thick Cloud (Red) - Thin Cloud (Yellow) - Cloud Shadow (Blue) The model works best with imagery at 10-50m resolution. """) # Main cloud detection interface with gr.Row(): with gr.Column(scale=1): gr.Markdown("### Input Bands (JP2)") # Use filepaths which are needed for both local reading and byte reading red_input = gr.File(label="Red Channel (e.g., B04)", type="filepath") green_input = gr.File(label="Green Channel (e.g., B03)", type="filepath") blue_input = gr.File(label="Blue Channel (e.g., B02)", type="filepath") nir_input = gr.File(label="NIR Channel (e.g., B8A)", type="filepath") gr.Markdown("### Local Model Parameters (Used for Fallback)") batch_size = gr.Slider(minimum=1, maximum=32, value=4, step=1, label="Batch Size", info="Memory usage/speed for local model") patch_size = gr.Slider(minimum=256, maximum=2048, value=1024, step=128, label="Patch Size", info="Patch size for local model processing") patch_overlap = gr.Slider(minimum=64, maximum=512, value=256, step=64, label="Patch Overlap", info="Overlap for local model processing") process_btn = gr.Button("Process Cloud Detection", variant="primary") with gr.Column(scale=2): gr.Markdown("### Results") # Output components rgb_output = gr.Image(label="Original RGB Image (Approx. True Color)", type="numpy") cloud_output = gr.Image(label="Cloud Detection Visualization (Mask Overlay)", type="numpy") stats_output = gr.Textbox(label="Processing Status & Statistics", lines=10) # CPU usage monitoring section (Optional) with gr.Accordion("System Monitoring", open=False): cpu_button = gr.Button("Check CPU Usage") cpu_output = gr.Textbox(label="Current CPU Usage") cpu_button.click(fn=check_cpu_usage, inputs=None, outputs=cpu_output) # Examples section # Ensure example paths are relative to where the script is run, # or absolute if needed. Assumes 'jp2s' folder is present. example_base = os.path.join(repo_dir, "jp2s") # Use downloaded repo path example_files = [ os.path.join(example_base, "B04.jp2"), # Red os.path.join(example_base, "B03.jp2"), # Green os.path.join(example_base, "B02.jp2"), # Blue os.path.join(example_base, "B8A.jp2") # NIR ] # Check if example files actually exist before adding example if all(os.path.exists(f) for f in example_files): print("Adding examples...") gr.Examples( examples=[example_files + [4, 1024, 256]], # Corresponds to inputs below inputs=[red_input, green_input, blue_input, nir_input, batch_size, patch_size, patch_overlap], outputs=[rgb_output, cloud_output, stats_output], # Define outputs for examples too fn=process_satellite_images, # Function to run for examples cache_examples=False # Maybe disable caching if files change or for debugging ) else: print(f"WARN: Example JP2 files not found in '{example_base}'. Skipping examples.") # Setup main button click handler process_btn.click( fn=process_satellite_images, inputs=[red_input, green_input, blue_input, nir_input, batch_size, patch_size, patch_overlap], outputs=[rgb_output, cloud_output, stats_output] ) # --- Launch the App --- print("Launching Gradio app...") # Allow queueing and potentially increase workers if needed demo.queue(default_concurrency_limit=4).launch(debug=True, share=False) # share=True for public link if needed