from typing import Any, Dict

import torch
from unsloth import FastLanguageModel


class EndpointHandler:
    def __init__(self, path=""):
        max_seq_length = 2048  # Choose any! RoPE scaling is supported automatically.
        dtype = None  # None for auto detection. Float16 for Tesla T4/V100, Bfloat16 for Ampere+.
        load_in_4bit = True  # Use 4-bit quantization to reduce memory usage. Can be False.

        self.model, self.tokenizer = FastLanguageModel.from_pretrained(
            model_name=path,
            max_seq_length=max_seq_length,
            dtype=dtype,
            load_in_4bit=load_in_4bit,
            # token = "hf_...",  # needed for gated models such as meta-llama/Llama-2-7b-hf
        )

        self.alpaca_prompt = """
### Instruction:
{}

### Input:
{}

### Response:
{}"""
        self.EOS_TOKEN = self.tokenizer.eos_token

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        data args:
            input_text (:obj:`str`): the text to paraphrase
            lex_diversity (:obj:`int`): lexical diversity hint passed into the prompt
            order_diversity (:obj:`int`): order diversity hint passed into the prompt
            repetition_penalty (:obj:`float`), use_cache (:obj:`bool`), max_length (:obj:`int`)
        Return:
            A :obj:`dict`: will be serialized and returned
        """
        data = data.pop("inputs", data)
        input_text = data.get("input_text", "")
        lex_diversity = data.get("lex_diversity", 80)
        order_diversity = data.get("order_diversity", 20)
        repetition_penalty = data.get("repetition_penalty", 1.0)
        use_cache = data.get("use_cache", False)
        max_length = data.get("max_length", 128)

        prediction = self.paraphrase(
            input_text,
            lex_diversity,
            order_diversity,
            repetition_penalty=repetition_penalty,
            use_cache=use_cache,
            max_length=max_length,
        )
        return {"prediction": prediction}

    def paraphrase(
        self,
        input_text,
        lex_diversity,
        order_diversity,
        repetition_penalty,
        use_cache,
        max_length,
        **kwargs,
    ):
        FastLanguageModel.for_inference(self.model)  # Enable native 2x faster inference

        inputs = self.tokenizer(
            [
                self.alpaca_prompt.format(
                    "You are an AI assistant, capable of paraphrasing any text to a "
                    "human-like version of the text. Human writing often exhibits bursts "
                    "and lulls, with a mix of long and short sentences.",  # instruction
                    f"lexical = {lex_diversity}, order = {order_diversity} {input_text}",  # input
                    "",  # output - leave this blank for generation!
                )
            ],
            return_tensors="pt",
        ).to("cuda")

        outputs = self.model.generate(
            **inputs,
            max_new_tokens=max_length,
            use_cache=use_cache,  # previously hardcoded to False; honor the request parameter
            repetition_penalty=repetition_penalty,
        )
        output_text = self.tokenizer.batch_decode(outputs)
        return output_text
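

# --- Local smoke test (a minimal sketch, not part of the deployed endpoint) ---
# This only exercises the handler's own interface defined above. The model path
# below is a placeholder assumption; point it at an Unsloth-compatible checkpoint
# you actually have access to. A CUDA GPU is required, since inputs are moved to "cuda".
if __name__ == "__main__":
    handler = EndpointHandler(path="unsloth/llama-3-8b-bnb-4bit")  # hypothetical model path
    payload = {
        "inputs": {
            "input_text": "The quick brown fox jumps over the lazy dog.",
            "lex_diversity": 80,
            "order_diversity": 20,
            "repetition_penalty": 1.1,
            "use_cache": True,
            "max_length": 128,
        }
    }
    print(handler(payload))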