from hybrid_model import HybridQAModel | |
from transformers import QuestionAnsweringPipeline | |
class HybridQAPipeline(QuestionAnsweringPipeline):
    """Question-answering pipeline whose predictions come from a HybridQAModel.

    The pipeline is constructed like its parent class, but after the parent
    initialiser has run, ``self.model`` is replaced by a ``HybridQAModel``
    built from the ``custom`` keyword argument.
    """

    def __init__(self, model=None, tokenizer=None, **kwargs):
        """Initialise the parent pipeline, then install the hybrid model.

        Args:
            model: Forwarded unchanged to the parent constructor.
            tokenizer: Forwarded unchanged to the parent constructor.
            **kwargs: Must contain a ``custom`` entry, which is stored as
                ``self.config`` and later handed to ``HybridQAModel``.
                All keyword arguments (``custom`` included) are also
                forwarded to the parent constructor.

        Raises:
            KeyError: If ``custom`` is not present in ``kwargs``.
        """
        # The entry is read, not popped, so the parent constructor still
        # receives it along with every other keyword argument.
        self.config = kwargs['custom']
        super().__init__(
            model=model,
            tokenizer=tokenizer,
            **kwargs,
        )
        # Assigned after the parent initialiser so that this is the model
        # object the instance ends up holding.
        hybrid_model = HybridQAModel(self.config)
        self.model = hybrid_model

    def __call__(self, question, context):
        """Return ``self.model.predict(question, context)``.

        This override hands both arguments straight to the model's
        ``predict`` method and returns its result unchanged.
        """
        prediction = self.model.predict(question, context)
        return prediction