binaryaaron committed on
Commit
2e3949d
·
unverified ·
1 Parent(s): dfa140e

update handler for inputs and parameters

Browse files
Files changed (1) hide show
  1. handler.py +13 -8
handler.py CHANGED
@@ -2,21 +2,26 @@ from typing import Dict, List, Any
2
  import transformers
3
  import torch
4
 
5
# Maximum number of new tokens generated per request (this version of the
# handler offered no per-request override).
MAX_TOKENS = 8192


class EndpointHandler(object):
    """Inference-endpoint handler wrapping a Hugging Face text-generation pipeline."""

    def __init__(self, path=''):
        # NOTE(review): `path` is part of the endpoint-handler contract but is
        # unused here — the model is pinned to a fixed Hub repository.
        self.pipeline: transformers.Pipeline = transformers.pipeline(
            "text-generation",
            model="humane-intelligence/gemma2-9b-cpt-sealionv3-instruct-endpoint",
            model_kwargs={"torch_dtype": torch.bfloat16, "low_cpu_mem_usage": True},
            device_map="auto",
        )

    def __call__(self, text_inputs: Any) -> List[Dict[str, Any]]:
        """Generate text for ``text_inputs``.

        Args:
            text_inputs: Prompt(s) accepted by the text-generation pipeline.

        Returns:
            The raw pipeline output — a list of dicts containing
            "generated_text" — not the float-keyed structure the original
            annotation claimed.
        """
        # Fixed: removed a stray debug print of outputs[0]["generated_text"][-1]
        # and corrected the return annotation (pipeline values are strings).
        outputs = self.pipeline(
            text_inputs,
            max_new_tokens=MAX_TOKENS,
        )
        return outputs
 
2
  import transformers
3
  import torch
4
 
5
# Default generation budget, applied whenever the request does not supply its
# own "max_new_tokens" in "parameters".
MAX_TOKENS = 1024


class EndpointHandler(object):
    """Inference-endpoint handler for the SEA-LION v3 instruct model."""

    def __init__(self, path=''):
        # NOTE(review): `path` is part of the endpoint-handler contract but is
        # unused here — the model is pinned to a fixed Hub repository.
        self.pipeline: transformers.Pipeline = transformers.pipeline(
            "text-generation",
            model="humane-intelligence/gemma2-9b-cpt-sealionv3-instruct-endpoint",
            model_kwargs={"torch_dtype": torch.bfloat16},
            device_map="auto",
        )

    def __call__(self, data: Any) -> List[Dict[str, Any]]:
        """Handle one inference request.

        Args:
            data: Request payload; must contain "inputs" (the prompt) and may
                contain "parameters" (kwargs forwarded to the pipeline).

        Returns:
            The raw pipeline output (a list of dicts with "generated_text"),
            not the float-keyed structure the original annotation claimed.

        Raises:
            KeyError: If the payload has no "inputs" field (unchanged from
                the original behavior).
        """
        inputs = data.pop("inputs")  # KeyError on malformed payloads, as before
        # Fixed: a request whose "parameters" omitted max_new_tokens previously
        # fell through to the model's own default instead of MAX_TOKENS; apply
        # the endpoint default uniformly while honoring explicit overrides.
        parameters = data.pop("parameters", None) or {}
        parameters.setdefault("max_new_tokens", MAX_TOKENS)
        return self.pipeline(inputs, **parameters)