from typing import Dict, List, Any

from unsloth import FastLanguageModel


class EndpointHandler:
    """Custom Inference Endpoints handler: summarizes the introduction of medical
    research articles with an Unsloth-accelerated language model."""

    max_seq_length = 4096  # Any length works; Unsloth supports RoPE scaling internally.
    dtype = None  # None for auto detection; float16 for Tesla T4/V100, bfloat16 for Ampere+.
    load_in_4bit = False  # Set to True to enable 4-bit quantization and reduce memory usage.
    max_new_tokens = 500  # Upper bound on generated tokens per summary.
    summary_prompt = """
We will provide you with the introduction of a scientific research article in the medical field, delimited by '<article>' tags.
Your mission is to produce a short summary of this introduction in a single paragraph of 200 characters maximum.
### Input:
<article>{}</article>
### Response:
{}
"""
    def __init__(self, path: str = ""):
        self.model, self.tokenizer = FastLanguageModel.from_pretrained(
            model_name=path,
            max_seq_length=self.max_seq_length,
            dtype=self.dtype,
            load_in_4bit=self.load_in_4bit,
        )
        FastLanguageModel.for_inference(self.model)  # Enable Unsloth's native 2x faster inference.
        # Left-pad so the generated tokens of every prompt in a batch start at the same position.
        self.tokenizer.padding_side = "left"
        self.tokenizer.pad_token = self.tokenizer.eos_token
    def _secure_inputs(self, data: Dict[str, Any]):
        """Validate the payload and normalize a single abstract into a list."""
        inputs = data.get("inputs", None)
        if inputs is None:
            return [{"error": "inputs should be shaped like {'inputs': <string or list of strings (abstracts)>}"}], False
        if isinstance(inputs, str):
            inputs = [inputs]
        return inputs, True
    def _format_inputs(self, inputs: List[str]):
        """Wrap each abstract in the summary prompt and record each prompt's length in characters."""
        prompts = [self.summary_prompt.format(abstract, "") for abstract in inputs]
        prompts_lengths = [len(prompt) for prompt in prompts]
        return prompts, prompts_lengths
    def _generate_outputs(self, inputs: List[str]):
        """Tokenize the prompts, generate on the GPU, and decode the full sequences."""
        tokenized = self.tokenizer(inputs, return_tensors="pt", padding=True).to("cuda")
        outputs = self.model.generate(**tokenized, max_new_tokens=self.max_new_tokens, use_cache=True)
        decoded = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)
        return decoded
    def _format_outputs(self, outputs: List[str], inputs_lengths: List[int]):
        """Strip the echoed prompt (by character length) so only the generated summary remains."""
        decoded_without_input = [
            output_str[input_len:].strip()
            for output_str, input_len in zip(outputs, inputs_lengths)
        ]
        return decoded_without_input
    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        inputs, is_secure = self._secure_inputs(data)
        if not is_secure:
            return inputs
        inputs, inputs_lengths = self._format_inputs(inputs)
        outputs = self._generate_outputs(inputs)
        outputs = self._format_outputs(outputs, inputs_lengths)
        outputs = [{"summary": output_} for output_ in outputs]
        return outputs
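

# Minimal local usage sketch (not part of the Inference Endpoints contract): the
# handler is normally instantiated by the endpoint runtime with the repository path.
# The path below is a hypothetical placeholder, and a CUDA GPU is required because
# _generate_outputs moves the tokenized batch to "cuda".
if __name__ == "__main__":
    handler = EndpointHandler(path="./")  # e.g. the repository root containing the model weights
    payload = {"inputs": "<introduction of a medical research article>"}
    print(handler(payload))  # -> [{"summary": "..."}]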