brunogreen25 committed
Commit 465eb75
1 Parent(s): a5162e0

Upload 3 files

Files changed (2)
  1. app.py +19 -16
  2. prompt_generation.py +9 -8
app.py CHANGED
@@ -3,15 +3,16 @@ from streamlit_chat import message
 from utils import PAGE, read_pdf
 from prompt_generation import OpenAILLM
 from dotenv import load_dotenv
-
 load_dotenv()


 def init():
     if 'current_page' not in st.session_state:
         st.session_state.current_page = PAGE.MAIN
-        st.session_state.mcq_question_number = 3
-        st.session_state.llm = OpenAILLM(mcq_question_number=st.session_state.mcq_question_number)
+        st.session_state.mcq_question_number = 10
+        st.session_state.mcq_false_answer_number = 3
+        st.session_state.llm = OpenAILLM(mcq_question_number=st.session_state.mcq_question_number,
+                                         mcq_false_answer_number=st.session_state.mcq_false_answer_number)
         st.session_state.chat_start = False
         st.session_state.chat_messages = []
@@ -131,35 +132,37 @@ def mcq_page():

     # Setup MCQ
     if st.session_state.current_question == 0:
+        # Start MCQ and get the first question and answer
         st.session_state.llm.start_mcq()
+        st.session_state.question, st.session_state.answers = st.session_state.llm.get_mcq_question()
+        st.session_state.current_question += 1
+
+    # Handler when pressing next
+    def increase_current_question():
+        st.session_state.current_question = st.session_state.current_question + 1
+        st.session_state.llm.mcq_record_answer(st.session_state.selected_answer)
+        st.session_state.question, st.session_state.answers = st.session_state.llm.get_mcq_question()

     # For every MCQ question
-    if st.session_state.current_question < st.session_state.mcq_question_number:
+    if st.session_state.current_question <= st.session_state.mcq_question_number:
         # QA header
-        st.header(f"Question {st.session_state.current_question + 1} / {st.session_state.mcq_question_number}")
-
-        # Generate the QA text
-        question, answers = st.session_state.llm.get_mcq_question()
+        st.header(f"Question {st.session_state.current_question} / {st.session_state.mcq_question_number}")

         # QA form
         with st.form(key='my_form', clear_on_submit=True):
-            selected_answer = st.radio(f"{question}:", answers)
-            send_button = st.form_submit_button(label="Next")
-            if send_button:
-                print("SELECTED ANSWER: ", selected_answer)
-                st.session_state.llm.mcq_record_answer(selected_answer)
-                st.session_state.current_question += 1
+            st.session_state.selected_answer = st.radio(f"{st.session_state.question}:", st.session_state.answers)
+            st.form_submit_button(label="Next", on_click=increase_current_question)
     else:
         # Results header
         st.header("Results")

         # For the last QA, show score
-        st.session_state.current_question += 1
+        # st.session_state.current_question += 1
         score, score_perc = st.session_state.llm.get_mcq_score()
         st.markdown("<h4>" + f"Score: {score} / {st.session_state.mcq_question_number} ({score_perc} %)" + "</h4>", unsafe_allow_html=True)

         # List your answers and the correct ones
-        for i, qa in enumerate(st.session_state.llm.mcq_answer_sheet):
+        for i, qa in enumerate(st.session_state.llm.mcq_answer_sheet[:-1]):
             question, answer, user_answer = qa['question'], qa['answer'], qa['user_answer']
             st.write("---")
             st.write(f"**Question {i+1}/{st.session_state.mcq_question_number}:** {question}")
prompt_generation.py CHANGED
@@ -12,8 +12,9 @@ import random

 class OpenAILLM:
     def __init__(self, temperature: float = 1.,
-                 model_name: str = 'gpt-4o',
-                 mcq_question_number: int = 10):
+                 model_name: str = 'gpt-4',
+                 mcq_question_number: int = 10,
+                 mcq_false_answer_number: int = 3):
         # Model-related instantiations
         self.llm = ChatOpenAI(temperature=temperature, model_name=model_name)
         self.Memory = ConversationBufferMemory
@@ -26,8 +27,9 @@ class OpenAILLM:
         self.chat_document_intro = "Read the following document: "
         self.chat_message_begin = "What would you like to know about the uploaded document?"
         self.mcq_question_number = mcq_question_number
-        self.mcq_intro = """
-        Generate a question, correct answer and 3 possible false answers from the inputted document.
+        self.mcq_false_answer_number = mcq_false_answer_number
+        self.mcq_intro = f"""
+        Generate a question, correct answer and {self.mcq_false_answer_number} possible false answers from the inputted document.
         Make sure that it is unique from the ones you have generated before!
         Only create 3 possible false answers and a correct answers!
         """
@@ -43,6 +45,7 @@ class OpenAILLM:
     def empty_text(self):
         self.docs = []
         self.chain_chat.memory = self.Memory()
+        self.mcq_answer_sheet = []

     def get_text_summary(self):
         summary = self.chain_summary.run(self.docs)
@@ -66,7 +69,7 @@ class OpenAILLM:
             ResponseSchema(name="question", description="Question generated from provided document."),
             ResponseSchema(name="answer", description="One correct answer for the asked question."),
             ResponseSchema(name="choices",
-                           description="3 available false options for a multiple-choice question in comma separated."),
+                           description=f"{self.mcq_false_answer_number} available false options for a multiple-choice question in comma separated."),
         ]
         output_format_instructions = StructuredOutputParser.from_response_schemas(
             response_schemas).get_format_instructions()
@@ -85,16 +88,14 @@ class OpenAILLM:
         self.start_chat()

     def get_mcq_question(self):
-        print("HERE")
         while True:
             try:
                 response = self.chain_chat.predict(input=self.mcq_query)
-                print(response)
                 response_parsed = json.loads(response[len(r"```json"):-len(r"```")])

                 question = response_parsed["question"]
                 answers = [response_parsed["answer"]] + [false_answer.strip() for false_answer in
-                                                         response_parsed["choices"].split(',')]
+                                                         response_parsed["choices"].split(',')][:self.mcq_false_answer_number]
                 break
             except Exception as e:
                 print(e)
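
Note on get_mcq_question: the model returns its JSON wrapped in a Markdown code fence, so the code slices off the leading ```json and trailing ``` before calling json.loads, retrying on any parse failure, and the new [:self.mcq_false_answer_number] slice caps how many false answers are kept. A minimal sketch of that parse step, with a hard-coded response in place of the chain_chat.predict call and false_answer_number standing in for self.mcq_false_answer_number:

import json

# Hard-coded stand-in for a chain_chat.predict(...) response.
response = """```json
{"question": "2 + 2 = ?", "answer": "4", "choices": "3, 5, 22"}
```"""

# Slice off the leading ```json and trailing ``` fence, then parse.
payload = response[len("```json"):-len("```")]
parsed = json.loads(payload)

# First entry is the correct answer; cap the false answers as the commit does.
false_answer_number = 3
answers = [parsed["answer"]] + [c.strip() for c in parsed["choices"].split(",")][:false_answer_number]
print(parsed["question"], answers)  # -> 2 + 2 = ? ['4', '3', '5', '22']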