import os
import time

import streamlit as st
from transformers import pipeline, Conversation

# NOTE(review): `Conversation` and the "conversational" pipeline task only
# exist in older transformers releases — pin the version this app targets.

model_id = "arampacha/DialoGPT-medium-simpsons"


@st.cache(allow_output_mutation=True)
def get_pipeline():
    """Build the conversational pipeline once and reuse it across reruns."""
    return pipeline("conversational", model=model_id)


dialog = get_pipeline()

# Keyword arguments forwarded to every `dialog(...)` call.
parameters = {
    "min_length": None,
    "max_length": 100,
    "top_p": 0.92,
    "temperature": 1.0,
    "repetition_penalty": None,
    "do_sample": True,
}


def on_input():
    """Send the submitted message to the model and record both turns.

    Registered as the text input's ``on_change`` callback, so Streamlit runs it
    once per submission, before the script body reruns. It appends the user's
    message and the model's reply to ``full_text``, keeps the
    ``past_user_inputs`` / ``generated_responses`` history in session state,
    and clears the input box.
    """
    user_input = st.session_state.user_input
    if not user_input.strip():
        # Nothing to send (e.g. the box was emptied); don't query the model.
        return

    st.session_state.full_text += f"_user_ >>> {user_input}\n\n"
    # Widget state may only be modified from a callback (or before the widget
    # is created), which is why the box is cleared here.
    st.session_state.user_input = ""

    conv = Conversation(
        text=user_input,
        past_user_inputs=st.session_state.past_user_inputs,
        generated_responses=st.session_state.generated_responses,
    )
    try:
        # The model call is inside the try so generation failures are reported
        # to the user instead of crashing the app.
        conv = dialog(conv, **parameters)
        st.session_state.update({
            "past_user_inputs": conv.past_user_inputs,
            "generated_responses": conv.generated_responses,
        })
        st.session_state.full_text += f"_chatbot_ > {conv.generated_responses[-1]}\n\n"
    except Exception as e:
        st.write("D'oh! Something went wrong. Try to rerun the app.")
        st.write(conv)
        st.write(e)
    # Number of submitted messages; 0 means the dialog has not started yet.
    st.session_state.count += 1


def restart():
    """Forget the whole conversation and start over."""
    st.session_state.clear()


# init session state (first run, and again after a restart)
for _key, _default in (
    ("past_user_inputs", []),
    ("generated_responses", []),
    ("full_text", ""),
    ("user_input", ""),
    ("count", 0),
):
    if _key not in st.session_state:
        st.session_state[_key] = _default

# body
st.title("Chat with Simpsons")
st.image(
    "https://raw.githubusercontent.com/arampacha/chat-with-simpsons/main/the-simpsons.png",
    caption="(c) 20th Century Fox Television",
)
if st.session_state.count == 0:
    st.write("Start dialog by inputting some text:")

# The transcript is already up to date here: callbacks run before the script.
st.markdown(st.session_state.full_text)

st.text_input(
    "user >> ",
    on_change=on_input,  # pass the function itself; Streamlit calls it on submit
    key="user_input",
)

st.button("Restart", on_click=restart)