# Source snapshot: revision 410636c, 6,290 bytes.
import atexit
import Levenshtein
import os
import sys
sys.path.append(os.path.join(os.path.dirname(__file__), os.path.pardir))
import streamlit as st
from streamlit_chat import message
from query_methods import query, avail_query_methods
import pickle
conversations_file = "conversations.pkl"


def load_conversations(path=None):
    """Load the pickled list of saved conversations.

    Args:
        path: File to read. Defaults to the module-level ``conversations_file``.

    Returns:
        The unpickled object (the list that ``save_conversations`` writes), or
        an empty list when the file does not exist yet or is empty.

    NOTE(review): ``pickle.load`` can execute arbitrary code from a crafted
    file; this is only safe while the file is written exclusively by this app.
    """
    if path is None:
        path = conversations_file
    try:
        with open(path, "rb") as f:
            return pickle.load(f)
    except (FileNotFoundError, EOFError):
        # Missing file (first run) or zero-length file: start with no history.
        return []
def save_conversations(conversations, current_conversation, path=None):
    """Insert or update *current_conversation* in *conversations*, then persist.

    The first stored conversation equal to *current_conversation* is replaced
    by it. Otherwise it is appended, unless it has no user input yet: the
    sidebar labels and searches conversations by ``user_inputs[0]``, so an
    empty entry would crash the app the next time the history is shown.

    The whole list is pickled to a temporary file that is then moved over the
    target with ``os.replace``, so an interrupted write leaves the previous
    file intact.

    Args:
        conversations: List of conversation dicts; mutated in place.
        current_conversation: Dict with ``'user_inputs'`` and
            ``'generated_responses'`` lists.
        path: Destination file. Defaults to the module-level
            ``conversations_file``.
    """
    if path is None:
        path = conversations_file
    for idx, conversation in enumerate(conversations):
        if conversation == current_conversation:
            conversations[idx] = current_conversation
            break
    else:
        if current_conversation.get('user_inputs'):
            conversations.append(current_conversation)
    # Keep the temp file in the same directory so os.replace stays a rename.
    temp_path = os.path.join(os.path.dirname(path), "temp_" + os.path.basename(path))
    with open(temp_path, "wb") as f:
        pickle.dump(conversations, f)
    os.replace(temp_path, path)
def delete_conversation(conversations, current_conversation, path=None):
    """Remove *current_conversation* from *conversations* and persist the rest.

    The entry that *is* ``current_conversation`` is removed. Failing that, the
    first entry equal to it is removed. If neither exists, the list is left
    unchanged. The remaining list is pickled to a temporary file and moved over
    the target with ``os.replace``.

    Args:
        conversations: List of conversation dicts; mutated in place.
        current_conversation: The conversation to delete.
        path: Destination file. Defaults to the module-level
            ``conversations_file``.
    """
    if path is None:
        path = conversations_file
    # Prefer identity, so an equal-looking sibling is never removed instead.
    for idx, conversation in enumerate(conversations):
        if conversation is current_conversation:
            del conversations[idx]
            break
    else:
        if current_conversation in conversations:
            conversations.remove(current_conversation)
    temp_path = os.path.join(os.path.dirname(path), "temp_" + os.path.basename(path))
    with open(temp_path, "wb") as f:
        pickle.dump(conversations, f)
    os.replace(temp_path, path)
def exit_handler():
    """Save the active conversation when the Python process shuts down.

    NOTE(review): atexit callbacks run outside any Streamlit script run, so
    ``st.session_state`` may be unavailable here, or may not hold this
    session's data. Confirm against the Streamlit version in use. A failure to
    read it is reported rather than raised, because nothing useful can be done
    about it during shutdown.
    """
    print("Exiting, saving data...")
    try:
        conversations = st.session_state.conversations
        current_conversation = st.session_state.current_conversation
    except (AttributeError, KeyError) as err:
        print(f"Could not read session state, nothing saved: {err!r}")
        return
    # Skip conversations without messages: the sidebar labels entries with
    # user_inputs[0], so persisting an empty one breaks the next start-up.
    if not isinstance(current_conversation, dict) or not current_conversation.get('user_inputs'):
        return
    save_conversations(conversations, current_conversation)


# Streamlit re-executes this script on every interaction. Register the handler
# once per session instead of adding another atexit callback on each rerun.
if not st.session_state.get('exit_handler_registered', False):
    atexit.register(exit_handler)
    st.session_state['exit_handler_registered'] = True
st.header("Chat Placeholder")

# Per-session defaults. Each value is a factory, so it is only evaluated (e.g.
# the history file is only read) when the key is missing from the session.
_SESSION_DEFAULTS = {
    'conversations': load_conversations,
    'input_text': lambda: '',
    'selected_conversation': lambda: None,
    'input_field_key': lambda: 0,
    'query_method': lambda: query,
    'search_query': lambda: '',
}
for _state_key, _make_default in _SESSION_DEFAULTS.items():
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _make_default()

# Start an empty conversation when none is active (never set, or reset to None).
if st.session_state.get('current_conversation') is None:
    st.session_state['current_conversation'] = {'user_inputs': [], 'generated_responses': []}
