# Spaces:
# Sleeping
# Sleeping
import gradio as gr
import numpy as np
import torch
import yaml
from PIL import Image
from PIL import ImageDraw
from roboflow import Roboflow
from ultralytics import YOLO
def load_model(file_path):
    """Load trained Ultralytics YOLO weights for inference.

    Args:
        file_path: Path to an Ultralytics checkpoint (e.g. ``best.pt``).

    Returns:
        An ``ultralytics.YOLO`` model ready for prediction.
    """
    # YOLO(<checkpoint>) restores both the architecture and the weights.
    # The previous approach could not work:
    # - YOLO('args.yaml') points at a training-arguments file, not a model definition.
    # - load_state_dict(torch.load(...)) received an Ultralytics checkpoint, which is
    #   a dict of training metadata plus a pickled model rather than a bare state_dict.
    #
    # The previous version also fetched a Roboflow model using an API key
    # hard-coded in source, then discarded it. That network call and the embedded
    # credential are removed. If Roboflow access is ever needed again, read the
    # key from an environment variable instead (and rotate the leaked one).
    pytorch_model = YOLO(file_path)
    return pytorch_model
# Weights are loaded once, at import time; predict_fracture reads this
# module-level `model` on every call.
file_path = "best.pt"
model = load_model(file_path=file_path)
def predict_fracture(image, conf_threshold=0.5, fracture_class=0):
    """Detect fractures in an image and return a copy with boxes drawn on it.

    Args:
        image: Image as a NumPy array (what ``gr.Image()`` delivers by
            default), or ``None`` when the Gradio input has been cleared.
        conf_threshold: Fracture boxes scoring above this are outlined in red;
            the rest are outlined in orange.
        fracture_class: Class index treated as "fracture". Detections of
            other classes are not drawn.

    Returns:
        A ``PIL.Image.Image`` (RGB) with the detections annotated, or ``None``
        if no image was supplied.
    """
    # With live=True Gradio also calls this when the input is cleared.
    if image is None:
        return None

    # Work on an RGB PIL copy: the caller's array is never mutated, and the
    # coloured annotations render correctly even for grayscale input.
    annotated = Image.fromarray(image).convert("RGB")

    # Give Ultralytics the PIL image and let it resize, letterbox and
    # normalise it. It treats PIL input as RGB, whereas raw NumPy arrays
    # would be assumed to be BGR.
    result = model(annotated, verbose=False)[0]
    boxes = result.boxes
    if boxes is None:  # non-detection models produce no boxes
        return annotated

    draw = ImageDraw.Draw(annotated)
    for xyxy, score, cls in zip(boxes.xyxy.tolist(), boxes.conf.tolist(), boxes.cls.tolist()):
        if int(cls) != fracture_class:
            continue
        color = "red" if score > conf_threshold else "orange"
        xmin, ymin, xmax, ymax = (int(v) for v in xyxy)
        draw.rectangle([xmin, ymin, xmax, ymax], outline=color, width=2)
        # `fill` is Pillow's keyword for the text colour.
        draw.text((xmin, ymin), f"Fracture: {score:.2f}", fill=color)
    return annotated
def to_tensor(image):
    """Convert an image to a float32 tensor laid out as (channels, height, width).

    Pixel values are divided by 255, so 8-bit input maps onto [0, 1].

    Args:
        image: A PIL image or array-like of shape (H, W, C) or (H, W).
            Two-dimensional (grayscale) input is given a single channel axis.

    Returns:
        A float32 ``torch.Tensor`` of shape (C, H, W).
    """
    pixels = np.array(image) / 255.0
    if pixels.ndim == 2:
        # Grayscale images have no channel axis, and transpose((2, 0, 1))
        # would raise on them, so add one.
        pixels = pixels[:, :, np.newaxis]
    return torch.from_numpy(pixels.transpose((2, 0, 1))).float()
# Gradio interface: the uploaded image goes to predict_fracture and the
# annotated image it returns is displayed.
# predict_fracture is passed directly. The previous lambda forwarded an extra
# positional argument (load_model), so every request raised a TypeError.
iface = gr.Interface(
    fn=predict_fracture,
    inputs=gr.Image(),
    outputs=gr.Image(),
    live=True,
    title="Bone Fracture Detection",
    description="Upload an X-ray image to detect bone fractures using Roboflow's YOLOv8 model.",
)
iface.launch()