from hybrid_model import HybridQAModel
from transformers import PretrainedConfig, QuestionAnsweringPipeline

# Default sub-model identifiers handed to HybridQAConfig. These presumably are
# Hugging Face Hub repository ids resolved by HybridQAModel — confirm there.
DEFAULT_EXTRACTIVE_ID = "datarpit/distilbert-base-uncased-finetuned-natural-questions"
DEFAULT_GENERATIVE_ID = "MaRiOrOsSi/t5-base-finetuned-question-answering"


class HybridQAConfig(PretrainedConfig):
    """Configuration naming the two sub-models used by ``HybridQAModel``.

    Args:
        extractive_id: Identifier of the extractive QA model (may be None).
        generative_id: Identifier of the generative QA model (may be None).
        **kwargs: Forwarded unchanged to ``PretrainedConfig.__init__``.
    """

    def __init__(self, extractive_id=None, generative_id=None, **kwargs):
        # Store our own fields before delegating so the base initializer
        # only ever sees the generic kwargs.
        self.extractive_id = extractive_id
        self.generative_id = generative_id
        super().__init__(**kwargs)


class HybridQAPipeline(QuestionAnsweringPipeline):
    """Question-answering pipeline that delegates prediction to ``HybridQAModel``.

    Args:
        model: Passed to ``QuestionAnsweringPipeline.__init__`` only; the
            pipeline's ``self.model`` is replaced by a ``HybridQAModel`` afterwards.
        tokenizer: Passed to ``QuestionAnsweringPipeline.__init__``.
        extractive_id: Identifier stored in the config as the extractive model
            (keyword-only; defaults to ``DEFAULT_EXTRACTIVE_ID``).
        generative_id: Identifier stored in the config as the generative model
            (keyword-only; defaults to ``DEFAULT_GENERATIVE_ID``).
        **kwargs: Forwarded to ``QuestionAnsweringPipeline.__init__``.
    """

    def __init__(
        self,
        model=None,
        tokenizer=None,
        *,
        extractive_id=DEFAULT_EXTRACTIVE_ID,
        generative_id=DEFAULT_GENERATIVE_ID,
        **kwargs,
    ):
        self.config = HybridQAConfig(extractive_id, generative_id)
        super().__init__(model=model, tokenizer=tokenizer, **kwargs)
        # NOTE(review): the base initializer runs against the caller-supplied
        # `model` (default None), which is then discarded here. Confirm the base
        # class tolerates model=None, or build HybridQAModel first and pass it in.
        self.model = HybridQAModel(self.config)

    def __call__(self, question, context):
        """Answer ``question`` given ``context``.

        Bypasses the base pipeline's preprocess/forward/postprocess flow and
        returns whatever ``HybridQAModel.predict`` returns.
        """
        return self.model.predict(question, context)