DeepakJaiz committed
Commit d24816c
1 Parent(s): 832aaa2

Upload 3 files

Files changed (3)
  1. app.py +491 -0
  2. requirements.txt +0 -0
  3. text_utils.py +109 -0
app.py ADDED
@@ -0,0 +1,491 @@
import os
import json
import time
from typing import List
import faiss
import pypdf
import random
import itertools
import text_utils
import pandas as pd
import altair as alt
import streamlit as st
from io import StringIO
from llama_index import Document
from langchain.llms import Anthropic
from langchain import HuggingFaceHub
from langchain.chains import RetrievalQA
from langchain.vectorstores import FAISS
from llama_index import LangchainEmbedding
from langchain.chat_models import ChatOpenAI
from langchain.retrievers import SVMRetriever
from langchain.chains import QAGenerationChain
from langchain.retrievers import TFIDFRetriever
from langchain.evaluation.qa import QAEvalChain
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.embeddings.openai import OpenAIEmbeddings
from gpt_index import LLMPredictor, ServiceContext, GPTFaissIndex
from langchain.text_splitter import RecursiveCharacterTextSplitter, CharacterTextSplitter
from text_utils import GRADE_DOCS_PROMPT, GRADE_ANSWER_PROMPT, GRADE_DOCS_PROMPT_FAST, GRADE_ANSWER_PROMPT_FAST, GRADE_ANSWER_PROMPT_BIAS_CHECK, GRADE_ANSWER_PROMPT_OPENAI

# Keep dataframe in memory to accumulate experimental results
if "existing_df" not in st.session_state:
    summary = pd.DataFrame(columns=['chunk_chars',
                                    'overlap',
                                    'split',
                                    'model',
                                    'retriever',
                                    'embedding',
                                    'num_neighbors',
                                    'Latency',
                                    'Retrieval score',
                                    'Answer score'])
    st.session_state.existing_df = summary
else:
    summary = st.session_state.existing_df

@st.cache_data
def load_docs(files: List) -> str:
    """
    Load docs from files
    @param files: list of files to load
    @return: string of all docs concatenated
    """

    st.info("`Reading doc ...`")
    all_text = ""
    for file_path in files:
        file_extension = os.path.splitext(file_path.name)[1]
        if file_extension == ".pdf":
            pdf_reader = pypdf.PdfReader(file_path)
            file_content = ""
            for page in pdf_reader.pages:
                file_content += page.extract_text()
            file_content = text_utils.clean_pdf_text(file_content)
            all_text += file_content
        elif file_extension == ".txt":
            stringio = StringIO(file_path.getvalue().decode("utf-8"))
            file_content = stringio.read()
            all_text += file_content
        else:
            st.warning('Please provide txt or pdf.', icon="⚠️")
    return all_text


@st.cache_data
def generate_eval(text: str, num_questions: int, chunk: int):
    """
    Generate eval set
    @param text: text to generate eval set from
    @param num_questions: number of questions to generate
    @param chunk: chunk size to draw question from in the doc
    @return: eval set as JSON list
    """
    st.info("`Generating eval set ...`")
    n = len(text)
    starting_indices = [random.randint(0, n - chunk) for _ in range(num_questions)]
    sub_sequences = [text[i:i + chunk] for i in starting_indices]
    chain = QAGenerationChain.from_llm(ChatOpenAI(temperature=0))
    eval_set = []
    for i, b in enumerate(sub_sequences):
        try:
            qa = chain.run(b)
            eval_set.append(qa)
        except Exception:
            st.warning('Error generating question %s.' % str(i + 1), icon="⚠️")
    eval_set_full = list(itertools.chain.from_iterable(eval_set))
    return eval_set_full

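# Note (illustrative, based on how the eval set is consumed below): QAGenerationChain
# typically yields question/answer pairs, so eval_set_full is expected to look roughly like
#     [{"question": "...", "answer": "..."}, ...]
# which matches the keys read later in run_evaluation (data["question"], data["answer"]).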

@st.cache_resource
def split_texts(text, chunk_size: int, overlap, split_method: str):
    """
    Split text into chunks
    @param text: text to split
    @param chunk_size: chunk size in characters
    @param overlap: overlap between chunks in characters
    @param split_method: text splitter to use
    @return: list of str splits
    """
    st.info("`Splitting doc ...`")
    if split_method == "RecursiveTextSplitter":
        text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size,
                                                       chunk_overlap=overlap)
    elif split_method == "CharacterTextSplitter":
        text_splitter = CharacterTextSplitter(separator=" ",
                                              chunk_size=chunk_size,
                                              chunk_overlap=overlap)
    else:
        st.warning("`Split method not recognized. Using RecursiveCharacterTextSplitter`", icon="⚠️")
        text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size,
                                                       chunk_overlap=overlap)

    split_text = text_splitter.split_text(text)
    return split_text


@st.cache_resource
def make_llm(model_version: str):
    """
    Make LLM from model version
    @param model_version: model_version
    @return: LLM
    """
    if model_version in ("gpt-3.5-turbo", "gpt-4"):
        chosen_model = ChatOpenAI(model_name=model_version, temperature=0)
    elif model_version == "anthropic":
        chosen_model = Anthropic(temperature=0)
    elif model_version == "flan-t5-xl":
        chosen_model = HuggingFaceHub(repo_id="google/flan-t5-xl", model_kwargs={"temperature": 0, "max_length": 64})
    else:
        st.warning("`Model version not recognized. Using gpt-3.5-turbo`", icon="⚠️")
        chosen_model = ChatOpenAI(model_name="gpt-3.5-turbo", temperature=0)
    return chosen_model

@st.cache_resource
def make_retriever(splits, retriever_type, embedding_type, num_neighbors, _llm):
    """
    Make document retriever
    @param splits: list of str splits
    @param retriever_type: retriever type
    @param embedding_type: embedding type
    @param num_neighbors: number of neighbors for retrieval
    @param _llm: model
    @return: retriever
    """
    st.info("`Making retriever ...`")
    # Set embeddings
    if embedding_type == "OpenAI":
        embedding = OpenAIEmbeddings()
    elif embedding_type == "HuggingFace":
        embedding = HuggingFaceEmbeddings()
    else:
        st.warning("`Embedding type not recognized. Using OpenAI`", icon="⚠️")
        embedding = OpenAIEmbeddings()

    # Select retriever
    if retriever_type == "similarity-search":
        try:
            vector_store = FAISS.from_texts(splits, embedding)
        except ValueError:
            st.warning("`Error using OpenAI embeddings (disallowed TikToken token in the text). Using HuggingFace.`",
                       icon="⚠️")
            vector_store = FAISS.from_texts(splits, HuggingFaceEmbeddings())
        retriever_obj = vector_store.as_retriever(k=num_neighbors)
    elif retriever_type == "SVM":
        retriever_obj = SVMRetriever.from_texts(splits, embedding)
    elif retriever_type == "TF-IDF":
        retriever_obj = TFIDFRetriever.from_texts(splits)
    elif retriever_type == "Llama-Index":
        documents = [Document(t, LangchainEmbedding(embedding)) for t in splits]
        llm_predictor = LLMPredictor(llm=_llm)
        context = ServiceContext.from_defaults(chunk_size_limit=512, llm_predictor=llm_predictor)
        d = 1536  # dimensionality of OpenAI embeddings
        faiss_index = faiss.IndexFlatL2(d)
        retriever_obj = GPTFaissIndex.from_documents(documents, faiss_index=faiss_index, service_context=context)
    else:
        st.warning("`Retriever type not recognized. Using SVM`", icon="⚠️")
        retriever_obj = SVMRetriever.from_texts(splits, embedding)
    return retriever_obj

def make_chain(llm, retriever, retriever_type: str) -> RetrievalQA:
    """
    Make chain
    @param llm: model
    @param retriever: retriever
    @param retriever_type: retriever type
    @return: chain (or return retriever for Llama-Index)
    """
    st.info("`Making chain ...`")
    if retriever_type == "Llama-Index":
        qa = retriever
    else:
        qa = RetrievalQA.from_chain_type(llm,
                                         chain_type="stuff",
                                         retriever=retriever,
                                         input_key="question")
    return qa

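# Note (illustrative, following how run_evaluation invokes the chain below): with
# input_key="question", the RetrievalQA chain can be called with a dict such as
#     qa({"question": "What is ...?", "answer": "..."})
# and returns a dict whose "result" key holds the generated answer.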
def grade_model_answer(predicted_dataset: List, predictions: List, grade_answer_prompt: str) -> List:
    """
    Grades the distilled answer based on ground truth and model predictions.
    @param predicted_dataset: A list of dictionaries containing ground truth questions and answers.
    @param predictions: A list of dictionaries containing model predictions for the questions.
    @param grade_answer_prompt: The prompt level for the grading. Either "Fast" or "Full".
    @return: A list of scores for the distilled answers.
    """
    # Grade the distilled answer
    st.info("`Grading model answer ...`")
    # Set the grading prompt based on the grade_answer_prompt parameter
    if grade_answer_prompt == "Fast":
        prompt = GRADE_ANSWER_PROMPT_FAST
    elif grade_answer_prompt == "Descriptive w/ bias check":
        prompt = GRADE_ANSWER_PROMPT_BIAS_CHECK
    elif grade_answer_prompt == "OpenAI grading prompt":
        prompt = GRADE_ANSWER_PROMPT_OPENAI
    else:
        prompt = GRADE_ANSWER_PROMPT

    # Create an evaluation chain
    eval_chain = QAEvalChain.from_llm(
        llm=ChatOpenAI(model_name="gpt-3.5-turbo", temperature=0),
        prompt=prompt
    )

    # Evaluate the predictions and ground truth using the evaluation chain
    graded_outputs = eval_chain.evaluate(
        predicted_dataset,
        predictions,
        question_key="question",
        prediction_key="result"
    )

    return graded_outputs


def grade_model_retrieval(gt_dataset: List, predictions: List, grade_docs_prompt: str):
    """
    Grades the relevance of retrieved documents based on ground truth and model predictions.
    @param gt_dataset: list of dictionaries containing ground truth questions and answers.
    @param predictions: list of dictionaries containing model predictions for the questions.
    @param grade_docs_prompt: prompt level for the grading. Either "Fast" or "Full".
    @return: list of scores for the retrieved documents.
    """
    # Grade the docs retrieval
    st.info("`Grading relevance of retrieved docs ...`")

    # Set the grading prompt based on the grade_docs_prompt parameter
    prompt = GRADE_DOCS_PROMPT_FAST if grade_docs_prompt == "Fast" else GRADE_DOCS_PROMPT

    # Create an evaluation chain
    eval_chain = QAEvalChain.from_llm(
        llm=ChatOpenAI(model_name="gpt-3.5-turbo", temperature=0),
        prompt=prompt
    )

    # Evaluate the predictions and ground truth using the evaluation chain
    graded_outputs = eval_chain.evaluate(
        gt_dataset,
        predictions,
        question_key="question",
        prediction_key="result"
    )
    return graded_outputs

def run_evaluation(chain, retriever, eval_set, grade_prompt, retriever_type, num_neighbors):
    """
    Runs evaluation on a model's performance on a given evaluation dataset.
    @param chain: Model chain used for answering questions
    @param retriever: Document retriever used for retrieving relevant documents
    @param eval_set: List of dictionaries containing questions and corresponding ground truth answers
    @param grade_prompt: String prompt used for grading model's performance
    @param retriever_type: String specifying the type of retriever used
    @param num_neighbors: Number of neighbors to retrieve using the retriever
    @return: A tuple of four items:
    - answers_grade: A dictionary containing scores for the model's answers.
    - retrieval_grade: A dictionary containing scores for the model's document retrieval.
    - latencies_list: A list of latencies in seconds for each question answered.
    - predictions_list: A list of dictionaries containing the model's predicted answers and relevant documents for each question.
    """
    st.info("`Running evaluation ...`")
    predictions_list = []
    retrieved_docs = []
    gt_dataset = []
    latencies_list = []

    for data in eval_set:

        # Get answer and log latency
        start_time = time.time()
        if retriever_type != "Llama-Index":
            predictions_list.append(chain(data))
        elif retriever_type == "Llama-Index":
            answer = chain.query(data["question"], similarity_top_k=num_neighbors, response_mode="tree_summarize",
                                 use_async=True)
            predictions_list.append({"question": data["question"], "answer": data["answer"], "result": answer.response})
        gt_dataset.append(data)
        end_time = time.time()
        elapsed_time = end_time - start_time
        latencies_list.append(elapsed_time)

        # Retrieve docs
        retrieved_doc_text = ""
        if retriever_type == "Llama-Index":
            for i, doc in enumerate(answer.source_nodes):
                retrieved_doc_text += "Doc %s: " % str(i + 1) + doc.node.text + " "
        else:
            docs = retriever.get_relevant_documents(data["question"])
            for i, doc in enumerate(docs):
                retrieved_doc_text += "Doc %s: " % str(i + 1) + doc.page_content + " "

        retrieved = {"question": data["question"], "answer": data["answer"], "result": retrieved_doc_text}
        retrieved_docs.append(retrieved)

    # Grade
    answers_grade = grade_model_answer(gt_dataset, predictions_list, grade_prompt)
    retrieval_grade = grade_model_retrieval(gt_dataset, retrieved_docs, grade_prompt)
    return answers_grade, retrieval_grade, latencies_list, predictions_list

# Auth
st.sidebar.image("img/diagnostic.jpg")

oai_api_key = st.sidebar.text_input("`OpenAI API Key:`", type="password")
ant_api_key = st.sidebar.text_input("`(Optional) Anthropic API Key:`", type="password")
hf_api_key = st.sidebar.text_input("`(Optional) HuggingFace API Token:`", type="password")

with st.sidebar.form("user_input"):

    num_eval_questions = st.select_slider("`Number of eval questions`",
                                          options=[1, 5, 10, 15, 20], value=5)

    chunk_chars = st.select_slider("`Choose chunk size for splitting`",
                                   options=[500, 750, 1000, 1500, 2000], value=1000)

    overlap = st.select_slider("`Choose overlap for splitting`",
                               options=[0, 50, 100, 150, 200], value=100)

    split_method = st.radio("`Split method`",
                            ("RecursiveTextSplitter",
                             "CharacterTextSplitter"),
                            index=0)

    model = st.radio("`Choose model`",
                     ("gpt-3.5-turbo",
                      "gpt-4",
                      "anthropic"),
                     # Error raised by inference API: Model google/flan-t5-xl time out
                     # "flan-t5-xl"),
                     index=0)

    retriever_type = st.radio("`Choose retriever`",
                              ("TF-IDF",
                               "SVM",
                               "Llama-Index",
                               "similarity-search"),
                              index=3)

    num_neighbors = st.select_slider("`Choose # chunks to retrieve`",
                                     options=[3, 4, 5, 6, 7, 8])

    embeddings = st.radio("`Choose embeddings`",
                          ("HuggingFace",
                           "OpenAI"),
                          index=1)

    grade_prompt = st.radio("`Grading style prompt`",
                            ("Fast",
                             "Descriptive",
                             "Descriptive w/ bias check",
                             "OpenAI grading prompt"),
                            index=0)

    submitted = st.form_submit_button("Submit evaluation")

# st.sidebar.write("`By:` [@RLanceMartin](https://twitter.com/RLanceMartin)")

# App
st.header("`Auto-evaluator`")
st.info(
    "`I am an evaluation tool for question-answering built on LangChain. Given documents, I will auto-generate a question-answer eval "
    "set and evaluate using the selected chain settings. Experiments with different configurations are logged. "
    "Optionally, provide your own eval set (as a JSON, see docs/karpathy-pod-eval.json for an example). If you don't have access to GPT-4 or Anthropic, you can use our free hosted app here: https://autoevaluator.langchain.com/`")

with st.form(key='file_inputs'):
    uploaded_file = st.file_uploader("`Please upload a file to evaluate (.txt or .pdf):` ",
                                     type=['pdf', 'txt'],
                                     accept_multiple_files=True)

    uploaded_eval_set = st.file_uploader("`[Optional] Please upload eval set (.json):` ",
                                         type=['json'],
                                         accept_multiple_files=False)

    submitted = st.form_submit_button("Submit files")

if uploaded_file and oai_api_key:

    os.environ["OPENAI_API_KEY"] = oai_api_key
    os.environ["ANTHROPIC_API_KEY"] = ant_api_key
    os.environ["HUGGINGFACEHUB_API_TOKEN"] = hf_api_key

    # Load docs
    text = load_docs(uploaded_file)
    # Generate num_eval_questions questions, each from context of 3k chars randomly selected
    if not uploaded_eval_set:
        eval_set = generate_eval(text, num_eval_questions, 3000)
    else:
        eval_set = json.loads(uploaded_eval_set.read())
    # Split text
    splits = split_texts(text, chunk_chars, overlap, split_method)
    # Make LLM
    llm = make_llm(model)
    # Make vector DB
    retriever = make_retriever(splits, retriever_type, embeddings, num_neighbors, llm)
    # Make chain
    qa_chain = make_chain(llm, retriever, retriever_type)
    # Grade model
    graded_answers, graded_retrieval, latency, predictions = run_evaluation(qa_chain, retriever, eval_set, grade_prompt,
                                                                            retriever_type, num_neighbors)

    # Assemble outputs
    d = pd.DataFrame(predictions)
    d['answer score'] = [g['text'] for g in graded_answers]
    d['docs score'] = [g['text'] for g in graded_retrieval]
    d['latency'] = latency

    # Summary statistics
    mean_latency = d['latency'].mean()
    correct_answer_count = len([text for text in d['answer score'] if "INCORRECT" not in text])
    correct_docs_count = len([text for text in d['docs score'] if "Context is relevant: True" in text])
    percentage_answer = (correct_answer_count / len(graded_answers)) * 100
    percentage_docs = (correct_docs_count / len(graded_retrieval)) * 100

    st.subheader("`Run Results`")
    st.info(
        "`I will grade the chain based on: 1/ the relevance of the retrieved documents relative to the question and 2/ "
        "the summarized answer relative to the ground truth answer. You can see (and change) the prompts used for "
        "grading in text_utils.`")
    st.dataframe(data=d, use_container_width=True)

    # Accumulate results
    st.subheader("`Aggregate Results`")
    st.info(
        "`Retrieval and answer scores are the percentage of retrieved documents deemed relevant by the LLM grader ("
        "relative to the question) and the percentage of summarized answers deemed relevant (relative to the ground truth "
        "answer), respectively. The size of each point corresponds to the latency (in seconds) of retrieval + answer "
        "summarization (larger circle = slower).`")
    new_row = pd.DataFrame({'chunk_chars': [chunk_chars],
                            'overlap': [overlap],
                            'split': [split_method],
                            'model': [model],
                            'retriever': [retriever_type],
                            'embedding': [embeddings],
                            'num_neighbors': [num_neighbors],
                            'Latency': [mean_latency],
                            'Retrieval score': [percentage_docs],
                            'Answer score': [percentage_answer]})
    summary = pd.concat([summary, new_row], ignore_index=True)
    st.dataframe(data=summary, use_container_width=True)
    st.session_state.existing_df = summary

    # Dataframe for visualization
    show = summary.reset_index().copy()
    show.columns = ['expt number', 'chunk_chars', 'overlap',
                    'split', 'model', 'retriever', 'embedding', 'num_neighbors', 'Latency', 'Retrieval score',
                    'Answer score']
    show['expt number'] = show['expt number'].apply(lambda x: "Expt #: " + str(x + 1))
    c = alt.Chart(show).mark_circle().encode(x='Retrieval score',
                                             y='Answer score',
                                             size=alt.Size('Latency'),
                                             color='expt number',
                                             tooltip=['expt number', 'Retrieval score', 'Latency', 'Answer score'])
    st.altair_chart(c, use_container_width=True, theme="streamlit")

else:

    st.warning("Please input file and API key(s)!")
requirements.txt ADDED
File without changes
text_utils.py ADDED
@@ -0,0 +1,109 @@
import re
from langchain.prompts import PromptTemplate


def clean_pdf_text(text: str) -> str:
    """Cleans text extracted from a PDF file."""
    # TODO: Remove References/Bibliography section.
    return remove_citations(text)


def remove_citations(text: str) -> str:
    """Removes in-text citations from a string."""
    # (Author, Year)
    text = re.sub(r'\([A-Za-z0-9,.\s]+\s\d{4}\)', '', text)
    # [1], [2], [3-5], [3, 33, 49, 51]
    text = re.sub(r'\[[0-9,-]+(,\s[0-9,-]+)*\]', '', text)
    return text

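# Example (illustrative only): remove_citations("Attention works well (Vaswani et al. 2017) [1], [3-5].")
# drops both the parenthesized author-year citation and the bracketed reference markers,
# leaving roughly "Attention works well , ." (surrounding whitespace and punctuation are not cleaned up).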
template = """You are a teacher grading a quiz.
You are given a question, the student's answer, and the true answer, and are asked to score the student answer as either CORRECT or INCORRECT.
Example Format:
QUESTION: question here
STUDENT ANSWER: student's answer here
TRUE ANSWER: true answer here
GRADE: CORRECT or INCORRECT here
Grade the student answers based ONLY on their factual accuracy. Ignore differences in punctuation and phrasing between the student answer and true answer. It is OK if the student answer contains more information than the true answer, as long as it does not contain any conflicting statements. Begin!
QUESTION: {query}
STUDENT ANSWER: {result}
TRUE ANSWER: {answer}
GRADE:
And explain why the STUDENT ANSWER is correct or incorrect.
"""

GRADE_ANSWER_PROMPT = PromptTemplate(input_variables=["query", "result", "answer"], template=template)

template = """You are a teacher grading a quiz.
You are given a question, the student's answer, and the true answer, and are asked to score the student answer as either CORRECT or INCORRECT.
You are also asked to identify potential sources of bias in the question and in the true answer.
Example Format:
QUESTION: question here
STUDENT ANSWER: student's answer here
TRUE ANSWER: true answer here
GRADE: CORRECT or INCORRECT here
Grade the student answers based ONLY on their factual accuracy. Ignore differences in punctuation and phrasing between the student answer and true answer. It is OK if the student answer contains more information than the true answer, as long as it does not contain any conflicting statements. Begin!
QUESTION: {query}
STUDENT ANSWER: {result}
TRUE ANSWER: {answer}
GRADE:
And explain why the STUDENT ANSWER is correct or incorrect, identify potential sources of bias in the QUESTION, and identify potential sources of bias in the TRUE ANSWER.
"""

GRADE_ANSWER_PROMPT_BIAS_CHECK = PromptTemplate(input_variables=["query", "result", "answer"], template=template)

template = """You are assessing a submitted student answer to a question relative to the true answer based on the provided criteria:

***
QUESTION: {query}
***
STUDENT ANSWER: {result}
***
TRUE ANSWER: {answer}
***
Criteria:
relevance: Is the submission referring to a real quote from the text?"
conciseness: Is the answer concise and to the point?"
correct: Is the answer correct?"
***
Does the submission meet the criterion? First, write out in a step by step manner your reasoning about the criterion to be sure that your conclusion is correct. Avoid simply stating the correct answers at the outset. Then print the "CORRECT" or "INCORRECT" (without quotes or punctuation) on its own line corresponding to the correct answer.
Reasoning:
"""

GRADE_ANSWER_PROMPT_OPENAI = PromptTemplate(input_variables=["query", "result", "answer"], template=template)

template = """You are a teacher grading a quiz.
You are given a question, the student's answer, and the true answer, and are asked to score the student answer as either CORRECT or INCORRECT.
Example Format:
QUESTION: question here
STUDENT ANSWER: student's answer here
TRUE ANSWER: true answer here
GRADE: CORRECT or INCORRECT here
Grade the student answers based ONLY on their factual accuracy. Ignore differences in punctuation and phrasing between the student answer and true answer. It is OK if the student answer contains more information than the true answer, as long as it does not contain any conflicting statements. Begin!
QUESTION: {query}
STUDENT ANSWER: {result}
TRUE ANSWER: {answer}
GRADE:"""

GRADE_ANSWER_PROMPT_FAST = PromptTemplate(input_variables=["query", "result", "answer"], template=template)

template = """
Given the question: \n
{query}
Decide if the following retrieved context is relevant: \n
{result}
Answer in the following format: \n
"Context is relevant: True or False." \n
And explain why it supports or does not support the correct answer: {answer}"""

GRADE_DOCS_PROMPT = PromptTemplate(input_variables=["query", "result", "answer"], template=template)

template = """
Given the question: \n
{query}
Decide if the following retrieved context is relevant to the {answer}: \n
{result}
Answer in the following format: \n
"Context is relevant: True or False." \n """

GRADE_DOCS_PROMPT_FAST = PromptTemplate(input_variables=["query", "result", "answer"], template=template)
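# Quick check (illustrative): PromptTemplate exposes .format, so a grading prompt can be
# previewed without running a chain, e.g.
#     print(GRADE_DOCS_PROMPT_FAST.format(query="What is discussed?",
#                                         result="Doc 1: ...", answer="..."))
# In app.py these templates are instead passed to QAEvalChain.from_llm(prompt=...).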