Spaces:
Runtime error
Runtime error
File size: 7,230 Bytes
a5162e0 465eb75 a5162e0 465eb75 a5162e0 465eb75 a5162e0 465eb75 a5162e0 465eb75 a5162e0 465eb75 a5162e0 465eb75 a5162e0 465eb75 a5162e0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 |
import streamlit as st
from streamlit_chat import message
from utils import PAGE, read_pdf
from prompt_generation import OpenAILLM
from dotenv import load_dotenv
# Load variables from a .env file into the process environment (python-dotenv).
# NOTE(review): presumably this supplies the API credentials used by OpenAILLM — confirm.
load_dotenv()
def init():
    """Prepare the per-session state and render the page title.

    The first run of a session is detected by the absence of
    ``current_page`` in ``st.session_state``; in that case the navigation
    state, the quiz settings, the LLM wrapper and the chat history are
    created. The page configuration and the centered title are emitted on
    every call.
    """
    state = st.session_state

    if 'current_page' not in state:
        question_count = 10
        false_answer_count = 3

        state.current_page = PAGE.MAIN
        state.mcq_question_number = question_count
        state.mcq_false_answer_number = false_answer_count
        # The LLM wrapper receives the same quiz settings that are kept in the session.
        state.llm = OpenAILLM(
            mcq_question_number=question_count,
            mcq_false_answer_number=false_answer_count,
        )
        state.chat_start = False
        state.chat_messages = []

    # Browser tab title/icon, followed by the centered page heading.
    # NOTE(review): these two calls are assumed to be part of init(); confirm against the original layout.
    st.set_page_config(page_title="AILearningBuddy", page_icon=":book:")
    st.markdown("<h1 style='text-align: center;'>AI Learning Buddy</h1>", unsafe_allow_html=True)
def main_page():
    """Render the main page: document upload plus the LEARN / TEST actions.

    An uploaded PDF or plain-text file (at most 2 MB) is converted to text
    and handed to the session's LLM wrapper. Once the wrapper reports that
    text is available, buttons for the summary, chat and quiz pages are
    shown; pressing one updates ``st.session_state.current_page`` and
    reruns the script.
    """
    # Header
    st.header("Main Page")

    # Upload docs
    file = st.file_uploader("Upload documents", type=["pdf", "txt"])
    MAX_FILE_SIZE = 2 * 1024 * 1024  # 2 MB

    if file is not None:
        if file.size > MAX_FILE_SIZE:
            st.error(f"File size should not exceed {MAX_FILE_SIZE / (1024 * 1024)} MB. Please upload a smaller file.")
        else:
            # Report success only after the content was actually read and
            # handed over, so an unsupported or undecodable file does not
            # produce a success message next to an error message.
            loaded = False
            if file.type == "application/pdf":
                st.session_state.llm.upload_text(read_pdf(file))
                loaded = True
            elif file.type == "text/plain":
                try:
                    text = file.read().decode("utf-8")
                except UnicodeDecodeError:
                    # A non-UTF-8 text file would otherwise crash the page.
                    st.error("The text file is not valid UTF-8. Please upload a UTF-8 encoded file.")
                else:
                    st.session_state.llm.upload_text(text)
                    loaded = True
            else:
                st.error("Unsupported file type.")

            if loaded:
                st.success("File uploaded successfully!")

    # Display buttons if file is uploaded
    if st.session_state.llm.is_text_uploaded():
        col1, col2 = st.columns([1, 1])

        with col1:
            st.markdown("<h4>LEARN</h4>", unsafe_allow_html=True)
            if st.button("Create summary", key="summary_button"):
                st.session_state.current_page = PAGE.SUMMARY
                st.rerun()
            if st.button("Chat about the file", key="chat_button"):
                st.session_state.current_page = PAGE.CHAT
                # Tells the chat page to open the conversation on its first render.
                st.session_state.chat_start = True
                st.rerun()

        with col2:
            st.markdown("<h4>TEST</h4>", unsafe_allow_html=True)
            if st.button("Create quiz", key="mcq_button"):
                st.session_state.current_page = PAGE.MCQ
                # 0 marks a quiz that has not been started yet.
                st.session_state.current_question = 0
                st.rerun()
def summary_page():
    """Render the summary page for the uploaded document.

    The back button returns to the main page, clears the text held by the
    LLM wrapper and reruns the script. Otherwise the summary obtained from
    the LLM wrapper is written below the header.
    """
    state = st.session_state

    # Navigation back to the main page
    back_pressed = st.button(":back: Main Page")
    if back_pressed:
        state.current_page = PAGE.MAIN
        state.llm.empty_text()
        st.rerun()

    st.header("Summary")

    # Fetch the summary from the LLM wrapper and show it
    st.write(state.llm.get_text_summary())
