Parth211 committed
Commit dfd232b · verified · 1 Parent(s): d83cae6

update accuracy code

Files changed (1): app.py (+181 −15)
app.py CHANGED
@@ -22,6 +22,19 @@ import tqdm
 import accelerate
 import re
 
+
+import torch
+from sacrebleu import corpus_bleu
+from rouge_score import rouge_scorer
+from bert_score import score
+from transformers import GPT2LMHeadModel, GPT2Tokenizer, pipeline
+import nltk
+from nltk.util import ngrams
+
+
+
+
+
 api_key = os.getenv('API_KEY')
 
 
@@ -249,30 +262,136 @@ def format_chat_history(message, chat_history):
     return formatted_chat_history
 
 
-def conversation(qa_chain, message, history):
+
+
+
+###############################################
+class RAGEvaluator:
+    def __init__(self):
+        self.gpt2_model, self.gpt2_tokenizer = self.load_gpt2_model()
+        self.bias_pipeline = pipeline("zero-shot-classification", model="Hate-speech-CNERG/dehatebert-mono-english")
+
+    def load_gpt2_model(self):
+        model = GPT2LMHeadModel.from_pretrained('gpt2')
+        tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
+        return model, tokenizer
+
+    def evaluate_bleu_rouge(self, candidates, references):
+        bleu_score = corpus_bleu(candidates, [references]).score
+        scorer = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeL'], use_stemmer=True)
+        rouge_scores = [scorer.score(ref, cand) for ref, cand in zip(references, candidates)]
+        rouge1 = sum([score['rouge1'].fmeasure for score in rouge_scores]) / len(rouge_scores)
+        return bleu_score, rouge1
+
+    def evaluate_bert_score(self, candidates, references):
+        P, R, F1 = score(candidates, references, lang="en", model_type='bert-base-multilingual-cased')
+        return P.mean().item(), R.mean().item(), F1.mean().item()
+
+    def evaluate_perplexity(self, text):
+        encodings = self.gpt2_tokenizer(text, return_tensors='pt')
+        max_length = self.gpt2_model.config.n_positions
+        stride = 512
+        lls = []
+        for i in range(0, encodings.input_ids.size(1), stride):
+            begin_loc = max(i + stride - max_length, 0)
+            end_loc = min(i + stride, encodings.input_ids.size(1))
+            trg_len = end_loc - i
+            input_ids = encodings.input_ids[:, begin_loc:end_loc]
+            target_ids = input_ids.clone()
+            target_ids[:, :-trg_len] = -100
+            with torch.no_grad():
+                outputs = self.gpt2_model(input_ids, labels=target_ids)
+                log_likelihood = outputs[0] * trg_len
+            lls.append(log_likelihood)
+        ppl = torch.exp(torch.stack(lls).sum() / end_loc)
+        return ppl.item()
+
+    def evaluate_diversity(self, texts):
+        all_tokens = [tok for text in texts for tok in text.split()]
+        unique_bigrams = set(ngrams(all_tokens, 2))
+        diversity_score = len(unique_bigrams) / len(all_tokens) if all_tokens else 0
+        return diversity_score
+
+    def evaluate_racial_bias(self, text):
+        results = self.bias_pipeline([text], candidate_labels=["hate speech", "not hate speech"])
+        bias_score = results[0]['scores'][results[0]['labels'].index('hate speech')]
+        return bias_score
+
+    def evaluate_all(self, question, response, reference):
+        candidates = [response]
+        references = [reference]
+        bleu, rouge1 = self.evaluate_bleu_rouge(candidates, references)
+        bert_p, bert_r, bert_f1 = self.evaluate_bert_score(candidates, references)
+        perplexity = self.evaluate_perplexity(response)
+        diversity = self.evaluate_diversity(candidates)
+        racial_bias = self.evaluate_racial_bias(response)
+        return {
+            "BLEU": bleu,
+            "ROUGE-1": rouge1,
+            "BERT P": bert_p,
+            "BERT R": bert_r,
+            "BERT F1": bert_f1,
+            "Perplexity": perplexity,
+            "Diversity": diversity,
+            "Racial Bias": racial_bias
+        }
+
+###################################
+
+evaluator = RAGEvaluator()
+
+
+#################################
+
+def display_metrics(metrics):
+    result = ""
+    for k, v in metrics.items():
+        if k == 'BLEU':
+            result += f"BLEU measures the overlap between the generated output and reference text based on n-grams. Higher scores indicate a better match. Score obtained: {v}\n\n"
+        elif k == 'ROUGE-1':
+            result += f"ROUGE-1 measures the overlap of unigrams between the generated output and reference text. Higher scores indicate a better match. Score obtained: {v}\n\n"
+        elif k == 'BERT P':
+            result += "BERTScore evaluates the semantic similarity between the generated output and reference text using BERT embeddings.\n\n"
+            result += f"**BERT Precision**: {metrics['BERT P']}\n"
+            result += f"**BERT Recall**: {metrics['BERT R']}\n"
+            result += f"**BERT F1 Score**: {metrics['BERT F1']}\n\n"
+        elif k == 'Perplexity':
+            result += f"Perplexity measures how well a language model predicts the text. Lower values indicate better fluency and coherence. Score obtained: {v}\n\n"
+        elif k == 'Diversity':
+            result += f"Diversity measures the uniqueness of bigrams in the generated output. Higher values indicate more diverse and varied output. Score obtained: {v}\n\n"
+        elif k == 'Racial Bias':
+            result += f"Racial Bias score indicates the presence of biased language in the generated output. Higher scores indicate more bias. Score obtained: {v}\n\n"
+    return result
+
+def conversation(qa_chain, message, history, evaluator):
     formatted_chat_history = format_chat_history(message, history)
-    #print("formatted_chat_history",formatted_chat_history)
-
-    # Generate response using QA chain
+    question_by_user = message
+
     response = qa_chain({"question": message, "chat_history": formatted_chat_history})
     response_answer = response["answer"]
+    answer_of_question = response["answer"]
     if response_answer.find("Helpful Answer:") != -1:
         response_answer = response_answer.split("Helpful Answer:")[-1]
     response_sources = response["source_documents"]
+    context = " ".join([d.page_content for d in response_sources])
+
     response_source1 = response_sources[0].page_content.strip()
     response_source2 = response_sources[1].page_content.strip()
     response_source3 = response_sources[2].page_content.strip()
-    # Langchain sources are zero-based
+
     response_source1_page = response_sources[0].metadata["page"] + 1
     response_source2_page = response_sources[1].metadata["page"] + 1
     response_source3_page = response_sources[2].metadata["page"] + 1
-    # print ('chat response: ', response_answer)
-    # print('DB source', response_sources)
-
-    # Append user message and response to chat history
+
     new_history = history + [(message, response_answer)]
-    # return gr.update(value=""), new_history, response_sources[0], response_sources[1]
-    return qa_chain, gr.update(value=""), new_history, response_source1, response_source1_page, response_source2, response_source2_page, response_source3, response_source3_page
+
+    # Evaluate the metrics
+    metrics = evaluator.evaluate_all(question_by_user, answer_of_question, context)
+    evaluation_metrics = display_metrics(metrics)
+
+    return (qa_chain, gr.update(value=""), new_history, response_source1, response_source1_page,
+            response_source2, response_source2_page, response_source3, response_source3_page,
+            question_by_user, context, answer_of_question, evaluation_metrics)
 
 
 def upload_file(file_obj):
@@ -285,6 +404,30 @@ def upload_file(file_obj):
     return list_file_path
 
 
+# Function to display metrics
+def display_metrics(metrics):
+    result = ""
+    for k, v in metrics.items():
+        if k == 'BLEU':
+            result += f"BLEU measures the overlap between the generated output and reference text based on n-grams. Higher scores indicate a better match. Score obtained: {v}\n\n"
+        elif k == 'ROUGE-1':
+            result += f"ROUGE-1 measures the overlap of unigrams between the generated output and reference text. Higher scores indicate a better match. Score obtained: {v}\n\n"
+        elif k == 'BERT P':
+            result += "BERTScore evaluates the semantic similarity between the generated output and reference text using BERT embeddings.\n\n"
+            result += f"**BERT Precision**: {metrics['BERT P']}\n"
+            result += f"**BERT Recall**: {metrics['BERT R']}\n"
+            result += f"**BERT F1 Score**: {metrics['BERT F1']}\n\n"
+        elif k == 'Perplexity':
+            result += f"Perplexity measures how well a language model predicts the text. Lower values indicate better fluency and coherence. Score obtained: {v}\n\n"
+        elif k == 'Diversity':
+            result += f"Diversity measures the uniqueness of bigrams in the generated output. Higher values indicate more diverse and varied output. Score obtained: {v}\n\n"
+        elif k == 'Racial Bias':
+            result += f"Racial Bias score indicates the presence of biased language in the generated output. Higher scores indicate more bias. Score obtained: {v}\n\n"
+    return result
+
+
+
+###################################
 def demo():
     with gr.Blocks(theme="base") as demo:
         vector_db = gr.State()
@@ -352,6 +495,21 @@ def demo():
         with gr.Row():
             submit_btn = gr.Button("Submit message")
             clear_btn = gr.ClearButton([msg, chatbot], value="Clear conversation")
+        with gr.Tab("Metrics"):
+            metrics_output = gr.Textbox(lines=10, label="Evaluation Metrics")
+
+
+
+        with gr.Tab("Metrics"):
+            metrics_output = gr.Textbox(lines=10, label="Evaluation Metrics")
+
+
+
+
+
+
+
+
 
         # Preprocessing events
         #upload_btn.upload(upload_file, inputs=[upload_btn], outputs=[document])
@@ -366,10 +524,13 @@ def demo():
             queue=False)
 
         # Chatbot events
-        msg.submit(conversation, \
-            inputs=[qa_chain, msg, chatbot], \
-            outputs=[qa_chain, msg, chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page], \
-            queue=False)
+        msg.submit(interact, inputs=[gr.State(), msg, history], outputs=[
+            gr.State(), chatbot, history, response_source1, response_source1_page,
+            response_source2, response_source2_page, response_source3, response_source3_page,
+            None, None, None, metrics_output
+        ])
+
+
         submit_btn.click(conversation, \
             inputs=[qa_chain, msg, chatbot], \
             outputs=[qa_chain, msg, chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page], \
@@ -378,6 +539,11 @@ def demo():
             inputs=None, \
             outputs=[chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page], \
             queue=False)
+
+
+
+
+
     demo.queue().launch(debug=True)
 
 
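As committed, the new `msg.submit` handler calls `interact` and lists `history`, `response_source1`, and bare `None` entries as components, none of which are defined in app.py; both chat events also still pass three inputs to the four-argument `conversation`, whose 13-value return outnumbers the nine wired outputs. A minimal reconciliation sketch (an assumption, not the committed code; `chat_fn` and the three `gr.State` sinks are hypothetical names), placed inside `demo()` next to the existing wiring:

```python
import functools

# Bind the module-level evaluator so Gradio only has to supply three inputs.
chat_fn = functools.partial(conversation, evaluator=evaluator)

# Hypothetical sinks for the question/context/answer values that have no widget.
question_state, context_state, answer_state = gr.State(), gr.State(), gr.State()

# Thirteen components to match conversation's 13-value return.
chat_outputs = [qa_chain, msg, chatbot,
                doc_source1, source1_page, doc_source2, source2_page,
                doc_source3, source3_page,
                question_state, context_state, answer_state, metrics_output]

msg.submit(chat_fn, inputs=[qa_chain, msg, chatbot], outputs=chat_outputs, queue=False)
submit_btn.click(chat_fn, inputs=[qa_chain, msg, chatbot], outputs=chat_outputs, queue=False)
```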
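The evaluation stack can also be smoke-tested outside Gradio before launching the app. A minimal sketch, assuming the new dependencies are installed (`pip install torch sacrebleu rouge-score bert-score transformers nltk`) and that app.py is import-safe at module level; the strings are illustrative, and the first run downloads the `gpt2` and `Hate-speech-CNERG/dehatebert-mono-english` checkpoints:

```python
# Smoke test for RAGEvaluator; question/response/reference are made-up examples.
from app import RAGEvaluator, display_metrics

evaluator = RAGEvaluator()
metrics = evaluator.evaluate_all(
    question="What does the warranty cover?",
    response="The warranty covers manufacturing defects for two years.",
    reference="Manufacturing defects are covered for a period of two years.",
)
print(display_metrics(metrics))
```

One caveat worth checking: the `zero-shot-classification` pipeline expects an NLI-style model, so pairing it with the dehatebert hate-speech classifier may not behave as intended.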