CodeJackR
commited on
Commit
·
2f4ef92
1
Parent(s):
e0fb0e6
Fix image upload errors
Browse files- handler.py +19 -12
handler.py
CHANGED
|
@@ -29,26 +29,33 @@ class EndpointHandler():
|
|
| 29 |
self.model = SamModel.from_pretrained("facebook/sam-vit-base").to(device)
|
| 30 |
self.processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
|
| 31 |
|
| 32 |
-
def __call__(self, data):
|
| 33 |
"""
|
| 34 |
Called on every HTTP request.
|
| 35 |
-
|
| 36 |
"""
|
| 37 |
# 1. Parse and decode the input image
|
| 38 |
-
|
| 39 |
-
if
|
| 40 |
-
raise ValueError("Missing 'inputs' key
|
| 41 |
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 47 |
|
| 48 |
# 2. Prepare prompts and process the image
|
| 49 |
height, width = img.size[1], img.size[0]
|
| 50 |
-
input_points = [[[width // 2, height // 2]]]
|
| 51 |
-
input_labels = [[1]]
|
| 52 |
|
| 53 |
inputs = self.processor(img, input_points=input_points, input_labels=input_labels, return_tensors="pt").to(device)
|
| 54 |
|
|
|
|
| 29 |
self.model = SamModel.from_pretrained("facebook/sam-vit-base").to(device)
|
| 30 |
self.processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
|
| 31 |
|
| 32 |
+
def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
|
| 33 |
"""
|
| 34 |
Called on every HTTP request.
|
| 35 |
+
Handles both base64-encoded images and PIL images.
|
| 36 |
"""
|
| 37 |
# 1. Parse and decode the input image
|
| 38 |
+
inputs = data.pop("inputs", None)
|
| 39 |
+
if inputs is None:
|
| 40 |
+
raise ValueError("Missing 'inputs' key in the payload.")
|
| 41 |
|
| 42 |
+
# Check the type of inputs to handle both base64 strings and pre-processed PIL Images
|
| 43 |
+
if isinstance(inputs, Image.Image):
|
| 44 |
+
# Input is already a PIL Image
|
| 45 |
+
img = inputs.convert("RGB")
|
| 46 |
+
elif isinstance(inputs, str):
|
| 47 |
+
# Input is a base64-encoded string
|
| 48 |
+
if inputs.startswith("data:"):
|
| 49 |
+
inputs = inputs.split(",", 1)[1] # Handle data URL format
|
| 50 |
+
image_bytes = base64.b64decode(inputs)
|
| 51 |
+
img = Image.open(io.BytesIO(image_bytes)).convert("RGB")
|
| 52 |
+
else:
|
| 53 |
+
raise TypeError("Unsupported input type. Expected a PIL Image or a base64 encoded string.")
|
| 54 |
|
| 55 |
# 2. Prepare prompts and process the image
|
| 56 |
height, width = img.size[1], img.size[0]
|
| 57 |
+
input_points = [[[width // 2, height // 2]]]
|
| 58 |
+
input_labels = [[1]]
|
| 59 |
|
| 60 |
inputs = self.processor(img, input_points=input_points, input_labels=input_labels, return_tensors="pt").to(device)
|
| 61 |
|