# test_space / app.py
# Author: JYYong
# Commit: 4cb51c5 ("maybe complete")
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
import warnings
class Chatbot():
    """Persona-conditioned chit-chat bot backed by a fine-tuned KoGPT causal LM.

    The dialogue is stored in ``self.talk`` as a flat list of turn strings that
    alternates between user turns (``"P01<sos>...<eos>"``) and bot turns
    (``"P02<sos>...<eos>"``). Every prompt sent to the model is the header built
    by :meth:`initialize` followed by all stored turns concatenated.
    """

    # Token budget for prompt + generated reply (passed to generate() as max_length).
    MAX_SEQ_LEN = 512
    # Fraction of MAX_SEQ_LEN above which a "nearly full" notice is recorded.
    WARN_RATIO = 0.8
    # Every stored turn is "<speaker><sos>" + text + "<eos>"; the speaker tag
    # plus "<sos>" has the same length for both speakers.
    _TURN_PREFIX_LEN = len("P01<sos>")
    _EOS = "<eos>"

    def __init__(self, model_path="/workspace/test_trainer/checkpoint-10000", device="cuda"):
        """Load the tokenizer and the fine-tuned model.

        Args:
            model_path: Directory of the fine-tuned causal-LM checkpoint.
            device: Torch device the model and its inputs are moved to.
        """
        self.tokenizer = AutoTokenizer.from_pretrained('kakaobrain/kogpt', revision='KoGPT6B-ryan1.5b')
        # Turn delimiters plus "#@...#" placeholder tokens (presumably the
        # masking tokens used in the fine-tuning dialogues — confirm).
        special_tokens_dict = {'additional_special_tokens': ['<sep>', '<eos>', '<sos>', '#@์ด๋ฆ„#', '#@๊ณ„์ •#', '#@์‹ ์›#', '#@์ „๋ฒˆ#', '#@๊ธˆ์œต#', '#@๋ฒˆํ˜ธ#', '#@์ฃผ์†Œ#', '#@์†Œ์†#', '#@๊ธฐํƒ€#']}
        self.tokenizer.add_special_tokens(special_tokens_dict)
        self.model = AutoModelForCausalLM.from_pretrained(model_path)
        # The vocabulary grew with the added tokens, so the embedding matrix must match.
        self.model.resize_token_embeddings(len(self.tokenizer))
        self.device = device
        self.model = self.model.to(device)
        self.info = None   # prompt header; set by initialize()
        self.talk = []     # alternating user/bot turn strings
        self.state = None  # last length notice produced by test(), if any

    @staticmethod
    def _encode_age(age):
        """Map a numeric age to the decade label used in the prompt header.

        Ages below 20 and ages of 70 or more get fixed labels; anything else
        becomes its decade (e.g. 45 -> "40" + suffix). ``int()`` is applied
        first so float slider values such as 45.0 do not render as "40.0".
        """
        age = int(age)
        if age < 20:
            return "20๋Œ€ ๋ฏธ๋งŒ"
        if age >= 70:
            return "70๋Œ€ ์ด์ƒ"
        return str(age // 10 * 10) + "๋Œ€"

    def initialize(self, topic, bot_addr, bot_age, bot_sex, my_addr, my_age, my_sex):
        """Build the topic/persona header that prefixes every prompt.

        Args:
            topic: Conversation topic label.
            bot_addr, bot_age, bot_sex: Bot persona (region, age, sex).
            my_addr, my_age, my_sex: User persona (region, age, sex).

        Returns:
            str: The header rendered for display by :meth:`info_check`.
        """
        self.info = (
            f"์ผ์ƒ ๋Œ€ํ™” {topic}<sep>"
            f"P01:{my_addr} {self._encode_age(my_age)} {my_sex}<sep>"
            f"P02:{bot_addr} {self._encode_age(bot_age)} {bot_sex}<sep>"
        )
        return self.info_check()

    def info_check(self):
        """Return the header with ``<sep>`` turned into newlines and P01/P02 relabeled.

        Requires :meth:`initialize` to have been called (``self.info`` is a str).
        """
        return self.info.replace('<sep>', '\n').replace('P01', '๋‹น์‹ ').replace('P02', '์ฑ—๋ด‡')

    def reset_talk(self):
        """Clear the stored dialogue history."""
        self.talk = []

    def _strip_turn(self, turn):
        """Remove the speaker/<sos> prefix and the trailing <eos> from a stored turn."""
        return turn[self._TURN_PREFIX_LEN:-len(self._EOS)]

    def test(self, myinp):
        """Add a user utterance, sample the bot's reply, and return the transcript.

        While the prompt is ``MAX_SEQ_LEN`` tokens or longer, the oldest
        (user, bot) exchange is dropped. A notice describing the prompt length
        is stored in ``self.state``; it is None when no notice applies.

        Args:
            myinp: The user's utterance.

        Returns:
            list[tuple[str, str]]: ``(user, bot)`` text pairs for every exchange
            still held in ``self.talk``, in the shape ``gr.Chatbot`` displays.

        Raises:
            RuntimeError: If :meth:`initialize` has not been called.
            ValueError: If the header plus this single exchange does not fit in
                ``MAX_SEQ_LEN`` tokens; the exchange is not kept in that case.
        """
        if self.info is None:
            raise RuntimeError("Chatbot.initialize() must be called before test().")
        self.state = None
        self.talk.append("P01<sos>" + myinp + self._EOS)
        self.talk.append("P02<sos>")
        truncated = False
        while True:
            now_inp = self.info + "".join(self.talk)
            inputs = self.tokenizer(now_inp, max_length=1024, truncation='longest_first', return_tensors='pt')
            seq_len = inputs.input_ids.size(1)
            if seq_len < self.MAX_SEQ_LEN:
                break
            if len(self.talk) <= 2:
                # Only the current exchange is left and it still does not fit:
                # undo it so the history is unchanged, then report the problem.
                del self.talk[-2:]
                raise ValueError(
                    f"Input too long: {seq_len} tokens with no history left to drop "
                    f"(limit {self.MAX_SEQ_LEN})."
                )
            # Drop the oldest (user, bot) pair together so the list keeps
            # alternating user/bot turns, which the pairing below relies on.
            del self.talk[:2]
            truncated = True

        if truncated:
            self.state = "<์ฃผ์˜> ๋Œ€ํ™” ๊ธธ์ด๊ฐ€ ๋„ˆ๋ฌด ๊ธธ์–ด์กŒ๊ธฐ ๋•Œ๋ฌธ์—, ์ดํ›„ ๋Œ€ํ™”๋Š” ๋งจ ์•ž์˜ ๋ฐœํ™”๋ฅผ ์กฐ๊ธˆ์”ฉ ์ง€์šฐ๋ฉด์„œ ์ง„ํ–‰๋ฉ๋‹ˆ๋‹ค."
        elif seq_len > self.MAX_SEQ_LEN * self.WARN_RATIO:
            self.state = f"<์ฃผ์˜> ํ˜„์žฌ ๋Œ€ํ™” ๊ธธ์ด๊ฐ€ ๊ณง ์ตœ๋Œ€ ๊ธธ์ด์— ๋„๋‹ฌํ•ฉ๋‹ˆ๋‹ค. ({seq_len} / {self.MAX_SEQ_LEN})"

        out = self.model.generate(
            inputs=inputs.input_ids.to(self.device),
            attention_mask=inputs.attention_mask.to(self.device),
            max_length=self.MAX_SEQ_LEN,
            do_sample=True,
            pad_token_id=self.tokenizer.pad_token_id,
            eos_token_id=self.tokenizer.convert_tokens_to_ids(self._EOS),
        )
        # Decode only the newly generated ids. Slicing the decoded full text by
        # len(now_inp) assumes the prompt round-trips character-for-character.
        reply = self.tokenizer.decode(out[0, seq_len:])
        if not reply.endswith(self._EOS):
            # Sampling hit max_length before <eos>: close the turn so the next
            # prompt stays well-formed and _strip_turn removes the right suffix.
            reply += self._EOS
        self.talk[-1] += reply
        return [
            (self._strip_turn(user), self._strip_turn(bot))
            for user, bot in zip(self.talk[0::2], self.talk[1::2])
        ]
if __name__ == "__main__":
    # Launch the Gradio demo: persona/topic settings on the left, chat on the right.
    warnings.filterwarnings("ignore")
    chatbot = Chatbot()

    # Choices shared by the bot-persona and user-persona panels.
    regions = ['์„œ์šธํŠน๋ณ„์‹œ', '๊ฒฝ๊ธฐ๋„', '๋ถ€์‚ฐ๊ด‘์—ญ์‹œ', '๋Œ€์ „๊ด‘์—ญ์‹œ', '๊ด‘์ฃผ๊ด‘์—ญ์‹œ', '์šธ์‚ฐ๊ด‘์—ญ์‹œ', '๊ฒฝ์ƒ๋‚จ๋„', '์ธ์ฒœ๊ด‘์—ญ์‹œ', '์ถฉ์ฒญ๋ถ๋„', '์ œ์ฃผ๋„', '๊ฐ•์›๋„', '์ถฉ์ฒญ๋‚จ๋„', '์ „๋ผ๋ถ๋„', '๋Œ€๊ตฌ๊ด‘์—ญ์‹œ', '์ „๋ผ๋‚จ๋„', '๊ฒฝ์ƒ๋ถ๋„', '์„ธ์ข…ํŠน๋ณ„์ž์น˜์‹œ', '๊ธฐํƒ€']
    sexes = ["๋‚จ์„ฑ", "์—ฌ์„ฑ"]

    demo = gr.Blocks()
    with demo:
        gr.Markdown("# <center>MINDs Lab Brain's Fast Neural Chit-Chatbot</center>")
        with gr.Row():
            with gr.Column():
                topic = gr.Radio(label="Topic", choices=['์—ฌ๊ฐ€ ์ƒํ™œ', '์‹œ์‚ฌ/๊ต์œก', '๋ฏธ์šฉ๊ณผ ๊ฑด๊ฐ•', '์‹์Œ๋ฃŒ', '์ƒ๊ฑฐ๋ž˜(์‡ผํ•‘)', '์ผ๊ณผ ์ง์—…', '์ฃผ๊ฑฐ์™€ ์ƒํ™œ', '๊ฐœ์ธ ๋ฐ ๊ด€๊ณ„', 'ํ–‰์‚ฌ'])
                with gr.Column():
                    gr.Markdown("Bot's persona")
                    bot_addr = gr.Dropdown(label="์ง€์—ญ", choices=regions)
                    bot_age = gr.Slider(label="๋‚˜์ด", minimum=10, maximum=80, value=45, step=1)
                    bot_sex = gr.Radio(label="์„ฑ๋ณ„", choices=sexes)
                with gr.Column():
                    gr.Markdown("Your persona")
                    my_addr = gr.Dropdown(label="์ง€์—ญ", choices=regions)
                    my_age = gr.Slider(label="๋‚˜์ด", minimum=10, maximum=80, value=45, step=1)
                    my_sex = gr.Radio(label="์„ฑ๋ณ„", choices=sexes)
                with gr.Row():
                    # A Button's visible text is its first (`value`) argument;
                    # `label=` does not set it.
                    apply_btn = gr.Button("์ ์šฉ")
                    state = gr.Textbox(label="์ƒํƒœ")
                # Build the prompt header from the selected topic and personas.
                apply_btn.click(
                    fn=chatbot.initialize,
                    inputs=[topic, bot_addr, bot_age, bot_sex, my_addr, my_age, my_sex],
                    outputs=state,
                )
            with gr.Column():
                screen = gr.Chatbot(label="์ต๋ช…์˜ ์ƒ๋Œ€")
                with gr.Row():
                    speak = gr.Textbox(label="์ž…๋ ฅ์ฐฝ")
                    talk_btn = gr.Button("Talk")
                # Send the typed utterance; the returned (user, bot) pairs refresh the chat.
                talk_btn.click(
                    fn=chatbot.test,
                    inputs=speak,
                    outputs=screen,
                )
    demo.launch(share=True)