# chatbot_full/qa_model.py
# Last commit 62ab090 ("Update qa_model.py") by letrunglinh; file size 2.51 kB.
# (Header reconstructed from a captured web file-view page.)
import json
import logging
import os
from pathlib import Path

import numpy as np
import torch
import torch.nn as nn
from transformers import AutoTokenizer, pipeline
from optimum.intel import OVModelForQuestionAnswering

from text_utils import post_process_answer
from graph_utils import find_best_cluster
from text_utils import *
# os.environ['HTTP_PROXY'] = 'http://proxy.hcm.fpt.vn:80'
class QAEnsembleModel_modify(nn.Module):
    """Extractive question answering over a set of candidate passages.

    The model named by ``model_name`` is exported to OpenVINO with
    ``OVModelForQuestionAnswering`` and wrapped in a transformers
    ``question-answering`` pipeline. ``forward`` runs that pipeline on every
    passage and returns one post-processed answer, or ``None``.
    """

    # Used instead of print() so per-passage pipeline output can be enabled
    # or silenced through the standard logging configuration.
    _logger = logging.getLogger(__name__)

    def __init__(self, model_name, entity_dict,
                 thr=0.1, device="cpu", auth_token=None):
        """Export the QA model and build the inference pipeline.

        Args:
            model_name: Model id or path; used for both the model and the
                tokenizer.
            entity_dict: Passed unchanged to ``post_process_answer`` in
                ``forward``.
            thr: Minimum raw pipeline score for an answer to be kept as a
                candidate in ``forward``.
            device: Not used by this implementation; kept so existing
                callers that pass it keep working.
            auth_token: Access token forwarded to ``from_pretrained``. When
                ``None``, the ``HF_TOKEN`` environment variable is read; if
                that is unset as well, ``None`` is forwarded.
        """
        super().__init__()
        # Credentials must not live in source code: a token that was ever
        # committed should be considered leaked and revoked.
        if auth_token is None:
            auth_token = os.environ.get("HF_TOKEN")
        model_convert = OVModelForQuestionAnswering.from_pretrained(
            model_name, export=True, use_auth_token=auth_token)
        nlp = pipeline('question-answering', model=model_convert,
                       tokenizer=model_name)
        # A list so forward() can iterate several pipelines; only one is
        # built here.
        self.nlps = [nlp]
        self.entity_dict = entity_dict
        self.thr = thr

    def forward(self, question, texts, ranking_scores=None):
        """Answer ``question`` using the passages in ``texts``.

        Args:
            question: The question string.
            texts: Sequence of context passages.
            ranking_scores: Optional per-passage weights aligned with
                ``texts``; defaults to all ones. Pairing uses ``zip``, so
                surplus entries in the longer sequence are ignored.

        Returns:
            The post-processed answer, or ``None`` when no passage yields an
            answer whose raw pipeline score exceeds ``self.thr``.
        """
        if ranking_scores is None:
            ranking_scores = np.ones((len(texts),))

        candidates = []  # raw answers whose raw score exceeds the threshold
        answer = None    # highest weighted-score answer of the first pipeline
        # Start at -inf rather than 0: with 0, an all-non-positive set of
        # weighted scores left `answer` unbound (UnboundLocalError below).
        # Whenever some weighted score is > 0 the selected answer is the same.
        best_score = float("-inf")
        for i, nlp in enumerate(self.nlps):
            for text, weight in zip(texts, ranking_scores):
                res = nlp({'question': question, 'context': text})
                self._logger.debug("QA pipeline output: %s", res)
                # The threshold is applied to the raw pipeline score ...
                if res["score"] > self.thr:
                    candidates.append(res["answer"])
                # ... while the anchor answer is ranked by the weighted score.
                weighted = res["score"] * weight
                if i == 0 and weighted > best_score:
                    answer = res["answer"]
                    best_score = weighted

        if not candidates:
            return None
        if answer is None:
            # Defensive: only reached if no weighted score compared greater
            # than -inf (e.g. every weighted score was NaN).
            answer = candidates[0]

        candidates = [post_process_answer(x, self.entity_dict) for x in candidates]
        answer = post_process_answer(answer, self.entity_dict)
        # NOTE(review): find_best_cluster comes from graph_utils; presumably it
        # chooses among `candidates` using `answer` as a reference — confirm.
        return post_process_answer(find_best_cluster(candidates, answer),
                                   self.entity_dict)