File size: 2,371 Bytes
6ccd417
 
 
dae67e9
 
69570eb
dae67e9
 
6ccd417
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0533c1e
dae67e9
6ccd417
 
 
dae67e9
 
 
 
 
 
0533c1e
 
dae67e9
 
 
 
6ccd417
0533c1e
 
dae67e9
0533c1e
 
 
 
7ffda7a
0533c1e
dae67e9
0533c1e
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import gradio as gr
from gradio.themes.utils import colors
from t5 import T5
from koalpaca import KoAlpaca

# When True, only the T5 model is loaded and KoAlpaca is skipped
# (presumably to keep local test runs lightweight -- confirm intent).
LOCAL_TEST: bool = False
# Model instances loaded at startup, appended in the order T5, KoAlpaca.
MODELS: list = []
# Position in MODELS of the model that chat() currently generates with.
cur_index: int = 0

def prepare_theme():
    """Return the Gradio theme used by the demo.

    Starts from ``gr.themes.Default`` with a gray primary hue and an emerald
    secondary/neutral hue, then overrides individual theme variables. Values
    such as ``"*primary_800"`` refer to shades of the theme's hues. Each
    overridden ``*_dark`` variable gets the same value as its light
    counterpart, so the page looks the same in light and dark mode.
    """
    base_theme = gr.themes.Default(
        primary_hue=colors.gray,
        secondary_hue=colors.emerald,
        neutral_hue=colors.emerald,
    )

    overrides = {
        # Page and block backgrounds.
        "body_background_fill": "*primary_800",
        "body_background_fill_dark": "*primary_800",
        "block_background_fill": "*primary_700",
        "block_background_fill_dark": "*primary_700",
        # Borders.
        "border_color_primary": "*secondary_300",
        "border_color_primary_dark": "*secondary_300",
        "block_border_width": "3px",
        "input_border_width": "2px",
        # Input fields and secondary surfaces.
        "input_background_fill": "*primary_700",
        "input_background_fill_dark": "*primary_700",
        "background_fill_secondary": "*primary_700",
        "background_fill_secondary_dark": "*primary_700",
        # Body text.
        "body_text_color": "white",
        "body_text_color_dark": "white",
        # Component labels.
        "block_label_text_color": "*secondary_300",
        "block_label_text_color_dark": "*secondary_300",
        "block_label_background_fill": "*primary_800",
        "block_label_background_fill_dark": "*primary_800",
        # Soft accent color.
        "color_accent_soft": "*primary_600",
        "color_accent_soft_dark": "*primary_600",
    }
    return base_theme.set(**overrides)

def chat(message, chat_history):
    """Generate a reply with the currently selected model and append it.

    Args:
        message: Text the user submitted from the textbox.
        chat_history: Current chatbot value, a list of
            ``(user_message, bot_response)`` pairs. ``None`` or an empty
            value is treated as an empty history.

    Returns:
        ``("", history)``. The empty string clears the textbox, and
        ``history`` is the updated conversation for the chatbot. The
        caller's list is not mutated; a new list is returned.
    """
    history = list(chat_history) if chat_history else []
    # Don't run the model (or add an empty bubble) for a blank submission.
    if not message or not message.strip():
        return "", history
    # Uses whichever model change_model_index last selected.
    response = MODELS[cur_index].generate(message)
    history.append((message, response))
    return "", history

def change_model_index(idx):
    """Select which entry of ``MODELS`` subsequent ``chat`` calls use.

    Wired to the model radio's ``select`` event. The radio is created with
    ``type='index'``, so ``idx`` is the position of the chosen label.

    Indices that do not refer to a loaded model (including ``None``) are
    ignored and the previous selection is kept. This matters because
    ``MODELS`` may hold fewer entries than the radio has choices, e.g. when
    ``LOCAL_TEST`` skips loading KoAlpaca.

    NOTE(review): ``cur_index`` is a module global, so the selection is
    shared by every connected session, not tracked per user.
    """
    global cur_index
    if idx is None or not 0 <= idx < len(MODELS):
        print(f"Ignoring model index {idx!r}; {len(MODELS)} model(s) loaded")
        return
    cur_index = idx
    print(cur_index)

if __name__=='__main__':
    theme = prepare_theme()

    # Load the models and record their display names in the same order, so
    # every radio index maps onto a loaded entry of MODELS.
    model_names = ['T5']
    MODELS.append(T5())
    if not LOCAL_TEST:
        model_names.append('KoAlpaca')
        MODELS.append(KoAlpaca())

    with gr.Blocks(theme=theme) as demo:
        with gr.Row():
            # type='index' makes the radio pass the chosen label's position,
            # which change_model_index stores as cur_index.
            rd = gr.Radio(model_names, value=model_names[0], type='index',
                          label='Model Selection', show_label=True,
                          interactive=True)
            with gr.Column(scale=5):  # chatbot section
                chatbot = gr.Chatbot(label="T5", bubble_full_width=False)
                with gr.Row():
                    txt = gr.Textbox(show_label=False,
                                     placeholder='Send a message...',
                                     container=False)

        # Submitting the textbox runs chat(), which clears the box and
        # returns the updated conversation for the chatbot.
        txt.submit(chat, [txt, chatbot], [txt, chatbot])
        rd.select(change_model_index, [rd])
    demo.launch(debug=True, share=True)