File size: 2,818 Bytes
59c3137
f53d612
 
59c3137
 
28a32c3
59c3137
f53d612
 
 
 
639e661
 
 
 
f53d612
639e661
f53d612
 
 
 
639e661
 
 
f53d612
 
 
 
 
 
 
 
 
 
 
 
 
 
639e661
 
b22dbd1
 
 
 
 
f132f1f
 
639e661
b22dbd1
 
387dfb8
b22dbd1
 
639e661
f132f1f
b22dbd1
 
 
639e661
f132f1f
b22dbd1
 
 
32478f9
b22dbd1
 
 
 
 
 
 
 
 
 
 
639e661
f132f1f
b22dbd1
 
f132f1f
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
import io
import logging
import os

import torch
from PIL import Image
from torchvision.transforms import ToTensor

from model import get_model

# Module-level configuration shared by the handler below.
# Class count handed to get_model(); EndpointHandler's label map names ids 1-3,
# so id 0 is presumably a background class — confirm against model.get_model.
NUM_CLASSES = 4
# Detections whose score is below this value are left out of the response.
CONFIDENCE_THRESHOLD = 0.5
# Prefer a CUDA device when torch can see one; otherwise run on the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

class EndpointHandler:
    """Inference handler wrapping an object-detection model.

    The model is built and loaded from ``model.pt`` once, at construction time.
    Each call decodes a raw image payload, runs the model, and returns the
    detections whose score is at least ``CONFIDENCE_THRESHOLD``.
    """

    # Class-level logger: failures inside __call__ keep their traceback in the
    # server log while the client still receives a short in-band error entry.
    _logger = logging.getLogger(__name__)

    def __init__(self, path: str = ""):
        """Build the model, load its weights, and set up preprocessing.

        Args:
            path: Directory containing ``model.pt``. The default ``""`` makes
                the path the relative ``"model.pt"`` (resolved against the
                current working directory).

        Raises:
            Any error from ``torch.load`` (e.g. missing file) or a ``KeyError``
            if the checkpoint has no ``"model_state_dict"`` entry.
        """
        self.model_weights_path = os.path.join(path, "model.pt")
        self.model = get_model(NUM_CLASSES).to(DEVICE)
        # NOTE(review): torch.load unpickles the file, so only load trusted
        # checkpoints. The default of ``weights_only`` differs between torch
        # versions — confirm this checkpoint loads under the deployed version.
        checkpoint = torch.load(self.model_weights_path, map_location=DEVICE)
        self.model.load_state_dict(checkpoint["model_state_dict"])
        # Inference mode (affects layers such as dropout and batch norm).
        self.model.eval()

        # PIL image -> float tensor laid out as (C, H, W).
        self.preprocess = ToTensor()

        # Model label ids -> display names; any other id is reported as "unknown".
        self.label_map = {1: "yellow", 2: "red", 3: "blue"}

    def preprocess_frame(self, image_bytes):
        """Decode encoded image bytes into a single-image batch on ``DEVICE``.

        Args:
            image_bytes: Encoded image data in any format PIL can open.

        Returns:
            A tensor of shape (1, 3, H, W): the image converted to RGB, passed
            through ``self.preprocess``, with a leading batch dimension.

        Raises:
            PIL.UnidentifiedImageError: if the bytes are not a readable image.
        """
        # Open as a context manager so the source image is closed; convert()
        # returns a new, fully loaded image that outlives the ``with`` block.
        with Image.open(io.BytesIO(image_bytes)) as raw_image:
            image = raw_image.convert("RGB")
        return self.preprocess(image).unsqueeze(0).to(DEVICE)

    @staticmethod
    def _extract_body(data):
        """Return the image payload from ``data["body"]``, or ``None`` if absent.

        ``None`` is returned when *data* has no ``"body"`` key, cannot be
        indexed by a string key at all (e.g. ``None`` or a list), or carries an
        empty / ``None`` body.
        """
        try:
            body = data["body"]
        except (KeyError, TypeError):
            return None
        return body or None

    def __call__(self, data):
        """Run detection on the raw image bytes found in ``data["body"]``.

        Args:
            data: Request payload; the encoded image is read from its
                ``"body"`` entry.

        Returns:
            One dict per detection scoring at least ``CONFIDENCE_THRESHOLD``::

                {"label": str, "score": float rounded to 2 places,
                 "box": {"xmin": int, "ymin": int, "xmax": int, "ymax": int}}

            Errors are reported in-band as ``[{"error": message}]`` rather than
            raised, so the response is always a list.
        """
        try:
            image_bytes = self._extract_body(data)
            if image_bytes is None:
                return [{"error": "No image data provided in request."}]

            image_tensor = self.preprocess_frame(image_bytes)
            with torch.no_grad():
                predictions = self.model(image_tensor)

            # Assumes the model returns a list of per-image dicts of tensors
            # (torchvision detection style); the batch holds one image.
            detections = predictions[0]
            boxes = detections["boxes"].cpu().tolist()
            labels = detections["labels"].cpu().tolist()
            scores = detections["scores"].cpu().tolist()

            results = []
            for box, label, score in zip(boxes, labels, scores):
                if score < CONFIDENCE_THRESHOLD:
                    continue
                # int() truncates toward zero: coordinates are whole pixels.
                x1, y1, x2, y2 = map(int, box)
                results.append({
                    "label": self.label_map.get(label, "unknown"),
                    "score": round(score, 2),
                    "box": {"xmin": x1, "ymin": y1, "xmax": x2, "ymax": y2},
                })
            return results
        except Exception as e:
            # Endpoint boundary: log the full traceback server-side, but answer
            # with an error entry in the same list-shaped schema as success.
            self._logger.exception("Inference request failed")
            return [{"error": str(e)}]