torinriley commited on
Commit
387dfb8
1 Parent(s): 3f726f1

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +23 -18
handler.py CHANGED
@@ -1,9 +1,9 @@
1
- import os
2
  import torch
3
  from torchvision import transforms
4
  from PIL import Image
5
  import io
6
 
 
7
  from model import get_model
8
 
9
  class EndpointHandler:
@@ -12,49 +12,54 @@ class EndpointHandler:
12
  Initialize the handler. Load the Faster R-CNN model.
13
  """
14
  self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
15
- self.model_weights_path = os.path.join(path, "model.pt")
16
-
17
- # Load the model
18
  self.model = get_model(num_classes=4)
19
- print(f"Loading weights from: {self.model_weights_path}")
20
  checkpoint = torch.load(self.model_weights_path, map_location=self.device)
21
  self.model.load_state_dict(checkpoint["model_state_dict"])
22
  self.model.to(self.device)
23
  self.model.eval()
24
 
25
- # Define image preprocessing
26
  self.transform = transforms.Compose([
27
- transforms.Resize((640, 640)),
28
  transforms.ToTensor(),
29
  ])
30
 
31
  def __call__(self, data):
32
  """
33
- Process the incoming request and return object detection predictions.
34
  """
35
  try:
36
- if "image" not in data:
37
- return [{"error": "No 'image' provided in request."}]
 
 
38
 
39
- image_bytes = data["image"].encode("latin1")
40
  image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
41
 
 
42
  input_tensor = self.transform(image).unsqueeze(0).to(self.device)
43
 
 
44
  with torch.no_grad():
45
- outputs = self.model(input_tensor)
46
 
47
- boxes = outputs[0]["boxes"].cpu().tolist()
48
- labels = outputs[0]["labels"].cpu().tolist()
49
- scores = outputs[0]["scores"].cpu().tolist()
 
50
 
 
51
  threshold = 0.5
52
- predictions = [
53
  {"box": box, "label": label, "score": score}
54
  for box, label, score in zip(boxes, labels, scores)
55
  if score > threshold
56
  ]
57
 
58
- return [{"predictions": predictions}]
59
  except Exception as e:
60
- return [{"error": str(e)}]
 
 
1
  import torch
2
  from torchvision import transforms
3
  from PIL import Image
4
  import io
5
 
6
+ # Load the Faster R-CNN model
7
  from model import get_model
8
 
9
  class EndpointHandler:
 
12
  Initialize the handler. Load the Faster R-CNN model.
13
  """
14
  self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
15
+ self.model_weights_path = os.path.join(path, "model.pt") # Adjust path
16
+
17
+ # Load model
18
  self.model = get_model(num_classes=4)
 
19
  checkpoint = torch.load(self.model_weights_path, map_location=self.device)
20
  self.model.load_state_dict(checkpoint["model_state_dict"])
21
  self.model.to(self.device)
22
  self.model.eval()
23
 
24
+ # Image preprocessing
25
  self.transform = transforms.Compose([
26
+ transforms.Resize((640, 640)),
27
  transforms.ToTensor(),
28
  ])
29
 
30
  def __call__(self, data):
31
  """
32
+ Process incoming binary image data and return object detection results.
33
  """
34
  try:
35
+ # Read raw binary data (image file)
36
+ image_bytes = data.get("body", b"")
37
+ if not image_bytes:
38
+ return {"error": "No image data provided in request."}
39
 
40
+
41
  image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
42
 
43
+
44
  input_tensor = self.transform(image).unsqueeze(0).to(self.device)
45
 
46
+
47
  with torch.no_grad():
48
+ predictions = self.model(input_tensor)
49
 
50
+
51
+ boxes = predictions[0]["boxes"].cpu().tolist()
52
+ labels = predictions[0]["labels"].cpu().tolist()
53
+ scores = predictions[0]["scores"].cpu().tolist()
54
 
55
+
56
  threshold = 0.5
57
+ results = [
58
  {"box": box, "label": label, "score": score}
59
  for box, label, score in zip(boxes, labels, scores)
60
  if score > threshold
61
  ]
62
 
63
+ return {"predictions": results}
64
  except Exception as e:
65
+ return {"error": str(e)}