torinriley commited on
Commit
dd95c76
1 Parent(s): 639e661

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +3 -10
handler.py CHANGED
@@ -4,7 +4,6 @@ from torchvision import transforms
4
  from PIL import Image
5
  import io
6
 
7
- # Import your Faster R-CNN model definition
8
  from model import get_model
9
 
10
  class EndpointHandler:
@@ -13,10 +12,10 @@ class EndpointHandler:
13
  Initialize the handler. Load the Faster R-CNN model.
14
  """
15
  self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
16
- self.model_weights_path = os.path.join(path, "model.pt") # Adjust for your file name
17
 
18
  # Load the model
19
- self.model = get_model(num_classes=4) # Modify for your num_classes
20
  print(f"Loading weights from: {self.model_weights_path}")
21
  checkpoint = torch.load(self.model_weights_path, map_location=self.device)
22
  self.model.load_state_dict(checkpoint["model_state_dict"])
@@ -25,7 +24,7 @@ class EndpointHandler:
25
 
26
  # Define image preprocessing
27
  self.transform = transforms.Compose([
28
- transforms.Resize((640, 640)), # Adjust size to match your training setup
29
  transforms.ToTensor(),
30
  ])
31
 
@@ -34,27 +33,21 @@ class EndpointHandler:
34
  Process the incoming request and return object detection predictions.
35
  """
36
  try:
37
- # Expect input data to include a Base64-encoded image
38
  if "image" not in data:
39
  return [{"error": "No 'image' provided in request."}]
40
 
41
- # Convert Base64-encoded image to bytes
42
  image_bytes = data["image"].encode("latin1")
43
  image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
44
 
45
- # Preprocess the image
46
  input_tensor = self.transform(image).unsqueeze(0).to(self.device)
47
 
48
- # Run inference
49
  with torch.no_grad():
50
  outputs = self.model(input_tensor)
51
 
52
- # Extract results
53
  boxes = outputs[0]["boxes"].cpu().tolist()
54
  labels = outputs[0]["labels"].cpu().tolist()
55
  scores = outputs[0]["scores"].cpu().tolist()
56
 
57
- # Confidence threshold
58
  threshold = 0.5
59
  predictions = [
60
  {"box": box, "label": label, "score": score}
 
4
  from PIL import Image
5
  import io
6
 
 
7
  from model import get_model
8
 
9
  class EndpointHandler:
 
12
  Initialize the handler. Load the Faster R-CNN model.
13
  """
14
  self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
15
+ self.model_weights_path = os.path.join(path, "model.pt")
16
 
17
  # Load the model
18
+ self.model = get_model(num_classes=4)
19
  print(f"Loading weights from: {self.model_weights_path}")
20
  checkpoint = torch.load(self.model_weights_path, map_location=self.device)
21
  self.model.load_state_dict(checkpoint["model_state_dict"])
 
24
 
25
  # Define image preprocessing
26
  self.transform = transforms.Compose([
27
+ transforms.Resize((640, 640)),
28
  transforms.ToTensor(),
29
  ])
30
 
 
33
  Process the incoming request and return object detection predictions.
34
  """
35
  try:
 
36
  if "image" not in data:
37
  return [{"error": "No 'image' provided in request."}]
38
 
 
39
  image_bytes = data["image"].encode("latin1")
40
  image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
41
 
 
42
  input_tensor = self.transform(image).unsqueeze(0).to(self.device)
43
 
 
44
  with torch.no_grad():
45
  outputs = self.model(input_tensor)
46
 
 
47
  boxes = outputs[0]["boxes"].cpu().tolist()
48
  labels = outputs[0]["labels"].cpu().tolist()
49
  scores = outputs[0]["scores"].cpu().tolist()
50
 
 
51
  threshold = 0.5
52
  predictions = [
53
  {"box": box, "label": label, "score": score}