File size: 1,267 Bytes
f745223
e8fb838
1ffd977
 
13a089e
e8fb838
 
f745223
cbcb343
f745223
 
 
d8a82cd
52c453e
f745223
13a089e
 
1ffd977
3988351
13a089e
 
 
 
 
 
 
 
 
 
1ffd977
 
f57923a
 
1ffd977
0635d16
 
13a089e
0635d16
 
 
13a089e
0635d16
1ffd977
2334dc1
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
import os
import torch
import gradio as gr

from strings import TITLE, ABSTRACT 
from gen import get_pretrained_models, get_output, setup_model_parallel

# Single-process setup. RANK/WORLD_SIZE are forced so this app always runs as
# one process. The rendezvous address/port only get defaults, so a deployment
# can override them through the environment (e.g. when port 50505 is taken)
# instead of editing this file. When they are unset, the values are the same
# as before.
# NOTE(review): presumably setup_model_parallel initializes torch.distributed
# from these env:// variables — confirm against gen.setup_model_parallel.
os.environ["RANK"] = "0"
os.environ["WORLD_SIZE"] = "1"
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "50505")

local_rank, world_size = setup_model_parallel()
# NOTE(review): "7B" / "tokenizer" look like the checkpoint size and the
# tokenizer location — confirm against gen.get_pretrained_models.
generator = get_pretrained_models("7B", "tokenizer", local_rank, world_size)

# Process-wide conversation log appended to by `chat`. It is a single module
# global, so every session writes into the same list, and nothing trims it.
history = []

def chat(user_input):
    """Generate a reply to *user_input* and stream it to the chatbot word by word.

    The prompt and the model's reply are appended to the module-level
    ``history`` list before streaming starts. Each yielded value is a complete
    Chatbot payload holding only the current exchange.

    Args:
        user_input: Text submitted from the textbox.

    Yields:
        ``[(user_input, partial_reply)]``. ``partial_reply`` grows by one
        space-separated word per step. The last value's reply equals the
        model output exactly.
    """
    # Only the first element of get_output's result is used.
    # NOTE(review): presumably one generation per prompt — confirm in gen.
    bot_response = get_output(generator, user_input)[0]

    history.append({"role": "user", "content": user_input})
    # The model's reply is an assistant turn; "system" is conventionally the
    # role for instructions given to the model, not for its output.
    history.append({"role": "assistant", "content": bot_response})

    words = bot_response.split(" ")
    # Re-join a growing prefix of words rather than appending `word + " "`.
    # This avoids a stray trailing space on every partial reply and rebuilds
    # the original text exactly, including runs of spaces. An empty reply
    # still yields once, so the chatbot always shows the user's message.
    for count in range(1, len(words) + 1):
        yield [(user_input, " ".join(words[:count]))]

# Page styling: center a fixed-width column and give the chat panel a fixed,
# scrollable height.
_CSS = """#col_container {width: 700px; margin-left: auto; margin-right: auto;}
                #chatbot {height: 400px; overflow: auto;}"""

with gr.Blocks(css=_CSS) as demo:
    gr.Markdown(f"## {TITLE}\n\n\n\n{ABSTRACT}")
    with gr.Column(elem_id="col_container"):
        chatbot = gr.Chatbot(elem_id="chatbot")
        textbox = gr.Textbox(placeholder="Enter a prompt")

        # Submitting the textbox runs `chat`. Because `chat` is a generator,
        # each value it yields replaces the chatbot's contents, which is what
        # makes the reply appear to stream.
        textbox.submit(fn=chat, inputs=textbox, outputs=chatbot)

# The queue is what lets Gradio run generator (streaming) handlers.
# api_open=False keeps the queued endpoints from being exposed as a REST API.
demo.queue(api_open=False).launch()