# Standard-library helpers used by search_relevant_docs below
# (these may also be re-exported by web_retrieval's star import).
from time import time
import concurrent.futures
import itertools

from minicheck_web.minicheck import MiniCheck
from web_retrieval import *

def sort_chunks_single_doc_claim(used_chunk, support_prob_per_chunk):
    '''
    Sort the chunks of a single document by their probability of being
    "supported", in descending order. Used when the user provides a document
    rather than relying on web retrieval.
    '''

    # Flatten the nested per-document lists into one list of chunks and one
    # list of scores, keeping the two aligned by position.
    flattened_docs = [doc for chunk in used_chunk for doc in chunk]
    flattened_scores = [score for chunk in support_prob_per_chunk for score in chunk]

    # Pair each chunk with its score and rank by score, highest first.
    doc_score = list(zip(flattened_docs, flattened_scores))
    ranked_doc_score = sorted(doc_score, key=lambda x: x[1], reverse=True)

    ranked_docs, scores = zip(*ranked_doc_score)

    return ranked_docs, scores
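
# A minimal illustration of the ranking above, assuming one document split
# into two chunks with hypothetical support probabilities:
#
#   sort_chunks_single_doc_claim([['chunk A', 'chunk B']], [[0.2, 0.9]])
#   -> (('chunk B', 'chunk A'), (0.9, 0.2))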


class EndpointHandler:

    def __init__(self, path="./"):
        # Load the MiniCheck scorer once so every request reuses the same model.
        self.scorer = MiniCheck(path=path)

    def __call__(self, data):
        '''
        Expects data of the form {'inputs': {'docs': [...], 'claims': [...]}}.
        If a single non-empty document is provided, its chunks are ranked
        against the claim directly; otherwise the (single) claim is checked
        against documents retrieved from the web.
        '''

        # Case 1: the user supplied one non-empty document.
        if len(data['inputs']['docs']) == 1 and data['inputs']['docs'][0] != '':
            _, _, used_chunk, support_prob_per_chunk = self.scorer.score(data=data)
            ranked_docs, scores = sort_chunks_single_doc_claim(used_chunk, support_prob_per_chunk)

            outputs = {
                'ranked_docs': ranked_docs,
                'scores': scores
            }

        # Case 2: no document supplied; retrieve supporting evidence from the web.
        else:
            assert len(data['inputs']['claims']) == 1, \
                "Only one claim is allowed for web retrieval in the current version."

            claim = data['inputs']['claims'][0]
            ranked_docs, scores, ranked_urls = self.search_relevant_docs(claim)

            outputs = {
                'ranked_docs': ranked_docs,
                'scores': scores,
                'ranked_urls': ranked_urls
            }

        return outputs
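
    # Illustrative request/response shapes for the web-retrieval path of
    # __call__ (all values below are made up):
    #
    #   {'inputs': {'docs': [''], 'claims': ['<claim>']}}
    #   -> {'ranked_docs': (<page text>, ...), 'scores': (0.93, ...),
    #       'ranked_urls': ('https://example.com/...', ...)}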
    def search_relevant_docs(self, claim, timeout=10, max_search_results_per_query=5, allow_duplicated_urls=False):

        search_results = search_google(claim, timeout=timeout)

        print('Searching webpages...')
        start = time()
        with concurrent.futures.ThreadPoolExecutor() as e:
            # Scrape all result URLs in parallel, passing the same timeout to each call.
            scraped_results = e.map(scrape_url, search_results, itertools.repeat(timeout))
        end = time()
        print(f"Finished searching in {round(end - start, 1)} seconds.\n")

        # Keep only pages that yielded text, dropping PDFs and pages containing
        # the Unicode replacement character (a sign of decoding errors), and
        # truncate each page to 50,000 characters.
        scraped_results = [
            (r[0][:50000], r[1])
            for r in scraped_results
            if r[0] and '\ufffd' not in r[0] and '.pdf' not in r[1]
        ]

        retrieved_docs, urls = zip(*scraped_results[:max_search_results_per_query])

        print('Scoring webpages...')
        start = time()
        retrieved_data = {
            'inputs': {
                'docs': list(retrieved_docs),
                'claims': [claim] * len(retrieved_docs)
            }
        }
        _, _, used_chunk, support_prob_per_chunk = self.scorer.score(data=retrieved_data)
        end = time()

        num_chunks = len([chunk for chunks in used_chunk for chunk in chunks])
        print(f'Finished {num_chunks} entailment checks in {round(end - start, 1)} seconds '
              f'({round(num_chunks / (end - start) * 60)} chunks/min).')

        ranked_docs, scores, ranked_urls = order_doc_score_url(
            used_chunk, support_prob_per_chunk, urls,
            allow_duplicated_urls=allow_duplicated_urls
        )

        return ranked_docs, scores, ranked_urls
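

if __name__ == '__main__':
    # Minimal smoke test; the document and claim below are made-up examples,
    # and running this assumes MiniCheck weights are available under ./ and
    # that web_retrieval's search backend is configured.
    handler = EndpointHandler(path="./")

    claim = 'MiniCheck verifies claims against documents.'

    # Single user-provided document: rank its chunks against the claim.
    print(handler({'inputs': {
        'docs': ['MiniCheck checks whether a claim is supported by a document.'],
        'claims': [claim],
    }}))

    # Empty document: retrieve and rank web evidence for the claim instead.
    print(handler({'inputs': {'docs': [''], 'claims': [claim]}}))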