LLaMA-7B / app.py
ysharma's picture
ysharma HF staff
update css styling for chatbot
0635d16
raw
history blame
No virus
1.27 kB
import os
import torch
import gradio as gr
from strings import TITLE, ABSTRACT
from gen import get_pretrained_models, get_output, setup_model_parallel
# Single-process distributed settings applied before model setup: this app
# always runs as rank 0 of a world of size 1 on localhost. (Presumably these
# are read by setup_model_parallel() when it initialises torch.distributed —
# confirm in gen.py.) Existing values are overwritten, not merged.
_DIST_ENV = {
    "RANK": "0",
    "WORLD_SIZE": "1",
    "MASTER_ADDR": "127.0.0.1",
    "MASTER_PORT": "50505",
}
os.environ.update(_DIST_ENV)

# Build the generator once at import time so every request reuses it.
# NOTE(review): "7B" / "tokenizer" look like a model-size tag and a tokenizer
# location; their exact meaning lives in gen.get_pretrained_models.
local_rank, world_size = setup_model_parallel()
generator = get_pretrained_models("7B", "tokenizer", local_rank, world_size)

# Chat log that chat() appends to; a single list for the whole process.
history = list()
def chat(user_input):
    """Run the model on one prompt and stream the reply into the chat panel.

    The full reply is generated up front with ``get_output``. The prompt and
    reply are recorded in the module-level ``history`` list. The reply is then
    revealed one space-separated word at a time so the Chatbot component
    appears to stream.

    Args:
        user_input: Prompt text submitted from the textbox.

    Yields:
        A one-element list ``[(user_input, partial_reply)]``. Each yield
        replaces the Chatbot's value, so only the latest exchange is shown.
        The final yield carries the reply exactly as the model produced it,
        with no added trailing whitespace.
    """
    # get_output returns an indexable result whose first element is used as
    # the reply (presumably one completion per prompt — confirm in gen.py).
    bot_response = get_output(generator, user_input)[0]

    # NOTE(review): ``history`` is one module-level list shared by every
    # browser session; nothing in this file reads or trims it, so it grows
    # for the lifetime of the process.
    history.append({"role": "user", "content": user_input})
    # Model output is tagged "assistant": in role/content chat transcripts
    # "system" denotes instructions, not model replies.
    history.append({"role": "assistant", "content": bot_response})

    # Reveal the reply word by word by slicing the original text. This keeps
    # its exact spacing, unlike re-appending ``word + " "``, which left a
    # trailing space on every message. ``end`` is the index just past the
    # current word; it starts at -1 to cancel the separator added for the
    # first word.
    end = -1
    for word in bot_response.split(" "):
        end += len(word) + 1
        yield [(user_input, bot_response[:end])]
# Page stylesheet: centre a 700px-wide column and cap the chat panel at
# 400px tall, scrolling its contents when they overflow.
_APP_CSS = """#col_container {width: 700px; margin-left: auto; margin-right: auto;}
#chatbot {height: 400px; overflow: auto;}"""

with gr.Blocks(css=_APP_CSS) as demo:
    # Header: title followed by the abstract, both from strings.py.
    gr.Markdown(f"## {TITLE}\n\n\n\n{ABSTRACT}")
    with gr.Column(elem_id="col_container"):
        chatbot = gr.Chatbot(elem_id="chatbot")
        textbox = gr.Textbox(placeholder="Enter a prompt")
        # Submitting the textbox feeds its text to chat(); each value chat()
        # yields becomes the chatbot's new content.
        textbox.submit(fn=chat, inputs=textbox, outputs=chatbot)

# Enable request queueing with api_open=False, as in the original setup
# (presumably required for chat()'s generator output to stream), then
# start the server.
demo.queue(api_open=False)
demo.launch()