parkinson-prediction / modules.py
Mahimai Raja J
Initial Release
f747692
raw
history blame
1.31 kB
from super_gradients.training import models
import supervision as sv
import torch
import cv2
# install the required packages
# pip install -q super-gradients==3.1.1 supervision
# predict function to predict the image
def predict(img, model):
SOURCE_IMAGE_PATH = img
image = cv2.imread(SOURCE_IMAGE_PATH)
result = list(model.predict(image, conf=0.70))[0]
detections = sv.Detections(
xyxy=result.prediction.bboxes_xyxy,
confidence=result.prediction.confidence,
class_id=result.prediction.labels.astype(int)
)
# Create a box annotator
box_annotator = sv.BoxAnnotator()
# Get the labels
labels = [
f"{result.class_names[class_id]} {confidence:0.2f}"
for _, _, confidence, class_id, _
in detections
]
# Annotate the image
annotated_frame = box_annotator.annotate(
scene=image.copy(),
detections=detections,
labels=labels
)
# sv.plot_image(annotated_frame, (12, 12))
return annotated_frame, labels
def setup():
DEVICE = 'cuda' if torch.cuda.is_available() else "cpu" # 'cpu' or 'cuda'
MODEL_ARCH = 'yolo_nas_l'
# Load the model
model = models.get(
MODEL_ARCH,
num_classes=2,
checkpoint_path="model.pth"
).to(DEVICE)
return model