vincentmireau and clemparpa committed
Commit 7138786 · verified · 1 Parent(s): 0550bc2

Update handler.py to handle batch inputs (#2)


- Update handler.py to handle batch inputs (b8b060e69b5d03b9a2e58faf8b1f29460e68454a)


Co-authored-by: parpaillon <clemparpa@users.noreply.huggingface.co>
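For reference, a minimal sketch of the payload shapes the updated handler is intended to accept, based on the diff below (the abstract strings are placeholders):

    # A single abstract still works; it is normalized into a one-element batch.
    single_payload = {"inputs": "Abstract of paper A ..."}

    # A batch of abstracts; one {"summary": ...} dict is returned per input.
    batch_payload = {"inputs": ["Abstract of paper A ...", "Abstract of paper B ..."]}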

Files changed (1)
  1. handler.py +42 -26
handler.py CHANGED
@@ -32,34 +32,50 @@ class EndpointHandler():
         )
         FastLanguageModel.for_inference(self.model) # Enable native 2x faster inference
 
+        self.tokenizer.padding_side = "left"
+        self.tokenizer.pad_token = self.tokenizer.eos_token
+
 
-    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
+    def _secure_inputs(self, data: Dict[str, Any]):
         inputs = data.get("inputs", None)
         if inputs is None:
-            return [{"err": "no inputs"}]
-
-        if not isinstance(inputs, str):
-            return [{"err": "inputs must bet of type {'inputs': str}"}]
-
-
-        tokenized_inputs = (
-            self.tokenizer(
-                self.summary_prompt.format(inputs, ""),
-                return_tensors="pt"
-            )
-            .to("cuda")
-        )
-
-        outputs = self.model.generate(
-            **tokenized_inputs,
-            max_new_tokens=self.max_new_tokens,
-            use_cache=True
-        )
-        outputs = outputs[:, tokenized_inputs["input_ids"].shape[1]:]
+            return [{"error": "inputs should be shaped like {'inputs': <string or List of strings (abstracts)>}"}], False
+
+        if isinstance(inputs, str):
+            inputs = [inputs]
+
+        return inputs, True
+
+    def _format_inputs(self, inputs: list[str]):
+        prompts = [self.summary_prompt.format(abstract, "") for abstract in inputs]
+        prompts_lengths = [len(prompt) for prompt in prompts]
+        return prompts, prompts_lengths
+
+    def _generate_outputs(self, inputs):
+        tokenized = self.tokenizer(inputs, return_tensors="pt", padding=True).to("cuda")
+        outputs = self.model.generate(**tokenized, max_new_tokens=500, use_cache=True)
+        decoded = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)
+        return decoded
+
+    def _format_outputs(self, outputs: list[str], inputs_lengths: list[int]):
+        decoded_without_input = [
+            output_str[input_len:].strip()
+            for output_str, input_len
+            in zip(outputs, inputs_lengths)
+        ]
+        return decoded_without_input
+
+
+    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
+        inputs, is_secure = self._secure_inputs(data)
+
+        if not is_secure:
+            return inputs
 
-        outputs = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
-        outputs = outputs.strip(" ").strip("\n")
+        inputs, inputs_length = self._format_inputs(inputs)
+        outputs = self._generate_outputs(inputs)
+        outputs = self._format_outputs(outputs, inputs_length)
 
-        return [{
-            "result": outputs
-        }]
+        outputs = [{"summary": output_} for output_ in outputs]
+
+        return outputs
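A minimal local usage sketch, assuming the handler is instantiated the way Hugging Face Inference Endpoints instantiates custom handlers (the constructor call and abstract strings below are placeholders, not part of the commit):

    from handler import EndpointHandler

    handler = EndpointHandler()  # the endpoint runtime supplies its own model path if the constructor takes one

    # Batch request: expect one {"summary": ...} dict per abstract, in order.
    results = handler({"inputs": ["Abstract of paper A ...", "Abstract of paper B ..."]})

    # Missing "inputs" short-circuits in _secure_inputs and returns the error shape.
    errors = handler({})  # -> [{"error": "inputs should be shaped like ..."}]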