Ashishkr's picture
Update model.py
46b4cb9
raw
history blame contribute delete
No virus
3.87 kB
from threading import Thread
from typing import Iterator
import torch
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
import os
import transformers
from torch import cuda, bfloat16
from peft import PeftModel, PeftConfig
# Module-level setup: load the base Llama-2 chat model, apply the PEFT adapter
# on top of it, and load the tokenizer. Everything below runs at import time.

# Hugging Face access token read from the environment; None if HF_API_TOKEN is unset.
token = os.environ.get("HF_API_TOKEN")
base_model_id = 'meta-llama/Llama-2-7b-chat-hf'
# Use the current CUDA device when one is available, otherwise fall back to CPU.
device = f'cuda:{cuda.current_device()}' if cuda.is_available() else 'cpu'
# Only the fp32 CPU-offload flag is set here; neither load_in_8bit nor
# load_in_4bit is passed.
# NOTE(review): presumably this means no quantization is actually applied -- confirm.
bnb_config = transformers.BitsAndBytesConfig(
llm_int8_enable_fp32_cpu_offload = True
)
model_config = transformers.AutoConfig.from_pretrained(
base_model_id,
use_auth_token=token
)
# device_map is commented out, so placement is done explicitly by the
# .to(device) call after the adapter is attached.
model = transformers.AutoModelForCausalLM.from_pretrained(
base_model_id,
trust_remote_code=True,
config=model_config,
quantization_config=bnb_config,
# device_map='auto',
use_auth_token=token
)
# The adapter config is loaded into `config`, but nothing else in the visible
# part of this file reads it.
config = PeftConfig.from_pretrained("Ashishkr/llama-2-medical-consultation")
# Wrap the base model with the PEFT adapter weights and move the result to `device`.
model = PeftModel.from_pretrained(model, "Ashishkr/llama-2-medical-consultation").to(device)
# Switch to evaluation mode; this module only runs generation.
model.eval()
# The tokenizer comes from the base model repository, not the adapter repository.
tokenizer = transformers.AutoTokenizer.from_pretrained(
base_model_id,
use_auth_token=token
)
# Previous prompt builder using the Llama-2 chat template ([INST] / <<SYS>>),
# kept for reference only; the active get_prompt is defined below.
# def get_prompt(message: str, chat_history: list[tuple[str, str]],
#                system_prompt: str) -> str:
#     texts = [f'<s>[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n']
#     # The first user input is _not_ stripped
#     do_strip = False
#     for user_input, response in chat_history:
#         user_input = user_input.strip() if do_strip else user_input
#         do_strip = True
#         texts.append(f'{user_input} [/INST] {response.strip()} </s><s>[INST] ')
#     message = message.strip() if do_strip else message
#     texts.append(f'{message} [/INST]')
#     return ''.join(texts)
def get_prompt(message: str, chat_history: list[tuple[str, str]], system_prompt: str) -> str:
    """Build the plain-text prompt sent to the model.

    The prompt starts with ``system_prompt`` on its own line. Every exchange
    except the most recent one is written as ``"<user> <response>"`` on its
    own line. The most recent exchange (if any) and the new ``message`` are
    placed together between the ``" input: "`` and ``" Response: "`` markers.

    Args:
        message: The new user message to answer.
        chat_history: Earlier ``(user_input, response)`` pairs, oldest first.
        system_prompt: Text placed at the very top of the prompt.

    Returns:
        The assembled prompt string, ending with ``" Response: "``.
    """
    header = f'{system_prompt}\n'

    # No history: the prompt is just the header and the new message.
    if not chat_history:
        return header + f' input: {message} Response: '

    # Split off the latest exchange; it is merged into the "input:" section.
    *earlier_turns, (prev_input, prev_response) = chat_history
    earlier_text = ''.join(
        f'{past_input} {past_response}\n'
        for past_input, past_response in earlier_turns
    )
    final_turn = f' input: {prev_input} {prev_response} {message} Response: '
    return header + earlier_text + final_turn
def get_input_token_length(message: str, chat_history: list[tuple[str, str]], system_prompt: str) -> int:
    """Return the number of tokens in the prompt built from the arguments.

    The prompt is assembled with ``get_prompt`` and tokenized without adding
    special tokens; the result is the size of the last axis of ``input_ids``.
    """
    encoded = tokenizer(
        [get_prompt(message, chat_history, system_prompt)],
        return_tensors='np',
        add_special_tokens=False,
    )
    token_ids = encoded['input_ids']
    return token_ids.shape[-1]
def run(message: str,
        chat_history: list[tuple[str, str]],
        system_prompt: str,
        max_new_tokens: int = 1024,
        temperature: float = 0.8,
        top_p: float = 0.95,
        top_k: int = 50) -> Iterator[str]:
    """Generate a reply to ``message`` and stream it back incrementally.

    Args:
        message: The new user message.
        chat_history: Earlier ``(user_input, response)`` pairs, oldest first.
        system_prompt: Text placed at the top of the prompt.
        max_new_tokens: Upper bound on the number of generated tokens.
        temperature: Sampling temperature passed to ``model.generate``.
        top_p: Nucleus-sampling threshold passed to ``model.generate``.
        top_k: Top-k sampling cutoff passed to ``model.generate``.

    Yields:
        The accumulated reply text so far, once per streamed chunk. If the
        text "instruction:" appears in the reply, the reply is cut just
        before it, that truncated text is yielded, and streaming stops.
    """
    stop_marker = "instruction:"

    prompt = get_prompt(message, chat_history, system_prompt)
    inputs = tokenizer([prompt], return_tensors='pt', add_special_tokens=False).to(device)

    # skip_prompt=True: only newly generated text is streamed, not the prompt.
    # timeout is the wait limit on the streamer's internal text queue.
    streamer = TextIteratorStreamer(tokenizer,
                                    timeout=20.,
                                    skip_prompt=True,
                                    skip_special_tokens=True)
    generate_kwargs = dict(
        inputs,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
        num_beams=1,
    )
    # generate() blocks until it finishes, so it runs in a worker thread while
    # this generator consumes the streamer.
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    outputs = []
    for text in streamer:
        outputs.append(text)
        # Search the accumulated reply rather than the single chunk, so a
        # marker split across two chunks is still detected.
        reply, found, _ = ''.join(outputs).partition(stop_marker)
        # Always yield before stopping: previously the text preceding the
        # marker was appended but the loop broke before it was ever yielded.
        yield reply
        if found:
            # NOTE: breaking here does not signal the worker thread; it keeps
            # generating until it finishes on its own.
            break