File size: 5,475 Bytes
fa23262
4cb51c5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fa23262
4cb51c5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
273c30f
 
4cb51c5
273c30f
 
4cb51c5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
273c30f
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
import warnings


class Chatbot():
    """KoGPT-based Korean chit-chat bot.

    Wraps a fine-tuned ``kakaobrain/kogpt`` causal LM plus mutable
    conversation state:

    * ``self.info`` — persona/topic prefix string built by ``initialize()``.
    * ``self.talk`` — alternating user (``P01``) / bot (``P02``) utterances,
      each framed as ``"PXX<sos>…<eos>"``.
    """

    def __init__(self):
        # KoGPT tokenizer; the added special tokens cover the dialogue
        # markers (<sep>/<sos>/<eos>) and the anonymization placeholders
        # that appear in the fine-tuning corpus.
        self.tokenizer = AutoTokenizer.from_pretrained('kakaobrain/kogpt', revision='KoGPT6B-ryan1.5b')
        special_tokens_dict = {'additional_special_tokens': ['<sep>', '<eos>', '<sos>', '#@이름#', '#@계정#', '#@신원#', '#@전번#', '#@금융#', '#@번호#', '#@주소#', '#@소속#', '#@기타#']}
        self.tokenizer.add_special_tokens(special_tokens_dict)

        # Fine-tuned checkpoint; embedding matrix resized to account for the
        # tokens added above, then moved to the GPU.
        self.model = AutoModelForCausalLM.from_pretrained("/workspace/test_trainer/checkpoint-10000")
        self.model.resize_token_embeddings(len(self.tokenizer))
        self.model = self.model.cuda()

        self.info = None  # persona/topic prefix; set by initialize()
        self.talk = []    # conversation history (list of framed utterances)

    def initialize(self, topic, bot_addr, bot_age, bot_sex, my_addr, my_age, my_sex):
        """Build the persona prefix from the UI selections.

        Ages are bucketed into the decade labels used by the training data.
        Returns a human-readable summary of the stored prefix (see
        ``info_check``).
        """
        def bucket(age):
            # Map a raw age to its decade bucket ("20대 미만" … "70대 이상").
            if age < 20:
                return "20대 미만"
            if age >= 70:
                return "70대 이상"
            return str(age // 10 * 10) + "대"
        bot_age = bucket(bot_age)
        my_age = bucket(my_age)
        self.info = f"일상 대화 {topic}<sep>P01:{my_addr} {my_age} {my_sex}<sep>P02:{bot_addr} {bot_age} {bot_sex}<sep>"
        return self.info_check()

    def info_check(self):
        """Return ``self.info`` rewritten for display (one field per line,
        speaker tags replaced with human-readable labels)."""
        return self.info.replace('<sep>', '\n').replace('P01', '당신').replace('P02', '챗봇')

    def reset_talk(self):
        """Clear the conversation history (persona prefix is kept)."""
        self.talk = []

    def test(self, myinp):
        """Append the user's utterance, generate a bot reply, and return the
        full history as (user, bot) string pairs for a ``gr.Chatbot`` widget.
        """
        # TODO(review): `state` is computed but never returned/surfaced in the
        # UI — consider exposing it via a second output component.
        state = None
        self.talk.append("P01<sos>" + myinp + "<eos>")
        self.talk.append("P02<sos>")  # open frame the model will complete

        # Re-tokenize, dropping the oldest utterance while the prompt exceeds
        # the 512-token budget used for generation.
        while True:
            now_inp = self.info + "".join(self.talk)
            inputs = self.tokenizer(now_inp, max_length=1024, truncation='longest_first', return_tensors='pt')
            seq_len = inputs.input_ids.size(1)
            if seq_len > 512 * 0.8:
                state = f"<주의> 현재 대화 길이가 곧 최대 길이에 도달합니다. ({seq_len} / 512)"

            if seq_len >= 512:
                state = "<주의> 대화 길이가 너무 길어졌기 때문에, 이후 대화는 맨 앞의 발화를 조금씩 지우면서 진행됩니다."
                # BUG FIX: was `talk = talk[1:]`, which raised
                # UnboundLocalError. Trim the shared history so the loop
                # actually shrinks the prompt on the next iteration.
                self.talk = self.talk[1:]
            else:
                break

        out = self.model.generate(
            inputs=inputs.input_ids.cuda(),
            attention_mask=inputs.attention_mask.cuda(),
            max_length=512,
            do_sample=True,
            pad_token_id=self.tokenizer.pad_token_id,
            eos_token_id=self.tokenizer.encode('<eos>')[0]
            )
        decoded = self.tokenizer.batch_decode(out)[0]
        # Keep only the newly generated continuation (text past the prompt)
        # and close the open "P02<sos>" frame with it.
        self.talk[-1] += decoded[len(now_inp):]
        # [8:-5] strips the 8-char "PXX<sos>" prefix and 5-char "<eos>" suffix.
        return [(self.talk[i][8:-5], self.talk[i + 1][8:-5]) for i in range(0, len(self.talk) - 1, 2)]


if __name__ == "__main__":
    warnings.filterwarnings("ignore")

    chatbot = Chatbot()
    demo = gr.Blocks()

    with demo:
        gr.Markdown("# <center>MINDs Lab Brain's Fast Neural Chit-Chatbot</center>")
        with gr.Row():
            with gr.Column():
                topic = gr.Radio(label="Topic", choices=['์—ฌ๊ฐ€ ์ƒํ™œ', '์‹œ์‚ฌ/๊ต์œก', '๋ฏธ์šฉ๊ณผ ๊ฑด๊ฐ•', '์‹์Œ๋ฃŒ', '์ƒ๊ฑฐ๋ž˜(์‡ผํ•‘)', '์ผ๊ณผ ์ง์—…', '์ฃผ๊ฑฐ์™€ ์ƒํ™œ', '๊ฐœ์ธ ๋ฐ ๊ด€๊ณ„', 'ํ–‰์‚ฌ'])
                with gr.Column():
                    gr.Markdown(f"Bot's persona")
                    bot_addr = gr.Dropdown(label="์ง€์—ญ", choices=['์„œ์šธํŠน๋ณ„์‹œ', '๊ฒฝ๊ธฐ๋„', '๋ถ€์‚ฐ๊ด‘์—ญ์‹œ', '๋Œ€์ „๊ด‘์—ญ์‹œ', '๊ด‘์ฃผ๊ด‘์—ญ์‹œ', '์šธ์‚ฐ๊ด‘์—ญ์‹œ', '๊ฒฝ์ƒ๋‚จ๋„', '์ธ์ฒœ๊ด‘์—ญ์‹œ', '์ถฉ์ฒญ๋ถ๋„', '์ œ์ฃผ๋„', '๊ฐ•์›๋„', '์ถฉ์ฒญ๋‚จ๋„', '์ „๋ผ๋ถ๋„', '๋Œ€๊ตฌ๊ด‘์—ญ์‹œ', '์ „๋ผ๋‚จ๋„', '๊ฒฝ์ƒ๋ถ๋„', '์„ธ์ข…ํŠน๋ณ„์ž์น˜์‹œ', '๊ธฐํƒ€'])
                    bot_age = gr.Slider(label="๋‚˜์ด", minimum=10, maximum=80, value=45, step=1)
                    bot_sex = gr.Radio(label="์„ฑ๋ณ„", choices=["๋‚จ์„ฑ", "์—ฌ์„ฑ"])
                with gr.Column():
                    gr.Markdown(f"Your persona")
                    my_addr = gr.Dropdown(label="์ง€์—ญ", choices=['์„œ์šธํŠน๋ณ„์‹œ', '๊ฒฝ๊ธฐ๋„', '๋ถ€์‚ฐ๊ด‘์—ญ์‹œ', '๋Œ€์ „๊ด‘์—ญ์‹œ', '๊ด‘์ฃผ๊ด‘์—ญ์‹œ', '์šธ์‚ฐ๊ด‘์—ญ์‹œ', '๊ฒฝ์ƒ๋‚จ๋„', '์ธ์ฒœ๊ด‘์—ญ์‹œ', '์ถฉ์ฒญ๋ถ๋„', '์ œ์ฃผ๋„', '๊ฐ•์›๋„', '์ถฉ์ฒญ๋‚จ๋„', '์ „๋ผ๋ถ๋„', '๋Œ€๊ตฌ๊ด‘์—ญ์‹œ', '์ „๋ผ๋‚จ๋„', '๊ฒฝ์ƒ๋ถ๋„', '์„ธ์ข…ํŠน๋ณ„์ž์น˜์‹œ', '๊ธฐํƒ€'])
                    my_age = gr.Slider(label="๋‚˜์ด", minimum=10, maximum=80, value=45, step=1)
                    my_sex = gr.Radio(label="์„ฑ๋ณ„", choices=["๋‚จ์„ฑ", "์—ฌ์„ฑ"])
                with gr.Row():
                    btn = gr.Button(label="์ ์šฉ")
                    state = gr.Textbox(label="์ƒํƒœ")
                    btn.click(
                        fn=chatbot.initialize, 
                        inputs=[topic, bot_addr, bot_age, bot_sex, my_addr, my_age, my_sex], 
                        outputs=state
                    )
                    
            with gr.Column():
                screen = gr.Chatbot(label="์ต๋ช…์˜ ์ƒ๋Œ€")
                with gr.Row():
                    speak = gr.Textbox(label="์ž…๋ ฅ์ฐฝ")
                    btn = gr.Button(label="Talk")
                    btn.click(
                        fn=chatbot.test, 
                        inputs=speak, 
                        outputs=screen
                    )
    demo.launch(share=True)