hybrid-qa / hybrid_pipe.py
from hybrid_model import HybridQAModel
from transformers import QuestionAnsweringPipeline, PretrainedConfig


class HybridQAPipeline(QuestionAnsweringPipeline):
    def __init__(self, model=None, tokenizer=None, **kwargs):
        # Hub checkpoints for the extractive and generative halves of the hybrid model.
        extractive_id = "datarpit/distilbert-base-uncased-finetuned-natural-questions"
        generative_id = "MaRiOrOsSi/t5-base-finetuned-question-answering"
        self.config = HybridQAConfig(extractive_id, generative_id)
        super().__init__(model=model, tokenizer=tokenizer, **kwargs)
        # Replace the base pipeline's model with the hybrid wrapper built from the config.
        self.model = HybridQAModel(self.config)

    def __call__(self, question, context):
        # Delegate directly to the hybrid model's predict method rather than the
        # standard QuestionAnsweringPipeline pre-/post-processing path.
        return self.model.predict(question, context)


class HybridQAConfig(PretrainedConfig):
    def __init__(self, extractive_id=None, generative_id=None, **kwargs):
        # Store the two checkpoint IDs on the config that is passed to HybridQAModel.
        self.extractive_id = extractive_id
        self.generative_id = generative_id
        super().__init__(**kwargs)
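

# Usage sketch (not part of the original file): a minimal example of instantiating and
# calling the pipeline. The base checkpoint "distilbert-base-uncased-distilled-squad" is
# only an assumption used to satisfy QuestionAnsweringPipeline's model/tokenizer
# arguments; __init__ above replaces self.model with HybridQAModel, so the answers come
# from hybrid_model.HybridQAModel.predict, which is defined elsewhere in this repo.
if __name__ == "__main__":
    from transformers import AutoModelForQuestionAnswering, AutoTokenizer

    base_id = "distilbert-base-uncased-distilled-squad"  # assumed placeholder checkpoint
    pipe = HybridQAPipeline(
        model=AutoModelForQuestionAnswering.from_pretrained(base_id),
        tokenizer=AutoTokenizer.from_pretrained(base_id),
    )
    # Illustrative question/context pair; the return value is whatever predict() yields.
    answer = pipe(
        question="Who founded the company?",
        context="The company was founded in 1998 by two graduate students.",
    )
    print(answer)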