File size: 1,619 Bytes
c84d465
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
import streamlit as st

from langchain.memory import ConversationBufferMemory
from langchain.schema import HumanMessage, AIMessage
from streamlit_chat_media import message


class ChatHistory:
    """Keep a conversation memory in Streamlit session state and render it.

    The memory object lives under ``st.session_state["history"]`` and is also
    exposed as ``self.history``.
    """

    def __init__(self):
        """Bind to the session's memory, creating it only when it is absent."""
        # Build the memory lazily: passing a freshly constructed object as the
        # ``.get`` default would instantiate it on every call, even when a
        # history is already stored in the session.
        if "history" not in st.session_state:
            st.session_state["history"] = ConversationBufferMemory(
                memory_key="chat_history", return_messages=True
            )
        self.history = st.session_state["history"]

    def default_greeting(self) -> str:
        """Return the canned opening line shown on the user side."""
        # The literal previously held UTF-8 emoji bytes decoded with a
        # single-byte code page; the waving-hand emoji is restored here.
        return "Hi ! 👋"

    def default_prompt(self, topic) -> str:
        """Return the canned assistant invitation mentioning ``topic``.

        Args:
            topic: Value interpolated into the prompt text.
        """
        # Same mis-decoding repair as in ``default_greeting`` (hugging face).
        return f"Hello ! Ask me anything about {topic} 🤗"

    def initialize(self, topic) -> None:
        """Emit the two canned opening messages via ``message``.

        Args:
            topic: Forwarded to ``default_prompt``.
        """
        message(self.default_greeting(), key='hi', avatar_style="adventurer", is_user=True)
        message(self.default_prompt(topic), key='ai', avatar_style="thumbs")

    def reset(self) -> None:
        """Clear the stored history and set the ``reset_chat`` flag to False."""
        st.session_state["history"].clear()
        st.session_state["reset_chat"] = False

    def generate_messages(self, container) -> None:
        """Render every stored human/AI message inside ``container``.

        Messages that are neither ``HumanMessage`` nor ``AIMessage`` are
        skipped. Nothing is rendered when no history is stored.

        Args:
            container: Object used as a context manager around the rendering
                calls (presumably a Streamlit container; confirm with caller).
        """
        history = st.session_state.get("history")
        if not history:
            return

        with container:
            # Widget keys are derived from the message position, so they stay
            # identical to the ones the index-based loop produced.
            for index, msg in enumerate(history.chat_memory.messages):
                if isinstance(msg, HumanMessage):
                    message(
                        msg.content,
                        is_user=True,
                        key=f"{index}_user",
                        avatar_style="adventurer",
                    )
                elif isinstance(msg, AIMessage):
                    message(msg.content, key=str(index), avatar_style="thumbs")