File size: 3,828 Bytes
2717026
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
482b320
2717026
 
70cbeb4
 
2717026
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
import html
import time
from os.path import basename

import streamlit as st
import yaml
from streamlit_authenticator import Authenticate
from yaml.loader import SafeLoader

from htmlTemplates import css, bot_template, user_template
from sources import sources_ref
import constants
import backend


import torch
# Start-up diagnostics: report whether PyTorch can see a CUDA GPU and, if so, its name.
_cuda_available = torch.cuda.is_available()
print(f"Is CUDA available: {_cuda_available}")  # author's machine printed: True
if _cuda_available:
    _device_index = torch.cuda.current_device()
    print(f"CUDA device: {torch.cuda.get_device_name(_device_index)}")  # author's machine printed: Tesla T4


def on_query():
    """``on_change`` callback for the question text box.

    Copies the submitted text into ``constants.question`` so ``main_page`` can
    answer it on the rerun, then empties the box for the next question.
    """
    state = st.session_state
    submitted = state.InputText
    constants.question = submitted
    state.InputText = ""


def handle_userinput(user_question):
    """Ask the QA model a question, record the exchange and redraw the chat.

    The question and its answer are prepended to
    ``st.session_state.chat_history``, so the list holds exchanges newest first
    with each question immediately before its answer. The whole history is then
    rendered through the HTML chat templates.

    Args:
        user_question: Text the user submitted.
    """
    res = st.session_state.model({'query': user_question})
    answer, docs = res["result"], res["source_documents"]
    # source = sources_ref[basename(docs[0].metadata["source"])]

    st.session_state.chat_history.insert(0, user_question)
    # NOTE: messages are HTML-escaped at render time, so if the source line
    # below is re-enabled its markup must be added after escaping.
    # st.session_state.chat_history.insert(1, answer + "<br><br><em>Source: " + source)
    st.session_state.chat_history.insert(1, answer)

    for i, message in enumerate(st.session_state.chat_history):
        # Even slots hold user questions, odd slots the model's answers.
        template = user_template if i % 2 == 0 else bot_template
        # The page is rendered with unsafe_allow_html=True, and messages are
        # arbitrary text (user input / model output). Escape them so markup in
        # a message cannot inject HTML/script or break the chat layout.
        st.write(template.replace("{{MSG}}", html.escape(message)), unsafe_allow_html=True)


def login_page():
    """Render the login form and, once authenticated, the database picker and chat.

    Credentials and cookie settings are read from ``auth_config.yaml``. After a
    successful login the user chooses a database; pressing "Enter" sets
    ``st.session_state.db_clicked`` via ``db_button_clicked``. The picker is then
    cleared and ``main_page`` is rendered.
    """
    st.set_page_config(page_title="logotherapyGPT",
                       page_icon=":books:")

    with open('auth_config.yaml', encoding='utf-8') as file:
        config = yaml.load(file, Loader=SafeLoader)

    authenticator = Authenticate(
        config['credentials'],
        config['cookie']['name'],
        config['cookie']['key'],
        config['cookie']['expiry_days'],
        config['preauthorized']
    )

    name, authentication_status, username = authenticator.login('Login', 'main')

    if authentication_status:

        if 'db_clicked' not in st.session_state:
            st.session_state.db_clicked = False

        placeholder_box = st.empty()
        placeholder_button = st.empty()
        placeholder_box.selectbox(
            'Choose a Database',
            ("All", "Frankl's Works", "Journal of Search for Meaning"),
            key="database"
        )
        placeholder_button.button(
            "Enter",
            on_click=db_button_clicked
        )

        if st.session_state.db_clicked:
            placeholder_box.empty()
            placeholder_button.empty()
            main_page()

    elif authentication_status is None:
        # Nothing submitted yet. This must be tested before the failure case:
        # `not None` is True, so a plain falsiness check would tell visitors
        # who haven't typed anything that their password is wrong.
        st.warning('Please enter your username and password')

    else:
        st.error('Username/password is incorrect')


def db_button_clicked():
    """``on_click`` handler for the "Enter" button: record that a database was chosen."""
    st.session_state["db_clicked"] = True


def _render_chat_history():
    """Redraw the stored conversation (newest exchange first) without querying the model."""
    for i, message in enumerate(st.session_state.chat_history):
        # Even slots hold user questions, odd slots the model's answers.
        template = user_template if i % 2 == 0 else bot_template
        # Escape: messages are arbitrary text rendered with unsafe_allow_html=True.
        st.write(template.replace("{{MSG}}", html.escape(message)), unsafe_allow_html=True)


def main_page():
    """Render the chat UI for the database selected on the login page.

    Loads the QA model once per session and shows the question box. It then
    answers the pending question that ``on_query`` stored in
    ``constants.question``. With no new question, it just redraws the existing
    chat history.
    """
    st.write(css, unsafe_allow_html=True)

    st.markdown("# logotherapyGPT")
    st.markdown("Chosen Database: " + st.session_state.database)
    st.markdown("An AI chatbot to quickly access logotherapy information. Trained on the works of Viktor "
                "Frankl and The International Forum for Logotherapy.")

    if "model" not in st.session_state:
        with st.spinner("Loading Database..."):
            st.session_state.model = backend.load_qa(st.session_state.database)

    if "chat_history" not in st.session_state:
        st.session_state.chat_history = []

    st.text_input(
        "Input",
        label_visibility="collapsed",
        placeholder="Ask a Question...",
        on_change=on_query,
        key="InputText"
    )

    # Consume the pending question. `constants.question` is module-level state
    # that outlives this rerun (and presumably other sessions served by the
    # same process). If left set, every later rerun or page refresh would
    # re-ask it and duplicate it in the chat history.
    question = constants.question
    constants.question = None

    print(question)

    # Skip empty / whitespace-only submissions instead of sending them to the model.
    has_question = question is not None and bool(question.strip())

    if st.session_state.model is not None and has_question:
        start = time.perf_counter()

        with st.spinner("Processing..."):
            handle_userinput(question)

        end = time.perf_counter()
        print(f"\n> Answer (took {round(end - start, 2)} s.):")
    else:
        # No new question on this rerun: keep the conversation visible.
        _render_chat_history()


if __name__ == '__main__':
    # Script entry point: the login page renders the chat UI after authentication.
    login_page()