alvarobartt HF staff commited on
Commit
bc5f69b
·
verified ·
1 Parent(s): 5a89820

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +18 -6
handler.py CHANGED
@@ -151,6 +151,8 @@ class EndpointHandler:
151
  )
152
 
153
  def __call__(self, data: Dict[str, Any]) -> Dict[str, List[Any]]:
 
 
154
  if "instances" in data:
155
  logger.warning("Using `instances` instead of `inputs` is deprecated.")
156
  data["inputs"] = data.pop("instances")
@@ -159,19 +161,29 @@ class EndpointHandler:
159
  raise ValueError(
160
  "The request body must contain a key 'inputs' with a list of inputs."
161
  )
162
-
163
- logger.info(f"Received incoming request with {data=}")
 
 
 
 
 
 
 
 
 
 
 
 
164
 
165
  predictions = []
166
  for input in data["inputs"]:
167
  if "prompt" not in input:
168
  raise ValueError(
169
- "The request input body must contain a key 'prompt' with the prompt to use."
170
  )
171
 
172
- logger.info(f"{input=}")
173
- # generation_config = input.get("generation_config", dict(max_new_tokens=1024, do_sample=False))
174
- generation_config = dict(max_new_tokens=1024, do_sample=False)
175
 
176
  if "image_url" not in input:
177
  # pure-text conversation
 
151
  )
152
 
153
  def __call__(self, data: Dict[str, Any]) -> Dict[str, List[Any]]:
154
+ logger.info(f"Received incoming request with {data=}")
155
+
156
  if "instances" in data:
157
  logger.warning("Using `instances` instead of `inputs` is deprecated.")
158
  data["inputs"] = data.pop("instances")
 
161
  raise ValueError(
162
  "The request body must contain a key 'inputs' with a list of inputs."
163
  )
164
+
165
+ if not isinstance(data["inputs"], list):
166
+ raise ValueError(
167
+ "The request inputs must be a list of dictionaries with either the key"
168
+ " 'prompt' or 'prompt' + 'image_url', and optionally including the key"
169
+ " 'generation_config'."
170
+ )
171
+
172
+ if not all(isinstance(input, dict) and "prompt" in input.keys() for input in data["inputs"]):
173
+ raise ValueError(
174
+ "The request inputs must be a list of dictionaries with either the key"
175
+ " 'prompt' or 'prompt' + 'image_url', and optionally including the key"
176
+ " 'generation_config'."
177
+ )
178
 
179
  predictions = []
180
  for input in data["inputs"]:
181
  if "prompt" not in input:
182
  raise ValueError(
183
+ "The request input body must contain at least the key 'prompt' with the prompt to use."
184
  )
185
 
186
+ generation_config = input.get("generation_config", dict(max_new_tokens=1024, do_sample=False))
 
 
187
 
188
  if "image_url" not in input:
189
  # pure-text conversation