CheRy01 commited on
Commit
72d8d64
1 Parent(s): b16d24f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +39 -116
app.py CHANGED
@@ -2,121 +2,44 @@ import gradio as gr
2
  import torch
3
  from ultralytics import YOLO
4
 
5
-
6
- import io
7
- import requests
8
- from google.colab import auth
9
- from googleapiclient.http import MediaIoBaseDownload
10
- from googleapiclient.discovery import build
11
-
12
- def download_from_google_drive(id, mime_type):
13
- auth.authenticate_user()
14
- drive_service = build('drive', 'v3')
15
- request = drive_service.files().export_media(fileId=id, mimeType=mime_type)
16
- file = io.BytesIO()
17
- downloader = MediaIoBaseDownload(file, request)
18
- done = False
19
- while done is False:
20
- status, done = downloader.next_chunk()
21
- print("Download %d%%." % int(status.progress() * 100))
22
- file.seek(0)
23
- return file
24
-
25
- def download_image_from_google_drive(shareable_link, image_name):
26
- file_id = shareable_link.split("/")[5]
27
- request = requests.get(shareable_link)
28
- request.raise_for_status()
29
- file_metadata = request.json()
30
- mime_type = file_metadata["mimeType"]
31
- image_file = download_from_google_drive(file_id, mime_type)
32
- with open(image_name, "wb") as f:
33
- f.write(image_file.read())
34
-
35
- download_image_from_google_drive('https://drive.google.com/file/d/1D0YH45dh5vubZsg54nAMBRw8RRKxlheC/view?usp=drive_link', 'one.jpg')
36
- download_image_from_google_drive('https://drive.google.com/file/d/18xJsjWJqSqA2ca6e2LO570rMoQeJKjGo/view?usp=drive_link', 'two.jpg')
37
- download_image_from_google_drive('https://drive.google.com/file/d/1PPQC0qkmoziNNEtKv7uPv9wh72hxNKMo/view?usp=drive_link', 'three.jpg')
38
-
39
-
40
- def yoloV8_func(image: gr.inputs.Image = None,
41
- image_size: gr.inputs.Slider = 640,
42
- conf_threshold: gr.inputs.Slider = 0.4,
43
- iou_threshold: gr.inputs.Slider = 0.50):
44
- """This function performs YOLOv8 object detection on the given image.
45
-
46
- Args:
47
- image (gr.inputs.Image, optional): Input image to detect objects on. Defaults to None.
48
- image_size (gr.inputs.Slider, optional): Desired image size for the model. Defaults to 640.
49
- conf_threshold (gr.inputs.Slider, optional): Confidence threshold for object detection. Defaults to 0.4.
50
- iou_threshold (gr.inputs.Slider, optional): Intersection over Union threshold for object detection. Defaults to 0.50.
51
- """
52
- # Load the YOLOv8 model from the 'best.pt' checkpoint
53
- model_path = "best.pt"
54
- model = YOLO(model_path)
55
-
56
- # Perform object detection on the input image using the YOLOv8 model
57
- results = model.predict(image,
58
- conf=conf_threshold,
59
- iou=iou_threshold,
60
- imgsz=image_size)
61
-
62
- # Print the detected objects' information (class, coordinates, and probability)
63
- box = results[0].boxes
64
- print("Object type:", box.cls)
65
- print("Coordinates:", box.xyxy)
66
- print("Probability:", box.conf)
67
-
68
- # Custom rendering function using OpenCV
69
- def render_custom(model, image, result):
70
- class_ids = result.boxes.cls.cpu().numpy()
71
- confidences = result.boxes.conf.cpu().numpy()
72
- x1y1x2y2 = result.boxes.xyxy.cpu().numpy()
73
-
74
- for i in range(len(class_ids)):
75
- class_id = class_ids[i]
76
- confidence = confidences[i]
77
- x1, y1, x2, y2 = x1y1x2y2[i]
78
-
79
- label = f"{model.names[class_id]} {confidence:.2f}"
80
- color = (0, 255, 0) # Green color for bounding boxes
81
-
82
- cv2.rectangle(image, (x1, y1), (x2, y2), color, 2)
83
- cv2.putText(image, label, (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)
84
-
85
- return image
86
-
87
- # Render the output image with bounding boxes around detected objects
88
- custom_render = render_custom(model, image, results[0])
89
- return custom_render
90
-
91
-
92
- inputs = [
93
- gr.inputs.Image(type="filepath", label="Input Image"),
94
- gr.inputs.Slider(minimum=320, maximum=1280, default=640,
95
- step=32, label="Image Size"),
96
- gr.inputs.Slider(minimum=0.0, maximum=1.0, default=0.25,
97
- step=0.05, label="Confidence Threshold"),
98
- gr.inputs.Slider(minimum=0.0, maximum=1.0, default=0.45,
99
- step=0.05, label="IOU Threshold"),
100
- ]
101
-
102
-
103
- outputs = gr.outputs.Image(type="filepath", label="Output Image")
104
-
105
- title = "Bone Fracture Detection"
106
-
107
-
108
- examples = [['one.jpg', 640, 0.5, 0.7],
109
- ['two.jpg', 800, 0.5, 0.6],
110
- ['three.jpg', 900, 0.5, 0.8]]
111
-
112
- yolo_app = gr.Interface(
113
- fn=yoloV8_func,
114
- inputs=inputs,
115
- outputs=outputs,
116
- title=title,
117
- examples=examples,
118
- cache_examples=True,
119
  )
120
 
121
- # Launch the Gradio interface in debug mode with queue enabled
122
- yolo_app.launch(debug=True, enable_queue=True)
 
2
  import torch
3
  from ultralytics import YOLO
4
 
5
+ file_path = 'best.pt'
6
+
7
+ def load_model(file_path):
8
+ # load the model weights from the file
9
+ yolov8_model = torch.load(file_path)
10
+
11
+ yolov8_model.eval()
12
+
13
+ return yolov8_model
14
+
15
+ def predict_fracture(image):
16
+ # Preprocess the image for YOLOv8
17
+ img_tensor = to_tensor(image).unsqueeze(0) # Convert image to tensor and add batch dimension
18
+ results = yolov8_model(img_tensor) # Perform inference
19
+
20
+ # Display the results on the image
21
+ img_with_boxes = image.copy()
22
+ for box in results.xyxy[0]:
23
+ label = int(box[5])
24
+ score = float(box[4])
25
+ if label == 0: # Assuming 0 corresponds to the bone fracture class
26
+ color = "red" if score > 0.5 else "orange" # Adjust the threshold as needed
27
+ xmin, ymin, xmax, ymax = box[:4].int().tolist()
28
+ img_with_boxes.rectangle([xmin, ymin, xmax, ymax], outline=color, width=2)
29
+ img_with_boxes.text((xmin, ymin), f"Fracture: {score:.2f}", font_size=12, color=color)
30
+
31
+ return img_with_boxes
32
+
33
+
34
+ # Gradio Interface
35
+ iface = gr.Interface(
36
+ predict_fracture,
37
+ inputs=gr.Image(),
38
+ outputs=gr.Image(),
39
+ live=True,
40
+ #capture_session=True,
41
+ title="Bone Fracture Detection",
42
+ description="Upload an X-ray image to detect bone fractures using YOLOv8.",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43
  )
44
 
45
+ iface.launch()