manish committed
Commit ac360bc
Parent: 28e5926

use pipeline

Files changed (1): handler.py (+20 -7)
handler.py CHANGED
@@ -1,13 +1,16 @@
  from typing import Dict, List, Any
- from transformers import AutoTokenizer, AutoModelForCausalLM
+ from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline

  class EndpointHandler():
      def __init__(self, path=""):
          # init
-         self.tokenizer = AutoTokenizer.from_pretrained("verseAI/vai-GPT-NeoXT-Chat-Base-20B")
-         self.model = AutoModelForCausalLM.from_pretrained("verseAI/vai-GPT-NeoXT-Chat-Base-20B", device_map="auto", load_in_8bit=True)
+         # load the model
+         tokenizer = AutoTokenizer.from_pretrained("verseAI/vai-GPT-NeoXT-Chat-Base-20B")
+         model = AutoModelForCausalLM.from_pretrained("verseAI/vai-GPT-NeoXT-Chat-Base-20B", device_map="auto", load_in_8bit=True)
+         # create inference pipeline
+         self.pipeline = pipeline("text-generation", model=model, tokenizer=tokenizer)

-     def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
+     def __call__(self, data: Dict[str, Any]) -> List[List[Dict[str, float]]]:
          """
          data args:
              inputs (:obj: `str`)
@@ -18,15 +21,25 @@ class EndpointHandler():
          from transformers import AutoTokenizer, AutoModelForCausalLM
          """

-         input = data.pop("inputs", data)
-         print(input)
+         inputs = data.pop("inputs", data)
+         parameters = data.pop("parameters", None)
+         # print(input)

-         # infer
+         # pass inputs with all kwargs in data
+         if parameters is not None:
+             prediction = self.pipeline(inputs, **parameters)
+         else:
+             prediction = self.pipeline(inputs)
+         # postprocess the prediction
+         return prediction
+
+         """
          inputs = self.tokenizer("<human>: Hello!\n<bot>:", return_tensors='pt').to(self.model.device)
          outputs = self.model.generate(**inputs, max_new_tokens=10, do_sample=True, temperature=0.8)
          output_str = self.tokenizer.decode(outputs[0])
          print(output_str)
          # return output_str
          return {"generated_text": output_str}
+         """

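For reference, a minimal sketch of how the updated handler could be exercised locally. This is not part of the commit: it assumes handler.py is importable from the working directory and that the verseAI/vai-GPT-NeoXT-Chat-Base-20B weights (plus bitsandbytes, required by load_in_8bit=True) can be loaded. The payload shape mirrors what __call__ pops from data.

# Hypothetical local smoke test for the pipeline-based handler (a sketch,
# not part of the commit). Assumes handler.py is importable and the model
# weights are available.
from handler import EndpointHandler

handler = EndpointHandler(path=".")

# __call__ pops "inputs" and an optional "parameters" dict; the latter is
# forwarded to the text-generation pipeline as generation kwargs.
payload = {
    "inputs": "<human>: Hello!\n<bot>:",
    "parameters": {"max_new_tokens": 10, "do_sample": True, "temperature": 0.8},
}

prediction = handler(payload)
# For a single string input the pipeline returns a list of dicts, e.g.
# [{"generated_text": "<human>: Hello!\n<bot>: Hi there!"}]
print(prediction)

Note that after this change the handler returns the pipeline's raw output, which for text-generation is a list of {"generated_text": ...} dicts rather than the single dict the commented-out code built, so the -> List[List[Dict[str, float]]] annotation does not match what is actually returned.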