manish committed on
Commit 28e5926
1 Parent(s): 0d9a7b6

change return type

Files changed (1)
  1. handler.py +8 -2
handler.py CHANGED
@@ -7,7 +7,7 @@ class EndpointHandler():
         self.tokenizer = AutoTokenizer.from_pretrained("verseAI/vai-GPT-NeoXT-Chat-Base-20B")
         self.model = AutoModelForCausalLM.from_pretrained("verseAI/vai-GPT-NeoXT-Chat-Base-20B", device_map="auto", load_in_8bit=True)
 
-    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
+    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
         """
         data args:
             inputs (:obj: `str`)
@@ -17,10 +17,16 @@ class EndpointHandler():
 
         from transformers import AutoTokenizer, AutoModelForCausalLM
         """
+
+        input = data.pop("inputs", data)
+        print(input)
+
         # infer
         inputs = self.tokenizer("<human>: Hello!\n<bot>:", return_tensors='pt').to(self.model.device)
         outputs = self.model.generate(**inputs, max_new_tokens=10, do_sample=True, temperature=0.8)
         output_str = self.tokenizer.decode(outputs[0])
         print(output_str)
-        return output_str
+        # return output_str
+        return {"generated_text": output_str}
+
 
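
For context, this handler follows the Hugging Face Inference Endpoints custom-handler convention, where __call__ receives the request payload and must return a JSON-serializable object; the commit changes that return value from a raw string to a dict. Below is a minimal local sketch of how the new return shape could be exercised. It is an assumption, not part of this commit, that EndpointHandler can be constructed without arguments and that the 20B model fits on the available GPU, since __init__ is only partially shown in the diff.

# sketch_call_handler.py (illustrative only, not part of this commit)
from handler import EndpointHandler

# Constructing the handler loads the tokenizer and the 8-bit model, as in __init__ above.
handler = EndpointHandler()

# The payload mirrors the "inputs" key that __call__ now pops from `data`.
result = handler({"inputs": "<human>: Hello!\n<bot>:"})

# After this commit the handler returns a dict instead of a plain string,
# e.g. {"generated_text": "<human>: Hello!\n<bot>: ..."}
print(result["generated_text"])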