# FTCVision-PyTorch / handler.py
# Repository page header: torinriley, commit "Update handler.py" (32478f9, verified).
import torch
from model import get_model
from torchvision.transforms import ToTensor
from PIL import Image
import io
import os
# ---------------------------------------------------------------------------
# Module-level configuration
# ---------------------------------------------------------------------------

# Class count handed to get_model() when the detector is constructed.
NUM_CLASSES = 4

# Minimum score a detection needs in order to appear in the response.
CONFIDENCE_THRESHOLD = 0.5

# Use the GPU when CUDA is available; otherwise fall back to the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
class EndpointHandler:
    """Serve object-detection predictions for image requests.

    The detector weights are loaded once in ``__init__``. Each call then
    receives one request payload and always returns a list: one dict per
    detection that passed the confidence threshold, or a single
    ``{"error": "..."}`` dict when the request could not be processed.
    """

    def __init__(self, path: str = ""):
        """
        Initialize the handler: load the model.

        Args:
            path: Directory that contains the ``model.pt`` checkpoint.
        """
        self.model_weights_path = os.path.join(path, "model.pt")
        self.model = get_model(NUM_CLASSES).to(DEVICE)

        # SECURITY NOTE: torch.load() unpickles the checkpoint, which can run
        # arbitrary code. Only load checkpoints from a trusted source.
        # NOTE(review): consider passing weights_only=True if the installed
        # torch version supports it and the checkpoint holds plain tensors.
        checkpoint = torch.load(self.model_weights_path, map_location=DEVICE)
        self.model.load_state_dict(checkpoint["model_state_dict"])
        self.model.eval()

        # Converts a PIL image into a tensor.
        self.preprocess = ToTensor()

        # Class id -> label; ids missing from this map are reported as
        # "unknown" in the response.
        self.label_map = {1: "yellow", 2: "red", 3: "blue"}

    @staticmethod
    def _extract_payload(data):
        """Return the image payload found in ``data``, or ``None`` if absent.

        ``"body"`` is checked first (the key this handler has always read).
        ``"inputs"`` is accepted as a fallback; presumably that is the key the
        Hugging Face serving layer uses -- confirm against the deployment.
        Values that are ``None`` or empty bytes count as absent.
        """
        if not data:
            return None
        for key in ("body", "inputs"):
            if key not in data:
                continue
            value = data[key]
            if value is None:
                continue
            if isinstance(value, (bytes, bytearray)) and not value:
                continue
            return value
        return None

    def _format_detections(self, boxes, labels, scores, threshold):
        """Build the response entries for detections scoring >= ``threshold``.

        Args:
            boxes: Sequence of ``[x1, y1, x2, y2]`` coordinates.
            labels: Sequence of class ids, parallel to ``boxes``.
            scores: Sequence of confidence scores, parallel to ``boxes``.
            threshold: Minimum score for a detection to be kept.

        Returns:
            A list of ``{"label", "score", "box"}`` dicts in input order.
        """
        results = []
        for box, label, score in zip(boxes, labels, scores):
            if score < threshold:
                continue
            # Coordinates are truncated to integers for the response.
            x1, y1, x2, y2 = map(int, box)
            results.append({
                "label": self.label_map.get(label, "unknown"),
                "score": round(score, 2),
                "box": {
                    "xmin": x1,
                    "ymin": y1,
                    "xmax": x2,
                    "ymax": y2,
                },
            })
        return results

    def preprocess_frame(self, image_bytes):
        """
        Convert image data to a batched tensor on ``DEVICE``.

        Args:
            image_bytes: Encoded image as a bytes-like object, or an already
                decoded ``PIL.Image.Image``.

        Returns:
            The ``ToTensor`` output with a leading batch dimension added.
        """
        if isinstance(image_bytes, Image.Image):
            # Already decoded upstream; nothing to open.
            image = image_bytes
        else:
            image = Image.open(io.BytesIO(image_bytes))
        image = image.convert("RGB")
        return self.preprocess(image).unsqueeze(0).to(DEVICE)

    def __call__(self, data):
        """
        Run detection on the image carried by ``data``.

        Args:
            data: Request payload; the image is read from ``data["body"]``
                or, failing that, ``data["inputs"]``.

        Returns:
            A list of detection dicts (``label``, ``score``, ``box``), or a
            one-element list ``[{"error": message}]`` on failure. This method
            does not raise for ordinary exceptions.
        """
        try:
            payload = self._extract_payload(data)
            if payload is None:
                # Return error in the expected output format
                return [{"error": "No image data provided in request."}]

            image_tensor = self.preprocess_frame(payload)
            with torch.no_grad():
                predictions = self.model(image_tensor)

            # Only one image is submitted, so only the first entry is read.
            detections = predictions[0]
            return self._format_detections(
                detections["boxes"].cpu().tolist(),
                detections["labels"].cpu().tolist(),
                detections["scores"].cpu().tolist(),
                CONFIDENCE_THRESHOLD,
            )
        except Exception as e:
            # Request boundary: report the failure in the same list-shaped
            # schema instead of propagating it to the serving layer.
            return [{"error": str(e)}]