# hybrid_qa_pipeline.py — HybridQAPipeline: a QuestionAnsweringPipeline
# wrapper that delegates prediction to a project-local HybridQAModel.
from hybrid_model import HybridQAModel
from transformers import QuestionAnsweringPipeline
class HybridQAPipeline(QuestionAnsweringPipeline):
    """Question-answering pipeline backed by a project-local HybridQAModel.

    Construction requires a ``custom`` keyword argument carrying the hybrid
    model's configuration; all remaining keyword arguments are forwarded to
    ``QuestionAnsweringPipeline.__init__`` unchanged.
    """

    def __init__(self, model=None, tokenizer=None, **kwargs):
        """Initialize the parent pipeline, then swap in the hybrid model.

        Args:
            model: passed through to the parent pipeline constructor.
            tokenizer: passed through to the parent pipeline constructor.
            **kwargs: must contain ``custom`` (the HybridQAModel config);
                everything else is forwarded to the parent.

        Raises:
            KeyError: if ``custom`` is not supplied.
        """
        # pop() (not plain indexing) so 'custom' is consumed here and NOT
        # forwarded to the parent __init__, which does not accept that
        # keyword and would reject the whole call.
        self.config = kwargs.pop('custom')
        super().__init__(model=model, tokenizer=tokenizer, **kwargs)
        # Deliberately replace whatever model the parent set up with the
        # hybrid model built from the custom config.
        self.model = HybridQAModel(self.config)

    def __call__(self, question, context):
        """Answer ``question`` against ``context`` via the hybrid model.

        Bypasses the parent pipeline's preprocess/forward/postprocess flow
        entirely and delegates straight to ``HybridQAModel.predict``.
        """
        return self.model.predict(question, context)