Spaces:
Runtime error
Runtime error
File size: 5,475 Bytes
fa23262 4cb51c5 fa23262 4cb51c5 273c30f 4cb51c5 273c30f 4cb51c5 273c30f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 |
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
import warnings
class Chatbot:
    """Persona-conditioned chit-chat bot wrapping a causal language model.

    The prompt fed to the model is ``self.info`` (topic plus the two personas,
    built by :meth:`initialize`) followed by the running transcript
    ``self.talk``.  Each transcript entry has the form
    ``"<speaker><sos><utterance><eos>"`` with speaker ``P01`` for the user and
    ``P02`` for the bot; entries alternate user/bot, starting with the user.

    Attributes:
        tokenizer: KoGPT tokenizer extended with ``SPECIAL_TOKENS``.
        model: causal LM loaded from ``model_path`` and moved to ``device``.
        info: prompt header set by :meth:`initialize` (``None`` until then).
        talk: transcript entries (user/bot alternation) between :meth:`test` calls.
        state: length notice from the last successful :meth:`test` call, or
            ``None`` when the context is comfortably short.
    """

    # Speaker tags and control tokens used to serialize the transcript.
    USER = "P01"
    BOT = "P02"
    SOS = "<sos>"
    EOS = "<eos>"
    # Token budget for prompt + reply, passed to generate(max_length=...).
    MAX_LEN = 512
    # Above MAX_LEN * WARN_RATIO prompt tokens a "context almost full" notice is set.
    WARN_RATIO = 0.8
    # Cap used only while measuring the prompt length.
    TOKENIZE_CAP = 1024
    # Dialogue control tokens plus placeholders for masked personal data
    # (name, account, identity, phone number, finance, number, address,
    # affiliation, other); registered as special tokens so they stay atomic.
    SPECIAL_TOKENS = (
        '<sep>', '<eos>', '<sos>',
        '#@이름#', '#@계정#', '#@신원#', '#@전번#', '#@금융#',
        '#@번호#', '#@주소#', '#@소속#', '#@기타#',
    )

    def __init__(self, model_path="/workspace/test_trainer/checkpoint-10000", device="cuda"):
        """Load the tokenizer and model.

        Args:
            model_path: directory of the causal-LM checkpoint to load.
            device: torch device the model is moved to (default ``"cuda"``).
        """
        self.tokenizer = AutoTokenizer.from_pretrained('kakaobrain/kogpt', revision='KoGPT6B-ryan1.5b')
        self.tokenizer.add_special_tokens({'additional_special_tokens': list(self.SPECIAL_TOKENS)})
        self.model = AutoModelForCausalLM.from_pretrained(model_path)
        # The added special tokens need embedding rows in the model.
        self.model.resize_token_embeddings(len(self.tokenizer))
        self.model = self.model.to(device)
        self.info = None
        self.talk = []
        self.state = None

    def initialize(self, topic, bot_addr, bot_age, bot_sex, my_addr, my_age, my_sex):
        """Set the conversation topic and both personas.

        Ages are bucketed by decade: below 20 -> "20대 미만" (under 20),
        70 and above -> "70대 이상" (70 or older), otherwise e.g. "40대".

        Returns:
            The new header in human-readable form (see :meth:`info_check`).
        """
        def encode(age):
            # int() so a float age (e.g. 45.0) yields "40대", not "40.0대".
            age = int(age)
            if age < 20:
                return "20대 미만"
            if age >= 70:
                return "70대 이상"
            return f"{age // 10 * 10}대"

        bot_age = encode(bot_age)
        my_age = encode(my_age)
        self.info = f"일상 대화 {topic}<sep>P01:{my_addr} {my_age} {my_sex}<sep>P02:{bot_addr} {bot_age} {bot_sex}<sep>"
        return self.info_check()

    def info_check(self):
        """Return ``self.info`` with separators as newlines and speaker tags as
        "당신" (you) / "챗봇" (chatbot)."""
        return self.info.replace('<sep>', '\n').replace('P01', '당신').replace('P02', '챗봇')

    def reset_talk(self):
        """Forget the transcript (the topic/persona header is kept)."""
        self.talk = []
        self.state = None

    def test(self, myinp):
        """Send one user message, generate the bot's reply, return the history.

        Args:
            myinp: the user's utterance.

        Returns:
            List of ``(user_text, bot_text)`` pairs for the kept transcript.

        Raises:
            ValueError: if :meth:`initialize` was not called, or the prompt
                cannot fit in ``MAX_LEN`` tokens even with only the newest turn.
                The transcript is left unchanged in that case.
        """
        if self.info is None:
            raise ValueError("initialize() must be called before test()")
        self.talk.append(self.USER + self.SOS + myinp + self.EOS)
        self.talk.append(self.BOT + self.SOS)
        try:
            inputs, seq_len, start = self._fit_context()
            reply = self._generate_reply(inputs, seq_len)
        except Exception:
            # Undo the half-built turn so the transcript stays user/bot aligned.
            del self.talk[-2:]
            raise
        # Forget the exchanges that no longer fit in the context window.
        del self.talk[:start]
        # Always close the turn with <eos> so history slicing and the next
        # prompt see a well-formed entry, even if generation hit max_length.
        self.talk[-1] += reply + self.EOS

        if start:
            self.state = "<주의> 대화 길이가 너무 길어졌기 때문에, 이후 대화는 맨 앞의 발화를 조금씩 지우면서 진행됩니다."
        elif seq_len > self.MAX_LEN * self.WARN_RATIO:
            self.state = f"<주의> 현재 대화 길이가 곧 최대 길이에 도달합니다. ({seq_len} / {self.MAX_LEN})"
        else:
            self.state = None
        return self._history()

    def _fit_context(self):
        """Tokenize ``info`` + transcript, skipping oldest exchanges until it fits.

        Returns:
            ``(inputs, seq_len, start)``: tokenizer output for
            ``self.info + "".join(self.talk[start:])``, its token length
            (< ``MAX_LEN``), and the even number of leading entries skipped.

        Raises:
            ValueError: if the newest turn alone still does not fit.
        """
        start = 0
        while True:
            prompt = self.info + "".join(self.talk[start:])
            inputs = self.tokenizer(prompt, max_length=self.TOKENIZE_CAP, truncation='longest_first', return_tensors='pt')
            seq_len = inputs.input_ids.size(1)
            if seq_len < self.MAX_LEN:
                return inputs, seq_len, start
            if len(self.talk) - start <= 2:
                raise ValueError(f"prompt does not fit: {seq_len} tokens, limit is {self.MAX_LEN}")
            # Skip a whole (user, bot) exchange so speakers stay paired.
            start += 2

    def _generate_reply(self, inputs, seq_len):
        """Sample a reply and return its text, without the trailing <eos>."""
        eos_id = self.tokenizer.convert_tokens_to_ids(self.EOS)
        device = self.model.device
        out = self.model.generate(
            inputs=inputs.input_ids.to(device),
            attention_mask=inputs.attention_mask.to(device),
            max_length=self.MAX_LEN,
            do_sample=True,
            pad_token_id=self.tokenizer.pad_token_id,
            eos_token_id=eos_id,
        )
        # generate() echoes the prompt for decoder-only models: keep only the
        # new token ids, up to the first <eos>, and decode those directly
        # (special placeholders like #@이름# are kept in the text).
        new_ids = out[0, seq_len:].tolist()
        if eos_id in new_ids:
            new_ids = new_ids[:new_ids.index(eos_id)]
        return self.tokenizer.decode(new_ids)

    def _history(self):
        """Pair up transcript entries as (user, bot) texts without tags."""
        head = len(self.USER) + len(self.SOS)
        tail = len(self.EOS)
        return [(user[head:-tail], bot[head:-tail]) for user, bot in zip(self.talk[0::2], self.talk[1::2])]
