import os
import socket
import time

import gradio as gr
import numpy as np
from PIL import Image
from loguru import logger
import cv2
import torch

# Import MMRotate
from mmrotate.models import build_detector
from mmdet.apis import init_detector, inference_detector
from mmrotate.apis import inference_detector_by_patches
from mmrotate.datasets import DOTADataset

# Default (square) patch size fed to the model, in pixels.
IMG_SIZE = 1024
# Overlap between neighbouring patches when tiling large images, in pixels.
OVERLAP = 192
# Half of the overlap on each side of a patch. Integer division keeps the
# derived patch step an int.
MARGIN = OVERLAP // 2
# Number of patches inferred per batch; depends on the GPU memory.
BATCH_SIZE = 16

# CLASSES
CLASSES = ['ship',]

# Choose to use a config and initialize the detector
config_file = 'redet_re50_refpn_1x_dota_ms_rr_le90.py'
# Setup a checkpoint file to load
weights_file = 'weights/best_mAP_epoch_20.pth'

# Pick the best available device: CUDA first, then Apple MPS, then CPU.
device = torch.device(
    "cuda:0" if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available()
    else "cpu"
)
logger.info(f"Using device: {device}")

# Check Gradio version
logger.info(f"Gradio version: {gr.__version__}")

# Build the model from a config file and a checkpoint file.
model = init_detector(config_file, weights_file, device)


def _count_detections(result, threshold):
    """Count detections whose confidence is strictly greater than ``threshold``.

    Assumes ``result`` is a per-class sequence of arrays whose last column is
    the confidence score (the layout the original ``result[0][:, -1]``
    indexing relied on). Counts are summed over every class; with the single
    ``ship`` class this equals the count for class 0.
    """
    return int(sum(np.count_nonzero(dets[:, -1] > threshold) for dets in result))


# Define the inference function
def predict_image(img, threshold):
    """Detect ships in an RGB image and draw the predicted oriented boxes.

    Images larger than ``IMG_SIZE`` in either dimension are tiled into
    overlapping ``IMG_SIZE`` patches with ``inference_detector_by_patches``.
    Smaller images are run through ``inference_detector`` in a single pass.

    Args:
        img: RGB image, either a ``PIL.Image.Image`` (any mode; converted to
            RGB) or a ``numpy.ndarray`` of shape (H, W, 3) in RGB order.
        threshold: Confidence threshold. Only detections scoring strictly
            above it are counted and drawn.

    Returns:
        Tuple ``(annotated, shape, num_detections, elapsed)``:
        ``annotated`` is an RGB ndarray with the boxes drawn, ``shape`` is the
        input array shape, ``num_detections`` is the number of boxes above
        ``threshold`` and ``elapsed`` is the inference time in seconds.

    Raises:
        ValueError: If ``img`` is not a single 3-channel image.
    """
    if isinstance(img, Image.Image):
        # Normalise grayscale / palette / RGBA uploads to 3-channel RGB.
        img = np.array(img.convert("RGB"))
    if not isinstance(img, np.ndarray) or img.ndim != 3 or img.shape[2] != 3:
        raise ValueError(
            "predict_image(): input 'img' should be a single RGB image "
            "in PIL or Numpy array format."
        )

    # MMDetection/MMRotate treat in-memory arrays as BGR (the order produced
    # by mmcv.imread); the UI contract is RGB, so flip the channels.
    img_bgr = np.ascontiguousarray(img[..., ::-1])

    start_time = time.perf_counter()
    if img.shape[0] > IMG_SIZE or img.shape[1] > IMG_SIZE:
        # Large image: tile into overlapping patches and merge the results.
        result = inference_detector_by_patches(
            model,
            img_bgr,
            sizes=[IMG_SIZE],
            steps=[IMG_SIZE - 2 * MARGIN],
            ratios=[1.0],
            merge_iou_thr=0.3,
            bs=BATCH_SIZE,
        )
    else:
        result = inference_detector(model, img_bgr)
    elapsed = time.perf_counter() - start_time

    num_detections = _count_detections(result, threshold)
    img_preds = model.show_result(img_bgr, result, score_thr=threshold, show=False)
    # The rendered image keeps the input's (BGR) channel order; convert back
    # to RGB for display.
    img_preds = np.ascontiguousarray(img_preds[..., ::-1])
    return img_preds, img.shape, num_detections, elapsed
# Example images, each paired with a suggested confidence threshold, for
# users to choose from.
example_data = [
    ["./demo/82f13510a.jpg", 0.75],
    ["./demo/836f35381.jpg", 0.75],
    ["./demo/848d2afef.jpg", 0.75],
    ["./demo/Satellite_Image_Marina_New_Zealand.jpg", 0.4],
    ["./demo/Pleiades_HD15_Miami_Marina.jpg", 0.4],
    # Add more example images and thresholds as needed
]

# Define CSS for some elements
css = """
.image-preview {
    height: 820px !important;
    width: 800px !important;
}
"""

TITLE = "Ship Detection on Optical Satellite image"

# Footer text shown under the demo.
FOOTER = (
    "\n\n"
    "This demo is provided by Jeff Faudi and DL4EO. "
    "This model is based on the MMRotate framework which provides oriented "
    "bounding boxes. We believe that oriented bounding boxes are better suited "
    "for detection in satellite images. This model has been trained on the "
    "Airbus Ship Detection dataset available on Kaggle. The associated license "
    "is CC BY-NC-SA. This demonstration CANNOT be used for commercial purposes. "
    "Please contact me for more information on how you could get access to a "
    "commercial-grade model or API."
    "\n\n"
)

# Define the Gradio Interface
demo = gr.Blocks(title=TITLE, css=css).queue()
with demo:
    gr.Markdown(f"\n\n{TITLE}\n\n")
    with gr.Row():
        with gr.Column(scale=0):
            input_image = gr.Image(type="pil", interactive=True)
            run_button = gr.Button(value="Run")
            with gr.Accordion("Advanced options", open=True):
                threshold = gr.Slider(
                    label="Confidence threshold",
                    minimum=0.0,
                    maximum=1.0,
                    value=0.25,
                    step=0.01,
                )
            dimensions = gr.Textbox(label="Image size", interactive=False)
            detections = gr.Textbox(label="Predicted ships", interactive=False)
            stopwatch = gr.Number(
                label="Execution time (sec.)", interactive=False, precision=3
            )
        with gr.Column(scale=2):
            output_image = gr.Image(
                type="pil", elem_classes='image-preview', interactive=False
            )

    outputs = [output_image, dimensions, detections, stopwatch]
    run_button.click(
        fn=predict_image,
        inputs=[input_image, threshold],
        outputs=outputs,
    )
    gr.Examples(
        examples=example_data,
        inputs=[input_image, threshold],
        outputs=outputs,
        fn=predict_image,
        # cache_examples=True,
        label='Try these images!',
    )
    gr.Markdown(FOOTER)

# Launch settings depend on whether we run inside a Docker container.
if os.path.exists('/.dockerenv'):
    logger.info('Running inside a Docker container')
    # Bind to the container's hostname instead of the default localhost,
    # presumably so the app is reachable through Docker's port mapping.
    launch_kwargs = dict(server_name=socket.gethostname(), debug=True)
else:
    logger.info('Not running inside a Docker container')
    launch_kwargs = dict(debug=False)

demo.launch(inline=False, server_port=7860, **launch_kwargs)