from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoModelForQuestionAnswering
from transformers import pipeline
from huggingface_hub import PyTorchModelHubMixin

import torch
import torch.nn as nn
import torch.nn.functional as F

from hybrid_config import HybridQAConfig


class HybridQAModel(nn.Module, PyTorchModelHubMixin):
    """Hybrid QA model that combines an extractive and a generative checkpoint."""

    config_class = HybridQAConfig

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.load_models(config.extractive_id, config.generative_id)

    def load_models(self, extractive_id, generative_id):
        # Load tokenizers and base models for both the extractive (span-based)
        # and generative (seq2seq) question-answering heads.
        self.tokenizer_extractive = AutoTokenizer.from_pretrained(extractive_id)
        self.tokenizer_generative = AutoTokenizer.from_pretrained(generative_id)

        self.model_extractive = AutoModelForQuestionAnswering.from_pretrained(extractive_id)
        self.model_generative = AutoModelForSeq2SeqLM.from_pretrained(generative_id)

    def predict(self, question, context):
        # Run both heads and return whichever answer comes with the higher
        # confidence estimate.
        result_gen, conf_gen = self.infer_generative(self.model_generative, self.tokenizer_generative, question)
        result_ext, conf_ext = self.infer_extractive(self.model_extractive, self.tokenizer_extractive, question, context)

        if conf_gen > conf_ext:
            return {'guess': result_gen, 'confidence': conf_gen}
        else:
            return {'guess': result_ext, 'confidence': conf_ext}

    def infer_generative(self, model, tokenizer, input_text, **generate_kwargs):
        max_input_length = min(tokenizer.model_max_length, model.config.max_length)
        input_ids = tokenizer.encode(input_text, return_tensors="pt", max_length=max_input_length, truncation=True)

        with torch.no_grad():
            # Request per-step scores so confidence can be estimated from the
            # actual token distributions rather than from the raw token ids.
            outputs = model.generate(input_ids, return_dict_in_generate=True, output_scores=True, **generate_kwargs)
        decoded_output = tokenizer.decode(outputs.sequences[0], skip_special_tokens=True)

        # Confidence heuristic: 1 minus the mean entropy of the per-step
        # softmax distributions (higher entropy -> lower confidence).
        step_entropies = []
        for step_scores in outputs.scores:
            probs = F.softmax(step_scores, dim=-1).squeeze(0)
            step_entropies.append(-(probs * torch.log(probs + 1e-12)).sum(dim=-1))
        confidence_score = 1 - torch.stack(step_entropies).mean().item()

        # Persist a local copy of the generative base model.
        model.save_pretrained("./base_models")
        return decoded_output, confidence_score

    def infer_extractive(self, model, tokenizer, question, context):
        # The question-answering pipeline handles tokenization and span
        # extraction, and returns a score we use directly as confidence.
        qa_pipeline = pipeline("question-answering", model=model, tokenizer=tokenizer)
        result = qa_pipeline(question=question, context=context)
        confidence_score = result['score']

        # Persist a local copy of the extractive base model.
        model.save_pretrained("./base_models")
        return result['answer'], confidence_score
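

# Usage sketch (illustrative only): assumes HybridQAConfig accepts
# `extractive_id` and `generative_id` keyword arguments, mirroring the
# attributes read in __init__, and that the checkpoint names below are
# reachable on the Hugging Face Hub.
if __name__ == "__main__":
    config = HybridQAConfig(
        extractive_id="distilbert-base-cased-distilled-squad",
        generative_id="google/flan-t5-small",
    )
    model = HybridQAModel(config)
    prediction = model.predict(
        question="Who wrote 'Pride and Prejudice'?",
        context="Jane Austen was an English novelist best known for 'Pride and Prejudice'.",
    )
    print(prediction)  # e.g. {'guess': ..., 'confidence': ...}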