kiranr committed on
Commit
41a22ce
1 Parent(s): c59532d

Update handler.py

Files changed (1):
  1. handler.py +30 -5
handler.py CHANGED
@@ -6,22 +6,47 @@ from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
 device = 0 if torch.cuda.is_available() else -1
 
 
+format_input = (
+    "Below is an instruction that describes a task. "
+    "Write a response that appropriately completes the request.\n\n"
+    "### Instruction:\n{instruction}\n\n### Response:"
+)
+
+
 class EndpointHandler:
     def __init__(self, path=""):
         # load the model
         tokenizer = AutoTokenizer.from_pretrained(path)
-        model = AutoModelForCausalLM.from_pretrained(path, low_cpu_mem_usage=True)
+        model = AutoModelForCausalLM.from_pretrained(
+            path,
+            device_map="auto",
+            torch_dtype=torch.float16,
+        )
         # create inference pipeline
-        self.pipeline = pipeline("text-generation", model=model, tokenizer=tokenizer, device=device)
+        self.pipeline = pipeline(
+            "text-generation",
+            model=model,
+            tokenizer=tokenizer,
+            device=device,
+            max_length=256,
+        )
 
     def __call__(self, data: Any) -> List[List[Dict[str, float]]]:
         inputs = data.pop("inputs", data)
         parameters = data.pop("parameters", None)
 
+        text_input = format_input.format(instruction=inputs)
+
         # pass inputs with all kwargs in data
         if parameters is not None:
-            prediction = self.pipeline(inputs, **parameters)
+            prediction = self.pipeline(text_input, **parameters)
         else:
-            prediction = self.pipeline(inputs)
+            prediction = self.pipeline(text_input)
+
         # postprocess the prediction
-        return prediction
+        output = [
+            {"generated_text": pred["generated_text"].split("### Response:")[1].strip()}
+            for pred in prediction
+        ]
+
+        return output
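
For context, a minimal local smoke test of the updated handler could look like the sketch below. The checkpoint path and generation parameters are placeholders, not part of the commit; the payload shape follows the dict convention the handler already expects (an "inputs" key plus optional "parameters").

from handler import EndpointHandler

# Hypothetical local checkpoint path; any directory that AutoTokenizer /
# AutoModelForCausalLM can load would do.
handler = EndpointHandler(path="./model")

# "inputs" gets wrapped in the Alpaca-style format_input template;
# "parameters" is forwarded to the text-generation pipeline as kwargs.
payload = {
    "inputs": "Explain what a context manager does in Python.",
    "parameters": {"do_sample": True, "temperature": 0.7},
}

result = handler(payload)
print(result[0]["generated_text"])  # only the text after "### Response:"

The split on "### Response:" in the new postprocessing step works because, by default, the pipeline's generated_text echoes the prompt (which ends in that marker) before the model's completion, so index 1 of the split is the completion alone.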