Files changed (1) hide show
  1. handler.py +4 -4
handler.py CHANGED
@@ -47,14 +47,14 @@ class EndpointHandler():
47
  return inputs, True
48
 
49
  def _format_inputs(self, inputs: list[str]):
50
- prompts = [summary_prompt.format(abstract, "") for abstract in inputs]
51
  prompts_lengths = [len(prompt) for prompt in prompts]
52
  return prompts, prompts_lengths
53
 
54
  def _generate_outputs(self, inputs):
55
- tokenized = tokenizer(inputs, return_tensors="pt", padding=True).to("cuda")
56
- outputs = model.generate(**tokenized, max_new_tokens=500, use_cache=True)
57
- decoded = tokenizer.batch_decode(outputs, skip_special_tokens=True)
58
  return decoded
59
 
60
  def _format_outputs(self, outputs: list[str], inputs_lengths: list[int]):
 
47
  return inputs, True
48
 
49
  def _format_inputs(self, inputs: list[str]):
50
+ prompts = [self.summary_prompt.format(abstract, "") for abstract in inputs]
51
  prompts_lengths = [len(prompt) for prompt in prompts]
52
  return prompts, prompts_lengths
53
 
54
  def _generate_outputs(self, inputs):
55
+ tokenized = self.tokenizer(inputs, return_tensors="pt", padding=True).to("cuda")
56
+ outputs = self.model.generate(**tokenized, max_new_tokens=500, use_cache=True)
57
+ decoded = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)
58
  return decoded
59
 
60
  def _format_outputs(self, outputs: list[str], inputs_lengths: list[int]):