dh-mc committed
Commit 01f4bd7
1 Parent(s): 32a6937

supported flag APPLY_CHAT_TEMPLATE_FOR_RAG

.env.example CHANGED
@@ -25,7 +25,11 @@ OPENAI_MODEL_NAME=
  OLLAMA_MODEL_NAME=llama3:8b

  OLLAMA_RP=1.15
+ HF_RP=1.15

+ LANGCHAIN_DEBUG=false
+ BATCH_SIZE=1
+ APPLY_CHAT_TEMPLATE_FOR_RAG=true

  # cpu, mps or cuda:0 - if unset, use whatever detected
  HF_EMBEDDINGS_DEVICE_TYPE=
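The new keys are read back with plain os.getenv calls in the app modules (APPLY_CHAT_TEMPLATE_FOR_RAG is compared against the literal string "true", as shown in llm_qa_chain.py below). A minimal sketch of that parsing; the defaults and comments here are assumptions for illustration, not values taken from the repo:

import os

from dotenv import load_dotenv  # python-dotenv, the same package that provides find_dotenv in init.py

load_dotenv()

hf_rp = float(os.getenv("HF_RP", "1.15"))  # presumably the repetition penalty for HuggingFace models
langchain_debug = os.getenv("LANGCHAIN_DEBUG", "false").lower() == "true"
batch_size = int(os.getenv("BATCH_SIZE", "1"))
apply_chat_template_for_rag = os.getenv("APPLY_CHAT_TEMPLATE_FOR_RAG") == "true"

print(hf_rp, langchain_debug, batch_size, apply_chat_template_for_rag)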
app_modules/init.py CHANGED
@@ -10,7 +10,7 @@ from langchain.vectorstores.chroma import Chroma
  from langchain.vectorstores.faiss import FAISS

  from app_modules.llm_loader import LLMLoader
- from app_modules.utils import get_device_types, init_settings, load_spacy_model
+ from app_modules.utils import get_device_types, init_settings

  found_dotenv = find_dotenv(".env")

@@ -53,21 +53,13 @@ def app_init():
      using_faiss = os.environ.get("FAISS_INDEX_PATH") is not None
      llm_model_type = os.environ.get("LLM_MODEL_TYPE")

-     debug_metrics = os.getenv("DEBUG_METRICS", "false").lower() == "true"
-
-     if debug_metrics:
-         start = timer()
-         load_spacy_model()
-         end = timer()
-         print(f"Completed in {end - start:.3f}s")
-
      qa_with_rag = os.getenv("QA_WITH_RAG", "true").lower() == "true"
      print(f"qa_with_rag: {qa_with_rag}")

      retrieve_from_questions_file = os.getenv("RETRIEVER_TYPE") == "questions_file"
      print(f"retrieve_from_questions_file: {retrieve_from_questions_file}", flush=True)

-     if qa_with_rag and not retrieve_from_questions_file or debug_metrics:
+     if qa_with_rag and not retrieve_from_questions_file:
          print(f"hf_embeddings_model_name: {hf_embeddings_model_name}")
          start = timer()
          embeddings = HuggingFaceInstructEmbeddings(
app_modules/llm_chat_chain.py CHANGED
@@ -6,7 +6,7 @@ from langchain.chains import ConversationChain, LLMChain
  from langchain.prompts import PromptTemplate
  from langchain.chains.base import Chain

- from app_modules.llm_inference import LLMInference
+ from app_modules.llm_inference import LLMInference, get_system_prompt_and_user_message
  from app_modules.utils import CustomizedConversationSummaryBufferMemory
  from langchain.chains import LLMChain
  from langchain.globals import get_debug
@@ -15,23 +15,6 @@ chat_history_enabled = os.getenv("CHAT_HISTORY_ENABLED", "false").lower() == "tr
  B_INST, E_INST = "[INST]", "[/INST]"


- def get_system_prompt_and_user_message(orca=False):
-     # system_prompt = "You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature. If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information."
-     system_prompt = (
-         "You are Orca, an AI language model created by Microsoft. You are a cautious assistant. You carefully follow instructions. You are helpful and harmless and you follow ethical guidelines and promote positive behavior."
-         if orca
-         else "You are a chatbot having a conversation with a human."
-     )
-
-     user_message = "{input}"
-
-     if chat_history_enabled:
-         user_message = "Chat History:\n\n{history} \n\n" + user_message
-         system_prompt += " Read the chat history to get context."
-
-     return system_prompt, user_message
-
-
  def create_llama_2_prompt_template():
      B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"

@@ -141,19 +124,7 @@ class ChatChain(LLMInference):
          if not isinstance(inputs, list):
              inputs = {"input": inputs["question"]}
          elif self.llm_loader.llm_model_type == "huggingface":
-             inputs = [
-                 [
-                     {
-                         "role": "system",
-                         "content": self.get_system_message(i),
-                     },
-                     {
-                         "role": "user",
-                         "content": self.get_user_message(i),
-                     },
-                 ]
-                 for i in inputs
-             ]
+             inputs = [self.apply_chat_template(input["question"]) for input in inputs]
          else:
              inputs = [{"input": i["question"]} for i in inputs]

@@ -161,9 +132,3 @@
              print("_process_inputs:", json.dumps(inputs, indent=4))

          return inputs
-
-     def get_system_message(self, input) -> Chain:
-         return get_system_prompt_and_user_message()[0]
-
-     def get_user_message(self, input) -> Chain:
-         return input["question"]
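With this change the HuggingFace code path in _process_inputs delegates prompt assembly to the shared apply_chat_template helper on LLMInference, so each batch item becomes a list of role/content messages instead of being built inline. A rough standalone sketch of the resulting structure; the model name is illustrative and the real helper lives in llm_inference.py below:

# Standalone approximation of the new code path, not the repo's actual classes.
def apply_chat_template(user_message, model_name="some-chat-model"):
    messages = []
    if not model_name.lower().startswith("gemma"):
        # assumption: the system message is skipped for Gemma-style chat templates
        messages.append(
            {"role": "system", "content": "You are a chatbot having a conversation with a human."}
        )
    messages.append({"role": "user", "content": user_message})
    return messages


inputs = [{"question": "What is RAG?"}, {"question": "What does BLEU-1 measure?"}]
inputs = [apply_chat_template(item["question"]) for item in inputs]
# each element is now [{"role": "system", ...}, {"role": "user", ...}]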
app_modules/llm_inference.py CHANGED
@@ -14,6 +14,25 @@ from langchain.chains.base import Chain
  from app_modules.llm_loader import LLMLoader, TextIteratorStreamer
  from app_modules.utils import remove_extra_spaces

+ chat_history_enabled = os.getenv("CHAT_HISTORY_ENABLED", "false").lower() == "true"
+
+
+ def get_system_prompt_and_user_message(orca=False):
+     # system_prompt = "You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature. If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information."
+     system_prompt = (
+         "You are Orca, an AI language model created by Microsoft. You are a cautious assistant. You carefully follow instructions. You are helpful and harmless and you follow ethical guidelines and promote positive behavior."
+         if orca
+         else "You are a chatbot having a conversation with a human."
+     )
+
+     user_message = "{input}"
+
+     if chat_history_enabled:
+         user_message = "Chat History:\n\n{history} \n\n" + user_message
+         system_prompt += " Read the chat history to get context."
+
+     return system_prompt, user_message
+

  class LLMInference(metaclass=abc.ABCMeta):
      def __init__(self, llm_loader):
@@ -143,3 +162,22 @@ class LLMInference(metaclass=abc.ABCMeta):

          t.join()
          return que.get()
+
+     def apply_chat_template(self, user_message):
+         result = (
+             []
+             if self.llm_loader.model_name.lower().startswith("gemma")
+             else [
+                 {
+                     "role": "system",
+                     "content": get_system_prompt_and_user_message()[0],
+                 }
+             ]
+         )
+         result.append(
+             {
+                 "role": "user",
+                 "content": user_message,
+             }
+         )
+         return result
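The message lists built by apply_chat_template follow the standard transformers chat format (the Gemma special case presumably exists because Gemma chat templates do not accept a separate system role). Downstream they can be rendered into a prompt string with the tokenizer's own chat template; a minimal sketch, with the model name chosen purely for illustration since the actual wiring through LLMLoader is not part of this diff:

from transformers import AutoTokenizer

# Illustrative model; any model with a chat template that accepts a system role works here.
tokenizer = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta")

messages = [
    {"role": "system", "content": "You are a chatbot having a conversation with a human."},
    {"role": "user", "content": "What is Retrieval-Augmented Generation?"},
]

prompt = tokenizer.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True
)
print(prompt)  # the chat-formatted string the pipeline would then complete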
app_modules/llm_qa_chain.py CHANGED
@@ -6,12 +6,17 @@ from langchain.chains import ConversationalRetrievalChain
  from langchain.chains.base import Chain
  from app_modules.llm_inference import LLMInference
  from app_modules.utils import CustomizedConversationSummaryBufferMemory
+
  from langchain_core.retrievers import BaseRetriever
  from langchain_core.documents import Document
  from langchain_core.callbacks.manager import CallbackManagerForRetrieverRun
  from langchain.globals import get_debug

  retrieve_from_questions_file = os.getenv("RETRIEVER_TYPE") == "questions_file"
+ apply_chat_template_for_rag = os.getenv("APPLY_CHAT_TEMPLATE_FOR_RAG") == "true"
+
+ print(f"retrieve_from_questions_file: {retrieve_from_questions_file}", flush=True)
+ print(f"apply_chat_template_for_rag: {apply_chat_template_for_rag}", flush=True)

  if retrieve_from_questions_file:
      questions_file_path = os.getenv("QUESTIONS_FILE_PATH")
@@ -108,8 +113,11 @@ class QAChain(LLMInference):
          # find the query in the df
          filtered = df[df["question"].str.lower() == query.lower()]

-         context = filtered.iloc[0]["context"]
+         context = filtered.iloc[0]["context"] if len(filtered) > 0 else ""

-         return (
-             f"{qa_system_prompt}\n\n{context}\n\nQuestion: {query}\n\nHelpful Answer:"
-         )
+         if apply_chat_template_for_rag:
+             return self.apply_chat_template(
+                 f"{qa_system_prompt}\n\n{context}\n\nQuestion: {query}"
+             )
+         else:
+             return f"{qa_system_prompt}\n\n{context}\n\nQuestion: {query}\n\nHelpful Answer:"
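So the RAG prompt now takes one of two shapes depending on APPLY_CHAT_TEMPLATE_FOR_RAG: a chat-message list to be rendered by the model's chat template, or the original flat string ending in "Helpful Answer:". A small illustration with made-up system prompt, context, and query values; for non-Gemma models the message list would additionally carry a system message, per apply_chat_template above:

# Illustration only; qa_system_prompt, context and query are placeholder strings.
qa_system_prompt = "Use the following pieces of context to answer the question at the end."
context = "Jamaica's most widely spoken languages are Jamaican English and Jamaican Patois."
query = "what does jamaican people speak"

for apply_chat_template_for_rag in (True, False):
    if apply_chat_template_for_rag:
        # message-list form, later rendered by the tokenizer's chat template
        prompt = [{"role": "user", "content": f"{qa_system_prompt}\n\n{context}\n\nQuestion: {query}"}]
    else:
        # legacy flat-string form
        prompt = f"{qa_system_prompt}\n\n{context}\n\nQuestion: {query}\n\nHelpful Answer:"
    print(type(prompt).__name__, prompt)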
app_modules/utils.py CHANGED
@@ -7,7 +7,8 @@ import os
  import platform
  import re
  from pathlib import Path
-
+ import evaluate
+ import pandas as pd
  import requests
  import torch
  from tqdm import tqdm
@@ -186,234 +187,79 @@ class CustomizedConversationSummaryBufferMemory(ConversationSummaryBufferMemory)
      )


- def CalculateDistance(entry1, entry2, distance_calculator):
-     if entry1 == entry2:
-         return 0
-     distance = distance_calculator.evaluate_string_pairs(
-         prediction=entry1, prediction_b=entry2
-     )
-     # print(f"entry1: {entry1}, entry2: {entry2}, distance: {distance['score']}")
-     return distance["score"]
-
-
- def FindInList(entry, elist, distance_calculator=None, debug=False):
-     for item in elist:
-         if distance_calculator is not None:
-             distance = CalculateDistance(entry, item, distance_calculator)
-             if distance < distance_threshold:
-                 if debug:
-                     print(
-                         f"FindInList - matched by distance {distance:.3f}: {entry} - {item}"
-                     )
-                 return True
-         if entry == item:
-             return True
-     return False
-
-
- def CalculatePRF1F2(
-     goldAnswerList, predAnswerList, distance_calculator=None, debug=False
- ):
-     if len(goldAnswerList) == 0:
-         if len(predAnswerList) == 0:
-             return [
-                 1.0,
-                 1.0,
-                 1.0,
-                 1.0,
-             ]  # consider it 'correct' when there is no labeled answer, and also no predicted answer
-         else:
-             return [
-                 0.0,
-                 1.0,
-                 0.0,
-                 0.0,
-             ]  # precision=0 and recall=1 when there is no labeled answer, but has some predicted answer(s)
-     elif len(predAnswerList) == 0:
-         return [
-             1.0,
-             0.0,
-             0.0,
-             0.0,
-         ]  # precision=1 and recall=0 when there is labeled answer(s), but no predicted answer
-     else:
-         glist = goldAnswerList
-         plist = predAnswerList
-
-         tp = 1e-40  # numerical trick
-         fp = 0.0
-         fn = 0.0
-
-         for gentry in glist:
-             if FindInList(
-                 gentry, plist, distance_calculator=distance_calculator, debug=True
-             ):
-                 tp += 1
-             else:
-                 fn += 1
-         for pentry in plist:
-             if not FindInList(pentry, glist, distance_calculator=distance_calculator):
-                 fp += 1
-
-         precision = tp / (tp + fp)
-         recall = tp / (tp + fn)
-
-         f1 = (2 * precision * recall) / (precision + recall)
-         f2 = (5 * precision * recall) / (4 * precision + recall)
-         return [precision, recall, f1, f2]
-
-
- nlp = None
- distance_threshold = 0.05
-
-
- def load_spacy_model():
-     import spacy
-
-     global nlp
-     if nlp is not None:
-         return nlp
-
-     global distance_threshold
-     distance_threshold = float(os.getenv("DISTANCE_THRESHOLD", "0.05"))
-
-     spacy_model_name = os.getenv("SPACY_MODEL_NAME", "en_core_web_trf")
-
-     while True:
-         try:
-             print(f"loading spacy model from {spacy_model_name}")
-             nlp = spacy.load(spacy_model_name)
-             print(f"loaded spacy model from {spacy_model_name}")
-             return nlp
-         except OSError:
-             print(f"downloading spacy model {spacy_model_name}")
-             spacy.cli.download(spacy_model_name)
-             print(f"downloaded spacy model {spacy_model_name}")
-
-
- def clean_text(text):
-     text = text.lower()
-     text = text.replace('"', "")
-     text = text.replace(".", "")
-     # text = text.replace("ō", "o")
-     return text
-
-
- def get_entities_in_text(text, debug=False):
-     nlp = load_spacy_model()
-     doc = nlp(text)
-     entities_in_text = []
-     for word in doc.ents:
-         if debug:
-             print(word.text, word.label_)
-         entity = clean_text(word.text)
-         if entity not in entities_in_text:
-             entities_in_text.append(entity)
-
-     entities_in_text.sort()
-     return entities_in_text
-
-
- def calculate_metrics(question, answer, distance_calculator=None, debug=False):
-     ground_truth = question["answers"]
-     ground_truth.sort()
-
-     if debug:
-         print(f"question: {question}")
-         print(f"answer: {answer}")
-
-         print("entities_in_question ---------------")
-         entities_in_question = get_entities_in_text(question["question"], debug)
-
-         print("entities_in_answer -----------------")
-         entities_in_answer = get_entities_in_text(answer, debug)
-
-         print("done with NER with spaCy -----------")
-
-         entities_in_answer.sort()
-
-         predAnswerList = [
-             pentry
-             for pentry in entities_in_answer
-             if not FindInList(pentry, entities_in_question)
-         ]
-
-         print(f"entities_in_question: {entities_in_question}")
-         print(f"entities_in_answer: {entities_in_answer}")
-         print(f"ground_truth: {ground_truth}")
-         print(f"pred_answers: {predAnswerList}")
-
-         precision, recall, f1, f2 = CalculatePRF1F2(
-             ground_truth,
-             predAnswerList,
-             debug=debug,
-             distance_calculator=distance_calculator,
-         )
-         print(f"precision: {precision}, recall: {recall}, f1: {f1}, f2: {f2}")
-     else:
-         precision = 0.0
-         recall = 0.0
-         f1 = 0.0
-         f2 = 0.0
-         entities_in_answer = []
-         entities_in_question = []
-
-     return (
-         precision,
-         recall,
-         f1,
-         f2,
-         entities_in_answer,
-         ground_truth,
-         entities_in_question,
-     )
-
-
- def calculate_metrics_gemini(question, answer, debug=False):
-     precision = 0.0
-     recall = 0.0
-     f1 = 0.0
-
-     return (precision, recall, f1)
-
-
- if __name__ == "__main__":
-     from langchain_community.embeddings import HuggingFaceInstructEmbeddings
-     from langchain.evaluation import load_evaluator
-
-     hf_embeddings_device_type, hf_pipeline_device_type = get_device_types()
-     print(f"hf_embeddings_device_type: {hf_embeddings_device_type}")
-     print(f"hf_pipeline_device_type: {hf_pipeline_device_type}")
-
-     hf_embeddings_model_name = "hkunlp/instructor-large"
-     print(f"hf_embeddings_model_name: {hf_embeddings_model_name}")
-     embeddings = HuggingFaceInstructEmbeddings(
-         model_name=hf_embeddings_model_name,
-         model_kwargs={"device": hf_embeddings_device_type},
-     )
-
-     hf_evaluator = load_evaluator("pairwise_embedding_distance", embeddings=embeddings)
-
-     question = {
-         "question": "what does jamaican people speak",
-         "entities_in_question": ["jamaican"],
-         "answers": ["jamaican english", "jamaican creole english language"],
-     }
-     answer = "Jamaican people primarily speak Jamaican Patois, which is an English-based creole language with significant West African influences. It is spoken as a native language by the majority of Jamaicans and also exists in various forms among Jamaican expatriates and non-Jamaicans in different parts of the world. The phonology of Jamaican Patois includes around 21 consonants (with some dialectal variation regarding the status of /h/ as a phoneme) and between nine and sixteen vowels, some of which are capable of nasalization or lengthening. There are also instances of palatalization in Jamaican Patois, where certain consonants appear to be phonemic in some dialects but may be considered phonetic in others. For example, the palatal stops [c], [ɟ], and [ɲ] may be analyzed as phonemes or as instances of phonetic palatalization depending on the account."
-     calculate_metrics(question, answer, distance_calculator=hf_evaluator, debug=True)
-
-     question = {
-         "question": "who is governor of ohio 2011",
-         "entities_in_question": ["2011"],
-         "answers": ["john kasich", "return j. meigs, jr.", "ted strickland"],
-     }
-     answer = "The lieutenant governor of Ohio in 2011 was Mary Taylor, who served alongside Governor John Kasich. She assumed office on January 10, 2011, after being elected as the lieutenant governor in the 2010 election. During her tenure, she faced criticism for using the state airplane for personal errands and reportedly had high turnover among her staff."
-     calculate_metrics(question, answer, distance_calculator=hf_evaluator, debug=True)
-
-     question = {
-         "question": "where is the fukushima daiichi nuclear power station",
-         "entities_in_question": ["the fukushima daiichi nuclear power station"],
-         "answers": ["japan", "okuma"],
-     }
-     answer = "The Fukushima Daiichi Nuclear Power Station is located in the towns of Ōkuma and Futaba in Fukushima Prefecture, Japan."
-     calculate_metrics(question, answer, distance_calculator=hf_evaluator, debug=True)
+ bleu = evaluate.load("bleu")
+ rouge = evaluate.load("rouge")
+
+
+ def calc_metrics(df):
+     predictions = [df["answer"][i] for i in range(len(df))]
+     references = [df["ground_truth"][i] for i in range(len(df))]
+
+     bleu_scores = bleu.compute(
+         predictions=predictions, references=references, max_order=1
+     )
+     rouge_scores = rouge.compute(predictions=predictions, references=references)
+     return {"bleu_scores": bleu_scores, "rouge_scores": rouge_scores}
+
+
+ pattern_abnormal_newlines = re.compile(r"\n{5,}")
+ pattern_text_repetitions = re.compile(r"\b(\w.+?)\b(\1+)", re.M | re.DOTALL)
+ exception_pattern = re.compile(r"(\w+\.)\1")
+
+
+ # final version for repetition detection
+ def detect_repetitions(
+     text, debug=False, pattern_text_repetitions=pattern_text_repetitions
+ ):
+     subtotals = [0, 0]
+
+     if isinstance(text, str):
+         patterns = [pattern_abnormal_newlines, pattern_text_repetitions]
+         for i, pattern in enumerate(patterns):
+             if debug:
+                 print(
+                     f"----detect {'abnormal newlines' if i == 0 else 'text repetitions'}----"
+                 )
+             matches = pattern.finditer(text)
+             for match in matches:
+                 if debug:
+                     print(match)
+                     for groupNum in range(0, len(match.groups())):
+                         groupNum = groupNum + 1
+                         print(
+                             "Group {groupNum} found at {start}-{end}: `{group}`".format(
+                                 groupNum=groupNum,
+                                 start=match.start(groupNum),
+                                 end=match.end(groupNum),
+                                 group=match.group(groupNum),
+                             )
+                         )
+
+                 if exception_pattern.match(match[0]):
+                     if debug:
+                         print("ignored: ", match[0])
+                     continue
+
+                 start, end = match.span()
+                 subtotals[i] += end - start
+
+     result = (subtotals[0], subtotals[1], subtotals[0] + subtotals[1])
+
+     if debug:
+         print(result)
+     return result
+
+
+ def detect_abnormal_newlines(text, debug=False):
+     return detect_repetitions(text, debug=debug)[0]
+
+
+ def detect_text_repetitions(text, debug=False):
+     return detect_repetitions(text, debug=debug)[1]
+
+
+ def detect_repetition_scores(text, debug=False):
+     newline_score, repetition_score, total_repetitions = detect_repetitions(
+         text, debug=debug
+     )
+     return pd.Series([newline_score, repetition_score, total_repetitions])
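In short, the spaCy/NER-based precision-recall metrics move out of utils.py and are replaced by corpus BLEU-1 and ROUGE scores from the evaluate library plus a regex-based repetition detector. A small usage sketch, assuming the new dependencies (evaluate, rouge_score, pandas) are installed; the sample answer and references below are made up:

import evaluate
import pandas as pd

bleu = evaluate.load("bleu")
rouge = evaluate.load("rouge")

# Toy data in the same shape qa_chain_test.py builds: one answer plus a list of references per row.
df = pd.DataFrame(
    {
        "answer": ["Jamaicans speak English and Jamaican Patois."],
        "ground_truth": [["jamaican english", "jamaican creole english language"]],
    }
)

bleu_scores = bleu.compute(
    predictions=list(df["answer"]), references=list(df["ground_truth"]), max_order=1
)
rouge_scores = rouge.compute(
    predictions=list(df["answer"]), references=list(df["ground_truth"])
)
print(bleu_scores["bleu"], rouge_scores["rougeL"])  # the two numbers the test harness records

# Inside the repo the helpers are imported directly, e.g.:
#   from app_modules.utils import calc_metrics, detect_repetition_scores
#   df[["newline_score", "repetition_score", "total_repetitions"]] = df["answer"].apply(detect_repetition_scores)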
qa_chain_test.py CHANGED
@@ -12,7 +12,7 @@ if chatting:

  from app_modules.init import app_init
  from app_modules.llm_qa_chain import QAChain
- from app_modules.utils import print_llm_response
+ from app_modules.utils import print_llm_response, calc_metrics, detect_repetition_scores

  llm_loader, qa_chain = app_init()

@@ -116,7 +116,9 @@ if __name__ == "__main__":
          query = df["question"][i]
          id = df["id"][i]

-         ground_truth = question["answers"]
+         ground_truth = question[
+             "wellFormedAnswers" if "wellFormedAnswers" in question else "answers"
+         ]

          word_count = len(nltk.word_tokenize(answer))

@@ -128,6 +130,10 @@
              "ground_truth": ground_truth,
          }

+     df2[["newline_score", "repetition_score", "total_repetitions"]] = df2[
+         "answer"
+     ].apply(detect_repetition_scores)
+
      pd.options.display.float_format = "{:.3f}".format
      print(df2.describe())

@@ -147,6 +153,8 @@
      df2.to_csv(csv_file, mode="a", index=False, header=True)
      print(f"test results saved to file: {csv_file}")

+     scores = calc_metrics(df2)
+
      df = pd.DataFrame(
          {
              "model": [llm_loader.model_name],
@@ -154,6 +162,8 @@
              "word_count": [word_count],
              "inference_time": [total_time],
              "inference_speed": [word_count / total_time],
+             "bleu1": [scores["bleu_scores"]["bleu"]],
+             "rougeL": [scores["rouge_scores"]["rougeL"]],
          }
      )

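One detail worth noting in the test harness: the ground truth now prefers a question's wellFormedAnswers field (as in MS MARCO-style data) and falls back to answers when it is absent. A tiny sketch of that selection logic with invented records:

# Invented example records, only to show the field-selection logic added above.
questions = [
    {"question": "where is okuma", "answers": ["japan", "okuma"]},
    {
        "question": "what does jamaican people speak",
        "answers": ["jamaican english"],
        "wellFormedAnswers": ["Jamaican people speak Jamaican English."],
    },
]

for question in questions:
    ground_truth = question[
        "wellFormedAnswers" if "wellFormedAnswers" in question else "answers"
    ]
    print(question["question"], "->", ground_truth)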
requirements.txt CHANGED
@@ -9,4 +9,6 @@ gradio==4.26.0
  spaces==0.27.1
  black==24.4.0
  chardet==5.2.0
- sentencepiece==0.2.0
+ sentencepiece==0.2.0
+ evaluate==0.4.2
+ rouge_score==0.1.2