if __name__ == "__main__":
    warnings.filterwarnings("ignore")
    chatbot = Chatbot()

    # Conversation topics offered to the user.
    TOPICS = ['여가 생활', '시사/교육', '미용과 건강', '식음료', '상거래(쇼핑)', '일과 직업', '주거와 생활', '개인 및 관계', '행사']
    # Regions shared by both persona pickers (defined once to keep them in sync).
    REGIONS = [
        '서울특별시', '경기도', '부산광역시', '대전광역시', '광주광역시', '울산광역시',
        '경상남도', '인천광역시', '충청북도', '제주도', '강원도', '충청남도',
        '전라북도', '대구광역시', '전라남도', '경상북도', '세종특별자치시', '기타',
    ]
    SEXES = ["남성", "여성"]

    demo = gr.Blocks()
    with demo:
        gr.Markdown("# <center>MINDs Lab Brain's Fast Neural Chit-Chatbot</center>")
        with gr.Row():
            with gr.Column():
                topic = gr.Radio(label="Topic", choices=TOPICS)
            with gr.Column():
                gr.Markdown("Bot's persona")
                bot_addr = gr.Dropdown(label="지역", choices=REGIONS)
                bot_age = gr.Slider(label="나이", minimum=10, maximum=80, value=45, step=1)
                bot_sex = gr.Radio(label="성별", choices=SEXES)
            with gr.Column():
                gr.Markdown("Your persona")
                my_addr = gr.Dropdown(label="지역", choices=REGIONS)
                my_age = gr.Slider(label="나이", minimum=10, maximum=80, value=45, step=1)
                my_sex = gr.Radio(label="성별", choices=SEXES)
        with gr.Row():
            # A button's caption is its first (value) argument, not `label`.
            apply_btn = gr.Button("적용")
            state = gr.Textbox(label="상태")
            apply_btn.click(
                fn=chatbot.initialize,
                inputs=[topic, bot_addr, bot_age, bot_sex, my_addr, my_age, my_sex],
                outputs=state,
            )
        with gr.Column():
            screen = gr.Chatbot(label="익명의 상대")
            with gr.Row():
                speak = gr.Textbox(label="입력창")
                talk_btn = gr.Button("Talk")
                talk_btn.click(
                    fn=chatbot.test,
                    inputs=speak,
                    outputs=screen,
                )
    demo.launch(share=True)
|