Spaces:
Runtime error
Runtime error
import streamlit as st | |
from streamlit_chat import message | |
from utils import PAGE, read_pdf | |
from prompt_generation import OpenAILLM | |
from dotenv import load_dotenv | |
load_dotenv() | |
def init():
    """Create the per-session state the first time this session runs the script.

    The ``'current_page'`` key acts as the "already initialised" marker; when it
    is missing, the default page, quiz sizes, the ``OpenAILLM`` wrapper and the
    chat bookkeeping are all stored in ``st.session_state``.
    """
    state = st.session_state
    if 'current_page' in state:
        return

    # Quiz configuration, shared between session state and the LLM wrapper.
    question_count = 10
    false_answer_count = 3

    state.current_page = PAGE.MAIN
    state.mcq_question_number = question_count
    state.mcq_false_answer_number = false_answer_count
    state.llm = OpenAILLM(
        mcq_question_number=question_count,
        mcq_false_answer_number=false_answer_count,
    )
    # Chat bookkeeping: chat_start triggers the opening turn on the chat page.
    state.chat_start = False
    state.chat_messages = []
# Page title / favicon, followed by the centred application header.
st.set_page_config(
    page_title="AILearningBuddy",
    page_icon=":book:",
)
st.markdown(
    "<h1 style='text-align: center;'>AI Learning Buddy</h1>",
    unsafe_allow_html=True,
)
def _read_uploaded_text(file):
    """Return the text of an uploaded PDF/TXT file, or None after showing an error.

    PDFs are delegated to ``read_pdf``; plain-text files are decoded as UTF-8
    (a leading BOM, if any, is stripped).
    """
    if file.type == "application/pdf":
        return read_pdf(file)
    if file.type == "text/plain":
        try:
            # getvalue() returns the whole buffer regardless of the current
            # read position, unlike read().
            return file.getvalue().decode("utf-8-sig")
        except UnicodeDecodeError:
            # Previously this crashed the page with an uncaught exception.
            st.error("The text file is not valid UTF-8. Please upload a UTF-8 encoded file.")
            return None
    st.error("Unsupported file type.")
    return None


def main_page():
    """Render the landing page: upload a document, then choose a learning/test mode.

    A PDF or TXT file of at most 2 MB is read and passed to
    ``st.session_state.llm.upload_text``. Once ``llm.is_text_uploaded()`` is
    true, buttons switch ``current_page`` to the summary, chat or quiz page and
    trigger a rerun.
    """
    # Header
    st.header("Main Page")

    # Upload docs
    file = st.file_uploader("Upload documents", type=["pdf", "txt"])
    MAX_FILE_SIZE = 2 * 1024 * 1024  # 2 MB
    if file is not None:
        if file.size > MAX_FILE_SIZE:
            st.error(f"File size should not exceed {MAX_FILE_SIZE / (1024 * 1024)} MB. Please upload a smaller file.")
        else:
            text = _read_uploaded_text(file)
            if text is not None:
                st.session_state.llm.upload_text(text)
                # Report success only after the content was actually read.
                st.success("File uploaded successfully!")

    # Display buttons if file is uploaded
    if st.session_state.llm.is_text_uploaded():
        col1, col2 = st.columns([1, 1])
        with col1:
            st.markdown("<h4>LEARN</h4>", unsafe_allow_html=True)
            if st.button("Create summary", key="summary_button"):
                st.session_state.current_page = PAGE.SUMMARY
                st.rerun()
            if st.button("Chat about the file", key="chat_button"):
                st.session_state.current_page = PAGE.CHAT
                # Tells chat_page to open the conversation on its first render.
                st.session_state.chat_start = True
                st.rerun()
        with col2:
            st.markdown("<h4>TEST</h4>", unsafe_allow_html=True)
            if st.button("Create quiz", key="mcq_button"):
                st.session_state.current_page = PAGE.MCQ
                # 0 tells mcq_page to start a fresh quiz.
                st.session_state.current_question = 0
                st.rerun()
def summary_page():
    """Render the document summary page.

    Pressing the back button returns to the main page and clears the uploaded
    text in the LLM wrapper; otherwise the wrapper's summary is written out.
    """
    state = st.session_state
    go_back = st.button(":back: Main Page")
    if go_back:
        state.current_page = PAGE.MAIN
        state.llm.empty_text()
        st.rerun()  # stops this run; nothing below executes

    st.header("Summary")
    st.write(state.llm.get_text_summary())