def _unescape_response(text):
    r"""Return *text* with literal backslash escapes (such as ``\n``) decoded.

    The ``unicode-escape`` codec reads its input bytes as Latin-1. Encoding
    the text as UTF-8 first would turn every non-ASCII character into mojibake
    (``"é"`` -> ``"Ã©"``). Encoding as Latin-1 with ``backslashreplace``
    keeps Latin-1 characters as single bytes and turns every other character
    into a ``\uXXXX`` escape, which the codec then decodes back. Text with a
    malformed escape (e.g. a trailing ``\x``) is returned unchanged.
    """
    try:
        return text.encode('latin-1', 'backslashreplace').decode('unicode-escape')
    except UnicodeDecodeError:
        return text


input_placeholder = st.empty()
# The widget key embeds input_field_key. Bumping the counter (after a
# submission, or on "New Conversation") makes the next run render a fresh,
# empty widget, instead of one still holding text that was already sent.
user_input = input_placeholder.text_input(
    'You:', value=st.session_state['input_text'], key=f'input_text_{st.session_state["input_field_key"]}'
)
submit_button = st.button("Submit")
# Ignore blank input, including a Submit click on an empty field.
if user_input.strip() and (user_input != st.session_state['input_text'] or submit_button):
    output = query(user_input, st.session_state['query_method'])
    escaped_output = _unescape_response(output)
    active_conversation = st.session_state['current_conversation']
    active_conversation['user_inputs'].append(user_input)
    active_conversation['generated_responses'].append(escaped_output)
    save_conversations(st.session_state.conversations, active_conversation)
    st.session_state['input_text'] = ''
    st.session_state['input_field_key'] += 1  # Increment key value for new widget
    # Clear the input field by swapping in a widget under the new, empty key.
    user_input = input_placeholder.text_input(
        'You:', value=st.session_state['input_text'], key=f'input_text_{st.session_state["input_field_key"]}'
    )
# "New Conversation": clear the selection, start from an empty transcript, and
# bump input_field_key, which is used to build the input widget's key.
if st.sidebar.button("New Conversation"):
    fresh_conversation = {'user_inputs': [], 'generated_responses': []}
    st.session_state['selected_conversation'] = None
    st.session_state['current_conversation'] = fresh_conversation
    st.session_state['input_field_key'] = st.session_state['input_field_key'] + 1

# Backend passed to query(); the first entry of avail_query_methods is preselected.
selected_api = st.sidebar.selectbox("Select API:", options=avail_query_methods, index=0)
st.session_state['query_method'] = selected_api

# Proxy
proxy_setting = st.sidebar.text_input("Proxy: ")
st.session_state['proxy'] = proxy_setting
# Searchbar
search_query = st.sidebar.text_input("Search Conversations:", value=st.session_state.get('search_query', ''), key='search')


def _conversation_title(conversation):
    """Return the first user message of *conversation*, or '' when it has none."""
    user_inputs = conversation['user_inputs']
    return user_inputs[0] if user_inputs else ''


# Each listed entry is an (index, conversation) pair. The index points into
# st.session_state.conversations, so selecting or deleting a search result
# acts on the stored list rather than on the filtered view.
all_conversations = st.session_state.conversations
if search_query:
    listed = [
        (idx, conversation)
        for idx, conversation in enumerate(all_conversations)
        if search_query in _conversation_title(conversation)
    ]
    # Best matches (smallest edit distance between query and title) first.
    listed.sort(key=lambda item: Levenshtein.distance(search_query, _conversation_title(item[1])))
    sidebar_header = f"Search Results ({len(listed)})"
else:
    listed = list(enumerate(all_conversations))
    sidebar_header = "Conversation History"

# Sidebar
st.sidebar.header(sidebar_header)
sidebar_col1, sidebar_col2 = st.sidebar.columns([5, 1])
for idx, conversation in listed:
    if sidebar_col1.button(f"Conversation {idx + 1}: {_conversation_title(conversation)}", key=f"sidebar_btn_{idx}"):
        st.session_state['selected_conversation'] = idx
        st.session_state['current_conversation'] = conversation
    if sidebar_col2.button('🗑️', key=f"sidebar_btn_delete_{idx}"):
        selected = st.session_state['selected_conversation']
        if selected == idx or conversation is st.session_state['current_conversation']:
            # The open conversation is being deleted: fall back to a blank one.
            st.session_state['selected_conversation'] = None
            st.session_state['current_conversation'] = {'user_inputs': [], 'generated_responses': []}
        elif selected is not None and selected > idx:
            # Entries after the deleted one shift down by one position.
            st.session_state['selected_conversation'] = selected - 1
        delete_conversation(all_conversations, conversation)
        # Newer Streamlit releases replace st.experimental_rerun with st.rerun.
        rerun = getattr(st, 'rerun', None) or st.experimental_rerun
        rerun()

# Selecting an entry also stores it in current_conversation, so render that
# object. Indexing the listed (possibly filtered) entries by the stored
# selection could show the wrong conversation or go out of range.
conversation_to_display = st.session_state.current_conversation
for i in reversed(range(len(conversation_to_display['generated_responses']))):
    message(conversation_to_display["generated_responses"][i], key=f"display_generated_{i}")
    message(conversation_to_display['user_inputs'][i], is_user=True, key=f"display_user_{i}")