import gradio as gr
from vllm import LLM, SamplingParams

model_id = "georgesung/llama2_7b_chat_uncensored"

# Only used when a system_header is configured; this model's prompt format
# has no system section, so the message stays empty.
SYSTEM_MESSAGE = ""

prompt_config = {
    "system_header": None,
    "system_footer": None,
    "user_header": "### HUMAN:",
    "user_footer": None,
    "input_header": None,
    "response_header": "### RESPONSE:",
}
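
# These headers follow the "### HUMAN:" / "### RESPONSE:" prompt format the
# georgesung/llama2_7b_chat_uncensored fine-tune expects; the None fields let
# other prompt formats be described with the same keys.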


def get_llm_response_chat(llm, sampling_params, prompt):
    outputs = llm.generate(prompt, sampling_params)
    output = outputs[0].outputs[0].text

    # Trim the EOS token if the model emitted it at the end of the completion.
    eos_token = llm.get_tokenizer().eos_token
    if output.endswith(eos_token):
        output = output[:-len(eos_token)]
    return output
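
# Example call (llm and sampling_params as constructed in main() below):
#   get_llm_response_chat(llm, sampling_params, "### HUMAN:\nHi\n\n### RESPONSE:\n")
# returns the reply text, e.g. "Hello! How can I help you?" (illustrative).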


def hist_to_prompt(history):
    # Serialize a Gradio chat history (a list of [user, bot] pairs) into the
    # model's prompt format; a turn whose bot reply is still None ends with
    # an open response header for the model to complete.
    prompt = ""
    if prompt_config["system_header"]:
        system_footer = ""
        if prompt_config["system_footer"]:
            system_footer = prompt_config["system_footer"]
        prompt += f"{prompt_config['system_header']}\n{SYSTEM_MESSAGE}{system_footer}\n\n"

    for human_text, bot_text in history:
        user_footer = ""
        if prompt_config["user_footer"]:
            user_footer = prompt_config["user_footer"]

        prompt += f"{prompt_config['user_header']}\n{human_text}{user_footer}\n\n"
        prompt += f"{prompt_config['response_header']}\n"
        if bot_text:
            prompt += f"{bot_text}\n\n"
    return prompt
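
# Example: hist_to_prompt([["Hi", "Hello!"], ["How are you?", None]]) produces
#
#   ### HUMAN:
#   Hi
#
#   ### RESPONSE:
#   Hello!
#
#   ### HUMAN:
#   How are you?
#
#   ### RESPONSE:
#
# leaving the last response open for the model to complete.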


def get_bot_response(text):
    # Return everything after the last response header; if no header is
    # found, the text is returned unchanged.
    bot_text_index = text.rfind(prompt_config["response_header"])
    if bot_text_index != -1:
        text = text[bot_text_index + len(prompt_config["response_header"]):].strip()
    return text
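
# Example: get_bot_response("### HUMAN:\nHi\n\n### RESPONSE:\nHello!")
# returns "Hello!".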


def main():
    # The tokenizer is overridden with a stock Llama tokenizer repo.
    llm = LLM(model=model_id, tokenizer="hf-internal-testing/llama-tokenizer")

    # Near-greedy decoding: very low temperature and a tight top_p keep the
    # replies close to deterministic.
    sampling_params = SamplingParams(temperature=0.01, top_p=0.1, top_k=40,
                                     max_tokens=2048)

    with gr.Blocks() as demo:
        gr.Markdown(
            """
            # Let's chat
            """)

        chatbot = gr.Chatbot()
        msg = gr.Textbox()
        clear = gr.Button("Clear")

        def user(user_message, history):
            # Append the new user turn with an empty bot slot and clear the
            # textbox.
            return "", history + [[user_message, None]]

        def bot(history):
            # Render the whole conversation into the prompt format, then fill
            # in the pending reply slot with the model's completion.
            hist_text = hist_to_prompt(history)
            bot_message = get_llm_response_chat(llm, sampling_params, hist_text)
            history[-1][1] = bot_message
            return history

        # On submit, record the user turn immediately, then run bot() to
        # generate the reply.
        msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
            bot, chatbot, chatbot
        )
        clear.click(lambda: None, None, chatbot, queue=False)

    # Enable the request queue so long-running generation calls are handled
    # in order rather than timing out.
    demo.queue()
    demo.launch()


if __name__ == "__main__":
    main()