# Source: chatbot_full / qa_model.py (Hugging Face file view)
# Author: letrunglinh - last commit 31f75a0 "Update qa_model.py"
# Hub scan status: no virus detected; file size 2.51 kB.
import json
import logging
import os
from pathlib import Path

import numpy as np
import torch
import torch.nn as nn
from optimum.intel import OVModelForQuestionAnswering
from transformers import AutoTokenizer, pipeline

from graph_utils import find_best_cluster
from text_utils import post_process_answer
# Kept last, as in the original: this star import may rebind names above.
from text_utils import *
# os.environ['HTTP_PROXY'] = 'http://proxy.hcm.fpt.vn:80'
class QAEnsembleModel_modify(nn.Module):
    """Extractive QA over several candidate contexts using an OpenVINO-exported model.

    ``forward`` runs a Hugging Face ``question-answering`` pipeline on every
    context, collects the answers whose raw pipeline score exceeds ``thr``,
    and returns one post-processed answer chosen with ``find_best_cluster``.
    """

    # Raw pipeline results are logged here at DEBUG level instead of printed.
    _logger = logging.getLogger(__name__)

    def __init__(self, model_name, entity_dict,
                 thr=0.1, device="cpu", auth_token=None):
        """Export ``model_name`` to OpenVINO, compile it and wrap it in a QA pipeline.

        Args:
            model_name: Hugging Face model id or local path; used for both the
                model and the tokenizer.
            entity_dict: passed to ``post_process_answer`` in ``forward``.
            thr: minimum raw pipeline score for an answer to be collected.
            device: accepted for backward compatibility; not used by this
                class.
            auth_token: Hugging Face access token for private repositories.
                When omitted, the ``HF_TOKEN`` or ``HUGGINGFACE_HUB_TOKEN``
                environment variable is used; if neither is set the model is
                loaded anonymously.
        """
        super().__init__()
        # Credentials must not live in source code: any token previously
        # committed to this file is exposed and should be revoked.
        token = (
            auth_token
            or os.environ.get("HF_TOKEN")
            or os.environ.get("HUGGINGFACE_HUB_TOKEN")
        )
        model = OVModelForQuestionAnswering.from_pretrained(
            model_name, export=True, use_auth_token=token
        )
        model.half()
        model.compile()
        # A list so further pipelines could be ensembled; currently one entry.
        self.nlps = [
            pipeline('question-answering', model=model, tokenizer=model_name)
        ]
        self.entity_dict = entity_dict
        self.thr = thr

    def forward(self, question, texts, ranking_scores=None):
        """Answer ``question`` from the candidate contexts ``texts``.

        Args:
            question: question string fed to every QA pipeline.
            texts: sequence of context strings.
            ranking_scores: optional per-context weights (presumably retrieval
                scores - confirm with callers), zipped with ``texts``;
                defaults to all ones.

        Returns:
            The post-processed answer, or ``None`` when no pipeline answer
            scored above ``self.thr`` (including when ``texts`` is empty).
        """
        if ranking_scores is None:
            ranking_scores = np.ones((len(texts),))

        candidates = []
        best_answer = None
        # Start at -inf so an anchor is always picked, even when every
        # weighted score is <= 0; starting at 0 left the anchor unbound and
        # raised UnboundLocalError further down.
        best_score = float("-inf")

        for i, nlp in enumerate(self.nlps):
            for text, weight in zip(texts, ranking_scores):
                res = nlp({'question': question, 'context': text})
                self._logger.debug("QA pipeline result: %s", res)
                # The threshold is applied to the raw (unweighted) score.
                if res["score"] > self.thr:
                    candidates.append(res["answer"])
                # Only the first pipeline chooses the anchor answer, ranked by
                # raw score times the context's ranking weight.
                # NOTE(review): the anchor is not filtered by thr, so it may
                # not be among `candidates` - confirm this is intended.
                if i == 0:
                    weighted = res["score"] * weight
                    if weighted > best_score:
                        best_answer = res["answer"]
                        best_score = weighted

        if not candidates:
            return None

        candidates = [post_process_answer(x, self.entity_dict) for x in candidates]
        best_answer = post_process_answer(best_answer, self.entity_dict)
        return post_process_answer(
            find_best_cluster(candidates, best_answer), self.entity_dict
        )