Spaces:
Build error
Build error
File size: 2,090 Bytes
7f7285f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 |
# -*- coding: utf-8 -*-
'''
@Author : Jiangjie Chen
@Time : 2020/9/21 16:13
@Contact : jjchen19@fudan.edu.cn
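@Description: Evidence retrieval pipeline - retrieves candidate documents for a claim, then ranks their sentences.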
'''
import cjjpy as cjj
import os
# from .document_retrieval import DocRetrieval
from .doc_retrieval_by_api import DocRetrieval
from .sentence_selection import SentSelector
# Settings bundle passed to SentSelector when EvidenceRetrieval is built.
# NOTE(review): the meaning of each value is defined by the sentence-selection
# model, not by this module - confirm there before changing any of them.
arg_values = dict(
    batch_size=32,
    dropout=0.6,
    use_cuda=True,
    bert_hidden_dim=768,
    layer=1,
    num_labels=3,
    evi_num=5,
    threshold=0.0,
    max_len=120,
)

# Wrapped in cjj.AttrDict (presumably so entries can be read as attributes,
# e.g. ``args.batch_size`` - verify against cjjpy).
args = cjj.AttrDict(arg_values)
class EvidenceRetrieval:
    '''
    Two-stage evidence retriever: fetch candidate sentences for a claim with
    ``DocRetrieval``, then rank them with ``SentSelector``.
    '''

    def __init__(self, er_model_dir=cjj.AbsParentDir(__file__, '...', 'models/evidence_retrieval/')):
        '''
        :param er_model_dir: directory holding the ``bert_base/`` folder and the
            ``retrieval_model/model.best.pt`` checkpoint. The default is computed
            once, at class-definition time, by ``cjj.AbsParentDir``
            (presumably a path relative to this file - verify against cjjpy).
        '''
        # Documents come from the API-based retriever; a local-DB retriever
        # (``.document_retrieval``) exists as an alternative but is not used here.
        self.doc_retrieval = DocRetrieval(link_type='tagme')

        bert_dir = os.path.join(er_model_dir, 'bert_base/')
        checkpoint_path = os.path.join(er_model_dir, 'retrieval_model/model.best.pt')
        self.sent_selector = SentSelector(bert_dir, checkpoint_path, args)

    def retrieve(self, claim):
        '''
        Retrieve documents for ``claim`` and return the ranked evidence.

        :param claim: str
        :return: whatever ``rank_sentences`` returns for the retrieved sentences
        '''
        candidates = self.doc_retrieval.retrieve_docs(claim)
        return self.rank_sentences(claim, candidates)

    def rank_sentences(self, claim, sentences, id=None):
        '''
        Rank candidate sentences against a claim.

        :param claim: str
        :param sentences: [(ent, num, sent) * N]
        :param id: key identifying this claim in the selector's output; when
            None, ``len(claim)`` is used instead
        :return: [(ent, num, sent) * k], or an empty list if the selector's
            result has no entry for the id
        '''
        sample_id = len(claim) if id is None else id
        sample = {
            'claim': claim,
            'evidence': sentences,
            'id': sample_id,
        }
        # The selector takes a batch of samples; a single-element batch is sent
        # and the result is looked up by the sample's id.
        ranked = self.sent_selector.rank_sentences([sample])
        return ranked.get(sample_id, [])