import os
from threading import Thread
from typing import Iterator

import torch
from huggingface_hub import login
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    StoppingCriteria,
    StoppingCriteriaList,
    TextIteratorStreamer,
)

login(token=os.environ["hf_read_token"])


class StopWordsCriteria(StoppingCriteria):
    def __init__(self, tokenizer, stop_words, stop_ids, stream_callback):
        self._tokenizer = tokenizer
        self._stop_words = stop_words
        self._stop_ids = stop_ids
        self._partial_result = ''
        self._stream_buffer = ''
        self._stream_callback = stream_callback

    # Use both stop words (the "<human>" turn tag) and stop token ids (EOS tokens).
    def __call__(
        self,
        input_ids: torch.LongTensor,
        scores: torch.FloatTensor,
        **kwargs,
    ) -> bool:
        first = not self._partial_result
        text = self._tokenizer.decode(input_ids[0, -1])
        self._partial_result += text
        # Check stop words
        for stop_word in self._stop_words:
            if stop_word in self._partial_result:
                return True
        # Check stop ids
        for stop_id in self._stop_ids:
            if input_ids[0][-1] == stop_id:
                return True
        if self._stream_callback:
            if first:
                text = text.lstrip()
            # Buffer tokens if the partial result ends with a prefix of a stop word,
            # e.g. "<hu", so a stop word split across tokens is never streamed out.
            for stop_word in self._stop_words:
                for i in range(1, len(stop_word)):
                    if self._partial_result.endswith(stop_word[:i]):
                        self._stream_buffer += text
                        return False
            self._stream_callback(self._stream_buffer + text)
            self._stream_buffer = ''
        return False


# The model id was lost from the source; this is a placeholder. The stop ids used
# below (32001, 32002) assume a Llama-style checkpoint with two extra chat tokens
# added on top of the 32000-token base vocabulary.
model_id = 'your-org/your-chat-model'  # placeholder
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id, torch_dtype=torch.float16, device_map='auto')


def get_prompt(message: str, chat_history: list[tuple[str, str]], system_prompt: str) -> str:
    # The <<SYS>>, <human>, and <bot> markers below were stripped from the source by
    # HTML sanitization; these tags are an assumption and should be matched to the
    # exact template the checkpoint was trained on.
    texts = [f'<<SYS>>\n{system_prompt}\n<</SYS>>\n\n']
    # The first user input is _not_ stripped
    do_strip = False
    for user_input, response in chat_history:
        user_input = user_input.strip() if do_strip else user_input
        do_strip = True
        texts.append(f'{user_input} <bot>: {response.strip()} <human>: ')
    message = message.strip() if do_strip else message
    texts.append(f'{message} <bot>:')
    print(texts)
    print('---------------------------------------------')
    return ''.join(texts)


def get_input_token_length(message: str, chat_history: list[tuple[str, str]], system_prompt: str) -> int:
    prompt = get_prompt(message, chat_history, system_prompt)
    input_ids = tokenizer(
        [prompt], return_token_type_ids=False, return_tensors='np',
        add_special_tokens=False)['input_ids']
    return input_ids.shape[-1]


def run(message: str,
        chat_history: list[tuple[str, str]],
        system_prompt: str,
        max_new_tokens: int = 1024,
        temperature: float = 0.8,
        top_p: float = 0.90,
        top_k: int = 20) -> Iterator[str]:
    prompt = get_prompt(message, chat_history, system_prompt)
    print(prompt)
    print('=================================================')
    inputs = tokenizer(
        [prompt], return_token_type_ids=False, return_tensors='pt',
        add_special_tokens=False).to('cuda')

    streamer = TextIteratorStreamer(
        tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
    # The stop words were also emptied by sanitization; "<human>" and "<bot>" are
    # assumed to match the turn tags used in get_prompt above.
    stop_criteria = StopWordsCriteria(
        tokenizer=tokenizer,
        stop_words=['<human>', '<bot>'],
        stop_ids=[1, 2, 32001, 32002],  # BOS/EOS plus the two added chat-token ids
        stream_callback=None,
    )
    generate_kwargs = dict(
        inputs,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
        stopping_criteria=StoppingCriteriaList([stop_criteria]),
        num_beams=1,
    )
    # Run generation on a worker thread so tokens can be yielded as they stream in.
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    outputs = []
    for text in streamer:
        outputs.append(text)
        yield ''.join(outputs)
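

# A minimal sketch of how the pieces above fit together: stream a single-turn reply
# to stdout. The system prompt and question here are illustrative placeholders, not
# from the original app; run() yields the accumulated reply so far, so only the new
# suffix is printed on each iteration.
if __name__ == '__main__':
    history: list[tuple[str, str]] = []
    system = 'You are a helpful assistant.'
    question = 'Explain what a stopping criterion does in one sentence.'
    print('prompt tokens:', get_input_token_length(question, history, system))
    printed = ''
    for partial in run(question, history, system, max_new_tokens=128):
        print(partial[len(printed):], end='', flush=True)
        printed = partial
    print()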