"""Streaming chat inference for Llama-2-7b-chat with a PEFT adapter applied.

Importing this module loads the base model, wraps it with the adapter,
moves it to ``device`` and loads the tokenizer. ``run`` then streams a
reply for a chat-style conversation.
"""
from threading import Thread
from typing import Iterator
import os

import torch
import transformers
from peft import PeftModel, PeftConfig
from torch import cuda, bfloat16
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

# Hugging Face access token taken from the environment (None when unset).
token = os.environ.get("HF_API_TOKEN")

base_model_id = 'meta-llama/Llama-2-7b-chat-hf'
adapter_model_id = "Ashishkr/llama-2-medical-consultation"

# Current CUDA device when one is available, otherwise the CPU.
device = f'cuda:{cuda.current_device()}' if cuda.is_available() else 'cpu'

bnb_config = transformers.BitsAndBytesConfig(
    llm_int8_enable_fp32_cpu_offload=True
)

model_config = transformers.AutoConfig.from_pretrained(
    base_model_id,
    use_auth_token=token
)

# NOTE(review): trust_remote_code=True permits repository-supplied code to
# run while loading; confirm it is actually required for this checkpoint.
model = transformers.AutoModelForCausalLM.from_pretrained(
    base_model_id,
    trust_remote_code=True,
    config=model_config,
    quantization_config=bnb_config,
    # device_map='auto',
    use_auth_token=token
)

config = PeftConfig.from_pretrained(adapter_model_id)
# device_map is not used above; the adapter-wrapped model is placed on
# `device` explicitly here.
model = PeftModel.from_pretrained(model, adapter_model_id).to(device)
model.eval()

# Authenticate with the same `token` used for the model; the previous code
# referenced an undefined name (`hf_auth`) and failed at import time.
tokenizer = transformers.AutoTokenizer.from_pretrained(
    base_model_id,
    use_auth_token=token
)


def get_prompt(message: str, chat_history: list[tuple[str, str]],
               system_prompt: str) -> str:
    """Build a Llama-2 chat prompt from a conversation.

    The system prompt is wrapped in ``<<SYS>>`` / ``<</SYS>>`` inside the
    first ``[INST]`` block, each completed turn is closed with ``</s>`` and
    the next one opened with ``<s>[INST]``. The prompt ends with
    ``[/INST]`` after the new message so the model produces the reply.

    Args:
        message: The new user message to answer.
        chat_history: Earlier ``(user_input, response)`` pairs, oldest first.
        system_prompt: Instruction text placed in the system section.

    Returns:
        The complete prompt string. BOS/EOS markers are written into the
        text, so it should be tokenized with ``add_special_tokens=False``.
    """
    texts = [f'<s>[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n']
    # The first user input is _not_ stripped
    do_strip = False
    for user_input, response in chat_history:
        user_input = user_input.strip() if do_strip else user_input
        do_strip = True
        texts.append(f'{user_input} [/INST] {response.strip()} </s><s>[INST] ')
    message = message.strip() if do_strip else message
    texts.append(f'{message} [/INST]')
    return ''.join(texts)


def get_input_token_length(message: str, chat_history: list[tuple[str, str]],
                           system_prompt: str) -> int:
    """Return the number of tokens in the prompt built for this conversation.

    Args:
        message: The new user message.
        chat_history: Earlier ``(user_input, response)`` pairs.
        system_prompt: Instruction text placed in the system section.

    Returns:
        Size of the last dimension of the tokenized prompt's ``input_ids``.
    """
    prompt = get_prompt(message, chat_history, system_prompt)
    input_ids = tokenizer([prompt], return_tensors='np',
                          add_special_tokens=False)['input_ids']
    return input_ids.shape[-1]


def run(message: str,
        chat_history: list[tuple[str, str]],
        system_prompt: str,
        max_new_tokens: int = 1024,
        temperature: float = 0.8,
        top_p: float = 0.95,
        top_k: int = 50) -> Iterator[str]:
    """Generate a reply and stream it as it is produced.

    Generation runs in a background thread that feeds a
    ``TextIteratorStreamer``; this generator consumes the streamer.

    Args:
        message: The new user message to answer.
        chat_history: Earlier ``(user_input, response)`` pairs.
        system_prompt: Instruction text placed in the system section.
        max_new_tokens: Passed to ``model.generate``.
        temperature: Sampling temperature passed to ``model.generate``.
        top_p: Nucleus-sampling value passed to ``model.generate``.
        top_k: Top-k sampling value passed to ``model.generate``.

    Yields:
        The reply text accumulated so far (not just the newest chunk),
        once for every chunk received from the streamer.
    """
    prompt = get_prompt(message, chat_history, system_prompt)
    # Place the inputs on the same device the model was moved to, instead of
    # a hard-coded 'cuda', so the CPU fallback of `device` is honoured.
    inputs = tokenizer([prompt], return_tensors='pt',
                       add_special_tokens=False).to(device)

    streamer = TextIteratorStreamer(tokenizer,
                                    timeout=10.,
                                    skip_prompt=True,
                                    skip_special_tokens=True)
    generate_kwargs = dict(
        inputs,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
        num_beams=1,
    )
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    outputs = []
    for text in streamer:
        outputs.append(text)
        yield ''.join(outputs)