def chat_page():
    """Render the chat page for talking about the uploaded document.

    The back button resets the chat state, clears the text held by the LLM
    wrapper and returns to the main page. On the first render after
    entering the page (``chat_start`` is set) the conversation is opened
    via ``llm.start_chat()``; afterwards each submitted, non-empty message
    is sent to ``llm.get_chat_response()``. The history is stored in
    ``st.session_state.chat_messages`` as alternating user / model entries.
    """
    # Header
    if st.button(":back: Main Page"):
        st.session_state.current_page = PAGE.MAIN
        st.session_state.chat_start = False
        st.session_state.chat_messages = []
        st.session_state.llm.empty_text()
        st.rerun()
    st.header("Chat About the Document")

    # Response and user container (responses are rendered above the input form)
    response_container = st.container()
    user_container = st.container()

    with user_container:
        with st.form(key='my_form', clear_on_submit=True):
            user_input = st.text_area("Type here:", key='input', height=100)
            send_button = st.form_submit_button(label='Send')

    if st.session_state.chat_start:
        # Opening exchange: the LLM wrapper supplies both the first input and its reply.
        user_input, model_response = st.session_state.llm.start_chat()
        st.session_state.chat_start = False
        st.session_state.chat_messages += [user_input, model_response]
    elif send_button and user_input and user_input.strip():
        # Ignore empty / whitespace-only submissions instead of sending
        # them to the model and rendering an empty chat bubble.
        model_response = st.session_state.llm.get_chat_response(user_input)
        st.session_state.chat_messages += [user_input, model_response]

    # Display chat messages. Index 0 (the input returned by start_chat) is
    # not shown; odd indices are model replies, even indices user messages.
    with response_container:
        for i, text in enumerate(st.session_state.chat_messages[1:], start=1):
            if i % 2:
                message(text, key=f'{i}_AI', avatar_style="pixel-art")
            else:
                message(text, is_user=True, key=f'{i}_user',
                        avatar_style="adventurer-neutral")
def mcq_page():
    """Render the multiple-choice quiz page.

    ``st.session_state.current_question`` drives the flow: 0 means the quiz
    has not been started, 1..``mcq_question_number`` is the question being
    shown, and anything above that shows the results with the score and
    the per-question answer sheet. The back button resets the counter,
    clears the text held by the LLM wrapper and returns to the main page.
    """
    # Header
    if st.button(":back: Main Page"):
        st.session_state.current_page = PAGE.MAIN
        st.session_state.current_question = 0
        st.session_state.llm.empty_text()
        st.rerun()

    # Setup MCQ
    if st.session_state.current_question == 0:
        # Start MCQ and get the first question and answer
        st.session_state.llm.start_mcq()
        st.session_state.question, st.session_state.answers = st.session_state.llm.get_mcq_question()
        st.session_state.current_question += 1

    # Handler when pressing next
    def increase_current_question(answer_key):
        # Streamlit runs this callback BEFORE the script reruns, so the
        # submitted choice must be read from the radio's own session-state
        # key. A value assigned from st.radio()'s return during the previous
        # run would still be the default (first) option.
        chosen = st.session_state[answer_key]
        st.session_state.selected_answer = chosen
        st.session_state.current_question = st.session_state.current_question + 1
        st.session_state.llm.mcq_record_answer(chosen)
        st.session_state.question, st.session_state.answers = st.session_state.llm.get_mcq_question()

    # For every MCQ question
    if st.session_state.current_question <= st.session_state.mcq_question_number:
        # QA header
        st.header(f"Question {st.session_state.current_question} / {st.session_state.mcq_question_number}")

        # A per-question key keeps one question's choice from leaking into the next.
        answer_key = f"mcq_answer_{st.session_state.current_question}"

        # QA form
        with st.form(key='my_form', clear_on_submit=True):
            st.session_state.selected_answer = st.radio(
                f"{st.session_state.question}:",
                st.session_state.answers,
                key=answer_key,
            )
            st.form_submit_button(label="Next", on_click=increase_current_question, args=(answer_key,))
    else:
        # Results header
        st.header("Results")

        # For the last QA, show score
        score, score_perc = st.session_state.llm.get_mcq_score()
        st.markdown(f"<h4>Score: {score} / {st.session_state.mcq_question_number} ({score_perc} %)</h4>",
                    unsafe_allow_html=True)

        # List your answers and the correct ones.
        # NOTE(review): the last sheet entry is skipped — presumably the extra
        # question fetched after the final answer was recorded; confirm.
        for i, qa in enumerate(st.session_state.llm.mcq_answer_sheet[:-1]):
            question, answer, user_answer = qa['question'], qa['answer'], qa['user_answer']
            st.write("---")
            st.write(f"**Question {i+1}/{st.session_state.mcq_question_number}:** {question}")
            st.write(f"**Correct answer:** {answer}")
            st.write(f"**User answer:** {user_answer}")
# Entry point: set up the session, then render whichever page is active.
init()

# Route to the page stored in the session state; an unknown value renders nothing.
active_page = st.session_state.current_page
if active_page == PAGE.MAIN:
    main_page()
elif active_page == PAGE.SUMMARY:
    summary_page()
elif active_page == PAGE.CHAT:
    chat_page()
elif active_page == PAGE.MCQ:
    mcq_page()
|