CodeJackR committed on
Commit
c78d04e
·
1 Parent(s): f9b3f94

Input image as image

Browse files
Files changed (1) hide show
  1. handler.py +23 -36
handler.py CHANGED
@@ -9,6 +9,9 @@ from transformers import SamModel, SamProcessor
9
  from typing import Dict, List, Any
10
  import torch.nn.functional as F
11
 
 
 
 
12
  class EndpointHandler():
13
  def __init__(self, path=""):
14
  """
@@ -17,51 +20,29 @@ class EndpointHandler():
17
  """
18
  try:
19
  # Load the model and processor from the local path
20
- self.model = SamModel.from_pretrained(path)
21
  self.processor = SamProcessor.from_pretrained(path)
22
  except Exception as e:
23
  # Fallback to loading from a known SAM model if local loading fails
24
  print(f"Failed to load from local path: {e}")
25
  print("Attempting to load from facebook/sam-vit-base")
26
- self.model = SamModel.from_pretrained("facebook/sam-vit-base")
27
  self.processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
28
 
29
- def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
30
  """
31
  Called on every HTTP request.
32
- Expecting base64 encoded image in the 'inputs' field or 'image' field.
 
 
33
  """
34
- # Handle different input formats
35
- if "inputs" in data:
36
- if isinstance(data["inputs"], str):
37
- # Handle data URL format (data:image/jpeg;base64,...)
38
- image_data = data["inputs"]
39
- if image_data.startswith("data:"):
40
- # Strip data URL prefix
41
- image_data = image_data.split(",", 1)[1]
42
- # Base64 encoded image
43
- image_bytes = base64.b64decode(image_data)
44
- elif isinstance(data["inputs"], dict) and "image" in data["inputs"]:
45
- # Nested structure with image field
46
- image_data = data["inputs"]["image"]
47
- if image_data.startswith("data:"):
48
- # Strip data URL prefix
49
- image_data = image_data.split(",", 1)[1]
50
- image_bytes = base64.b64decode(image_data)
51
- else:
52
- raise ValueError("Invalid input format. Expected base64 encoded image string.")
53
- elif "image" in data:
54
- # Direct image field
55
- image_data = data["image"]
56
- if image_data.startswith("data:"):
57
- # Strip data URL prefix
58
- image_data = image_data.split(",", 1)[1]
59
- image_bytes = base64.b64decode(image_data)
60
- else:
61
- raise ValueError("No image found in request. Expected 'inputs' or 'image' field with base64 encoded image.")
62
-
63
- # Process the image
64
- img = Image.open(io.BytesIO(image_bytes)).convert("RGB")
65
 
66
  # SAM requires input prompts, so we'll generate a center point prompt
67
  height, width = img.size[1], img.size[0] # PIL returns (width, height)
@@ -120,8 +101,14 @@ class EndpointHandler():
120
  out.seek(0)
121
  mask_base64 = base64.b64encode(out.getvalue()).decode('utf-8')
122
 
 
 
 
 
 
 
123
  # Return in the expected format
124
- return [{"mask_png_base64": mask_base64, "num_masks": 1}]
125
 
126
  def main():
127
  # Hardcoded input and output paths
 
9
  from typing import Dict, List, Any
10
  import torch.nn.functional as F
11
 
12
+ # set device
13
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
14
+
15
  class EndpointHandler():
16
  def __init__(self, path=""):
17
  """
 
20
  """
21
  try:
22
  # Load the model and processor from the local path
23
+ self.model = SamModel.from_pretrained(path).to(device)
24
  self.processor = SamProcessor.from_pretrained(path)
25
  except Exception as e:
26
  # Fallback to loading from a known SAM model if local loading fails
27
  print(f"Failed to load from local path: {e}")
28
  print("Attempting to load from facebook/sam-vit-base")
29
+ self.model = SamModel.from_pretrained("facebook/sam-vit-base").to(device)
30
  self.processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
31
 
32
+ def __call__(self, data: Any) -> Any:
33
  """
34
  Called on every HTTP request.
35
+ Args:
36
+ data (:obj:):
37
+ includes the input data and the parameters for the inference.
38
  """
39
+ inputs = data.pop("inputs", data)
40
+ parameters = data.pop("parameters", {})
41
+
42
+ raw_images = [Image.open(io.BytesIO(_img)) for _img in inputs]
43
+
44
+ # img = Image.open(io.BytesIO(image_bytes)).convert("RGB")
45
+ img = raw_images[0]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46
 
47
  # SAM requires input prompts, so we'll generate a center point prompt
48
  height, width = img.size[1], img.size[0] # PIL returns (width, height)
 
101
  out.seek(0)
102
  mask_base64 = base64.b64encode(out.getvalue()).decode('utf-8')
103
 
104
+ # Decode the returned mask and save
105
+ mask_bytes = base64.b64decode(mask_base64)
106
+ mask_img = Image.open(io.BytesIO(mask_bytes)).convert("RGB")
107
+ # mask_img.save(output_path, format="JPEG")
108
+ # print(f"Wrote mask to {output_path}")
109
+
110
  # Return in the expected format
111
+ return mask_img
112
 
113
  def main():
114
  # Hardcoded input and output paths