|
import os |
|
import socket |
|
import gradio as gr |
|
import numpy as np |
|
from PIL import Image, ImageDraw |
|
from pathlib import Path |
|
from loguru import logger |
|
import cv2 |
|
import torch |
|
import ultralytics |
|
from ultralytics import YOLO |
|
import time |
|
import base64 |
|
import requests |
|
import json |
|
|
|
|
|
# Remote inference endpoint, used only when no local weights are available.
DL4EO_API_URL = "https://dl4eo--oil-storage-predict.modal.run"

# Width, in pixels, of the rectangles drawn around detections.
LINE_WIDTH = 2

# Local YOLO weights; when present, inference runs in-process instead of
# calling the remote API.
WEIGHTS_FILE = './weights/best.pt'

model = None
if os.path.exists(WEIGHTS_FILE):
    model = YOLO(WEIGHTS_FILE)
    logger.info("Setup for local inference")
else:
    logger.info("Setup for remote inference")

# The API key is only needed for remote inference. Read it leniently so the
# app can run with local weights alone, but still fail at start-up (rather
# than on the first request) when remote inference is required and no key
# is configured.
DL4EO_API_KEY = os.environ.get('DL4EO_API_KEY')
if model is None and not DL4EO_API_KEY:
    raise RuntimeError(
        f"No local weights found at {WEIGHTS_FILE!r} and the DL4EO_API_KEY "
        "environment variable is not set; remote inference is unavailable."
    )
|
|
|
|
|
# Select a torch backend in order of preference: CUDA, then Apple MPS, then CPU.
# NOTE(review): `device` is only logged in this file; it is not passed to
# `model.predict`, so Ultralytics picks its own device — confirm intent.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
logger.info(f"Using device: {device}")

# Log the versions of the key libraries to help diagnose deployment issues.
logger.info(f"Ultralytics version: {ultralytics.__version__}")
logger.info(f"Gradio version: {gr.__version__}")
|
|
|
|
|
# Upper bound (seconds) on a single remote inference request, so a stalled
# API call cannot block a Gradio worker indefinitely. Tune as needed.
_API_TIMEOUT_SEC = 120


def _to_rgb_array(image):
    """Return *image* as an RGB NumPy array of shape ``(H, W, 3)``.

    Accepts a PIL image (converted to RGB first, so e.g. RGBA uploads work)
    or a NumPy array, which must already have exactly three channels.

    Raises:
        ValueError: if *image* is neither a PIL image nor a 3-channel array.
            This includes ``None``, e.g. when "Run" is clicked with no image.
    """
    if isinstance(image, Image.Image):
        img = np.array(image.convert("RGB"))
    elif isinstance(image, np.ndarray):
        img = image
    else:
        img = None

    if img is None or img.ndim != 3 or img.shape[2] != 3:
        raise ValueError(
            "predict_image(): input 'image' should be a single RGB image "
            "in PIL or Numpy array format."
        )
    return img


def _predict_remote(img, threshold):
    """Run detection through the DL4EO HTTP API.

    The raw pixel buffer is sent base64-encoded, together with its shape and
    the confidence threshold.

    Returns:
        ``(boxes, duration)`` exactly as reported in the API's JSON response.

    Raises:
        Exception: if the API answers with a non-200 status code.
        requests.RequestException: on network failure or timeout.
    """
    payload = {
        'image': base64.b64encode(np.ascontiguousarray(img)).decode(),
        'shape': img.shape,
        'threshold': threshold,
    }
    headers = {
        'Authorization': 'Bearer ' + DL4EO_API_KEY,
        'Content-Type': 'application/json',
    }
    response = requests.post(
        DL4EO_API_URL, json=payload, headers=headers, timeout=_API_TIMEOUT_SEC
    )
    if response.status_code != 200:
        raise Exception(
            f"Received status code={response.status_code} in inference API"
        )
    json_data = response.json()
    return json_data['boxes'], json_data['duration']


def _predict_local(img, threshold):
    """Run detection with the locally loaded YOLO ``model``.

    Returns:
        ``(boxes, duration)`` where *boxes* is a list of integer
        ``[left, top, right, bottom]`` pixel boxes and *duration* is the
        wall-clock inference time in seconds.
    """
    height, width = img.shape[:2]
    # Ultralytics treats NumPy inputs as BGR (OpenCV convention), so flip
    # our RGB array to match the channel order of its own image loading.
    bgr = np.ascontiguousarray(img[..., ::-1])
    start_time = time.perf_counter()
    # imgsz is (height, width): run at the image's native resolution.
    results = model.predict([bgr], imgsz=(height, width), conf=threshold)
    duration = time.perf_counter() - start_time
    # xyxy is an (N, 4) tensor; converting it as a whole also covers N == 0/1.
    boxes = results[0].boxes.xyxy.cpu().int().tolist()
    return boxes, duration


def _draw_boxes(image, boxes, line_width):
    """Draw *boxes* as red rectangles on PIL *image*, in place.

    Box edges touching (or beyond) the image border are moved *line_width*
    pixels outside the canvas. This is intended to keep the outline of an
    edge cut by the image border off the visible image.
    """
    width, height = image.size
    draw = ImageDraw.Draw(image)
    for left, top, right, bottom in boxes:
        # x coordinates are bounded by the width, y coordinates by the height.
        if left <= 0:
            left = -line_width
        if top <= 0:
            top = -line_width
        if right >= width - 1:
            right = width - 1 + line_width
        if bottom >= height - 1:
            bottom = height - 1 + line_width
        draw.rectangle([left, top, right, bottom], outline="red", width=line_width)


