File size: 940 Bytes
43dadf6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 |
from hybrid_model import HybridQAModel
from transformers import QuestionAnsweringPipeline, PretrainedConfig
class HybridQAPipeline(QuestionAnsweringPipeline):
    """Question-answering pipeline that delegates every query to a HybridQAModel.

    A ``HybridQAConfig`` naming one extractive and one generative checkpoint is
    built at construction time. A ``HybridQAModel`` is created from it, and
    ``__call__`` forwards directly to that model's ``predict`` method.
    """

    # Default checkpoint identifiers (presumably Hugging Face Hub repo ids).
    DEFAULT_EXTRACTIVE_ID = "datarpit/distilbert-base-uncased-finetuned-natural-questions"
    DEFAULT_GENERATIVE_ID = "MaRiOrOsSi/t5-base-finetuned-question-answering"

    def __init__(
        self,
        model=None,
        tokenizer=None,
        *,
        extractive_id=DEFAULT_EXTRACTIVE_ID,
        generative_id=DEFAULT_GENERATIVE_ID,
        **kwargs,
    ):
        """Build the hybrid config/model and initialise the base pipeline.

        Args:
            model: Forwarded to ``QuestionAnsweringPipeline.__init__``. Note that
                ``self.model`` is replaced by a ``HybridQAModel`` afterwards.
            tokenizer: Forwarded to ``QuestionAnsweringPipeline.__init__``.
            extractive_id: Identifier of the extractive QA checkpoint stored in
                the config. Defaults to ``DEFAULT_EXTRACTIVE_ID``.
            generative_id: Identifier of the generative QA checkpoint stored in
                the config. Defaults to ``DEFAULT_GENERATIVE_ID``.
            **kwargs: Forwarded unchanged to ``QuestionAnsweringPipeline.__init__``.
        """
        self.config = HybridQAConfig(
            extractive_id=extractive_id,
            generative_id=generative_id,
        )
        # NOTE(review): the transformers base Pipeline normally expects a loaded
        # model here. With the default model=None it may fail (e.g. when it reads
        # model.config), depending on the installed transformers version.
        # Confirm how this class is actually constructed.
        super().__init__(model=model, tokenizer=tokenizer, **kwargs)
        # Overwrite whatever the base class presumably stored from `model`, so
        # that __call__ always routes through HybridQAModel.predict.
        self.model = HybridQAModel(self.config)

    def __call__(self, question, context):
        """Answer *question* from *context* using the hybrid model.

        Returns whatever ``HybridQAModel.predict`` returns; its shape is not
        visible from this module.
        """
        return self.model.predict(question, context)
class HybridQAConfig(PretrainedConfig):
    """Configuration carrying an extractive and a generative model identifier.

    Args:
        extractive_id: Identifier of the extractive QA checkpoint (presumably a
            Hugging Face Hub repo id, as passed by HybridQAPipeline).
        generative_id: Identifier of the generative QA checkpoint (presumably a
            Hugging Face Hub repo id).
        **kwargs: Passed through to ``PretrainedConfig.__init__``.
    """

    def __init__(self, extractive_id=None, generative_id=None, **kwargs):
        # Store the two identifiers first (extractive, then generative), then
        # hand every remaining option to the PretrainedConfig base class.
        self.extractive_id, self.generative_id = extractive_id, generative_id
        super().__init__(**kwargs)