jeffaudi commited on
Commit
3eb34a9
1 Parent(s): 94b8172

Initial commit

Browse files
.gitignore ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ **/.DS_Store
2
+ **/__pycache__/
3
+ weights/
README.md CHANGED
@@ -1,9 +1,11 @@
1
  ---
2
- title: Aircraft Detection Optical Satellite
3
- emoji: 🏆
4
  colorFrom: indigo
5
- colorTo: red
6
- sdk: docker
 
 
7
  pinned: false
8
  license: cc-by-nc-sa-4.0
9
  ---
 
1
  ---
2
+ title: Aircraft Detection in Optical Satellite images
3
+ emoji: ✈️
4
  colorFrom: indigo
5
+ colorTo: blue
6
+ sdk: gradio
7
+ sdk_version: 4.19.2
8
+ app_file: app.py
9
  pinned: false
10
  license: cc-by-nc-sa-4.0
11
  ---
app.py ADDED
@@ -0,0 +1,176 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import socket
3
+ import gradio as gr
4
+ import numpy as np
5
+ from PIL import Image, ImageDraw
6
+ from pathlib import Path
7
+ from loguru import logger
8
+ import cv2
9
+ import torch
10
+ import ultralytics
11
+ from ultralytics import YOLO
12
+ import time
13
+ import base64
14
+ import requests
15
+ import json
16
+
17
+ # API for inferences
18
+ DL4EO_API_URL = "https://dl4eo--predict.modal.run"
19
+
20
+ # Auth Token to access API
21
+ DL4EO_API_KEY = os.environ['DL4EO_API_KEY']
22
+
23
+ # width of the boxes on image
24
+ LINE_WIDTH = 2
25
+
26
+ # Load a model if weights are present
27
+ WEIGHTS_FILE = './weights/best.pt'
28
+ model = None
29
+ if os.path.exists(WEIGHTS_FILE):
30
+ model = YOLO(WEIGHTS_FILE) # previously trained YOLOv8n model
31
+ logger.info(f"Setup for local inference")
32
+
33
+ # check if GPU if available
34
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu")
35
+ logger.info(f"Using device: {device}")
36
+
37
+ # Check Ultralytics modules version
38
+ logger.info(f"Ultralytics version: {ultralytics.__version__}")
39
+
40
+ # Check Gradio modules version
41
+ logger.info(f"Gradio version: {gr.__version__}")
42
+
# Define the inference function
def predict_image(image, threshold):
    """Detect aircraft in a satellite image and draw the predicted boxes.

    Runs the local YOLOv8 model when weights were loaded at startup,
    otherwise sends the image to the remote DL4EO inference API.

    Args:
        image: input image as a PIL.Image or an HxWx3 RGB numpy array.
        threshold: confidence threshold in [0, 1] for keeping detections.

    Returns:
        Tuple of (annotated PIL image, image size as a string, number of
        detected aircraft, inference duration in seconds).

    Raises:
        ValueError: if the input is not a single RGB image.
        Exception: if the remote inference API does not return HTTP 200.
    """
    if isinstance(image, Image.Image):
        img = np.array(image)
    else:
        # original code left 'img' unbound for non-PIL input (NameError)
        img = image

    if not isinstance(img, np.ndarray) or len(img.shape) != 3 or img.shape[2] != 3:
        raise ValueError("predict_image(): input 'image' should be a single RGB image in PIL or Numpy array format.")

    # numpy image arrays are (rows, cols, channels) -> (height, width, 3);
    # the original code had width and height swapped
    height, width = img.shape[0], img.shape[1]

    # make sure we always have a PIL image to draw on
    if not isinstance(image, Image.Image):
        image = Image.fromarray(img)

    if model is None:
        # No local weights: delegate inference to the remote API.
        # Encode the raw pixel buffer as base64
        image_base64 = base64.b64encode(np.ascontiguousarray(img)).decode()

        # JSON payload expected by the inference endpoint
        payload = {
            'image': image_base64,
            'shape': img.shape,
            'threshold': threshold,
        }

        headers = {
            'Authorization': 'Bearer ' + DL4EO_API_KEY,
            'Content-Type': 'application/json'  # Adjust the content type as needed
        }

        # Send the POST request to the API endpoint with the image as payload
        response = requests.post(DL4EO_API_URL, json=payload, headers=headers)

        # Fail loudly if the API did not answer correctly
        if response.status_code != 200:
            raise Exception(
                f"Received status code={response.status_code} in inference API"
            )

        json_data = json.loads(response.content)
        duration = json_data['duration']
        boxes = json_data['boxes']
    else:
        # Local inference with the YOLOv8 model
        start_time = time.time()
        results = model.predict([img], imgsz=(height, width), conf=threshold)
        duration = time.time() - start_time
        # Extract boxes from the results *before* converting them to plain
        # lists (the original code referenced 'boxes' before assignment)
        boxes = [box.xyxy.cpu().squeeze().int().tolist() for box in results[0].boxes]

    # draw the predicted boxes on the image
    draw = ImageDraw.Draw(image)

    for box in boxes:
        left, top, right, bottom = box

        # Push boxes touching the border fully off-image so the red outline
        # does not cover aircraft pixels at the edge of the frame.
        # left/right are x-coordinates (compare to width), top/bottom are
        # y-coordinates (compare to height).
        if left <= 0: left = -LINE_WIDTH
        if top <= 0: top = -LINE_WIDTH
        if right >= width - 1: right = width - 1 + LINE_WIDTH
        if bottom >= height - 1: bottom = height - 1 + LINE_WIDTH

        draw.rectangle([left, top, right, bottom], outline="red", width=LINE_WIDTH)

    return image, str(image.size), len(boxes), duration
