Create app.py
Browse files
app.py
ADDED
@@ -0,0 +1,91 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import cv2
|
3 |
+
import torch
|
4 |
+
import gradio as gr
|
5 |
+
import numpy as np
|
6 |
+
from ultralytics import YOLO
|
7 |
+
import supervision as sv
|
8 |
+
from PIL import Image
|
9 |
+
from huggingface_hub import snapshot_download
|
10 |
+
|
11 |
+
# Set up environment and device
|
12 |
+
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
13 |
+
|
14 |
+
# Adjustable parameters for detection
|
15 |
+
CONFIDENCE_THRESHOLD = 0.1 # Confidence threshold for detections
|
16 |
+
NMS_THRESHOLD = 0 # IoU threshold for non-maximum suppression
|
17 |
+
SLICE_WIDTH = 1024 # Width of each slice
|
18 |
+
SLICE_HEIGHT = 1024 # Height of each slice
|
19 |
+
OVERLAP_WIDTH = 200 # Overlap width between slices
|
20 |
+
OVERLAP_HEIGHT = 200 # Overlap height between slices
|
21 |
+
ANNOTATION_COLOR = sv.Color.RED # Red in BGR format for OpenCV
|
22 |
+
ANNOTATION_THICKNESS = 4 # Thickness of bounding box lines
|
23 |
+
|
24 |
+
# Download YOLO model weights from Hugging Face Hub
|
25 |
+
repo_id = 'edeler/ICC' # Replace with your Hugging Face repository ID
|
26 |
+
model_dir = snapshot_download(repo_id, local_dir='./models/ICC')
|
27 |
+
model_path = os.path.join(model_dir, "best.pt") # Adjust if filename differs
|
28 |
+
model = YOLO(model_path).to(device)
|
29 |
+
|
30 |
+
# Define the detection function for Gradio
|
31 |
+
def detect_objects(image: np.ndarray) -> Image.Image:
|
32 |
+
# Ensure the image is in BGR format if provided by PIL (Gradio gives us an RGB image)
|
33 |
+
if image.shape[-1] == 3:
|
34 |
+
image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
|
35 |
+
|
36 |
+
# Define callback function for slice-based inference
|
37 |
+
def callback(image_slice: np.ndarray) -> sv.Detections:
|
38 |
+
# Run inference on each slice
|
39 |
+
result = model(image_slice)[0]
|
40 |
+
# Convert detections to `sv.Detections` format for further processing
|
41 |
+
detections = sv.Detections.from_ultralytics(result)
|
42 |
+
# Filter detections based on confidence threshold
|
43 |
+
return detections[detections.confidence >= CONFIDENCE_THRESHOLD]
|
44 |
+
|
45 |
+
# Initialize InferenceSlicer with adjustable slice dimensions and overlap settings
|
46 |
+
slicer = sv.InferenceSlicer(
|
47 |
+
callback=callback,
|
48 |
+
slice_wh=(SLICE_WIDTH, SLICE_HEIGHT),
|
49 |
+
overlap_wh=(OVERLAP_WIDTH, OVERLAP_HEIGHT),
|
50 |
+
overlap_ratio_wh=None
|
51 |
+
)
|
52 |
+
|
53 |
+
# Perform slicing-based inference on the entire image
|
54 |
+
detections = slicer(image)
|
55 |
+
|
56 |
+
# Apply Non-Maximum Suppression (NMS) to the detections to avoid duplicate boxes
|
57 |
+
detections = detections.with_nms(threshold=NMS_THRESHOLD, class_agnostic=False)
|
58 |
+
|
59 |
+
# Initialize an annotator for bounding boxes with specified color and thickness
|
60 |
+
box_annotator = sv.OrientedBoxAnnotator(color=ANNOTATION_COLOR, thickness=ANNOTATION_THICKNESS)
|
61 |
+
|
62 |
+
# Annotate the image with bounding boxes after NMS
|
63 |
+
annotated_img = box_annotator.annotate(scene=image.copy(), detections=detections)
|
64 |
+
|
65 |
+
# Convert annotated image to RGB for Gradio display (PIL expects RGB)
|
66 |
+
annotated_img_rgb = cv2.cvtColor(annotated_img, cv2.COLOR_BGR2RGB)
|
67 |
+
return Image.fromarray(annotated_img_rgb)
|
68 |
+
|
69 |
+
# Reset function for Gradio UI
|
70 |
+
def gradio_reset():
|
71 |
+
return gr.update(value=None), gr.update(value=None)
|
72 |
+
|
73 |
+
# Set up Gradio interface
|
74 |
+
with gr.Blocks() as demo:
|
75 |
+
gr.Markdown("<h1>Interstitial Cell of Cajal Detection and Quantification Tool</h1>")
|
76 |
+
|
77 |
+
with gr.Row():
|
78 |
+
with gr.Column():
|
79 |
+
input_img = gr.Image(label="Upload an Image", type="numpy", interactive=True)
|
80 |
+
clear = gr.Button(value="Clear")
|
81 |
+
predict = gr.Button(value="Detect", variant="primary")
|
82 |
+
|
83 |
+
with gr.Column():
|
84 |
+
output_img = gr.Image(label="Detection Result", interactive=False)
|
85 |
+
|
86 |
+
# Define button actions
|
87 |
+
clear.click(gradio_reset, inputs=None, outputs=[input_img, output_img])
|
88 |
+
predict.click(detect_objects, inputs=[input_img], outputs=[output_img])
|
89 |
+
|
90 |
+
# Launch Gradio app
|
91 |
+
demo.launch(server_name="0.0.0.0", server_port=7860, debug=True)
|