Spaces:
Sleeping
Sleeping
File size: 3,828 Bytes
2717026 482b320 2717026 70cbeb4 2717026 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 |
import time
from os.path import basename
import yaml
import streamlit as st
from yaml.loader import SafeLoader
from streamlit_authenticator import Authenticate
from htmlTemplates import css, bot_template, user_template
from sources import sources_ref
import constants
import backend
import torch
# Startup diagnostic, executed at module import time: report on the server
# console whether PyTorch can see a CUDA device.
print(f"Is CUDA available: {torch.cuda.is_available()}")
# Example output noted by the author: True
if torch.cuda.is_available():
    # Name of the currently selected CUDA device.
    print(f"CUDA device: {torch.cuda.get_device_name(torch.cuda.current_device())}")
    # Example output noted by the author: Tesla T4
def on_query():
    """Callback for the question box: queue the typed text and clear the box.

    Copies ``st.session_state.InputText`` into ``constants.question`` (which
    ``main_page`` reads afterwards) and resets the input widget to "".

    NOTE(review): ``constants`` is an imported module, so the queued question
    lives in process-wide state rather than per-session state — likely shared
    between Streamlit sessions; confirm.
    """
    state = st.session_state
    typed = state["InputText"]
    constants.question = typed
    state["InputText"] = ""
def handle_userinput(user_question):
    """Ask the QA chain ``user_question`` and re-render the whole conversation.

    The question and its answer are prepended to
    ``st.session_state.chat_history`` (newest exchange first: even indices are
    questions, odd indices are answers). Every message is then written using
    the user/bot HTML templates.

    Args:
        user_question: Text the user typed into the question box.
    """
    import html  # local import keeps this fix self-contained

    res = st.session_state.model({'query': user_question})
    answer = res["result"]
    # Source attribution (via res["source_documents"] and sources_ref) is
    # currently disabled.

    history = st.session_state.chat_history
    history.insert(0, user_question)
    history.insert(1, answer)

    for i, message in enumerate(history):
        template = user_template if i % 2 == 0 else bot_template
        # Escape the text before splicing it into the HTML template. With
        # unsafe_allow_html=True, any "<", ">" or "&" in a question or answer
        # would otherwise be interpreted as markup and can break the layout.
        st.write(template.replace("{{MSG}}", html.escape(str(message))), unsafe_allow_html=True)
def login_page():
    """Render the login form and, once authenticated, the database picker.

    Credentials and cookie settings come from ``auth_config.yaml``. After the
    user picks a database and presses "Enter", the picker is cleared and the
    chat page (``main_page``) is rendered.
    """
    st.set_page_config(page_title="logotherapyGPT",
                       page_icon=":books:")

    # SafeLoader only constructs plain Python objects from the YAML.
    with open('auth_config.yaml', encoding='utf-8') as file:
        config = yaml.load(file, Loader=SafeLoader)

    authenticator = Authenticate(
        config['credentials'],
        config['cookie']['name'],
        config['cookie']['key'],
        config['cookie']['expiry_days'],
        config['preauthorized']
    )
    # Only the status is used here; the display name and username are ignored.
    _name, authentication_status, _username = authenticator.login('Login', 'main')

    if authentication_status:
        if 'db_clicked' not in st.session_state:
            st.session_state.db_clicked = False

        # Render the picker into placeholders so it can be removed below.
        placeholder_box = st.empty()
        placeholder_button = st.empty()
        placeholder_box.selectbox(
            'Choose a Database',
            ("All", "Frankl's Works", "Journal of Search for Meaning"),
            key="database"
        )
        placeholder_button.button(
            "Enter",
            on_click=db_button_clicked
        )

        if st.session_state.db_clicked:
            # NOTE(review): the picker is rendered on every run and only then
            # emptied, presumably so the "database" widget key stays in
            # session_state for main_page. Confirm before restructuring.
            placeholder_box.empty()
            placeholder_button.empty()
            main_page()
    elif authentication_status is False:
        # Must be an identity check: `not authentication_status` is also true
        # for None, which previously made the warning branch unreachable and
        # showed this error on a fresh, unsubmitted form.
        st.error('Username/password is incorrect')
    elif authentication_status is None:
        st.warning('Please enter your username and password')
def db_button_clicked():
    """on_click callback for the "Enter" button: record that a database was chosen."""
    st.session_state["db_clicked"] = True
def main_page():
    """Render the chat page for the database chosen in ``login_page``.

    Loads the QA chain for ``st.session_state.database`` once per session and
    shows the question box. It answers a question queued in
    ``constants.question`` (by ``on_query``) at most once. Otherwise it just
    redraws the existing conversation.
    """
    st.write(css, unsafe_allow_html=True)
    st.markdown("# logotherapyGPT")
    st.markdown("Chosen Database: " + st.session_state.database)
    st.markdown("An AI chatbot to quickly access logotherapy information. Trained on the works of Viktor "
                "Frankl and The International Forum for Logotherapy.")

    # Build the QA chain only once per session.
    if "model" not in st.session_state:
        with st.spinner("Loading Database..."):
            st.session_state.model = backend.load_qa(st.session_state.database)
    if "chat_history" not in st.session_state:
        st.session_state.chat_history = []

    st.text_input(
        "Input",
        label_visibility="collapsed",
        placeholder="Ask a Question...",
        on_change=on_query,
        key="InputText"
    )

    # Take the pending question and clear it straight away. If it is left set,
    # every later rerun re-sends the same question to the model and appends a
    # duplicate exchange to the history.
    # NOTE(review): `constants` is an imported module, so this slot is
    # process-wide rather than per-session. Consuming it immediately narrows,
    # but does not close, the window in which another session could see it.
    question = constants.question
    constants.question = None

    if question is not None and question.strip() and st.session_state.model is not None:
        start = time.time()
        with st.spinner("Processing..."):
            handle_userinput(question)  # also renders the updated history
        end = time.time()
        print(f"\n> Answer (took {round(end - start, 2)} s.):")
    else:
        # No new question on this run: keep showing the conversation so far.
        _render_chat_history(st.session_state.chat_history)


def _render_chat_history(history):
    """Draw the stored conversation without querying the model.

    ``history`` has the layout built by ``handle_userinput``: newest exchange
    first, questions at even indices and answers at odd indices.
    """
    import html  # local import keeps this block self-contained

    for i, message in enumerate(history):
        template = user_template if i % 2 == 0 else bot_template
        # Escape so question/answer text cannot inject markup into the template.
        st.write(template.replace("{{MSG}}", html.escape(str(message))), unsafe_allow_html=True)
if __name__ == '__main__':
    # Script entry point: every run starts at the login screen.
    login_page()
|