# Define example images and their default confidence thresholds
example_data = [
    ["./demo/airport01.jpg", 0.50],
    ["./demo/airport02.jpg", 0.50],
    ["./demo/airport03.jpg", 0.50],
    ["./demo/airport04.jpg", 0.50],
    ["./demo/Pleiades_Neo_Tucson_USA.jpg", 0.50],
]

# CSS to force a large, fixed-size preview for the output image
css = """
.image-preview {
    height: 820px !important;
    width: 800px !important;
}
"""
TITLE = "Aircraft detection with YOLOv8"

# Define the Gradio Interface: input column (image + controls + metrics)
# on the left, large annotated output image on the right.
demo = gr.Blocks(title=TITLE, css=css).queue()
with demo:
    gr.Markdown(f"<h1><center>{TITLE}<center><h1>")

    with gr.Row():
        with gr.Column(scale=0):
            input_image = gr.Image(type="pil", interactive=True, scale=1)
            run_button = gr.Button(value="Run", scale=0)
            with gr.Accordion("Advanced options", open=True):
                threshold = gr.Slider(label="Confidence threshold", minimum=0.0, maximum=1.0, value=0.25, step=0.01)
                dimensions = gr.Textbox(label="Image size", interactive=False)
                detections = gr.Number(label="Predicted aircrafts", interactive=False)
                stopwatch = gr.Number(label="Execution time (sec.)", interactive=False, precision=3)

        with gr.Column(scale=2):
            output_image = gr.Image(type="pil", elem_classes='image-preview', interactive=False, width=800, height=800)

    run_button.click(fn=predict_image, inputs=[input_image, threshold], outputs=[output_image, dimensions, detections, stopwatch])
    gr.Examples(
        examples=example_data,
        inputs = [input_image, threshold],
        outputs = [output_image, dimensions, detections, stopwatch],
        fn=predict_image,
        label='Try these images!'
    )

    # Credits and licensing notice shown under the interface
    gr.Markdown("<p>This demo is provided by <a href='https://www.linkedin.com/in/faudi/'>Jeff Faudi</a> and <a href='https://www.dl4eo.com/'>DL4EO</a>. The model has been trained with the <a href='https://www.ultralytics.com/yolo'>Ultralytics YOLOv8</a> framework on the <a href='https://www.cosmiqworks.org/rareplanes/'>RarePlanes</a>, <a href='https://www.kaggle.com/datasets/airbusgeo/airbus-aircrafts-sample-dataset'>Airbus</a> and <a href='https://github.com/dilsadunsal/HRPlanesv2-Data-Set'>HRPlanesv2</a> datasets. The associated licenses are <a href='https://about.google/brand-resource-center/products-and-services/geo-guidelines/#google-earth-web-and-apps'>GoogleEarth fair use</a> and <a href='https://creativecommons.org/licenses/by-nc-sa/4.0/deed.en'>CC-BY-SA-NC</a>. This demonstration CANNOT be used for commercial purposes. Please contact <a href='mailto:jeff@dl4eo.com'>me</a> for more information on how you could get access to a commercial grade model or API. </p>")

# Launch: bind to the container hostname when running inside Docker,
# otherwise use Gradio's default local binding.
if os.path.exists('/.dockerenv'):
    print('Running inside a Docker container')

    # NOTE(review): binding to the container hostname may not be reachable
    # from outside the container; "0.0.0.0" is the usual choice — confirm
    # against the deployment setup before changing.
    hostname = socket.gethostname()

    demo.launch(
        server_name=hostname,
        inline=False,
        server_port=7860,
        debug=True
    )
else:
    print('Not running inside a Docker container')
    demo.launch(
        inline=False,
        debug=False
    )
demo/Pleiades_Neo_Tucson_USA.jpg ADDED
demo/airport01.jpg ADDED
demo/airport02.jpg ADDED
demo/airport03.jpg ADDED
demo/airport04.jpg ADDED
requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ loguru
2
+ ultralytics==8.1.18
3
+ gradio==4.19.2