Liyan06 committed on
Commit
0677600
1 Parent(s): 98d958b

add span highlight

Browse files
Files changed (2) hide show
  1. handler.py +45 -5
  2. requirements.txt +2 -1
handler.py CHANGED
@@ -1,5 +1,7 @@
1
  from minicheck_web.minicheck import MiniCheck
2
  from web_retrieval import *
 
 
3
 
4
 
5
  def sort_chunks_single_doc_claim(used_chunk, support_prob_per_chunk):
@@ -22,23 +24,37 @@ def sort_chunks_single_doc_claim(used_chunk, support_prob_per_chunk):
22
  class EndpointHandler():
23
  def __init__(self, path="./"):
24
  self.scorer = MiniCheck(path=path)
 
 
25
 
26
  def __call__(self, data):
27
 
 
 
28
  # Using user-provided document to do fact-checking
29
  if len(data['inputs']['docs']) == 1 and data['inputs']['docs'][0] != '':
30
  _, _, used_chunk, support_prob_per_chunk = self.scorer.score(data=data)
31
  ranked_docs, scores = sort_chunks_single_doc_claim(used_chunk, support_prob_per_chunk)
32
 
 
 
 
 
 
 
 
 
 
33
  outputs = {
34
  'ranked_docs': ranked_docs,
35
- 'scores': scores
36
- }
 
37
 
38
  else:
39
  assert len(data['inputs']['claims']) == 1, "Only one claim is allowed for web retrieval for the current version."
40
 
41
- claim = data['inputs']['claims'][0]
42
  ranked_docs, scores, ranked_urls = self.search_relevant_docs(claim)
43
 
44
  outputs = {
@@ -60,7 +76,7 @@ class EndpointHandler():
60
  scraped_results = e.map(scrape_url, search_results, itertools.repeat(timeout))
61
  end = time()
62
  print(f"Finished searching in {round((end - start), 1)} seconds.\n")
63
- scraped_results = [(r[0][:50000], r[1]) for r in scraped_results if r[0] and '��' not in r[0] and ".pdf" not in r[1]]
64
 
65
  retrieved_docs, urls = zip(*scraped_results[:max_search_results_per_query])
66
 
@@ -79,4 +95,28 @@ class EndpointHandler():
79
 
80
  ranked_docs, scores, ranked_urls = order_doc_score_url(used_chunk, support_prob_per_chunk, urls, allow_duplicated_urls=allow_duplicated_urls)
81
 
82
- return ranked_docs, scores, ranked_urls
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  from minicheck_web.minicheck import MiniCheck
2
  from web_retrieval import *
3
+ from nltk.tokenize import sent_tokenize
4
+ import evaluate
5
 
6
 
7
  def sort_chunks_single_doc_claim(used_chunk, support_prob_per_chunk):
 
24
  class EndpointHandler():
25
  def __init__(self, path="./"):
26
  self.scorer = MiniCheck(path=path)
27
+ self.rouge = evaluate.load('rouge')
28
+
29
 
30
  def __call__(self, data):
31
 
32
+ claim = data['inputs']['claims'][0]
33
+
34
  # Using user-provided document to do fact-checking
35
  if len(data['inputs']['docs']) == 1 and data['inputs']['docs'][0] != '':
36
  _, _, used_chunk, support_prob_per_chunk = self.scorer.score(data=data)
37
  ranked_docs, scores = sort_chunks_single_doc_claim(used_chunk, support_prob_per_chunk)
38
 
39
+ span_to_highlight = []
40
+ for doc_chunk, score in zip(ranked_docs, scores):
41
+ # If the chunk can support the claim, find the sentence with the highest rouge score
42
+ if score > 0.5:
43
+ highest_score_sent, _ = self.chunk_and_highest_rouge_score(doc_chunk, claim)
44
+ span_to_highlight.append(highest_score_sent)
45
+ else:
46
+ span_to_highlight.append("")
47
+
48
  outputs = {
49
  'ranked_docs': ranked_docs,
50
+ 'scores': scores,
51
+ 'span_to_highlight': span_to_highlight
52
+ }
53
 
54
  else:
55
  assert len(data['inputs']['claims']) == 1, "Only one claim is allowed for web retrieval for the current version."
56
 
57
+
58
  ranked_docs, scores, ranked_urls = self.search_relevant_docs(claim)
59
 
60
  outputs = {
 
76
  scraped_results = e.map(scrape_url, search_results, itertools.repeat(timeout))
77
  end = time()
78
  print(f"Finished searching in {round((end - start), 1)} seconds.\n")
79
+ scraped_results = [(r[0][:20000], r[1]) for r in scraped_results if r[0] and '��' not in r[0] and ".pdf" not in r[1]]
80
 
81
  retrieved_docs, urls = zip(*scraped_results[:max_search_results_per_query])
82
 
 
95
 
96
  ranked_docs, scores, ranked_urls = order_doc_score_url(used_chunk, support_prob_per_chunk, urls, allow_duplicated_urls=allow_duplicated_urls)
97
 
98
+ return ranked_docs, scores, ranked_urls
99
+
100
+
101
+ def chunk_and_highest_rouge_score(self, doc, claim):
102
+
103
+ '''
104
+ Given a document and a claim, return the sentence with the highest rouge score and the score
105
+ '''
106
+
107
+ doc_sentences = sent_tokenize(doc)
108
+ claims = [claim] * len(doc_sentences)
109
+
110
+ results = self.rouge.compute(
111
+ predictions=doc_sentences,
112
+ references=claims,
113
+ use_aggregator=False)
114
+
115
+ highest_score = 0
116
+ highest_score_sent = ""
117
+ for i in range(len(doc_sentences)):
118
+ if results['rouge1'][i] > highest_score:
119
+ highest_score = results['rouge1'][i]
120
+ highest_score_sent = doc_sentences[i]
121
+
122
+ return highest_score_sent, highest_score
requirements.txt CHANGED
@@ -4,4 +4,5 @@ nltk==3.8.1
4
  pandas==2.2.1
5
  numpy==1.26.2
6
  tqdm
7
- bs4
 
 
4
  pandas==2.2.1
5
  numpy==1.26.2
6
  tqdm
7
+ bs4
8
+ rouge-score