def predict_image(image, threshold):
    """Detect oil storage tanks in *image* and return an annotated copy.

    Inference runs locally when YOLO weights were loaded at start-up (module
    ``model`` is set), and through the DL4EO API otherwise.

    Args:
        image: a PIL image or an ``(H, W, 3)`` NumPy array.
        threshold: minimum confidence for a detection to be kept.

    Returns:
        ``(annotated_image, size_str, n_boxes, duration)``:
        *annotated_image* is a new PIL image with red boxes drawn;
        *size_str* is ``str((width, height))``; *n_boxes* is the number of
        detections; *duration* is the inference time in seconds.

    Raises:
        ValueError: if *image* is not a single RGB image.
        Exception: if the remote API returns a non-200 status code.
    """
    img = _to_rgb_array(image)

    if model is None:
        boxes, duration = _predict_remote(img, threshold)
    else:
        boxes, duration = _predict_local(img, threshold)

    # Draw on a fresh image so the caller's input is left untouched.
    annotated = Image.fromarray(img)
    _draw_boxes(annotated, boxes, LINE_WIDTH)
    return annotated, str(annotated.size), len(boxes), duration
|
|
|
|
|
|
|
# IDs of the demo images offered as clickable examples. They are not part of
# the training dataset. Each one is run with a 0.50 confidence threshold.
_EXAMPLE_IMAGE_IDS = (
    "588fc1fb-b86a-4fb4-8161-d9bd3a1556ca",
    "605ffac0-69d5-4748-92c2-48d43f51afc1",
    "67f7c7ad-11a1-4c7f-9f2a-da7ef50bfdd8",
    "b8c0e212-3669-4ff8-81a5-32191d456f86",
    "df5ec618-c1f3-4cfe-88b1-86799d23c22d",
)
example_data = [[f"./demo/{image_id}.jpg", 0.50] for image_id in _EXAMPLE_IMAGE_IDS]

# Fixed-size styling for the output image (matched via `elem_classes`).
css = """
.image-preview {
    height: 820px !important;
    width: 800px !important;
}
"""

TITLE = "Oil storage detection on SPOT images (1.5 m) with YOLOv8"
|
|
|
|
|
# Build the UI. `.queue()` enables Gradio's request queue for this app.
demo = gr.Blocks(title=TITLE, css=css).queue()

with demo:
    # Closing tags fixed: the original emitted "<center><h1>" instead of
    # "</center></h1>", leaving the markup unterminated.
    gr.Markdown(f"<h1><center>{TITLE}</center></h1>")

    with gr.Row():
        # Left column: input image, run button, options and result statistics.
        with gr.Column(scale=0):
            input_image = gr.Image(type="pil", interactive=True, scale=1)
            run_button = gr.Button(value="Run", scale=0)
            with gr.Accordion("Advanced options", open=True):
                threshold = gr.Slider(label="Confidence threshold", minimum=0.0, maximum=1.0, value=0.25, step=0.01)
                dimensions = gr.Textbox(label="Image size", interactive=False)
                detections = gr.Number(label="Predicted objects", interactive=False)
                stopwatch = gr.Number(label="Execution time (sec.)", interactive=False, precision=3)

        # Right column: the input image annotated with predicted boxes.
        with gr.Column(scale=2):
            output_image = gr.Image(type="pil", elem_classes='image-preview', interactive=False, width=800, height=800)

    run_button.click(fn=predict_image, inputs=[input_image, threshold], outputs=[output_image, dimensions, detections, stopwatch])

    # cache_examples=True runs predict_image on every example ahead of time.
    gr.Examples(
        examples=example_data,
        inputs=[input_image, threshold],
        outputs=[output_image, dimensions, detections, stopwatch],
        fn=predict_image,
        cache_examples=True,
        label='Try these images! They are not included in the training dataset.'
    )

    gr.Markdown("""<p>This demo is provided by <a href='https://www.linkedin.com/in/faudi/'>Jeff Faudi</a> and <a href='https://www.dl4eo.com/'>DL4EO</a>.
The model has been trained with the <a href='https://www.ultralytics.com/yolo'>Ultralytics YOLOv8</a> framework on the
<a href='https://www.kaggle.com/datasets/airbusgeo/airbus-oil-storage-detection-dataset'>Airbus Oil Storage Dataset</a>.
The associated license is <a href='https://creativecommons.org/licenses/by-nc-sa/4.0/deed.en'>CC-BY-NC-SA</a>.
This demonstration CANNOT be used for commercial purposes. Please contact <a href='mailto:jeff@dl4eo.com'>me</a>
for more information on how you could get access to a commercial grade model or API. </p>""")
|
|
|
|
|
# Start the Gradio server with inline display, the API page and debug mode
# all turned off.
_launch_options = dict(
    inline=False,
    show_api=False,
    debug=False,
)
demo.launch(**_launch_options)
|
|
|
|