File size: 466 Bytes
0b88116
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
from hybrid_model import HybridQAModel
from transformers import QuestionAnsweringPipeline

class HybridQAPipeline(QuestionAnsweringPipeline):
    """Question-answering pipeline that delegates inference to ``HybridQAModel``.

    Constructed like a regular ``QuestionAnsweringPipeline`` plus one extra,
    required keyword argument, ``custom``, holding the configuration used to
    build the ``HybridQAModel``. Calling the pipeline bypasses the parent's
    ``__call__`` and returns whatever ``HybridQAModel.predict`` returns.
    """

    def __init__(self, model=None, tokenizer=None, **kwargs):
        """Initialise the parent pipeline, then install a ``HybridQAModel``.

        Args:
            model: Forwarded unchanged to ``QuestionAnsweringPipeline.__init__``.
                ``self.model`` is replaced by the hybrid model afterwards.
            tokenizer: Forwarded unchanged to the parent constructor.
            **kwargs: Forwarded to the parent constructor, except ``custom``,
                which is required and consumed here as the ``HybridQAModel``
                configuration (also stored on ``self.config``).

        Raises:
            KeyError: If the ``custom`` keyword argument is not supplied.
        """
        # Pop rather than index: ``custom`` is not a QuestionAnsweringPipeline
        # parameter, so it must not be forwarded to the parent constructor.
        config = kwargs.pop("custom")
        self.config = config
        super().__init__(model=model, tokenizer=tokenizer, **kwargs)
        # Build from the local value so nothing the parent constructor does to
        # ``self`` can change which configuration the hybrid model receives.
        # NOTE(review): this replaces whatever model the parent constructor set
        # up; any parent-side model setup (e.g. device placement, if performed)
        # is not applied to the hybrid model — confirm this is intended.
        self.model = HybridQAModel(config)

    def __call__(self, question, context):
        """Answer ``question`` from ``context`` using the hybrid model.

        Overrides the parent's ``__call__`` entirely: both inputs are passed
        straight to ``self.model.predict`` and its result is returned as-is.
        """
        return self.model.predict(question, context)