|
from hybrid_model import HybridQAModel
|
|
from transformers import QuestionAnsweringPipeline, PretrainedConfig
|
|
|
|
class HybridQAPipeline(QuestionAnsweringPipeline):
    """Question-answering pipeline that delegates answering to ``HybridQAModel``.

    A ``HybridQAConfig`` naming an extractive and a generative checkpoint is
    built from the ids given (or the class defaults), and a ``HybridQAModel``
    is constructed from it. Calls to the pipeline are forwarded to
    ``HybridQAModel.predict``.
    """

    # Default checkpoint ids, used when the caller does not override them.
    DEFAULT_EXTRACTIVE_ID = "datarpit/distilbert-base-uncased-finetuned-natural-questions"
    DEFAULT_GENERATIVE_ID = "MaRiOrOsSi/t5-base-finetuned-question-answering"

    def __init__(
        self,
        model=None,
        tokenizer=None,
        *,
        extractive_id=DEFAULT_EXTRACTIVE_ID,
        generative_id=DEFAULT_GENERATIVE_ID,
        **kwargs,
    ):
        """Build the hybrid configuration and model.

        Args:
            model: Forwarded to ``QuestionAnsweringPipeline.__init__``.
            tokenizer: Forwarded to ``QuestionAnsweringPipeline.__init__``.
            extractive_id: Identifier of the extractive QA checkpoint stored
                in the ``HybridQAConfig``.
            generative_id: Identifier of the generative QA checkpoint stored
                in the ``HybridQAConfig``.
            **kwargs: Forwarded to ``QuestionAnsweringPipeline.__init__``.
        """
        self.config = HybridQAConfig(extractive_id, generative_id)

        super().__init__(model=model, tokenizer=tokenizer, **kwargs)

        # NOTE(review): this deliberately replaces whatever ``self.model`` the
        # base initializer set up from ``model``: ``__call__`` relies on
        # ``HybridQAModel.predict``, so the hybrid model must win. Confirm the
        # installed ``transformers`` version accepts ``model=None`` here.
        self.model = HybridQAModel(self.config)

    def __call__(self, question, context):
        """Answer *question* given *context*.

        Returns whatever ``HybridQAModel.predict`` returns for the pair.
        """
        return self.model.predict(question, context)
|
|
|
|
class HybridQAConfig(PretrainedConfig):
    """Configuration holding the two checkpoint identifiers of a hybrid QA model.

    Args:
        extractive_id: Identifier of the extractive QA checkpoint, or ``None``.
        generative_id: Identifier of the generative QA checkpoint, or ``None``.
        **kwargs: Passed through unchanged to ``PretrainedConfig.__init__``.
    """

    def __init__(self, extractive_id=None, generative_id=None, **kwargs):
        # Record both checkpoint ids first, then let the base configuration
        # consume every remaining option.
        self.extractive_id, self.generative_id = extractive_id, generative_id
        super().__init__(**kwargs)