from typing import Dict, List, Any
#from unsloth import FastLanguageModel
from transformers import AutoTokenizer, AutoModelForCausalLM
from torch import bfloat16
class EndpointHandler():
    max_seq_length = 4096  # Choose any! We auto support RoPE Scaling internally!
    dtype = bfloat16  # None for auto detection. Float16 for Tesla T4, V100; bfloat16 for Ampere+.
    load_in_4bit = True  # Use 4-bit quantization to reduce memory usage. Can be False.
    max_new_tokens = 500

    summary_prompt = """
We will provide you an introduction to scientific research articles in the medical field, delimited by '<article>' tags.
Your mission is to produce a small summary of this introduction in a short paragraph of 200 characters maximum.
### Input:
<article>{}</article>
### Response:
{}
"""
def __init__(self, path=""):
self.model = AutoModelForCausalLM.from_pretrained(
path,
temperature=0,
# torch_dtype = self.dtype,
load_in_4bit = self.load_in_4bit,
)
# .to("cuda")
self.tokenizer = AutoTokenizer.from_pretrained(path)
self.tokenizer.padding_side="left"
self.tokenizer.pad_token= self.tokenizer.eos_token
    def _secure_inputs(self, data: Dict[str, Any]):
        error = [{"error": "inputs should be shaped like {'temperature': float, 'inputs': <string or List of strings (abstracts)>}"}]
        if not isinstance(data, dict) or "inputs" not in data:
            # Return a 3-tuple on the error path too, so __call__ can unpack it safely.
            return error, None, False
        temperature = data.get("temperature", 0.01)
        inputs = data["inputs"]
        if isinstance(inputs, str):
            inputs = [inputs]
        return inputs, temperature, True
    def _format_inputs(self, inputs: List[str]):
        prompts = [self.summary_prompt.format(abstract, "") for abstract in inputs]
        prompts_lengths = [len(prompt) for prompt in prompts]
        return prompts, prompts_lengths
    def _generate_outputs(self, inputs, temperature):
        tokenized = self.tokenizer(inputs, return_tensors="pt", padding=True).to("cuda")
        outputs = self.model.generate(
            **tokenized,
            temperature=temperature,
            max_new_tokens=self.max_new_tokens,
            use_cache=True,
        )
        decoded = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)
        return decoded
    def _format_outputs(self, outputs: List[str], inputs_lengths: List[int]):
        # Decoded outputs begin with the prompt text; drop the first
        # input_len characters to keep only the generated summary.
        decoded_without_input = [
            output_str[input_len:].strip()
            for output_str, input_len in zip(outputs, inputs_lengths)
        ]
        return decoded_without_input
    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        inputs, temperature, is_secure = self._secure_inputs(data)
        if not is_secure:
            return inputs
        inputs, inputs_lengths = self._format_inputs(inputs)
        outputs = self._generate_outputs(inputs, temperature)
        outputs = self._format_outputs(outputs, inputs_lengths)
        return [{"summary": output_} for output_ in outputs]
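

# --- Usage sketch (assumption, not part of the deployed handler) ---
# On a Hugging Face Inference Endpoint, the toolkit instantiates
# EndpointHandler(path) once and calls it with each request payload.
# The block below is a minimal local smoke test; the path and the sample
# abstract are placeholders, and running it requires a CUDA GPU with
# bitsandbytes installed for the 4-bit load.
if __name__ == "__main__":
    handler = EndpointHandler(path=".")  # assumed: "." is a local model checkout
    payload = {
        "temperature": 0.01,
        "inputs": "Sepsis remains a leading cause of mortality in intensive care units...",
    }
    print(handler(payload))  # expected shape: [{"summary": "..."}]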