|
import streamlit as st |
|
import os |
|
import pickle |
|
import time |
|
import requests |
|
|
|
st.set_page_config(page_title="Psychedelics GPT")

# Centered, bold page heading. It is raw HTML/CSS, which is why
# unsafe_allow_html must be enabled for st.markdown to render it.
_TITLE_HTML = """
<style>
.title {
text-align: center;
font-size: 2em;
font-weight: bold;
}
</style>
<div class="title"> Psychedelics Chatbot </div>
"""
st.markdown(_TITLE_HTML, unsafe_allow_html=True)

# Pickle file where the list of saved conversations is stored.
conversations_file = "conversations.pkl"
|
|
|
|
|
def load_conversations(path=None):
    """Load the saved conversations from disk.

    Args:
        path: File to read. Defaults to the module-level ``conversations_file``.

    Returns:
        The unpickled object (the list that ``save_conversations`` wrote), or
        ``[]`` when the file is missing, empty, or not a valid pickle.
    """
    # Deliberately NOT wrapped in st.cache_data. The function runs once per
    # session (when session_state has no "conversations" yet). A process-wide,
    # TTL-less cache would hand every new session the file contents from the
    # very first load and hide everything saved since.
    # SECURITY: pickle.load can execute arbitrary code embedded in the file.
    # Only point this at a file written by this app itself.
    target = conversations_file if path is None else path
    try:
        with open(target, "rb") as f:
            return pickle.load(f)
    except (FileNotFoundError, EOFError, pickle.UnpicklingError):
        # Missing, empty, or corrupt file: start with no conversations.
        return []
|
|
|
|
|
def save_conversations(conversations, path=None):
    """Persist *conversations* to disk as a pickle, replacing the file atomically.

    The data is first written to a sibling ``<path>.tmp`` file. That file is
    then moved over the target with ``os.replace``. A failure while pickling
    or writing therefore leaves the previously saved file intact; opening the
    target directly with "wb" would truncate it before writing.

    Args:
        conversations: Picklable object to store (the app passes its list of
            conversations).
        path: Destination file. Defaults to the module-level
            ``conversations_file``.

    Raises:
        Whatever ``open``/``pickle.dump``/``os.replace`` raise. The partial
        temp file is removed before the error propagates.
    """
    # NOTE(review): every session writes the same file, so the last session to
    # save overwrites the others' conversations. Confirm this is intended.
    target = conversations_file if path is None else path
    temp_path = f"{target}.tmp"
    try:
        with open(temp_path, "wb") as f:
            pickle.dump(conversations, f)
        os.replace(temp_path, target)
    except BaseException:
        # BaseException so cleanup also happens on interrupts; always re-raised.
        try:
            os.remove(temp_path)
        except OSError:
            pass
        raise
|
|
|
|
|
# Per-session state, created only when the keys are not yet present:
# - "conversations": the list of saved conversations, loaded from disk;
# - "current_conversation": the conversation on screen, seeded with an
#   assistant greeting.
if "conversations" not in st.session_state:
    st.session_state["conversations"] = load_conversations()

if "current_conversation" not in st.session_state:
    st.session_state["current_conversation"] = [
        {"role": "assistant", "content": "How may I assist you today?"},
    ]
|
|
|
|
|
def truncate_string(s, length=30):
    """Shorten *s* for use as a sidebar button label.

    A string of at most *length* characters is returned unchanged. A longer
    one is cut to its first *length* characters, has trailing whitespace
    stripped, and gets ``"..."`` appended.
    """
    if len(s) <= length:
        return s
    clipped = s[:length].rstrip()
    return clipped + "..."
