Spaces:
Runtime error
Runtime error
Commit
•
465eb75
1
Parent(s):
a5162e0
Upload 3 files
Browse files- app.py +19 -16
- prompt_generation.py +9 -8
app.py
CHANGED
@@ -3,15 +3,16 @@ from streamlit_chat import message
|
|
3 |
from utils import PAGE, read_pdf
|
4 |
from prompt_generation import OpenAILLM
|
5 |
from dotenv import load_dotenv
|
6 |
-
|
7 |
load_dotenv()
|
8 |
|
9 |
|
10 |
def init():
|
11 |
if 'current_page' not in st.session_state:
|
12 |
st.session_state.current_page = PAGE.MAIN
|
13 |
-
st.session_state.mcq_question_number =
|
14 |
-
st.session_state.
|
|
|
|
|
15 |
st.session_state.chat_start = False
|
16 |
st.session_state.chat_messages = []
|
17 |
|
@@ -131,35 +132,37 @@ def mcq_page():
|
|
131 |
|
132 |
# Setup MCQ
|
133 |
if st.session_state.current_question == 0:
|
|
|
134 |
st.session_state.llm.start_mcq()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
135 |
|
136 |
# For every MCQ question
|
137 |
-
if st.session_state.current_question
|
138 |
# QA header
|
139 |
-
st.header(f"Question {st.session_state.current_question
|
140 |
-
|
141 |
-
# Generate the QA text
|
142 |
-
question, answers = st.session_state.llm.get_mcq_question()
|
143 |
|
144 |
# QA form
|
145 |
with st.form(key='my_form', clear_on_submit=True):
|
146 |
-
selected_answer = st.radio(f"{question}:", answers)
|
147 |
-
|
148 |
-
if send_button:
|
149 |
-
print("SELECTED ANSWER: ", selected_answer)
|
150 |
-
st.session_state.llm.mcq_record_answer(selected_answer)
|
151 |
-
st.session_state.current_question += 1
|
152 |
else:
|
153 |
# Results header
|
154 |
st.header("Results")
|
155 |
|
156 |
# For the last QA, show score
|
157 |
-
st.session_state.current_question += 1
|
158 |
score, score_perc = st.session_state.llm.get_mcq_score()
|
159 |
st.markdown("<h4>" + f"Score: {score} / {st.session_state.mcq_question_number} ({score_perc} %)" + "</h4>", unsafe_allow_html=True)
|
160 |
|
161 |
# List your answers and the correct ones
|
162 |
-
for i, qa in enumerate(st.session_state.llm.mcq_answer_sheet):
|
163 |
question, answer, user_answer = qa['question'], qa['answer'], qa['user_answer']
|
164 |
st.write("---")
|
165 |
st.write(f"**Question {i+1}/{st.session_state.mcq_question_number}:** {question}")
|
|
|
3 |
from utils import PAGE, read_pdf
|
4 |
from prompt_generation import OpenAILLM
|
5 |
from dotenv import load_dotenv
|
|
|
6 |
load_dotenv()
|
7 |
|
8 |
|
9 |
def init():
|
10 |
if 'current_page' not in st.session_state:
|
11 |
st.session_state.current_page = PAGE.MAIN
|
12 |
+
st.session_state.mcq_question_number = 10
|
13 |
+
st.session_state.mcq_false_answer_number = 3
|
14 |
+
st.session_state.llm = OpenAILLM(mcq_question_number=st.session_state.mcq_question_number,
|
15 |
+
mcq_false_answer_number=st.session_state.mcq_false_answer_number)
|
16 |
st.session_state.chat_start = False
|
17 |
st.session_state.chat_messages = []
|
18 |
|
|
|
132 |
|
133 |
# Setup MCQ
|
134 |
if st.session_state.current_question == 0:
|
135 |
+
# Start MCQ and get the first question and answer
|
136 |
st.session_state.llm.start_mcq()
|
137 |
+
st.session_state.question, st.session_state.answers = st.session_state.llm.get_mcq_question()
|
138 |
+
st.session_state.current_question += 1
|
139 |
+
|
140 |
+
# Handler when pressing next
|
141 |
+
def increase_current_question():
|
142 |
+
st.session_state.current_question = st.session_state.current_question + 1
|
143 |
+
st.session_state.llm.mcq_record_answer(st.session_state.selected_answer)
|
144 |
+
st.session_state.question, st.session_state.answers = st.session_state.llm.get_mcq_question()
|
145 |
|
146 |
# For every MCQ question
|
147 |
+
if st.session_state.current_question <= st.session_state.mcq_question_number:
|
148 |
# QA header
|
149 |
+
st.header(f"Question {st.session_state.current_question} / {st.session_state.mcq_question_number}")
|
|
|
|
|
|
|
150 |
|
151 |
# QA form
|
152 |
with st.form(key='my_form', clear_on_submit=True):
|
153 |
+
st.session_state.selected_answer = st.radio(f"{st.session_state.question}:", st.session_state.answers)
|
154 |
+
st.form_submit_button(label="Next", on_click=increase_current_question)
|
|
|
|
|
|
|
|
|
155 |
else:
|
156 |
# Results header
|
157 |
st.header("Results")
|
158 |
|
159 |
# For the last QA, show score
|
160 |
+
# st.session_state.current_question += 1
|
161 |
score, score_perc = st.session_state.llm.get_mcq_score()
|
162 |
st.markdown("<h4>" + f"Score: {score} / {st.session_state.mcq_question_number} ({score_perc} %)" + "</h4>", unsafe_allow_html=True)
|
163 |
|
164 |
# List your answers and the correct ones
|
165 |
+
for i, qa in enumerate(st.session_state.llm.mcq_answer_sheet[:-1]):
|
166 |
question, answer, user_answer = qa['question'], qa['answer'], qa['user_answer']
|
167 |
st.write("---")
|
168 |
st.write(f"**Question {i+1}/{st.session_state.mcq_question_number}:** {question}")
|
prompt_generation.py
CHANGED
@@ -12,8 +12,9 @@ import random
|
|
12 |
|
13 |
class OpenAILLM:
|
14 |
def __init__(self, temperature: float = 1.,
|
15 |
-
model_name: str = 'gpt-
|
16 |
-
mcq_question_number: int = 10
|
|
|
17 |
# Model-related instantiations
|
18 |
self.llm = ChatOpenAI(temperature=temperature, model_name=model_name)
|
19 |
self.Memory = ConversationBufferMemory
|
@@ -26,8 +27,9 @@ class OpenAILLM:
|
|
26 |
self.chat_document_intro = "Read the following document: "
|
27 |
self.chat_message_begin = "What would you like to know about the uploaded document?"
|
28 |
self.mcq_question_number = mcq_question_number
|
29 |
-
self.
|
30 |
-
|
|
|
31 |
Make sure that it is unique from the ones you have generated before!
|
32 |
Only create 3 possible false answers and a correct answers!
|
33 |
"""
|
@@ -43,6 +45,7 @@ class OpenAILLM:
|
|
43 |
def empty_text(self):
|
44 |
self.docs = []
|
45 |
self.chain_chat.memory = self.Memory()
|
|
|
46 |
|
47 |
def get_text_summary(self):
|
48 |
summary = self.chain_summary.run(self.docs)
|
@@ -66,7 +69,7 @@ class OpenAILLM:
|
|
66 |
ResponseSchema(name="question", description="Question generated from provided document."),
|
67 |
ResponseSchema(name="answer", description="One correct answer for the asked question."),
|
68 |
ResponseSchema(name="choices",
|
69 |
-
description="
|
70 |
]
|
71 |
output_format_instructions = StructuredOutputParser.from_response_schemas(
|
72 |
response_schemas).get_format_instructions()
|
@@ -85,16 +88,14 @@ class OpenAILLM:
|
|
85 |
self.start_chat()
|
86 |
|
87 |
def get_mcq_question(self):
|
88 |
-
print("HERE")
|
89 |
while True:
|
90 |
try:
|
91 |
response = self.chain_chat.predict(input=self.mcq_query)
|
92 |
-
print(response)
|
93 |
response_parsed = json.loads(response[len(r"```json"):-len(r"```")])
|
94 |
|
95 |
question = response_parsed["question"]
|
96 |
answers = [response_parsed["answer"]] + [false_answer.strip() for false_answer in
|
97 |
-
response_parsed["choices"].split(',')]
|
98 |
break
|
99 |
except Exception as e:
|
100 |
print(e)
|
|
|
12 |
|
13 |
class OpenAILLM:
|
14 |
def __init__(self, temperature: float = 1.,
|
15 |
+
model_name: str = 'gpt-4',
|
16 |
+
mcq_question_number: int = 10,
|
17 |
+
mcq_false_answer_number: int = 3):
|
18 |
# Model-related instantiations
|
19 |
self.llm = ChatOpenAI(temperature=temperature, model_name=model_name)
|
20 |
self.Memory = ConversationBufferMemory
|
|
|
27 |
self.chat_document_intro = "Read the following document: "
|
28 |
self.chat_message_begin = "What would you like to know about the uploaded document?"
|
29 |
self.mcq_question_number = mcq_question_number
|
30 |
+
self.mcq_false_answer_number = mcq_false_answer_number
|
31 |
+
self.mcq_intro = f"""
|
32 |
+
Generate a question, correct answer and {self.mcq_false_answer_number} possible false answers from the inputted document.
|
33 |
Make sure that it is unique from the ones you have generated before!
|
34 |
Only create 3 possible false answers and a correct answers!
|
35 |
"""
|
|
|
45 |
def empty_text(self):
|
46 |
self.docs = []
|
47 |
self.chain_chat.memory = self.Memory()
|
48 |
+
self.mcq_answer_sheet = []
|
49 |
|
50 |
def get_text_summary(self):
|
51 |
summary = self.chain_summary.run(self.docs)
|
|
|
69 |
ResponseSchema(name="question", description="Question generated from provided document."),
|
70 |
ResponseSchema(name="answer", description="One correct answer for the asked question."),
|
71 |
ResponseSchema(name="choices",
|
72 |
+
description=f"{self.mcq_false_answer_number} available false options for a multiple-choice question in comma separated."),
|
73 |
]
|
74 |
output_format_instructions = StructuredOutputParser.from_response_schemas(
|
75 |
response_schemas).get_format_instructions()
|
|
|
88 |
self.start_chat()
|
89 |
|
90 |
def get_mcq_question(self):
|
|
|
91 |
while True:
|
92 |
try:
|
93 |
response = self.chain_chat.predict(input=self.mcq_query)
|
|
|
94 |
response_parsed = json.loads(response[len(r"```json"):-len(r"```")])
|
95 |
|
96 |
question = response_parsed["question"]
|
97 |
answers = [response_parsed["answer"]] + [false_answer.strip() for false_answer in
|
98 |
+
response_parsed["choices"].split(',')][:self.mcq_false_answer_number]
|
99 |
break
|
100 |
except Exception as e:
|
101 |
print(e)
|