ksee commited on
Commit
fd296d0
1 Parent(s): 115be49

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +21 -20
handler.py CHANGED
@@ -13,24 +13,25 @@ class EndpointHandler():
13
  self.model = PeftModel.from_pretrained(model, path)
14
 
15
  def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
16
- """
17
- Args:
18
- data (Dict): The payload with the text prompt
19
- and generation parameters.
20
- """
21
- # Get inputs
22
- prompt = data.pop("inputs", None)
23
- parameters = data.pop("parameters", None)
24
- if prompt is None:
25
- raise ValueError("Missing prompt.")
26
- # Preprocess
27
- input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids.to(device)
28
- # Forward
29
- if parameters is not None:
30
- output = self.model.generate(input_ids=input_ids, **parameters)
31
- else:
32
- output = self.model.generate(input_ids=input_ids)
33
- # Postprocess
34
- prediction = self.tokenizer.decode(output[0])
35
- return {"generated_text": prediction}
 
36
 
 
13
  self.model = PeftModel.from_pretrained(model, path)
14
 
15
  def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
16
+ """
17
+ Args:
18
+ data (Dict): The payload with the text prompt
19
+ and generation parameters.
20
+ """
21
+ # Get inputs
22
+ prompt = data.pop("inputs", None)
23
+ parameters = data.pop("parameters", None)
24
+ if prompt is None:
25
+ raise ValueError("Missing prompt.")
26
+ # Preprocess
27
+ input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids.to(device)
28
+ # Forward
29
+ # if parameters is not None:
30
+ # output = self.model.generate(input_ids=input_ids, **parameters)
31
+ # else:
32
+ # output = self.model.generate(input_ids=input_ids)
33
+ output = self.model.generate(input_ids, temperature=0.9, max_new_tokens=50)
34
+ # Postprocess
35
+ prediction = self.tokenizer.batch_decode(outputs.detach().cpu().numpy(), skip_special_tokens=True)[0]
36
+ return {"generated_text": prediction}
37