# Author: dangkhoa99
# Commit e21069a: "Update handler.py"
import copy
import time
from typing import Dict, List, Any

import torch
from peft import PeftModel, PeftConfig
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
class EndpointHandler:
    """Extractive-style QA handler: a 4-bit base model plus a PEFT adapter.

    ``__call__`` takes a ``{"context": ..., "question": ...}`` payload
    (optionally wrapped in ``{"inputs": ...}``), prompts the model and
    returns ``{"answer": str, "time": "<seconds> s"}``.
    """

    def __init__(self, path="dangkhoa99/falcon-7b-finetuned-QA-MRC-4-bit"):
        """Load the quantized base model, its tokenizer and the PEFT adapter.

        Args:
            path: Hub id or local directory of the PEFT adapter. The base
                model is taken from the adapter config's
                ``base_model_name_or_path``.
        """
        config = PeftConfig.from_pretrained(path)
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.float16,
        )
        # 4-bit loading is requested only via `quantization_config`; also
        # passing `load_in_4bit=True` as a kwarg is redundant and is rejected
        # (ValueError) by transformers versions that forbid combining them.
        self.model = AutoModelForCausalLM.from_pretrained(
            config.base_model_name_or_path,
            return_dict=True,
            device_map={"": 0},  # put the whole model on GPU 0
            trust_remote_code=True,
            quantization_config=bnb_config,
        )
        self.tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path)
        # Reuse EOS as the pad token so `padding=True` works.
        self.tokenizer.pad_token = self.tokenizer.eos_token
        self.model = PeftModel.from_pretrained(self.model, path)
        self.model.eval()  # inference only

    @staticmethod
    def _parse_inputs(data):
        """Return ``(context, question)`` from a request payload.

        Accepts ``{"inputs": {"context": ..., "question": ...}}`` or the flat
        ``{"context": ..., "question": ...}`` form. The payload is only read,
        never mutated.

        Raises:
            ValueError: if the payload is not a dict or a field is missing.
        """
        if not isinstance(data, dict):
            raise ValueError("Request payload must be a JSON object.")
        inputs = data.get("inputs", data)
        if not isinstance(inputs, dict):
            raise ValueError('"inputs" must be an object with "context" and "question".')
        missing = [key for key in ("context", "question") if key not in inputs]
        if missing:
            raise ValueError(f"Missing required input field(s): {', '.join(missing)}")
        return inputs["context"], inputs["question"]

    @staticmethod
    def _build_prompt(context, question) -> str:
        """Render the fine-tuning prompt template for one context/question pair."""
        return (
            "Answer the question based on the context below. If the question "
            "cannot be answered using the information provided answer with "
            "'No answer'. Stop response if end.\n"
            ">>TITLE<<: Flawless answer.\n"
            f">>CONTEXT<<: {context}\n"
            f">>QUESTION<<: {question}\n"
            ">>ANSWER<<:"
        )

    @staticmethod
    def _extract_answer(generated_text: str) -> str:
        """Cut the generated continuation at the first ``>>END<<`` marker and strip it."""
        return generated_text.split(">>END<<", 1)[0].strip()

    def __call__(self, data: Any) -> Dict[str, Any]:
        """
        Args:
            data: A dict like ``{"context": "some word", "question": "some word"}``,
                optionally wrapped as ``{"inputs": {...}}``.
        Return:
            A dict like ``{"answer": "some word", "time": "1.23 s"}`` containing:
            - "answer": the answer to the question based on the context
            - "time": wall-clock time spent in ``generate``
        Raises:
            ValueError: if the payload is malformed or a field is missing.
        """
        context, question = self._parse_inputs(data)
        prompt = self._build_prompt(context, question)
        batch = self.tokenizer(
            prompt,
            padding=True,
            truncation=True,
            return_tensors="pt",
        ).to("cuda:0")

        # Work on a per-request copy so the model's shared generation config
        # is not mutated on every call.
        generation_config = copy.deepcopy(self.model.generation_config)
        # NOTE(review): do_sample is not set here, so unless the model's own
        # generation config enables sampling, decoding is greedy and
        # top_p/temperature have no effect -- confirm the intended mode.
        generation_config.top_p = 0.7
        generation_config.temperature = 0.7
        generation_config.max_new_tokens = 256
        generation_config.num_return_sequences = 1
        generation_config.pad_token_id = self.tokenizer.eos_token_id
        generation_config.eos_token_id = self.tokenizer.eos_token_id

        input_ids = batch["input_ids"]
        prompt_length = input_ids.shape[1]
        start = time.perf_counter()
        with torch.autocast(device_type="cuda"):
            output_tokens = self.model.generate(
                input_ids=input_ids,
                attention_mask=batch.get("attention_mask"),
                generation_config=generation_config,
            )
        elapsed = time.perf_counter() - start

        # Decode only the newly generated tokens. Splitting the full decode on
        # '>>ANSWER<<:' breaks when the context contains that marker or when
        # truncation removed it, and it would leak special tokens (e.g. EOS).
        generated_text = self.tokenizer.decode(
            output_tokens[0][prompt_length:],
            skip_special_tokens=True,
        )
        return {
            "answer": self._extract_answer(generated_text),
            "time": f"{elapsed:.2f} s",
        }