loren-fact-checking / src /er_client /sentence_selection.py
jiangjiechen's picture
fix device issue: cpu
817b9ee
raw
history blame
2.13 kB
# -*- coding: utf-8 -*-
'''
@Author : Jiangjie Chen
@Time : 2020/9/20 11:42
@Contact : jjchen19@fudan.edu.cn
@Description:
'''
import torch
from transformers import BertTokenizer
from .retrieval_model.bert_model import BertForSequenceEncoder
from .retrieval_model.models import inference_model
from .retrieval_model.data_loader import DataLoaderTest
class SentSelector:
def __init__(self, pretrained_bert_path, select_model_path, args):
self.args = args
self.use_cuda = self.args.use_cuda and torch.cuda.is_available()
self.tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
self.bert_model = BertForSequenceEncoder.from_pretrained(pretrained_bert_path)
self.rank_model = inference_model(self.bert_model, self.args)
self.rank_model.load_state_dict(torch.load(select_model_path,
map_location=None if self.use_cuda else torch.device('cpu'))['model'])
if self.use_cuda:
self.bert_model = self.bert_model.cuda()
self.rank_model.cuda()
def rank_sentences(self, js: list):
'''
:param js: [{'claim': xxx, 'id': xx, 'evidence': xxx}]
:return: [(ent, num, sent, prob), (ent, num, sent, prob)]
'''
data_reader = DataLoaderTest(js, self.tokenizer, self.args, self.use_cuda)
self.rank_model.eval()
all_predict = dict()
for inp_tensor, msk_tensor, seg_tensor, ids, evi_list in data_reader:
probs = self.rank_model(inp_tensor, msk_tensor, seg_tensor)
probs = probs.tolist()
assert len(probs) == len(evi_list)
for i in range(len(probs)):
if ids[i] not in all_predict:
all_predict[ids[i]] = []
# if probs[i][1] >= probs[i][0]:
all_predict[ids[i]].append(tuple(evi_list[i]) + (probs[i],))
results = {}
for k, v in all_predict.items():
sorted_v = sorted(v, key=lambda x: x[-1], reverse=True)
results[k] = sorted_v[:self.args.evi_num]
return results