# NOTE: Hugging Face Space source (medalpaca-7b chat demo); web-page header and
# line-number gutter from the original scrape were removed so the file parses.
import os
import torch
from threading import Thread
from typing import Iterator
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
TextIteratorStreamer,
StoppingCriteria,
StoppingCriteriaList
)
from huggingface_hub import login
# Authenticate with the Hugging Face Hub; the read token must be provided via
# the `hf_read_token` environment variable (raises KeyError if unset).
login(token=os.environ["hf_read_token"])
class StopWordsCriteria(StoppingCriteria):
    """Stop generation when a stop word appears in the decoded text or the
    last generated token id matches a configured stop id.

    Optionally streams decoded text through ``stream_callback``, buffering
    tokens that might be the beginning of a stop word so that stop markers
    never leak to the consumer.
    """

    def __init__(self, tokenizer, stop_words, stop_ids, stream_callback):
        self._tokenizer = tokenizer
        self._stop_words = stop_words
        self._stop_ids = stop_ids
        self._partial_result = ''
        self._stream_buffer = ''
        self._stream_callback = stream_callback

    # use both stop words (human id) and stop token ids (EOS tokens)
    def __call__(
        self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs
    ) -> bool:
        is_first_token = not self._partial_result
        decoded = self._tokenizer.decode(input_ids[0, -1])
        self._partial_result += decoded
        # Stop as soon as any stop word occurs in the accumulated text.
        if any(word in self._partial_result for word in self._stop_words):
            return True
        # Stop when the most recent token id is one of the stop ids.
        last_token = input_ids[0][-1]
        if any(last_token == stop_id for stop_id in self._stop_ids):
            return True
        if self._stream_callback:
            if is_first_token:
                decoded = decoded.lstrip()
            # buffer tokens if the partial result ends with a prefix of a stop word, e.g. "<hu"
            for word in self._stop_words:
                if any(self._partial_result.endswith(word[:i])
                       for i in range(1, len(word))):
                    self._stream_buffer += decoded
                    return False
            self._stream_callback(self._stream_buffer + decoded)
            self._stream_buffer = ''
        return False
# Hub repository of the model weights to load.
model_id = "medalpaca/medalpaca-7b"
if torch.cuda.is_available():
    # Load the model in fp16 and let `device_map='auto'` place layers across
    # available accelerators; `use_auth_token` reuses the login() credentials.
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.float16,
        device_map='auto',
        use_auth_token=True,
    )
else:
    # No GPU available: the model is not loaded. NOTE(review): run() will
    # fail with AttributeError in this case — the CPU path is a stub.
    model = None
# Tokenizer is loaded regardless, so prompt-length checks work without a GPU.
tokenizer = AutoTokenizer.from_pretrained(model_id, use_auth_token=True)
def get_prompt(message: str, chat_history: list[tuple[str, str]],
               system_prompt: str) -> str:
    """Build the full model prompt from the system prompt, prior turns, and
    the new user message, using <Question>/<Answer> dialogue markers."""
    parts = [f'<<SYS>>\n{system_prompt}\n<</SYS>>\n\n']
    # Every user turn except the very first is whitespace-stripped.
    for turn_index, (user_input, response) in enumerate(chat_history):
        cleaned_input = user_input if turn_index == 0 else user_input.strip()
        parts.append(f'{cleaned_input} <Answer>: {response.strip()} <Question>: ')
    final_message = message.strip() if chat_history else message
    parts.append(f'{final_message} <Answer>:')
    print(parts)
    print('---------------------------------------------')
    return ''.join(parts)
def get_input_token_length(message: str, chat_history: list[tuple[str, str]], system_prompt: str) -> int:
    """Return the number of tokens the assembled prompt occupies."""
    full_prompt = get_prompt(message, chat_history, system_prompt)
    token_ids = tokenizer([full_prompt],
                          return_token_type_ids=False,
                          return_tensors='np',
                          add_special_tokens=False)['input_ids']
    # Last axis of the (1, seq_len) array is the token count.
    return token_ids.shape[-1]
def run(message: str,
        chat_history: list[tuple[str, str]],
        system_prompt: str,
        max_new_tokens: int = 1024,
        temperature: float = 0.8,
        top_p: float = 0.90,
        top_k: int = 20) -> Iterator[str]:
    """Generate a streamed model reply, yielding the cumulative text so far
    after every new chunk produced by the model."""
    prompt = get_prompt(message, chat_history, system_prompt)
    print(prompt)
    print('=================================================')
    # Tokenize the prompt and move the tensors to the GPU.
    inputs = tokenizer(
        [prompt],
        return_token_type_ids=False,
        return_tensors='pt',
        add_special_tokens=False).to('cuda')
    streamer = TextIteratorStreamer(
        tokenizer, timeout=10., skip_prompt=True, skip_special_tokens=True)
    # Halt generation on dialogue markers or EOS-style token ids.
    stopping = StoppingCriteriaList([
        StopWordsCriteria(
            tokenizer=tokenizer,
            stop_words=["<Question>", "<Answer>"],
            stop_ids=[1, 2, 32001, 32002],
            stream_callback=None,
        )
    ])
    generate_kwargs = dict(
        inputs,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
        stopping_criteria=stopping,
        num_beams=1,
    )
    # Generation runs on a worker thread; the streamer feeds decoded text
    # back to this generator as it is produced.
    worker = Thread(target=model.generate, kwargs=generate_kwargs)
    worker.start()
    chunks = []
    for new_text in streamer:
        chunks.append(new_text)
        yield ''.join(chunks)