mayrajeo committed
Commit fbad261 · verified · 1 Parent(s): 00d722c

Update app.py

Files changed (1)
  1. app.py +19 -33
app.py CHANGED
@@ -1,41 +1,36 @@
-import sys, os
+import os
 
 import gradio as gr
 import plotly.express as px
 import numpy as np
-import random
-from ultralytics import YOLO
-#from sahi.models.yolov8 import *
-from src.sahi_onnx import *
+from sahi.models.ultralytics import *
 from sahi.predict import get_sliced_prediction
 from sahi.utils.cv import visualize_object_predictions
 import PIL
 
-#model_base = "https://huggingface.co/mayrajeo/marine-vessel-detection/resolve/main/"
-model_base = 'onnx_models'
+MODEL_PATH = 'yolo11s_tci.pt'
 
 def inference(
     im:gr.Image=None,
-    model_path:gr.Dropdown='YOLOv8n',
     conf_thr:gr.Slider=0.25
 ):
-    #model = Yolov8DetectionModel(model_path=f'{model_base}/{model_path}/{model_path}.pt',
-    model = Yolov8onnxDetectionModel(model_path=f'{model_base}/{model_path}/{model_path.lower()}.onnx',
-                                     config_path=f'{model_base}/{model_path}/args.yaml',
-                                     device='cpu',
-                                     confidence_threshold=conf_thr,
-                                     category_mapping={'0': 'Boat'},
-                                     image_size=640)
+    model = UltralyticsDetectionModel(model_path=MODEL_PATH,
+                                      device='cuda',
+                                      confidence_threshold=conf_thr,
+                                      category_mapping={'0': 'Boat'},
+                                      image_size=640)
 
     res = get_sliced_prediction(im, model, slice_width=320,
                                 slice_height=320, overlap_height_ratio=0.2,
                                 overlap_width_ratio=0.2, verbose=0)
     img = PIL.Image.open(im)
-    visual_result = visualize_object_predictions(image=np.array(img),
+    img = PIL.ImageEnhance.Brightness(img)
+    img = img.enhance(2.0)
+    visual_result = visualize_object_predictions(image=np.array(img), color=(100, 193, 203),
                                                  object_prediction_list=res.object_prediction_list,
-                                                 text_size=0.4,
+                                                 text_size=0.4, hide_labels=True, hide_conf=True,
                                                  rect_th=1)
-    fig = px.imshow(visual_result['image'])
+    fig = px.imshow(visual_result['image'], width=900, height=900)
     fig.update_layout(showlegend=False, hovermode=False)
     fig.update_xaxes(visible=False)
     fig.update_yaxes(visible=False)
@@ -43,15 +38,7 @@ def inference(
 
 inputs = [
     gr.Image(type='filepath', label='Input'),
-    gr.Dropdown([
-        'YOLOv8n',
-        'YOLOv8s',
-        'YOLOv8m',
-        'YOLOv8l',
-        'YOLOv8x'
-        ],
-        value='YOLOv8n', label='Model'),
-    gr.Slider(minimum=0.0, maximum=1.0, value=0.25, step=0.05, label='Confidence Threshold'),
+    gr.components.Slider(minimum=0.0, maximum=1.0, value=0.25, step=0.05, label='Confidence Threshold'),
 ]
 
 outputs = [
@@ -61,7 +48,6 @@ outputs = [
 example_images = [[f'examples/{f}'] for f in os.listdir('examples')]
 
 
-
 gr.Interface(
     fn=inference,
     inputs=inputs,
@@ -71,9 +57,9 @@ gr.Interface(
     cache_examples=False,
     examples_per_page=10,
     title='Marine vessel detection from Sentinel 2 images',
-    description="""Models detect potential marine vessels from Sentinel 2 imagery.
-    Each example image covers 7.68x7.68 km (768x768 pixels).
+    description="""Model detects potential marine vessels from Sentinel 2 imagery.
+    Each example image covers 7.68x7.68 km (768x768 pixels). The result image has its brightness increased,
+    but the predictions are made based on Sentinel-2 L1C-TCI data.
     As we don't clean the prediction with stationary targets that look like vessels in this resolution,
-    there will most likely be false positives from lighthouses, above-water rocks and on land.\n
-    Be patient with responses, as free tier only has 2vCPUs so app might be slow sometimes."""
-).launch()
+    there will most likely be false positives from lighthouses, above-water rocks and on land."""
+).launch(share=True, server_name='0.0.0.0')
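
For reference, a minimal standalone sketch of the inference path introduced by this commit (not part of the diff itself): it assumes the yolo11s_tci.pt checkpoint is available next to the script, uses a hypothetical examples/sample.png input, and runs on device='cpu' for portability, whereas app.py requests 'cuda'.

import numpy as np
import PIL.Image
import PIL.ImageEnhance
from sahi.models.ultralytics import UltralyticsDetectionModel
from sahi.predict import get_sliced_prediction
from sahi.utils.cv import visualize_object_predictions

# Same model configuration as app.py, but on CPU (assumption for local runs).
model = UltralyticsDetectionModel(model_path='yolo11s_tci.pt',
                                  device='cpu',
                                  confidence_threshold=0.25,
                                  category_mapping={'0': 'Boat'},
                                  image_size=640)

image_path = 'examples/sample.png'  # hypothetical example image
# Slice the 768x768 scene into 320x320 tiles with 20% overlap, as in app.py.
res = get_sliced_prediction(image_path, model, slice_width=320,
                            slice_height=320, overlap_height_ratio=0.2,
                            overlap_width_ratio=0.2, verbose=0)

# Brighten only the displayed image; detection ran on the original L1C-TCI pixels.
img = PIL.ImageEnhance.Brightness(PIL.Image.open(image_path)).enhance(2.0)
visual = visualize_object_predictions(image=np.array(img), color=(100, 193, 203),
                                      object_prediction_list=res.object_prediction_list,
                                      text_size=0.4, hide_labels=True, hide_conf=True,
                                      rect_th=1)
PIL.Image.fromarray(visual['image']).save('prediction.png')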