File size: 2,437 Bytes
ae384c7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e30aa42
 
 
 
 
 
 
 
 
 
 
 
 
 
468a803
 
 
d8b5474
055073f
 
399d645
055073f
 
 
 
 
399d645
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e30aa42
 
 
399d645
7388e28
399d645
 
 
 
 
7388e28
 
 
 
 
468a803
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
import streamlit as st
from langchain.chains import ConversationChain
from langchain.chains.conversation.memory import ConversationEntityMemory
from langchain.chains.conversation.prompt import ENTITY_MEMORY_CONVERSATION_TEMPLATE
from langchain.llms import OpenAI


# Seed every session-state slot the app relies on. Keys that already exist
# are left untouched, so values set earlier in the session are not overwritten.
for _slot, _initial in (
    ("generated", []),       # model replies, appended in chat order
    ("past", []),            # user messages, appended in chat order
    ("input", ""),           # backing value for the chat text box (key="input")
    ("stored_session", []),  # list reserved for saved chat transcripts
):
    if _slot not in st.session_state:
        st.session_state[_slot] = _initial


def get_text():
    """Render the chat input box and return its current contents.

    The widget is bound to the ``"input"`` session-state key and is given the
    value currently stored under that key as its initial value. The label
    ``"You: "`` is passed but hidden via ``label_visibility``.

    Returns:
        Whatever ``st.text_input`` returns for the chat box.
    """
    return st.text_input(
        label="You: ",
        value=st.session_state["input"],
        key="input",
        placeholder="Your AI Assistant here.. Ask me anything!",
        label_visibility="hidden",
    )

def new_chat():
    """Archive the current conversation and reset the chat state.

    Used as the ``on_click`` callback of the "New Chat" sidebar button.

    The current exchange is flattened into a list of ``"User:..."`` /
    ``"Bot:..."`` strings (newest exchange first) and appended to
    ``st.session_state["stored_session"]``. The ``generated``, ``past`` and
    ``input`` slots are then cleared, and the entity memory is reset if one
    has been created.
    """
    past = st.session_state["past"]
    generated = st.session_state["generated"]

    # Pair each user message with its reply, then walk the pairs newest-first
    # (the same order the original index loop produced).
    save = []
    for user_msg, bot_msg in reversed(list(zip(past, generated))):
        save.append("User:" + user_msg)
        save.append("Bot:" + bot_msg)

    # The archive key is "stored_session" (the one initialised at start-up);
    # writing to the misspelled "started_session" raised KeyError.
    st.session_state["stored_session"].append(save)

    st.session_state["generated"] = []
    st.session_state["past"] = []
    st.session_state["input"] = ""

    # The entity memory is only created once an API key has been entered, but
    # the button is always rendered, so guard against it being absent.
    if "entity_memory" in st.session_state:
        # NOTE(review): newer LangChain releases appear to expose this as
        # `entity_store` rather than `store` - confirm against the pinned version.
        st.session_state.entity_memory.store = {}
        st.session_state.entity_memory.buffer.clear()
    


# ---------------------------------------------------------------------------
# Page layout and chat loop. Streamlit executes this top to bottom on each run.
# ---------------------------------------------------------------------------
st.title("Memory Bot")

# Sidebar controls: the OpenAI key (masked) and the model to talk to.
api = st.sidebar.text_input("API-key", type="password")
MODEL = st.sidebar.selectbox(label="Model", options=["gpt-3.5-turbo", "text-davinci-003"])

# Stays None until an API key is supplied, so the code below can tell whether
# a chain exists instead of hitting a NameError.
Conversation = None

if api:
    llm = OpenAI(
        temperature=0,
        # The keyword is `openai_api_key`; the misspelled `open_api_key`
        # never reached the client as a credential.
        openai_api_key=api,
        model_name=MODEL,
    )

    # Create the conversation memory once and keep it in session state so it
    # is reused rather than rebuilt on later runs.
    if "entity_memory" not in st.session_state:
        st.session_state.entity_memory = ConversationEntityMemory(
            llm=llm,
            k=10,
        )

    # Create the conversation chain around the stored memory.
    Conversation = ConversationChain(
        llm=llm,
        prompt=ENTITY_MEMORY_CONVERSATION_TEMPLATE,
        memory=st.session_state.entity_memory,
    )
else:
    st.error("No API key found!")

st.sidebar.button("New Chat", on_click=new_chat, type="primary")

user_input = get_text()

# Only query the model when a chain exists; without an API key the error
# banner above is already shown and the input is ignored.
if user_input and Conversation is not None:
    output = Conversation.run(input=user_input)
    st.session_state.past.append(user_input)
    st.session_state.generated.append(output)

with st.expander("Conversation"):
    # Show exchanges newest-first, each user message followed by its reply.
    # Previously the entire `past` list was printed on every iteration.
    exchanges = list(zip(st.session_state["past"], st.session_state["generated"]))
    for user_msg, bot_msg in reversed(exchanges):
        st.info(user_msg)
        st.success(bot_msg)