|
from typing import Dict, List, Any |
|
|
|
from transformers import AutoTokenizer, AutoModelForCausalLM |
|
from torch import bfloat16 |
|
|
|
|
|
|
|
class EndpointHandler():
    """Serve short summaries of medical article introductions with a causal LM.

    The handler loads a tokenizer and a causal language model from ``path``,
    wraps every incoming text in ``summary_prompt`` and answers with one
    ``{"summary": ...}`` dict per input text.

    A request is a dict shaped like
    ``{"temperature": float, "inputs": <str or list of str>}`` where
    ``temperature`` is optional. A malformed request is answered with a
    one-element list holding an ``{"error": ...}`` dict instead of raising.
    """

    # NOTE(review): max_seq_length and dtype are not referenced anywhere in
    # this class (dtype is not handed to the model loader); they are kept
    # because code outside this class may read them -- confirm before removing.
    max_seq_length = 4096
    dtype = bfloat16
    load_in_4bit = True
    # Upper bound on the number of tokens generated for each summary.
    max_new_tokens = 500
    # Temperature applied when the request does not carry one.
    default_temperature = 0.01

    # NOTE(review): whitespace inside this template is part of the prompt the
    # model sees -- confirm the blank lines match the format the model was
    # tuned on before changing them.
    summary_prompt = """

We will provide you an introduction to scientific research articles in the medical field, delimited by '<article>' tags.

Your mission is to produce a small summary of this introduction in a short paragraph of 200 characters maximum.



### Input:

<article>{}</article>



### Response:

{}

"""

    def __init__(self, path=""):
        """Load the model and the tokenizer found at ``path``.

        Args:
            path: Location handed to ``from_pretrained`` for both the model
                and the tokenizer.
        """
        # Sampling settings such as the temperature are generation arguments,
        # so they are passed to ``generate`` per request rather than here.
        self.model = AutoModelForCausalLM.from_pretrained(
            path,
            load_in_4bit=self.load_in_4bit,
        )

        self.tokenizer = AutoTokenizer.from_pretrained(path)
        # Generated tokens are appended on the right of every row, so the
        # padding of a batch has to sit on the left of the prompts.
        self.tokenizer.padding_side = "left"
        # Reuse the end-of-sequence token as padding token (presumably the
        # tokenizer ships without a dedicated one -- TODO confirm).
        self.tokenizer.pad_token = self.tokenizer.eos_token

    def _input_error(self):
        """Return the error payload sent back for a malformed request."""
        return [{"error": "inputs should be shaped like {'temperature': float, 'inputs': <string or List of strings (abstracts)>}"}]

    def _secure_inputs(self, data: Dict[str, Any]):
        """Validate a request payload and normalise it.

        Args:
            data: The raw request; expected to be a dict with an ``inputs``
                key (a string or a list of strings) and an optional
                ``temperature`` key (a non-negative number).

        Returns:
            A 3-tuple ``(inputs, temperature, is_secure)``. For a valid
            request ``inputs`` is a list of strings, ``temperature`` the
            requested (or default) temperature and ``is_secure`` is ``True``.
            For an invalid request the tuple is
            ``([{"error": ...}], None, False)``; it always has three items so
            the caller can unpack it unconditionally.
        """
        if not isinstance(data, dict) or "inputs" not in data:
            return self._input_error(), None, False

        inputs = data["inputs"]
        if isinstance(inputs, str):
            inputs = [inputs]
        if not isinstance(inputs, (list, tuple)):
            return self._input_error(), None, False
        if not all(isinstance(text, str) for text in inputs):
            return self._input_error(), None, False

        temperature = data.get("temperature")
        if temperature is None:
            temperature = self.default_temperature
        # bool is a subclass of int but is never a meaningful temperature.
        is_number = (
            isinstance(temperature, (int, float))
            and not isinstance(temperature, bool)
        )
        if not is_number or temperature < 0:
            return self._input_error(), None, False

        return list(inputs), temperature, True

    def _format_inputs(self, inputs: list[str]):
        """Wrap every text in the summary prompt.

        Args:
            inputs: The texts to summarise.

        Returns:
            A pair ``(prompts, prompts_lengths)``: the filled-in prompts and
            the length in characters of each of them.
        """
        prompts = [self.summary_prompt.format(abstract, "") for abstract in inputs]
        prompts_lengths = [len(prompt) for prompt in prompts]
        return prompts, prompts_lengths

    def _generate_outputs(self, inputs, temperature):
        """Run the model on a batch of prompts.

        Args:
            inputs: A prompt or a list of prompts.
            temperature: Sampling temperature. A value above zero turns
                sampling on; zero (or ``None``) selects greedy decoding.

        Returns:
            One string per prompt: the prompt exactly as it was passed in,
            followed by the decoded completion.
        """
        if isinstance(inputs, str):
            inputs = [inputs]
        if not inputs:
            # The tokenizer cannot build tensors from an empty batch.
            return []

        tokenized = self.tokenizer(
            inputs, return_tensors="pt", padding=True
        ).to(self.model.device)

        generation_kwargs = {
            "max_new_tokens": self.max_new_tokens,
            "use_cache": True,
        }
        if temperature is not None and temperature > 0:
            # The temperature only has an effect when sampling is enabled.
            generation_kwargs["do_sample"] = True
            generation_kwargs["temperature"] = temperature
        else:
            generation_kwargs["do_sample"] = False

        outputs = self.model.generate(**tokenized, **generation_kwargs)

        # Every row of ``outputs`` starts with the (left-padded) prompt tokens.
        # Decode only what was generated after them and re-attach the verbatim
        # prompt, so that callers can strip it by its character length even if
        # decoding would not reproduce the prompt character for character.
        prompt_token_count = tokenized["input_ids"].shape[1]
        completions = self.tokenizer.batch_decode(
            outputs[:, prompt_token_count:], skip_special_tokens=True
        )
        return [
            prompt + completion
            for prompt, completion in zip(inputs, completions)
        ]

    def _format_outputs(self, outputs: list[str], inputs_lengths: list[int]):
        """Remove the prompt from every generated text.

        Args:
            outputs: Generated texts, each starting with its prompt.
            inputs_lengths: Length in characters of the prompt of each text.

        Returns:
            The texts without their leading prompt and without surrounding
            whitespace.
        """
        return [
            output_str[input_len:].strip()
            for output_str, input_len in zip(outputs, inputs_lengths)
        ]

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Summarise the texts carried by a request.

        Args:
            data: The request payload, see the class docstring.

        Returns:
            One ``{"summary": ...}`` dict per input text, an empty list when
            the request carries no text, or a one-element list with an
            ``{"error": ...}`` dict when the request is malformed.
        """
        inputs, temperature, is_secure = self._secure_inputs(data)
        if not is_secure:
            # On failure ``inputs`` holds the error payload.
            return inputs
        if not inputs:
            return []

        prompts, prompts_lengths = self._format_inputs(inputs)
        generated = self._generate_outputs(prompts, temperature)
        summaries = self._format_outputs(generated, prompts_lengths)

        return [{"summary": summary} for summary in summaries]
|
|