# Hugging Face Space demo: streaming chat UI for vilm/vinallama-2.7b (Gradio).
from threading import Thread | |
import gradio as gr | |
import torch | |
from transformers import ( | |
pipeline, | |
AutoTokenizer, | |
TextIteratorStreamer, | |
) | |
def chat_history(history) -> str:
    """Flatten Gradio chat history into a chat-template prompt string.

    Args:
        history: List of ``[user_message, assistant_message]`` pairs as kept
            by ``gr.Chatbot``; the last pair's assistant slot may be ``None``
            while a reply is pending.

    Returns:
        The prompt produced by the tokenizer's chat template, with the
        generation prompt appended so the model continues as the assistant.
    """
    messages = []
    for dialog in history:
        for i, message in enumerate(dialog):
            # Even positions within a pair are the user turn, odd positions
            # the assistant reply.
            role = "user" if i % 2 == 0 else "assistant"
            # BUG FIX: skip the pending (None) assistant slot explicitly
            # instead of an unconditional pop(-1), which crashed on an empty
            # history and assumed the last message was always the None slot.
            if message is not None:
                messages.append({"role": role, "content": message})
    # NOTE(review): relies on the module-level `pipe` created under
    # __main__ — confirm this module is only run as a script.
    return pipe.tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
def model_loading_pipeline():
    """Build the text-generation pipeline and its token streamer.

    Returns:
        A ``(pipe, streamer)`` tuple: a transformers text-generation
        pipeline for ``vilm/vinallama-2.7b`` loaded in bfloat16, and a
        ``TextIteratorStreamer`` that skips echoing the prompt and raises
        if no new token arrives within 5 seconds.
    """
    model_id = "vilm/vinallama-2.7b"
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    # BUG FIX: the keyword is `timeout`, not `Timeout`. The misspelled
    # kwarg was silently swallowed into decode_kwargs, so the consuming
    # loop could block forever if the generation thread died.
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, timeout=5)
    pipe = pipeline(
        "text-generation",
        model=model_id,
        model_kwargs={
            "torch_dtype": torch.bfloat16,
        },
        streamer=streamer,
    )
    return pipe, streamer
def launch_app(pipe, streamer):
    """Wire up and launch the Gradio streaming-chat interface.

    Args:
        pipe: A transformers text-generation pipeline.
        streamer: The ``TextIteratorStreamer`` attached to ``pipe``; tokens
            generated on the worker thread are read back from it here.
    """
    with gr.Blocks() as demo:
        chatbot = gr.Chatbot()
        textbox = gr.Textbox()
        clear_btn = gr.Button("Clear")

        def queue_user_turn(user_message, history):
            # Append the new user turn with an empty assistant slot and
            # clear the input box.
            return "", history + [[user_message, None]]

        def stream_bot_turn(history):
            # Build the prompt from the full conversation so far.
            prompt = chat_history(history)
            history[-1][1] = ""
            generation_kwargs = dict(
                text_inputs=prompt,
                max_new_tokens=64,
                do_sample=True,
                temperature=0.7,
                top_k=50,
                top_p=0.95,
            )
            # Generation runs on a worker thread; the streamer hands tokens
            # to this generator as they are produced, so the UI updates
            # incrementally.
            worker = Thread(target=pipe, kwargs=generation_kwargs)
            worker.start()
            for token in streamer:
                history[-1][1] += token
                yield history

        textbox.submit(
            queue_user_turn, [textbox, chatbot], [textbox, chatbot], queue=False
        ).then(stream_bot_turn, chatbot, chatbot)
        clear_btn.click(lambda: None, None, chatbot, queue=False)
        demo.queue()
        demo.launch(share=True, debug=True)
if __name__ == "__main__": | |
pipe, streamer = model_loading_pipeline() | |
launch_app(pipe, streamer) | |