from typing import Any, Dict, List, Optional, Tuple

from torch import bfloat16
from transformers import AutoModelForCausalLM, AutoTokenizer


class EndpointHandler:
    """Inference handler that summarises introductions of medical articles.

    The model and tokenizer are loaded from ``path``. Each request text is
    wrapped in ``summary_prompt``, a continuation is generated, and only the
    generated part is returned, as a list of ``{"summary": str}`` dicts.
    """

    # NOTE(review): not referenced anywhere in this class; kept for reference.
    max_seq_length = 4096
    # NOTE(review): not passed to ``from_pretrained`` (loading uses the
    # library default dtype), so this attribute currently has no effect.
    dtype = bfloat16
    # Forwarded to ``from_pretrained`` to load the weights quantized to 4 bit.
    load_in_4bit = True
    # Upper bound on the number of tokens generated per request.
    max_new_tokens = 500

    # NOTE(review): the delimiter text after "delimited by" looks like it was
    # lost (only a quoted newline remains) -- confirm the intended tags.
    summary_prompt = (
        " We will provide you an introduction to scientific research articles in the"
        " medical field, delimited by '\n' tags. Your mission is to produce a small"
        " summary of this introduction in a short paragraph of 200 characters"
        " maximum. ### Input:\n{}\n### Response: {} "
    )

    def __init__(self, path=""):
        """Load the model and tokenizer from ``path`` (local dir or hub id)."""
        # Temperature is a generation setting supplied per request, so it is
        # deliberately not passed to ``from_pretrained``.
        self.model = AutoModelForCausalLM.from_pretrained(
            path,
            load_in_4bit=self.load_in_4bit,
        )
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        # Left padding keeps every prompt flush against its generated tokens,
        # which is what lets _generate_outputs cut the prompt at one index.
        self.tokenizer.padding_side = "left"
        self.tokenizer.pad_token = self.tokenizer.eos_token

    @staticmethod
    def _shape_error() -> List[Dict[str, str]]:
        """Build a fresh error payload describing the expected request shape."""
        return [{"error": "inputs should be shaped like {'temperature': float, 'inputs': }"}]

    def _secure_inputs(self, data: Dict[str, Any]) -> Tuple[Any, Optional[float], bool]:
        """Validate a request payload.

        Args:
            data: request body; must be a dict with an ``"inputs"`` key holding
                a string or a list/tuple of strings, and optionally a numeric
                ``"temperature"`` (defaults to 0.01).

        Returns:
            ``(inputs, temperature, True)`` on success, with ``inputs`` as a
            list of strings; otherwise ``(error_payload, None, False)``, where
            ``error_payload`` can be returned to the client as is. The tuple
            always has three items so the caller can unpack it unconditionally.
        """
        if not isinstance(data, dict) or "inputs" not in data:
            return self._shape_error(), None, False

        inputs = data["inputs"]
        if isinstance(inputs, str):
            inputs = [inputs]
        if not isinstance(inputs, (list, tuple)) or not all(isinstance(text, str) for text in inputs):
            return self._shape_error(), None, False

        temperature = data.get("temperature")
        if temperature is None:
            temperature = 0.01
        # bool is a subclass of int; reject it explicitly.
        if isinstance(temperature, bool) or not isinstance(temperature, (int, float)):
            return self._shape_error(), None, False

        return list(inputs), float(temperature), True

    def _format_inputs(self, inputs: List[str]) -> Tuple[List[str], List[int]]:
        """Wrap each input text in ``summary_prompt`` with an empty response slot.

        Returns the prompts and their character lengths. The lengths are kept
        for backward compatibility; output trimming no longer depends on them.
        """
        prompts = [self.summary_prompt.format(abstract, "") for abstract in inputs]
        prompts_lengths = [len(prompt) for prompt in prompts]
        return prompts, prompts_lengths

    def _generate_outputs(self, inputs: List[str], temperature: float) -> List[str]:
        """Generate a continuation for each prompt and return only the new text.

        Sampling is enabled only when ``temperature > 0``; otherwise decoding
        is greedy. ``transformers`` only applies ``temperature`` when
        ``do_sample=True`` and rejects a temperature of 0 when sampling.
        """
        tokenized = self.tokenizer(inputs, return_tensors="pt", padding=True).to(self.model.device)

        generation_kwargs: Dict[str, Any] = {
            "max_new_tokens": self.max_new_tokens,
            "use_cache": True,
        }
        if temperature > 0:
            generation_kwargs.update(do_sample=True, temperature=temperature)
        else:
            generation_kwargs["do_sample"] = False

        outputs = self.model.generate(**tokenized, **generation_kwargs)

        # generate() returns prompt + new tokens. With left padding, every
        # prompt occupies exactly the first input_ids.shape[1] positions, so
        # slicing token ids removes the prompt exactly. Slicing the decoded
        # string by character count breaks when decoding does not round-trip.
        prompt_length = tokenized["input_ids"].shape[1]
        return self.tokenizer.batch_decode(outputs[:, prompt_length:], skip_special_tokens=True)

    def _format_outputs(self, outputs: List[str], inputs_lengths: Optional[List[int]] = None) -> List[str]:
        """Strip surrounding whitespace from each generated summary.

        ``inputs_lengths`` is accepted for backward compatibility and ignored:
        ``_generate_outputs`` already removes the prompt at token level.
        """
        return [output_str.strip() for output_str in outputs]

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Summarise the request's input text(s).

        Returns one ``{"summary": str}`` dict per input text, an empty list
        for an empty batch, or a single-item error payload for a malformed
        request.
        """
        inputs, temperature, is_secure = self._secure_inputs(data)
        if not is_secure:
            return inputs
        if not inputs:
            # Tokenizing an empty batch would fail; nothing to summarise.
            return []

        prompts, _ = self._format_inputs(inputs)
        outputs = self._generate_outputs(prompts, temperature)
        return [{"summary": summary} for summary in self._format_outputs(outputs)]