import os
import torch
from threading import Thread
from typing import Iterator
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
TextIteratorStreamer,
StoppingCriteria,
StoppingCriteriaList
)
from huggingface_hub import login
# Authenticate with the Hugging Face Hub; expects a read token in the `hf_read_token` env var
login(token=os.environ["hf_read_token"])
class StopWordsCriteria(StoppingCriteria):
def __init__(self, tokenizer, stop_words, stop_ids, stream_callback):
self._tokenizer = tokenizer
self._stop_words = stop_words
self._stop_ids = stop_ids
self._partial_result = ''
self._stream_buffer = ''
self._stream_callback = stream_callback
    # Stop on either stop words (human-readable markers) or stop token ids (e.g. EOS)
def __call__(
self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs
) -> bool:
        first = not self._partial_result  # True only before any token has been decoded
        text = self._tokenizer.decode(input_ids[0, -1])
        self._partial_result += text
# Check stop words
for stop_word in self._stop_words:
if stop_word in self._partial_result:
return True
# Check stop ids
for stop_id in self._stop_ids:
if input_ids[0][-1] == stop_id:
return True
        if self._stream_callback:
            if first:
                # Strip leading whitespace from the very first streamed chunk
                text = text.lstrip()
# buffer tokens if the partial result ends with a prefix of a stop word, e.g. "<hu"
for stop_word in self._stop_words:
for i in range(1, len(stop_word)):
if self._partial_result.endswith(stop_word[0:i]):
self._stream_buffer += text
return False
self._stream_callback(self._stream_buffer + text)
self._stream_buffer = ''
return False
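# A minimal usage sketch of StopWordsCriteria with a live callback (an assumed
# usage pattern, not something this file exercises -- below, run() passes
# stream_callback=None and streams via TextIteratorStreamer instead):
#
#   criteria = StopWordsCriteria(
#       tokenizer=tokenizer,
#       stop_words=["<Question>"],
#       stop_ids=[tokenizer.eos_token_id],
#       stream_callback=lambda chunk: print(chunk, end='', flush=True),
#   )
#   model.generate(**inputs, stopping_criteria=StoppingCriteriaList([criteria]))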
model_id = "medalpaca/medalpaca-7b"
if torch.cuda.is_available():
model = AutoModelForCausalLM.from_pretrained(
model_id,
torch_dtype=torch.float16,
device_map='auto',
use_auth_token=True,
)
else:
    # Without a GPU we skip loading; run() checks for this and fails fast.
    model = None
tokenizer = AutoTokenizer.from_pretrained(model_id, use_auth_token=True)
def get_prompt(message: str, chat_history: list[tuple[str, str]],
system_prompt: str) -> str:
texts = [f'<<SYS>>\n{system_prompt}\n<</SYS>>\n\n']
# The first user input is _not_ stripped
do_strip = False
for user_input, response in chat_history:
user_input = user_input.strip() if do_strip else user_input
do_strip = True
texts.append(f'{user_input} <Answer>: {response.strip()} <Question>: ')
message = message.strip() if do_strip else message
texts.append(f'{message} <Answer>:')
    # Debug: show the assembled prompt segments
    print(texts)
    print('---------------------------------------------')
return ''.join(texts)
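# For illustration (the question and system prompt are made-up examples),
# get_prompt('What causes fever?', [], 'You are a medical assistant.') returns:
#
#   <<SYS>>
#   You are a medical assistant.
#   <</SYS>>
#
#   What causes fever? <Answer>: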
def get_input_token_length(message: str, chat_history: list[tuple[str, str]], system_prompt: str) -> int:
prompt = get_prompt(message, chat_history, system_prompt)
input_ids = tokenizer(
[prompt],
return_token_type_ids=False,
return_tensors='np',
add_special_tokens=False)['input_ids']
return input_ids.shape[-1]
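# Typically used to reject or truncate overly long inputs before calling run(),
# e.g. (2048 is LLaMA-7B's context size and 1024 matches run()'s default
# max_new_tokens; the check itself is an assumed usage pattern, not part of this file):
#
#   if get_input_token_length(message, chat_history, system_prompt) > 2048 - 1024:
#       raise ValueError('Input is too long for the model context window.')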
def run(message: str,
chat_history: list[tuple[str, str]],
system_prompt: str,
max_new_tokens: int = 1024,
temperature: float = 0.8,
top_p: float = 0.90,
top_k: int = 20) -> Iterator[str]:
    if model is None:
        raise RuntimeError('No GPU available: the model is only loaded when CUDA is present.')
    prompt = get_prompt(message, chat_history, system_prompt)
    # Debug: show the final prompt sent to the model
    print(prompt)
    print('=================================================')
inputs = tokenizer(
[prompt],
return_token_type_ids=False,
return_tensors='pt',
add_special_tokens=False).to('cuda')
streamer = TextIteratorStreamer(tokenizer,
timeout=10.,
skip_prompt=True,
skip_special_tokens=True)
    stop_criteria = StopWordsCriteria(
        tokenizer=tokenizer,
        stop_words=["<Question>", "<Answer>"],
        stop_ids=[1, 2, 32001, 32002],  # LLaMA BOS/EOS plus (assumed) added special token ids
        stream_callback=None,  # streaming is handled by TextIteratorStreamer below
    )
generate_kwargs = dict(
inputs,
streamer=streamer,
max_new_tokens=max_new_tokens,
do_sample=True,
top_p=top_p,
top_k=top_k,
temperature=temperature,
stopping_criteria=StoppingCriteriaList([stop_criteria]),
num_beams=1,
)
    # Generate on a background thread; tokens arrive incrementally via the streamer
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()
    outputs = []
    for text in streamer:
        outputs.append(text)
        yield ''.join(outputs)  # yield the cumulative response so far
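# A minimal usage sketch, assuming a GPU is available; the question and system
# prompt are made-up examples, not values defined in this file.
if __name__ == '__main__':
    response = ''
    for response in run('What are common symptoms of anemia?', [],
                        'You are a helpful medical assistant.'):
        pass  # each yielded value is the full response generated so far
    print(response)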