hybrid-qa / hybrid_pipe.py
justinhl's picture
Upload HybridQAPipeline
0b88116 verified
raw
history blame
466 Bytes
from hybrid_model import HybridQAModel
from transformers import QuestionAnsweringPipeline
class HybridQAPipeline(QuestionAnsweringPipeline):
    """Question-answering pipeline that delegates inference to a ``HybridQAModel``.

    Requires a ``custom`` keyword argument carrying the configuration used to
    build the hybrid model; all remaining keyword arguments are forwarded to
    :class:`transformers.QuestionAnsweringPipeline`.
    """

    def __init__(self, model=None, tokenizer=None, **kwargs):
        # pop() (rather than a plain lookup) removes 'custom' from kwargs so
        # it is NOT forwarded to the parent pipeline, which does not declare
        # that keyword. A missing 'custom' still raises KeyError, matching
        # the original behavior.
        self.config = kwargs.pop('custom')
        super().__init__(model=model, tokenizer=tokenizer, **kwargs)
        # Deliberately replace whatever model the parent __init__ installed:
        # all inference goes through the hybrid model instead.
        self.model = HybridQAModel(self.config)

    def __call__(self, question, context):
        """Return the hybrid model's prediction for (question, context).

        Bypasses the parent pipeline's pre/post-processing entirely and
        delegates straight to ``HybridQAModel.predict``.
        """
        return self.model.predict(question, context)