0xnewton-superlore committed on
Commit
426785e
1 Parent(s): d318463

nits throw on bad request

Browse files
Files changed (1) hide show
  1. handler.py +51 -6
handler.py CHANGED
@@ -1,11 +1,15 @@
1
  import base64
2
- import io
3
  import torch
4
  from typing import Dict, List, Any
 
5
  from transformers import CLIPProcessor, CLIPModel
6
  from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
7
  from PIL import Image
8
  from torch.nn.functional import cosine_similarity
 
 
 
 
9
 
10
  class EndpointHandler():
11
  def __init__(self, path: str="", image_size: int=224) -> None:
@@ -34,7 +38,7 @@ class EndpointHandler():
34
  data (Dict[str, Any]): A dictionary containing the following key:
35
  - "inputs" (Dict[str, list]): A dictionary containing the following keys:
36
  - "image_list" (List[str]): A list of base64-encoded images.
37
- - "text_list" (List[str]): A list of text strings.
38
 
39
  Returns:
40
  Dict[str, list]: A dictionary containing the following keys:
@@ -43,10 +47,37 @@ class EndpointHandler():
43
  - "similarity_scores" (List[List[float]]): A list of similarity scores between image and text embeddings.
44
  Empty if either "image_list" or "text_list" is empty.
45
  """
 
 
 
46
  inputs = data.get("inputs", {})
 
 
 
 
47
  image_list = inputs.get("image_list", []) # list of b64 images
48
- text_list = inputs.get("text_list", []) # list of texts
49
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50
  image_features = self.get_image_embeddings(image_list) if len(image_list) > 0 else None
51
  text_features = self.get_text_embeddings(text_list) if len(text_list) > 0 else None
52
 
@@ -68,7 +99,7 @@ class EndpointHandler():
68
  for base64_image in base64_images:
69
  # Decode the base64-encoded image and convert it to an RGB image
70
  image_data = base64.b64decode(base64_image)
71
- image = Image.open(io.BytesIO(image_data)).convert("RGB")
72
  preprocessed_image = self.image_transform(image).unsqueeze(0)
73
  preprocessed_images.append(preprocessed_image)
74
 
@@ -83,7 +114,7 @@ class EndpointHandler():
83
 
84
  return image_features
85
 
86
- def get_text_embeddings(self, text_list: List[str]) -> torch.Tensor:
87
  with torch.no_grad():
88
  # Tokenize the input text list
89
  input_tokens = self.processor(text_list, return_tensors="pt", padding=True, truncation=True)
@@ -93,4 +124,18 @@ class EndpointHandler():
93
  text_features = self.model.get_text_features(**input_tokens)
94
  return text_features
95
 
96
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import base64
 
2
  import torch
3
  from typing import Dict, List, Any
4
+ from io import BytesIO
5
  from transformers import CLIPProcessor, CLIPModel
6
  from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
7
  from PIL import Image
8
  from torch.nn.functional import cosine_similarity
9
+ from typing import Union
10
+
11
+ max_text_list_length = 30
12
+ max_image_list_length = 20
13
 
14
  class EndpointHandler():
15
  def __init__(self, path: str="", image_size: int=224) -> None:
 
38
  data (Dict[str, Any]): A dictionary containing the following key:
39
  - "inputs" (Dict[str, list]): A dictionary containing the following keys:
40
  - "image_list" (List[str]): A list of base64-encoded images.
41
+ - "text_list" (Union[List[str], str]): A list of text strings, or a single text string.
42
 
43
  Returns:
44
  Dict[str, list]: A dictionary containing the following keys:
 
47
  - "similarity_scores" (List[List[float]]): A list of similarity scores between image and text embeddings.
48
  Empty if either "image_list" or "text_list" is empty.
49
  """
50
+ if not isinstance(data, dict):
51
+ raise ValueError("Expected input data to be a dict.")
52
+
53
  inputs = data.get("inputs", {})
54
+
55
+ if not isinstance(inputs, dict):
56
+ raise ValueError("Expected 'inputs' to be a dict.")
57
+
58
  image_list = inputs.get("image_list", []) # list of b64 images
59
+ text_list = inputs.get("text_list", []) # list of texts (or just plain string)
60
 
61
+ if not isinstance(image_list, list):
62
+ raise ValueError("Expected 'image_list' to be a list.")
63
+ if not isinstance(text_list, list) and not isinstance(text_list, str):
64
+ raise ValueError("Expected 'text_list' to be a list or string.")
65
+ if not all(isinstance(image, str) for image in image_list):
66
+ raise ValueError("Expected 'image_list' to contain only strings.")
67
+ if isinstance(text_list, list) and not all(isinstance(text, str) for text in text_list):
68
+ raise ValueError("Expected 'text_list' to contain only strings.")
69
+
70
+ # if text_list is a string, convert to list
71
+ if isinstance(text_list, str):
72
+ text_list = [text_list]
73
+
74
+ if len(image_list) > max_image_list_length:
75
+ raise ValueError(f"Expected 'image_list' to have a maximum length of {max_image_list_length}.")
76
+ if len(text_list) > max_text_list_length:
77
+ raise ValueError(f"Expected 'text_list' to have a maximum length of {max_text_list_length}.")
78
+ if not all(is_valid_base64_image(image) for image in image_list):
79
+ raise ValueError("Expected 'image_list' to contain only valid base64-encoded images.")
80
+
81
  image_features = self.get_image_embeddings(image_list) if len(image_list) > 0 else None
82
  text_features = self.get_text_embeddings(text_list) if len(text_list) > 0 else None
83
 
 
99
  for base64_image in base64_images:
100
  # Decode the base64-encoded image and convert it to an RGB image
101
  image_data = base64.b64decode(base64_image)
102
+ image = Image.open(BytesIO(image_data)).convert("RGB")
103
  preprocessed_image = self.image_transform(image).unsqueeze(0)
104
  preprocessed_images.append(preprocessed_image)
105
 
 
114
 
115
  return image_features
116
 
117
+ def get_text_embeddings(self, text_list: Union[List[str], str]) -> torch.Tensor:
118
  with torch.no_grad():
119
  # Tokenize the input text list
120
  input_tokens = self.processor(text_list, return_tensors="pt", padding=True, truncation=True)
 
124
  text_features = self.model.get_text_features(**input_tokens)
125
  return text_features
126
 
127
+
128
def is_valid_base64_image(data: str) -> bool:
    """Return True if *data* is base64 text that decodes to a verifiable image.

    The string is base64-decoded, the bytes are opened with ``Image.open``
    and the result is checked with ``verify()``. Any ordinary failure along
    the way (bad padding, wrong argument type, unrecognised or corrupt image
    data) is reported as ``False`` instead of being raised.

    Args:
        data: Candidate base64-encoded image.

    Returns:
        True when decoding, opening and verification all succeed,
        otherwise False.
    """
    try:
        # Decode the base64 string and let PIL check the resulting bytes.
        Image.open(BytesIO(base64.b64decode(data))).verify()
    except Exception:
        # Catch `Exception` rather than using a bare `except:` so that
        # KeyboardInterrupt and SystemExit still propagate.
        return False
    return True