|
|
|
|
|
def display_chats_sidebar():
    """Render the sidebar: chat-management buttons plus one button per saved chat.

    - "Start New Chat" makes a fresh empty conversation current and registers
      it in ``st.session_state.conversations``.
    - "Clear All Chats" replaces all conversations with a single fresh one
      and saves that to disk.
    - Each non-empty conversation gets a button titled with its first user
      message (or "New Chat"). Clicking it makes that conversation current.
    """
    with st.sidebar.container():
        st.header('Settings')
        col1, col2 = st.columns([1, 1])

        if col1.button('Start New Chat', key="new_chat"):
            # Register the list object itself so messages appended to the
            # current conversation later are part of what gets saved.
            st.session_state.current_conversation = []
            st.session_state.conversations.append(st.session_state.current_conversation)

        if col2.button('Clear All Chats', key="clear_all"):
            # Keep the fresh conversation inside the (now otherwise empty) list.
            # Otherwise new messages would go to a list that is never saved.
            # Persist the clear so the deleted chats do not come back from disk.
            st.session_state.current_conversation = []
            st.session_state.conversations = [st.session_state.current_conversation]
            save_conversations(st.session_state.conversations)

    with st.sidebar.container():
        st.header('Conversations')
        for idx, conversation in enumerate(st.session_state.conversations):
            # Conversations with no messages yet get no button.
            if not conversation:
                continue
            chat_title_raw = next(
                (msg["content"] for msg in conversation if msg["role"] == "user"),
                "New Chat",
            )
            if st.sidebar.button(truncate_string(chat_title_raw), key=f"chat_button_{idx}"):
                # Same list object as conversations[idx], so new messages land
                # in the saved conversation.
                st.session_state.current_conversation = conversation
|
|
|
|
|
def main_app():
    """Render the active conversation and answer a newly submitted prompt.

    Every message in ``st.session_state.current_conversation`` is replayed
    first. When the user submits text through the chat input, it is sent to
    the backend. The answer and its numbered references are rendered and
    appended to the current conversation, and all conversations are saved.
    Backend failures are reported with ``st.error`` instead of crashing the
    page.
    """
    for message in st.session_state.current_conversation:
        with st.chat_message(message["role"]):
            st.write(message["content"])

    def generate_response(prompt_input):
        """Send *prompt_input* to the generation backend and return its JSON body.

        The endpoint defaults to the original hard-coded URL. It can be
        overridden with the ``PSYCHEDELICS_API_URL`` environment variable.

        Raises:
            requests.RequestException: connection failure, timeout, or HTTP
                error status.
            ValueError: the body is not JSON, or not an object whose
                ``"response"`` value is a string. requests' JSON decode error
                also subclasses ValueError.
        """
        payload = {
            "user_prompt": prompt_input,
            # NOTE(review): history is always sent empty, so the backend sees
            # each prompt in isolation. Confirm this is intended.
            "chat_history": [],
        }
        api_url = os.environ.get(
            "PSYCHEDELICS_API_URL", "http://3.223.163.181:8090/generate"
        )
        # (connect, read) timeouts in seconds. Without a timeout, a stalled
        # backend would block this script run forever.
        response = requests.post(api_url, json=payload, timeout=(10, 300))
        response.raise_for_status()
        data = response.json()
        if not isinstance(data, dict) or not isinstance(data.get("response"), str):
            raise ValueError(f"Unexpected response payload: {data!r}")
        return data

    if prompt := st.chat_input('Send a Message'):
        conversation = st.session_state.current_conversation
        conversation.append({"role": "user", "content": prompt})
        with st.chat_message("user"):
            st.write(prompt)

        with st.chat_message("assistant"):
            result = None
            with st.spinner("Thinking..."):
                try:
                    result = generate_response(prompt)
                except (requests.RequestException, ValueError) as err:
                    st.error(f"Could not get a response from the server: {err}")

            if result is not None:
                answer = result["response"]
                st.markdown(answer)
                content = answer
                sources = result.get("sources") or []
                if sources:
                    sources_str = 'References:\n' + '\n'.join(
                        f'{number}. {source}'
                        for number, source in enumerate(sources, start=1)
                    )
                    st.markdown(sources_str)
                    content = answer + "\n" + sources_str
                conversation.append({"role": "assistant", "content": content})

        # current_conversation is not always an element of the saved list.
        # For example, the greeting conversation created at session start is
        # never appended to it. Register it by identity so it is persisted.
        if not any(existing is conversation for existing in st.session_state.conversations):
            st.session_state.conversations.append(conversation)
        save_conversations(st.session_state.conversations)
|
|
|
|
|
# Render the sidebar before the chat area: its buttons may replace
# st.session_state.current_conversation, which main_app() then displays.
for _render_section in (display_chats_sidebar, main_app):
    _render_section()
|
|