# lcw99 - v0.1 (7fbbf5b)
import streamlit as st
from streamlit_chat import message
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
# Number of most recent "A: ..." / "B: ..." lines kept as prompt context.
_MAX_HISTORY_LINES = 5
# Value passed to ``model.generate`` as ``max_length``.
_MAX_GENERATION_LENGTH = 200


def on_text_input():
    """Handle a submitted chat message: generate a reply and record the turn.

    Reads the text stored under the ``user_input`` session-state key, clears
    that key, and, unless the text is blank, does the following:

    1. Appends ``"A: <text>"`` to ``st.session_state["chat_history"]`` and
       trims that list to its last ``_MAX_HISTORY_LINES`` entries.
    2. Builds a prompt made of every retained line prefixed with a newline,
       followed by ``"\\nB: "`` so the model completes speaker B's line.
    3. Generates a reply with the module-level ``tokenizer`` and ``model``.
    4. Appends ``"B: <reply>"`` to ``chat_history``, the raw user text to
       ``st.session_state.past`` and the reply to
       ``st.session_state.generated``.

    Takes no arguments and returns ``None``; all results are written to
    ``st.session_state``.
    """
    user_input = st.session_state.user_input
    # Reset the widget's state value so the input box is empty afterwards.
    st.session_state.user_input = ''

    # A whitespace-only submission carries no message: do not spend a model
    # call on it or add an empty turn to the transcript.
    if not user_input.strip():
        return

    history = st.session_state["chat_history"]
    history.append("A: " + user_input)
    # Drop everything but the newest lines (no-op while the list is short).
    del history[:-_MAX_HISTORY_LINES]

    prompt = "".join("\n" + line for line in history) + "\nB: "
    input_ids = tokenizer.encode(prompt, return_tensors='pt')

    # do_sample=True makes replies non-deterministic. Other sampling knobs
    # (no_repeat_ngram_size, top_k, top_p, temperature) are left at the
    # library defaults.
    output_ids = model.generate(
        input_ids,
        max_length=_MAX_GENERATION_LENGTH,
        pad_token_id=tokenizer.eos_token_id,
        do_sample=True,
    )

    bot_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
    # "#@이름#" is masked with "OOO" (presumably an anonymised-name
    # placeholder emitted by the model - confirm against the training data).
    # Newlines are flattened so each reply stays on one transcript line.
    bot_text = bot_text.replace("#@이름#", "OOO").replace("\n", " / ")

    history.append("B: " + bot_text)
    st.session_state.past.append(user_input)
    st.session_state.generated.append(bot_text)
model_dir = "lcw99/t5-base-korean-chit-chat"

# Streamlit re-executes this script from the top on every user interaction,
# so an uncached ``from_pretrained`` would reload the tokenizer and model on
# each rerun. Prefer ``st.cache_resource``; fall back to the older
# ``st.experimental_singleton`` on releases that do not provide it yet.
_cache_resource = getattr(st, "cache_resource", None) or st.experimental_singleton


@_cache_resource
def _load_model(name):
    """Return ``(tokenizer, model)`` for *name*, loaded once and then cached."""
    return (
        AutoTokenizer.from_pretrained(name),
        AutoModelForSeq2SeqLM.from_pretrained(name),
    )


tokenizer, model = _load_model(model_dir)

# Per-session state:
#   generated    - bot replies, one per turn
#   past         - user messages, one per turn (parallel to ``generated``)
#   chat_history - "A: ..." / "B: ..." lines used to build the model prompt
for state_key in ('generated', 'past', 'chat_history'):
    if state_key not in st.session_state:
        st.session_state[state_key] = []

st.title("Chit-Chat Korean")

# ``past`` and ``generated`` are appended to together by ``on_text_input``,
# so pairing them yields one (user message, bot reply) tuple per turn.
turns = list(zip(st.session_state['past'], st.session_state['generated']))

# Plain-text transcript of the whole conversation.
chat_hist = st.empty()
hist = "".join(
    "User:\t" + user_text + "\n" + "Bot :\t" + bot_reply + "\n"
    for user_text, bot_reply in turns
)
chat_hist.text_area("Chat history:", hist, height=300)

# The callback runs when the box's value changes; it reads the text from
# ``st.session_state.user_input`` and clears it.
user_input = st.text_input(
    'Please enter your message :', '', key="user_input", on_change=on_text_input
)

# Chat-bubble rendering of the same turns; keys must be unique per widget.
for i, (user_text, bot_reply) in enumerate(turns):
    message(user_text, is_user=True, key=str(i) + '_user')
    message(bot_reply, key=str(i))