# hybrid-qa / hybrid_model.py
# Uploaded by justinhl ("Upload 2 files", commit dead439, verified) - 2.77 kB
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoModelForQuestionAnswering
from transformers import pipeline
from huggingface_hub import PyTorchModelHubMixin
import torch
import torch.nn as nn
import torch.nn.functional as F
from hybrid_config import HybridQAConfig
class HybridQAModel(nn.Module, PyTorchModelHubMixin):
    """Answer questions with whichever of two QA models is more confident.

    Wraps an extractive question-answering model (selects an answer span
    from a supplied context) and a generative seq2seq model (generates an
    answer from the question alone). ``predict`` runs both and returns the
    answer with the higher confidence score.
    """

    config_class = HybridQAConfig

    def __init__(self, config):
        """Store *config* and load both sub-models it names.

        Args:
            config: Object exposing ``extractive_id`` and ``generative_id``,
                the identifiers passed to ``from_pretrained``.
        """
        super().__init__()
        self.config = config
        self.load_models(config.extractive_id, config.generative_id)

    def load_models(self, extractive_id, generative_id):
        """Load the tokenizer and model for the extractive and generative paths."""
        self.tokenizer_extractive = AutoTokenizer.from_pretrained(extractive_id)
        self.tokenizer_generative = AutoTokenizer.from_pretrained(generative_id)
        self.model_extractive = AutoModelForQuestionAnswering.from_pretrained(extractive_id)
        self.model_generative = AutoModelForSeq2SeqLM.from_pretrained(generative_id)

    def predict(self, question, context):
        """Answer *question* using whichever sub-model is more confident.

        The generative model is given only the question; the extractive
        model selects a span from *context*. Ties go to the extractive answer.

        Args:
            question: The question text.
            context: Passage the extractive model searches for the answer.

        Returns:
            ``{'guess': answer, 'confidence': score}`` for the chosen model.
        """
        result_gen, conf_gen = self.infer_generative(
            self.model_generative, self.tokenizer_generative, question
        )
        result_ext, conf_ext = self.infer_extractive(
            self.model_extractive, self.tokenizer_extractive, question, context
        )
        if conf_gen > conf_ext:
            return {'guess': result_gen, 'confidence': conf_gen}
        return {'guess': result_ext, 'confidence': conf_ext}

    def infer_generative(self, model, tokenizer, input_text, **generate_kwargs):
        """Generate an answer for *input_text* and score it.

        The confidence is the geometric mean of the probabilities the model
        assigned to the tokens it generated (``exp`` of the mean token
        log-probability), so it lies in [0, 1] like the extractive score.

        Args:
            model: A seq2seq model supporting ``generate``.
            tokenizer: The tokenizer matching *model*.
            input_text: Text to condition on (the question, in ``predict``).
            **generate_kwargs: Extra arguments forwarded to ``model.generate``.
                ``return_dict_in_generate`` and ``output_scores`` are always
                forced on because the confidence needs the step scores.

        Returns:
            A ``(decoded_answer, confidence)`` tuple.
        """
        # Truncate only to the tokenizer's input limit. model.config.max_length
        # is the default *generation* length (20 unless configured) and would
        # clip longer questions.
        encoded = tokenizer(input_text, return_tensors="pt", truncation=True)
        model_inputs = {"input_ids": encoded["input_ids"]}
        if "attention_mask" in encoded:
            model_inputs["attention_mask"] = encoded["attention_mask"]
        generate_kwargs = {
            **generate_kwargs,
            "return_dict_in_generate": True,
            "output_scores": True,
        }
        with torch.no_grad():
            outputs = model.generate(**model_inputs, **generate_kwargs)
            decoded_output = tokenizer.decode(outputs.sequences[0], skip_special_tokens=True)
            confidence_score = self._generation_confidence(outputs)
        return decoded_output, confidence_score

    @staticmethod
    def _generation_confidence(outputs):
        """Return exp(mean log-prob) of the first returned sequence's tokens."""
        # Beam search already reports a length-normalised log-probability
        # for each returned beam; index 0 matches outputs.sequences[0].
        sequences_scores = getattr(outputs, "sequences_scores", None)
        if sequences_scores is not None:
            return torch.exp(sequences_scores[0]).item()
        if not outputs.scores:
            return 0.0
        # Greedy/sampling: one (rows, vocab) logit tensor per generated step.
        step_logits = torch.stack(outputs.scores, dim=1).float()
        log_probs = F.log_softmax(step_logits, dim=-1)
        # The last len(scores) positions of ``sequences`` are the generated
        # tokens (earlier positions are the decoder start / prompt).
        generated = outputs.sequences[:, -step_logits.shape[1]:]
        token_log_probs = log_probs.gather(-1, generated.unsqueeze(-1)).squeeze(-1)
        # NOTE(review): exact for a single returned sequence; with several
        # sampled sequences, padding after an early EOS is included in the mean.
        return torch.exp(token_log_probs[0].mean()).item()

    def infer_extractive(self, model, tokenizer, question, context):
        """Extract an answer span for *question* from *context*.

        Args:
            model: A question-answering (span extraction) model.
            tokenizer: The tokenizer matching *model*.
            question: The question text.
            context: Passage to search for the answer.

        Returns:
            ``(answer, score)`` as reported by the ``question-answering``
            pipeline.
        """
        qa_pipeline = pipeline("question-answering", model=model, tokenizer=tokenizer)
        result = qa_pipeline(question=question, context=context)
        return result['answer'], result['score']