def chat_page():
    """Render the chat page for conversing with the LLM about the uploaded document.

    On the first render after the main page sets ``chat_start``, the
    conversation is opened with ``llm.start_chat()``, which returns an opening
    prompt (stored but not shown) and the model's reply. Each later non-empty
    submission is answered by ``llm.get_chat_response``. ``chat_messages``
    alternates prompt/response, so odd indices are model replies and even
    indices (from 2 on) are user messages.
    """
    # Header
    if st.button(":back: Main Page"):
        st.session_state.current_page = PAGE.MAIN
        st.session_state.chat_start = False
        st.session_state.chat_messages = []
        st.session_state.llm.empty_text()
        st.rerun()
    st.header("Chat About the Document")

    # The history container is created first so it renders above the input
    # form, even though it is filled after the new message is processed.
    response_container = st.container()
    user_container = st.container()

    with user_container:
        with st.form(key='my_form', clear_on_submit=True):
            user_input = st.text_area("Type here:", key='input', height=100)
            send_button = st.form_submit_button(label='Send')

    if st.session_state.chat_start:
        # Opening turn: the wrapper supplies both the hidden prompt and the reply.
        user_input, model_response = st.session_state.llm.start_chat()
        st.session_state.chat_start = False
        st.session_state.chat_messages += [user_input, model_response]
    elif send_button and user_input.strip():
        # Empty / whitespace-only submissions are ignored instead of being
        # sent to the model and appended to the history.
        model_response = st.session_state.llm.get_chat_response(user_input)
        st.session_state.chat_messages += [user_input, model_response]

    # Display chat messages, skipping index 0 (the hidden opening prompt).
    with response_container:
        for i, text in enumerate(st.session_state.chat_messages[1:], start=1):
            if i % 2:
                message(text, key=f'{i}_AI', avatar_style="pixel-art")
            else:
                message(text, is_user=True, key=f'{i}_user',
                        avatar_style="adventurer-neutral")
def mcq_page():
    """Render the multiple-choice quiz, one question per form submit, then the results.

    ``current_question`` drives the flow:

    * 0: the quiz has not started yet; ``llm.start_mcq()`` is called and the
      first question fetched.
    * 1..``mcq_question_number``: the question currently on screen.
    * greater than that: the score and the answer sheet are shown.
    """
    # Header
    if st.button(":back: Main Page"):
        st.session_state.current_page = PAGE.MAIN
        st.session_state.current_question = 0
        st.session_state.llm.empty_text()
        st.rerun()

    # Setup MCQ: start the quiz and load the first question.
    if st.session_state.current_question == 0:
        st.session_state.llm.start_mcq()
        st.session_state.question, st.session_state.answers = st.session_state.llm.get_mcq_question()
        st.session_state.current_question += 1

    def increase_current_question():
        """Form-submit callback: record the chosen answer and load the next question."""
        st.session_state.current_question = st.session_state.current_question + 1
        # Read the answer through the radio's widget key. Widgets inside a form
        # only send their value on submit, and Streamlit writes it to
        # session_state before this callback runs. The radio's *return value*
        # from the previous run would still be the default (first) option, which
        # is what was recorded before.
        st.session_state.llm.mcq_record_answer(st.session_state.selected_answer)
        # NOTE(review): this also runs after the last question, fetching one
        # extra question; the results view skips the last sheet entry ([:-1]),
        # presumably to compensate — confirm against OpenAILLM.
        st.session_state.question, st.session_state.answers = st.session_state.llm.get_mcq_question()

    # For every MCQ question
    if st.session_state.current_question <= st.session_state.mcq_question_number:
        st.header(f"Question {st.session_state.current_question} / {st.session_state.mcq_question_number}")
        with st.form(key='my_form', clear_on_submit=True):
            # The widget key makes the submitted choice available to the callback.
            st.radio(f"{st.session_state.question}:", st.session_state.answers,
                     key="selected_answer")
            st.form_submit_button(label="Next", on_click=increase_current_question)
    else:
        # Results header
        st.header("Results")
        score, score_perc = st.session_state.llm.get_mcq_score()
        st.markdown("<h4>" + f"Score: {score} / {st.session_state.mcq_question_number} ({score_perc} %)" + "</h4>", unsafe_allow_html=True)
        # List the user's answers next to the correct ones.
        for i, qa in enumerate(st.session_state.llm.mcq_answer_sheet[:-1]):
            question, answer, user_answer = qa['question'], qa['answer'], qa['user_answer']
            st.write("---")
            st.write(f"**Question {i+1}/{st.session_state.mcq_question_number}:** {question}")
            st.write(f"**Correct answer:** {answer}")
            st.write(f"**User answer:** {user_answer}")
# Main structure: make sure session state exists, then render the active page.
init()

# Page selector: the first page equal to current_page is rendered; an unknown
# page renders nothing (same as a `match` without a wildcard case).
for _page, _render in (
    (PAGE.MAIN, main_page),
    (PAGE.SUMMARY, summary_page),
    (PAGE.CHAT, chat_page),
    (PAGE.MCQ, mcq_page),
):
    if st.session_state.current_page == _page:
        _render()
        break