CodeJackR committed on
Commit
2f4ef92
·
1 Parent(s): e0fb0e6

Fix image upload errors

Browse files
Files changed (1) hide show
  1. handler.py +19 -12
handler.py CHANGED
@@ -29,26 +29,33 @@ class EndpointHandler():
29
  self.model = SamModel.from_pretrained("facebook/sam-vit-base").to(device)
30
  self.processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
31
 
32
- def __call__(self, data):
33
  """
34
  Called on every HTTP request.
35
- Expecting base64 encoded image in the 'inputs' field.
36
  """
37
  # 1. Parse and decode the input image
38
- image_data = data.pop("inputs", None)
39
- if not image_data:
40
- raise ValueError("Missing 'inputs' key with a base64 image string.")
41
 
42
- if isinstance(image_data, str) and image_data.startswith("data:"):
43
- image_data = image_data.split(",", 1)[1]
44
-
45
- image_bytes = base64.b64decode(image_data)
46
- img = Image.open(io.BytesIO(image_bytes)).convert("RGB")
 
 
 
 
 
 
 
47
 
48
  # 2. Prepare prompts and process the image
49
  height, width = img.size[1], img.size[0]
50
- input_points = [[[width // 2, height // 2]]] # Center point
51
- input_labels = [[1]] # Positive prompt
52
 
53
  inputs = self.processor(img, input_points=input_points, input_labels=input_labels, return_tensors="pt").to(device)
54
 
 
29
  self.model = SamModel.from_pretrained("facebook/sam-vit-base").to(device)
30
  self.processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
31
 
32
+ def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
33
  """
34
  Called on every HTTP request.
35
+ Handles both base64-encoded images and PIL images.
36
  """
37
  # 1. Parse and decode the input image
38
+ inputs = data.pop("inputs", None)
39
+ if inputs is None:
40
+ raise ValueError("Missing 'inputs' key in the payload.")
41
 
42
+ # Check the type of inputs to handle both base64 strings and pre-processed PIL Images
43
+ if isinstance(inputs, Image.Image):
44
+ # Input is already a PIL Image
45
+ img = inputs.convert("RGB")
46
+ elif isinstance(inputs, str):
47
+ # Input is a base64-encoded string
48
+ if inputs.startswith("data:"):
49
+ inputs = inputs.split(",", 1)[1] # Handle data URL format
50
+ image_bytes = base64.b64decode(inputs)
51
+ img = Image.open(io.BytesIO(image_bytes)).convert("RGB")
52
+ else:
53
+ raise TypeError("Unsupported input type. Expected a PIL Image or a base64 encoded string.")
54
 
55
  # 2. Prepare prompts and process the image
56
  height, width = img.size[1], img.size[0]
57
+ input_points = [[[width // 2, height // 2]]]
58
+ input_labels = [[1]]
59
 
60
  inputs = self.processor(img, input_points=input_points, input_labels=input_labels, return_tensors="pt").to(device)
61