File size: 2,586 Bytes
22f88f8
49eede3
22f88f8
49eede3
 
 
22f88f8
49eede3
ca77bef
 
 
 
 
49eede3
 
 
c6b7aba
49eede3
f3d44ff
49eede3
3d7233c
49eede3
 
 
 
 
 
22f88f8
49eede3
 
 
22f88f8
 
 
 
 
3d7233c
c6b7aba
22f88f8
 
 
 
3d7233c
22f88f8
c6b7aba
4de1f02
22f88f8
49eede3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22f88f8
49eede3
 
 
 
 
 
 
 
 
 
 
c6b7aba
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
import os
import streamlit as st
from transformers import pipeline, Conversation

import time

model_id = "arampacha/DialoGPT-medium-simpsons"

# `st.cache` is deprecated in favour of `st.cache_resource`; use the newer API
# when this Streamlit version provides it and fall back to the legacy
# decorator (with mutation allowed, as before) on older releases.
if hasattr(st, "cache_resource"):
    _cache_model = st.cache_resource
else:
    _cache_model = st.cache(allow_output_mutation=True)


@_cache_model
def get_pipeline():
    """Build and return the conversational pipeline for ``model_id``.

    The result is cached by Streamlit so the model is not reloaded on every
    script rerun.

    NOTE(review): newer ``transformers`` releases appear to have dropped the
    "conversational" task and ``Conversation`` — confirm the pinned version.
    """
    return pipeline("conversational", model=model_id)


dialog = get_pipeline()

# Keyword arguments forwarded to the pipeline call for every generated reply.
# None values are passed through unchanged.
parameters = dict(
    min_length=None,
    max_length=100,
    top_p=0.92,
    temperature=1.0,
    repetition_penalty=None,
    do_sample=True,
)


def _send_message(user_input):
    """Append *user_input* to the transcript, query the model and append its reply.

    On failure an error message is written to the page and the stored
    conversation history is left unchanged.
    """
    st.session_state.full_text += f"_user_  >>> {user_input}\n\n"
    # Show the user's turn right away; generation below may take a while.
    dialog_output.markdown(st.session_state.full_text)

    # Pass copies of the history: the Conversation object may append to the
    # lists it is given, and a failed generation must not leave the stored
    # history half-updated.
    conv = Conversation(
        text=user_input,
        past_user_inputs=list(st.session_state.past_user_inputs),
        generated_responses=list(st.session_state.generated_responses),
    )
    try:
        conv = dialog(conv, **parameters)
        reply = conv.generated_responses[-1]
    except Exception as e:  # UI boundary: report instead of crashing the page
        st.write("D'oh! Something went wrong. Try to rerun the app.")
        st.write(conv)
        st.write(e)
        return

    st.session_state.update({
        "past_user_inputs": conv.past_user_inputs,
        "generated_responses": conv.generated_responses,
    })
    st.session_state.full_text += f'_chatbot_ > {reply}\n\n'


def on_input():
    """Process the text currently held in ``st.session_state.user_input``.

    Does nothing but bump ``st.session_state.count`` on the first run of a
    session (``count == 0``). On later runs it clears the input box. If the
    submitted text is non-blank, it sends it to the model and records the
    exchange in the session transcript.
    """
    if st.session_state.count > 0:
        user_input = st.session_state.user_input
        st.session_state.user_input = ""
        # Reruns that were not triggered by a new message leave the box empty;
        # don't feed an empty turn to the model.
        if user_input.strip():
            _send_message(user_input)
    st.session_state.count += 1

# init session state: give every key the app relies on its starting value,
# leaving keys that already exist in this session untouched.
for _state_key, _initial in (
    ("past_user_inputs", []),
    ("generated_responses", []),
    ("full_text", ""),
    ("user_input", ""),
    ("count", 0),
):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _initial

# body
st.title("Chat with Simpsons")

st.image(
    "https://raw.githubusercontent.com/arampacha/chat-with-simpsons/main/the-simpsons.png",
    caption="(c) 20th Century Fox Television",
)
if st.session_state.count == 0:
    st.write("Start dialog by inputting some text:")

# Placeholder for the transcript. on_input() writes into it while the model
# is generating, so it must be created before on_input() runs.
dialog_output = st.empty()

if st.session_state.count > 0:
    dialog_output.markdown(st.session_state.full_text)

# on_input() is run inline on every script run rather than registered as an
# `on_change` callback. It renders into `dialog_output` (created above in this
# run) and resets `st.session_state.user_input`, which Streamlit does not
# allow once the widget with that key has been created in the run.
on_input()
st.text_input(
    "user >> ",
    key="user_input",
)

dialog_output.markdown(st.session_state.full_text)


def restart():
    """Drop all session state so the next run re-initialises a fresh conversation."""
    st.session_state.clear()


st.button("Restart", on_click=restart)