import threading

import torch
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

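# Load the ARWKV-7B preview checkpoint in fp16. device_map="auto" lets
# Accelerate place the weights, and trust_remote_code=True allows the
# custom model code shipped in the repository to run.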
model = AutoModelForCausalLM.from_pretrained(
    "RWKV-Red-Team/ARWKV-7B-Preview-0.1",
    device_map="auto",
    torch_dtype=torch.float16,
    trust_remote_code=True,
)
tokenizer = AutoTokenizer.from_pretrained(
    "RWKV-Red-Team/ARWKV-7B-Preview-0.1"
)
device = "cuda" |
|
|
|
|
|
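# Convert Gradio's (user, assistant) history pairs into the message format
# expected by tokenizer.apply_chat_template().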
def convert_history_to_messages(history):
    messages = []
    for user_msg, bot_msg in history:
        messages.append({"role": "user", "content": user_msg})
        if bot_msg is not None:
            messages.append({"role": "assistant", "content": bot_msg})
    return messages

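# Generate a reply to `prompt`, yielding the growing chat history so the
# Gradio UI can render tokens as they stream in.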
def stream_chat(prompt, history):
    messages = convert_history_to_messages(history)
    messages.append({"role": "user", "content": prompt})

    # Render the conversation with the model's chat template and tokenize it.
    text = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    model_inputs = tokenizer([text], return_tensors="pt").to(device)

    # The streamer yields decoded text chunks as generate() produces them,
    # skipping the prompt and special tokens.
    streamer = TextIteratorStreamer(
        tokenizer, skip_prompt=True, skip_special_tokens=True
    )

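    # model_inputs is a mapping, so its tensors become keyword arguments to
    # generate(); generation runs in a background thread so the streamer can
    # be consumed here without blocking.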
    generation_kwargs = dict(
        model_inputs,
        streamer=streamer,
        max_new_tokens=4096,
        do_sample=True,
        temperature=1.5,
        top_p=0.2,
        top_k=0,
    )
    thread = threading.Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

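    # Accumulate the partial response and yield the updated history on every
    # chunk so the chatbot updates incrementally.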
    response = ""
    for new_text in streamer:
        response += new_text
        yield history + [(prompt, response)]
    # The streamer is exhausted once generation ends; join the worker thread.
    thread.join()

with gr.Blocks() as demo:
    chatbot = gr.Chatbot(label="Chat with LLM", height=750)
    msg = gr.Textbox(label="Your Message")
    clear = gr.Button("Clear Chat")

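    # Submitting a message happens in two steps: `user` echoes the message
    # into the history immediately, then `bot` streams the model's reply
    # into the last history entry.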
    def user(user_message, history):
        return "", history + [[user_message, None]]

    def bot(history):
        prompt = history[-1][0]
        history[-1][1] = ""
        for updated_history in stream_chat(prompt, history[:-1]):
            yield updated_history

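    # The echo step runs with queue=False so it appears instantly; the
    # chained bot step streams its updates through the queue.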
    msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
        bot, chatbot, chatbot
    )
    clear.click(lambda: None, None, chatbot, queue=False)

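# queue() enables the request queue Gradio uses for generator (streaming)
# outputs; server_name="0.0.0.0" exposes the demo on all interfaces
# (Gradio's default port is 7860).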
demo.queue().launch(server_name="0.0.0.0")