Thomasboosinger commited on
Commit
03dec79
1 Parent(s): bbb90ac

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +14 -14
handler.py CHANGED
@@ -6,30 +6,30 @@ from typing import Dict, List, Any
6
 
7
  class EndpointHandler():
8
  def __init__(self, model_path=""):
9
- # Initialize the pipeline with the specified model and set the device to GPU
 
10
  self.pipeline = pipeline(task="zero-shot-object-detection", model=model_path, device=0)
11
 
12
  def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
13
  """
14
- Process an incoming request for zero-shot object detection.
 
15
 
16
  Args:
17
  data (Dict[str, Any]): The input data containing an encoded image and candidate labels.
18
 
19
  Returns:
20
- A list of dictionaries, each containing a label and its corresponding score.
21
  """
22
- # Correctly accessing the 'inputs' key and fixing the typo in 'candidates'
23
- inputs = data.get("inputs", {})
24
-
25
- # Decode the base64 image to a PIL image
26
- image = Image.open(BytesIO(base64.b64decode(inputs['image'])))
27
 
28
- # Get candidate labels
29
- candidate_labels=inputs["candidates"]
30
 
31
- # Correctly passing the image and candidate labels to the pipeline
32
- detection_results = self.pipeline(image=image, candidate_labels=inputs["candidates"])
33
 
34
- # Adjusting the return statement to match the expected output structure
35
- return detection_results
 
6
 
7
  class EndpointHandler():
8
  def __init__(self, model_path=""):
9
+ # Initialize the zero-shot object detection pipeline with the specified model
10
+ # and set the device to GPU for faster computation.
11
  self.pipeline = pipeline(task="zero-shot-object-detection", model=model_path, device=0)
12
 
13
  def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
14
  """
15
+ Handles incoming requests for zero-shot object detection, decoding the image
16
+ and predicting labels based on provided candidates.
17
 
18
  Args:
19
  data (Dict[str, Any]): The input data containing an encoded image and candidate labels.
20
 
21
  Returns:
22
+ List[Dict[str, Any]]: Predictions with labels and scores for the detected objects.
23
  """
24
+ # Decode the base64-encoded image to a PIL Image object for processing.
25
+ image_data = data.get("inputs", {}).get('image', '')
26
+ image = Image.open(BytesIO(base64.b64decode(image_data)))
 
 
27
 
28
+ # Extract candidate labels from the input data.
29
+ candidate_labels = data.get("inputs", {}).get("candidates", [])
30
 
31
+ # Perform zero-shot object detection using the provided image and candidate labels.
32
+ detection_results = self.pipeline(image=image, candidate_labels=candidate_labels)
33
 
34
+ # Return the detection results directly, which should match the expected output structure.
35
+ return detection_results