Spaces:
Runtime error
Runtime error
import gradio as gr | |
from transformers import AutoTokenizer, AutoModelForCausalLM | |
import warnings | |
class Chatbot():
    """Persona-conditioned chat wrapper around a causal language model.

    The prompt sent to the model is ``self.info`` (a header built by
    ``initialize``) followed by every entry of ``self.talk``.  ``self.talk``
    alternates user entries (``"P01<sos>...<eos>"``) and bot entries
    (``"P02<sos>...<eos>"``), always starting with a user entry.

    NOTE(review): the non-ASCII string literals in this class look
    mis-encoded (probably Korean text decoded with the wrong codec).  They
    are reproduced unchanged here; confirm them against the original file.
    """

    # Token budget used both for the prompt-length check and for ``generate``.
    MAX_LENGTH = 512

    def __init__(self):
        """Load the tokenizer and the fine-tuned checkpoint, then move the model to the GPU."""
        self.tokenizer = AutoTokenizer.from_pretrained('kakaobrain/kogpt', revision='KoGPT6B-ryan1.5b')
        special_tokens_dict = {'additional_special_tokens': ['<sep>', '<eos>', '<sos>', '#@์ด๋ฆ#', '#@๊ณ์ #', '#@์ ์#', '#@์ ๋ฒ#', '#@๊ธ์ต#', '#@๋ฒํธ#', '#@์ฃผ์#', '#@์์#', '#@๊ธฐํ#']}
        self.tokenizer.add_special_tokens(special_tokens_dict)
        self.model = AutoModelForCausalLM.from_pretrained("/workspace/test_trainer/checkpoint-10000")
        # The embedding matrix must grow to cover the special tokens added above.
        self.model.resize_token_embeddings(len(self.tokenizer))
        self.model = self.model.cuda()
        self.info = None   # prompt header; set by initialize()
        self.talk = []     # alternating user / bot entries
        self.state = None  # length warning from the latest test() call, if any

    @staticmethod
    def _encode_age(age):
        """Map a numeric age to the decade label used in the prompt header."""
        if age < 20:
            return "20๋ ๋ฏธ๋ง"
        if age >= 70:
            return "70๋ ์ด์"
        return str(age // 10 * 10) + "๋"

    def initialize(self, topic, bot_addr, bot_age, bot_sex, my_addr, my_age, my_sex):
        """Build and store the prompt header for a conversation.

        Args:
            topic: Conversation topic inserted at the start of the header.
            bot_addr, bot_age, bot_sex: Persona of the bot (written as P02).
            my_addr, my_age, my_sex: Persona of the user (written as P01).
                Ages are numbers and are bucketed into decade labels.

        Returns:
            The display form of the header, as produced by ``info_check``.
        """
        bot_age = self._encode_age(bot_age)
        my_age = self._encode_age(my_age)
        self.info = f"์ผ์ ๋ํ {topic}<sep>P01:{my_addr} {my_age} {my_sex}<sep>P02:{bot_addr} {bot_age} {bot_sex}<sep>"
        return self.info_check()

    def info_check(self):
        """Return the header with ``<sep>`` as newlines and P01/P02 replaced by display labels."""
        return self.info.replace('<sep>', '\n').replace('P01', '๋น์ ').replace('P02', '์ฑ๋ด')

    def reset_talk(self):
        """Forget the conversation so far; the header in ``self.info`` is kept."""
        self.talk = []

    def test(self, myinp):
        """Append the user's message, generate the bot's reply, and return the chat.

        ``initialize`` must have been called first: ``self.info`` is
        concatenated with the conversation to form the prompt.

        Args:
            myinp: The user's message text.

        Returns:
            A list of ``(user_text, bot_text)`` tuples, one per exchange still
            held in ``self.talk``, with the speaker prefix and the trailing
            ``<eos>`` sliced off.
        """
        self.state = None
        self.talk.append("P01<sos>" + myinp + "<eos>")
        # Open the bot's turn; the generated text is appended to this entry below.
        self.talk.append("P02<sos>")

        while True:
            now_inp = self.info + "".join(self.talk)
            inputs = self.tokenizer(now_inp, max_length=1024, truncation='longest_first', return_tensors='pt')
            seq_len = inputs.input_ids.size(1)
            if seq_len > self.MAX_LENGTH * 0.8:
                self.state = f"<์ฃผ์> ํ์ฌ ๋ํ ๊ธธ์ด๊ฐ ๊ณง ์ต๋ ๊ธธ์ด์ ๋๋ฌํฉ๋๋ค. ({seq_len} / 512)"
            if seq_len < self.MAX_LENGTH:
                break
            self.state = "<์ฃผ์> ๋ํ ๊ธธ์ด๊ฐ ๋๋ฌด ๊ธธ์ด์ก๊ธฐ ๋๋ฌธ์, ์ดํ ๋ํ๋ ๋งจ ์์ ๋ฐํ๋ฅผ ์กฐ๊ธ์ฉ ์ง์ฐ๋ฉด์ ์งํ๋ฉ๋๋ค."
            if len(self.talk) <= 2:
                # Only the current exchange is left; nothing older can be dropped.
                break
            # Drop the oldest exchange as a whole (user entry + bot entry) so
            # that even indices stay user turns, which the pairing below needs.
            del self.talk[:2]

        out = self.model.generate(
            inputs=inputs.input_ids.cuda(),
            attention_mask=inputs.attention_mask.cuda(),
            max_length=self.MAX_LENGTH,
            do_sample=True,
            pad_token_id=self.tokenizer.pad_token_id,
            eos_token_id=self.tokenizer.encode('<eos>')[0]
        )
        decoded = self.tokenizer.batch_decode(out)
        # Keep only what follows the prompt text in the decoded sequence.
        self.talk[-1] += decoded[0][len(now_inp):]

        # [8:-5] removes the 8-character "P0x<sos>" prefix and the last five
        # characters, i.e. the "<eos>" suffix (assumes the reply ended with
        # <eos> - TODO confirm for replies cut off by max_length).
        return [
            (mine[8:-5], bots[8:-5])
            for mine, bots in zip(self.talk[::2], self.talk[1::2])
        ]
# Script entry point: silences warnings, creates the Chatbot, declares the
# gradio Blocks UI, and launches it.
# NOTE(review): indentation was lost in this copy of the file, so the nesting
# of the `with gr.Row()` / `with gr.Column()` containers below cannot be read
# off the text - confirm the layout against the original source.
if __name__ == "__main__": | |
warnings.filterwarnings("ignore") | |
chatbot = Chatbot() | |
demo = gr.Blocks() | |
with demo: | |
gr.Markdown("# <center>MINDs Lab Brain's Fast Neural Chit-Chatbot</center>") | |
with gr.Row(): | |
with gr.Column(): | |
# Conversation topic; passed as the first argument to chatbot.initialize.
topic = gr.Radio(label="Topic", choices=['์ฌ๊ฐ ์ํ', '์์ฌ/๊ต์ก', '๋ฏธ์ฉ๊ณผ ๊ฑด๊ฐ', '์์๋ฃ', '์๊ฑฐ๋(์ผํ)', '์ผ๊ณผ ์ง์ ', '์ฃผ๊ฑฐ์ ์ํ', '๊ฐ์ธ ๋ฐ ๊ด๊ณ', 'ํ์ฌ']) | |
with gr.Column(): | |
# Bot persona controls: region dropdown, age slider (10-80), sex radio.
gr.Markdown(f"Bot's persona") | |
bot_addr = gr.Dropdown(label="์ง์ญ", choices=['์์ธํน๋ณ์', '๊ฒฝ๊ธฐ๋', '๋ถ์ฐ๊ด์ญ์', '๋์ ๊ด์ญ์', '๊ด์ฃผ๊ด์ญ์', '์ธ์ฐ๊ด์ญ์', '๊ฒฝ์๋จ๋', '์ธ์ฒ๊ด์ญ์', '์ถฉ์ฒญ๋ถ๋', '์ ์ฃผ๋', '๊ฐ์๋', '์ถฉ์ฒญ๋จ๋', '์ ๋ผ๋ถ๋', '๋๊ตฌ๊ด์ญ์', '์ ๋ผ๋จ๋', '๊ฒฝ์๋ถ๋', '์ธ์ข ํน๋ณ์์น์', '๊ธฐํ']) | |
bot_age = gr.Slider(label="๋์ด", minimum=10, maximum=80, value=45, step=1) | |
bot_sex = gr.Radio(label="์ฑ๋ณ", choices=["๋จ์ฑ", "์ฌ์ฑ"]) | |
with gr.Column(): | |
# User persona controls; same widgets and choices as the bot persona.
gr.Markdown(f"Your persona") | |
my_addr = gr.Dropdown(label="์ง์ญ", choices=['์์ธํน๋ณ์', '๊ฒฝ๊ธฐ๋', '๋ถ์ฐ๊ด์ญ์', '๋์ ๊ด์ญ์', '๊ด์ฃผ๊ด์ญ์', '์ธ์ฐ๊ด์ญ์', '๊ฒฝ์๋จ๋', '์ธ์ฒ๊ด์ญ์', '์ถฉ์ฒญ๋ถ๋', '์ ์ฃผ๋', '๊ฐ์๋', '์ถฉ์ฒญ๋จ๋', '์ ๋ผ๋ถ๋', '๋๊ตฌ๊ด์ญ์', '์ ๋ผ๋จ๋', '๊ฒฝ์๋ถ๋', '์ธ์ข ํน๋ณ์์น์', '๊ธฐํ']) | |
my_age = gr.Slider(label="๋์ด", minimum=10, maximum=80, value=45, step=1) | |
my_sex = gr.Radio(label="์ฑ๋ณ", choices=["๋จ์ฑ", "์ฌ์ฑ"]) | |
with gr.Row(): | |
btn = gr.Button(label="์ ์ฉ") | |
state = gr.Textbox(label="์ํ") | |
# Clicking this button calls chatbot.initialize with the topic and both
# personas; its return value is written to the `state` textbox.
btn.click( | |
fn=chatbot.initialize, | |
inputs=[topic, bot_addr, bot_age, bot_sex, my_addr, my_age, my_sex], | |
outputs=state | |
) | |
with gr.Column(): | |
screen = gr.Chatbot(label="์ต๋ช ์ ์๋") | |
with gr.Row(): | |
speak = gr.Textbox(label="์ ๋ ฅ์ฐฝ") | |
# `btn` is rebound to a second button here; the click handler for the first
# button was already registered above.
btn = gr.Button(label="Talk") | |
# Clicking this button sends the `speak` text to chatbot.test and writes the
# returned value to the `screen` Chatbot component.
btn.click( | |
fn=chatbot.test, | |
inputs=speak, | |
outputs=screen | |
) | |
# share=True is passed to launch(); see the gradio docs for its effect.
demo.launch(share=True) | |