import io
import logging
import os

import torch
from PIL import Image
from torchvision.transforms import ToTensor

from model import get_model

logger = logging.getLogger(__name__)

# Constants
# Passed to get_model(). Presumably 3 colour classes + 1 background class
# (label_map below has no entry for id 0) -- TODO confirm against model.py.
NUM_CLASSES = 4
# Detections scoring below this value are dropped from the response.
CONFIDENCE_THRESHOLD = 0.5
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class EndpointHandler:
    """Inference handler that runs the detection model on a single image.

    Instances are callable: ``handler(data)`` returns a list of detection
    dicts, or a one-element list holding an ``{"error": ...}`` dict.
    """

    def __init__(self, path: str = ""):
        """Load model weights from ``<path>/model.pt`` and set up preprocessing.

        Args:
            path: Directory containing ``model.pt``. The default ``""``
                resolves to ``model.pt`` relative to the working directory.
        """
        self.model_weights_path = os.path.join(path, "model.pt")
        self.model = get_model(NUM_CLASSES).to(DEVICE)
        # NOTE(review): torch.load unpickles the file, so only load trusted
        # checkpoints. Consider weights_only=True if the installed torch
        # version supports it and the checkpoint holds only tensors/plain data.
        checkpoint = torch.load(self.model_weights_path, map_location=DEVICE)
        self.model.load_state_dict(checkpoint["model_state_dict"])
        # Switch to inference behaviour (e.g. dropout / batch-norm in eval mode).
        self.model.eval()

        # Converts a PIL image into a (C, H, W) float tensor.
        self.preprocess = ToTensor()

        # Model label id -> class name; unmapped ids are reported as "unknown".
        self.label_map = {1: "yellow", 2: "red", 3: "blue"}

    def preprocess_frame(self, image_bytes):
        """Decode encoded image bytes into a batched tensor on ``DEVICE``.

        Args:
            image_bytes: Encoded image file contents (any format PIL can open).

        Returns:
            The RGB image as a tensor with a leading batch dimension of 1.
        """
        image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
        return self.preprocess(image).unsqueeze(0).to(DEVICE)

    def __call__(self, data):
        """Run detection on the raw image bytes found in ``data["body"]``.

        Args:
            data: Request payload. Expected to be a mapping whose ``"body"``
                key holds the encoded image bytes -- TODO confirm against
                the serving framework.

        Returns:
            A (possibly empty) list of
            ``{"label", "score", "box": {"xmin", "ymin", "xmax", "ymax"}}``
            dicts for detections scoring at least ``CONFIDENCE_THRESHOLD``.
            If the body is missing, or any ``Exception`` is raised while
            processing, a single-element list ``[{"error": <message>}]``
            is returned instead of raising.
        """
        try:
            if "body" not in data:
                # Return error in the expected output format
                return [{"error": "No image data provided in request."}]

            image_tensor = self.preprocess_frame(data["body"])
            with torch.no_grad():
                predictions = self.model(image_tensor)
            # The batch holds exactly one image, so only predictions[0] matters.
            return self._format_detections(predictions[0])
        except Exception as e:
            # Top-level request boundary: keep the response schema intact for
            # the client, but record the traceback server-side so failures are
            # diagnosable instead of silently swallowed.
            logger.exception("Inference request failed")
            return [{"error": str(e)}]

    def _format_detections(self, prediction):
        """Convert one image's model output into the response schema.

        Keeps detections with ``score >= CONFIDENCE_THRESHOLD``, maps label
        ids through ``self.label_map``, rounds scores to two decimals and
        truncates box coordinates to ``int``.
        """
        boxes = prediction["boxes"].cpu().tolist()
        labels = prediction["labels"].cpu().tolist()
        scores = prediction["scores"].cpu().tolist()

        results = []
        for box, label, score in zip(boxes, labels, scores):
            if score < CONFIDENCE_THRESHOLD:
                continue
            # int() truncates toward zero. Boxes are unpacked as
            # (x1, y1, x2, y2); presumably pixel units -- confirm with model.
            x1, y1, x2, y2 = map(int, box)
            results.append({
                "label": self.label_map.get(label, "unknown"),
                "score": round(score, 2),
                "box": {
                    "xmin": x1,
                    "ymin": y1,
                    "xmax": x2,
                    "ymax": y2,
                },
            })
        return results