# hybrid-qa1 / hybrid_pipe.py
from transformers import (
    QuestionAnsweringPipeline,
    PretrainedConfig,
    AutoTokenizer,
    AutoModelForQuestionAnswering,
    AutoModelForCausalLM,
    pipeline,
)
from huggingface_hub import PyTorchModelHubMixin
import torch
import torch.nn as nn
import os


class HybridQAPipeline(QuestionAnsweringPipeline):
    def __init__(self, model=None, tokenizer=None, **kwargs):
        # Default checkpoints for the extractive and generative backbones.
        extractive_id = "google-bert/bert-large-uncased-whole-word-masking-finetuned-squad"
        generative_id = "microsoft/Phi-3-mini-4k-instruct"
        self.config = HybridQAConfig(extractive_id, generative_id)
        self.model = HybridQAModel(self.config)
        super().__init__(model=self.model, tokenizer=tokenizer, **kwargs)

    def __call__(self, question, context):
        return self.model.predict(question, context)


class HybridQAConfig(PretrainedConfig):
    def __init__(self, extractive_id=None, generative_id=None, **kwargs):
        self.extractive_id = extractive_id
        self.generative_id = generative_id
        super().__init__(**kwargs)


class HybridQAModel(nn.Module, PyTorchModelHubMixin):
    # config_class = HybridQAConfig

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.load_models(config.extractive_id, config.generative_id)

    def can_generate(self):
        return False

    def load_models(self, extractive_id, generative_id):
        # Load both backbones: a span-extraction QA model and a causal LM.
        self.tokenizer_extractive = AutoTokenizer.from_pretrained(extractive_id, trust_remote_code=True)
        self.tokenizer_generative = AutoTokenizer.from_pretrained(generative_id, trust_remote_code=True)
        self.model_extractive = AutoModelForQuestionAnswering.from_pretrained(extractive_id, trust_remote_code=True)
        self.model_generative = AutoModelForCausalLM.from_pretrained(generative_id, trust_remote_code=True)

    def predict(self, question, context):
        # Run both models and keep the higher-confidence answer, preferring the
        # generative answer only when it is short (likely a direct answer).
        result_gen, conf_gen = self.infer_generative(self.model_generative, self.tokenizer_generative, question)
        result_ext, conf_ext = self.infer_extractive(self.model_extractive, self.tokenizer_extractive, question, context)
        if len(result_gen) < 30 and conf_gen > conf_ext:
            return {'guess': result_gen, 'confidence': conf_gen}
        else:
            return {'guess': result_ext, 'confidence': conf_ext}

    def infer_generative(self, model, tokenizer, input_text, **generate_kwargs):
        input_text += " Do not output anything but the question's answer."
        messages = [
            {"role": "user", "content": input_text}
        ]
        input_ids = tokenizer.apply_chat_template(messages, tokenize=True, add_generation_prompt=True, return_tensors="pt").to(model.device)
        generated_outputs = model.generate(input_ids, max_new_tokens=256, temperature=0.5, output_scores=True, return_dict_in_generate=True)
        # Process the outputs to calculate a normalized confidence.
        logits = generated_outputs.scores  # List of tensors, one per generated token
        softmax_scores = [torch.softmax(logit, dim=-1) for logit in logits]
        max_confidence_scores = [score.max().item() for score in softmax_scores]  # Maximum probability as per-token confidence
        average_confidence = sum(max_confidence_scores) / len(max_confidence_scores)  # Average over generated tokens
        decoded_output = tokenizer.decode(generated_outputs.sequences[0], skip_special_tokens=True)
        # Strip the prompt and keep only the last line of the model's reply.
        final_output = decoded_output[len(input_text):].split("\n")[-1]
        return final_output, average_confidence

    def infer_extractive(self, model, tokenizer, question, context):
        # Wrap the extractive model in a standard question-answering pipeline.
        qa_pipeline = pipeline("question-answering", model=model, tokenizer=tokenizer)
        result = qa_pipeline(question=question, context=context)
        confidence_score = result['score']
        return result['answer'], confidence_score

    def save_pretrained(self, save_directory, **kwargs):
        os.makedirs(save_directory, exist_ok=True)
        self.config.save_pretrained(save_directory, **kwargs)
        self.model_extractive.save_pretrained(save_directory, **kwargs)
        self.tokenizer_extractive.save_pretrained(save_directory, **kwargs)
        self.model_generative.save_pretrained(save_directory, **kwargs)
        self.tokenizer_generative.save_pretrained(save_directory, **kwargs)

    @classmethod
    def from_pretrained(cls, save_directory, *model_args, **model_kwargs):
        config = PretrainedConfig.from_pretrained(save_directory, trust_remote_code=True)
        model = cls(config)
        model.model_extractive = AutoModelForQuestionAnswering.from_pretrained(save_directory, trust_remote_code=True)
        model.tokenizer_extractive = AutoTokenizer.from_pretrained(save_directory, trust_remote_code=True)
        model.model_generative = AutoModelForCausalLM.from_pretrained(save_directory, trust_remote_code=True)
        model.tokenizer_generative = AutoTokenizer.from_pretrained(save_directory, trust_remote_code=True)
        return model
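

# A minimal usage sketch (not part of the original file): it assumes the two
# default checkpoints referenced above can be downloaded, and exercises the
# hybrid model directly via predict(). The question/context strings are
# illustrative placeholders, not data from the repository.
if __name__ == "__main__":
    config = HybridQAConfig(
        extractive_id="google-bert/bert-large-uncased-whole-word-masking-finetuned-squad",
        generative_id="microsoft/Phi-3-mini-4k-instruct",
    )
    model = HybridQAModel(config)
    prediction = model.predict(
        question="Who wrote Pride and Prejudice?",
        context="Pride and Prejudice is an 1813 novel by Jane Austen.",
    )
    print(prediction)  # e.g. {'guess': ..., 'confidence': ...}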