Spaces:
Runtime error
Runtime error
import streamlit as st | |
from streamlit_chat import message | |
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer | |
def on_text_input():
    """Handle a submitted message: query the model and record the exchange.

    Registered as the ``on_change`` callback of the ``user_input`` text box.
    It reads the submitted text from ``st.session_state.user_input`` and
    clears the box. It then builds an "A: ... / B: ..." dialogue prompt from
    the most recent turns in ``st.session_state["chat_history"]`` and samples
    a reply from the module-level ``model``/``tokenizer``. Finally it appends
    the user text and the reply to ``st.session_state.past`` and
    ``st.session_state.generated``.
    """
    user_input = st.session_state.user_input
    # Clear the input box for the next message. This is allowed here because
    # the callback runs before the widget is rendered again.
    st.session_state.user_input = ''
    # Ignore blank / whitespace-only submissions instead of spending a model
    # call on them and recording an empty turn in the transcript.
    if not (user_input and user_input.strip()):
        return

    chat_history = st.session_state["chat_history"]
    chat_history.append("A: " + user_input)
    # Keep only the 5 most recent turns so the prompt stays short
    # (a no-op while there are 5 or fewer).
    del chat_history[:-5]

    # One "\n<turn>" per kept turn, then an open "B: " for the model to complete.
    prompt = "".join("\n" + turn for turn in chat_history) + "\nB: "
    input_ids = tokenizer.encode(prompt, return_tensors='pt')
    output_ids = model.generate(
        input_ids,
        max_length=200,
        pad_token_id=tokenizer.eos_token_id,
        do_sample=True,
        # Other knobs the author tried and left disabled:
        # no_repeat_ngram_size=3, top_k=100, top_p=0.7, temperature=0.1
    )

    bot_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
    # "#@이름#" ("이름" = "name") looks like a name placeholder emitted by the
    # model; mask it. Multi-line replies are flattened with " / ".
    bot_text = bot_text.replace("#@이름#", "OOO").replace("\n", " / ")

    chat_history.append("B: " + bot_text)
    st.session_state.past.append(user_input)
    st.session_state.generated.append(bot_text)
model_dir = "lcw99/t5-base-korean-chit-chat"


def _load_model(name):
    """Return ``(tokenizer, model)`` for the Hugging Face model *name*."""
    return (
        AutoTokenizer.from_pretrained(name),
        AutoModelForSeq2SeqLM.from_pretrained(name),
    )


# Streamlit re-executes this whole script on every interaction. Cache the
# loaded objects so they are created once per server process rather than on
# every rerun. st.cache_resource is the current API; older Streamlit releases
# only provide st.cache.
if hasattr(st, "cache_resource"):
    _load_model = st.cache_resource(_load_model)
else:
    _load_model = st.cache(allow_output_mutation=True)(_load_model)

# Module-level names read by the on_text_input callback.
tokenizer, model = _load_model(model_dir)
# Per-session lists, created empty on the first run of each browser session:
#   generated    - bot replies shown in the transcript
#   past         - user messages, appended together with ``generated``
#   chat_history - "A: ..." / "B: ..." turns used to build the model prompt
for state_key in ('generated', 'past', 'chat_history'):
    if state_key not in st.session_state:
        st.session_state[state_key] = []
st.title("Chit-Chat Korean")

# Plain-text transcript: one "User:" and one "Bot :" line per exchange.
# ``past`` and ``generated`` are appended together, so they pair up by index.
exchanges = list(zip(st.session_state['past'], st.session_state['generated']))
chat_hist = st.empty()
hist = "".join(
    "User:\t" + user_msg + "\n" + "Bot :\t" + bot_msg + "\n"
    for user_msg, bot_msg in exchanges
)
chat_hist.text_area("Chat history:", hist, height=300)

# Submitting the box triggers on_text_input, which generates the reply.
st.text_input('Please enter your message :', '', key="user_input", on_change=on_text_input)

# Chat-bubble view of the same exchanges; the keys keep the widgets unique.
for i, (user_msg, bot_msg) in enumerate(exchanges):
    message(user_msg, is_user=True, key=str(i) + '_user')
    message(bot_msg, key=str(i))