|
from hybrid_model import HybridQAModel |
|
from transformers import QuestionAnsweringPipeline, PretrainedConfig |
|
|
|
class HybridQAPipeline(QuestionAnsweringPipeline):
    """Question-answering pipeline that delegates to a ``HybridQAModel``.

    The model is configured with one extractive and one generative
    checkpoint (see the class attributes below). Answering bypasses the base
    pipeline's pre/post-processing and calls the model's ``predict`` method
    directly.
    """

    # Checkpoint identifiers handed to HybridQAConfig (presumably Hugging Face
    # Hub model ids). Override in a subclass to swap either model.
    DEFAULT_EXTRACTIVE_ID = "datarpit/distilbert-base-uncased-finetuned-natural-questions"
    DEFAULT_GENERATIVE_ID = "MaRiOrOsSi/t5-base-finetuned-question-answering"

    def __init__(self, model=None, tokenizer=None, **kwargs):
        """Build the hybrid config/model and initialise the base pipeline.

        Args:
            model: Optional pre-built model. When ``None`` (the default), a
                ``HybridQAModel`` is constructed from ``self.config``.
            tokenizer: Forwarded unchanged to ``QuestionAnsweringPipeline``.
            **kwargs: Forwarded unchanged to ``QuestionAnsweringPipeline``.
        """
        self.config = HybridQAConfig(
            self.DEFAULT_EXTRACTIVE_ID, self.DEFAULT_GENERATIVE_ID
        )
        # Create the model *before* the base initializer runs. The base
        # Pipeline dereferences ``model`` (e.g. ``model.config``) in __init__,
        # so passing None crashes. Building the model afterwards also
        # silently discarded any model the caller supplied.
        if model is None:
            model = HybridQAModel(self.config)
        super().__init__(model=model, tokenizer=tokenizer, **kwargs)

    def __call__(self, question, context):
        """Answer ``question`` against ``context``.

        Returns whatever ``self.model.predict(question, context)`` returns.
        The base pipeline's batching and post-processing are not used.
        """
        return self.model.predict(question, context)
|
|
|
class HybridQAConfig(PretrainedConfig):
    """Configuration holding the two checkpoint ids a hybrid QA model uses.

    Attributes:
        extractive_id: identifier of the extractive QA checkpoint, or None.
        generative_id: identifier of the generative QA checkpoint, or None.

    Any further keyword arguments are passed through to ``PretrainedConfig``.
    """

    def __init__(self, extractive_id=None, generative_id=None, **kwargs):
        # Record both ids (extractive first) before handing the remaining
        # keyword arguments to the PretrainedConfig initializer.
        self.extractive_id, self.generative_id = extractive_id, generative_id
        super().__init__(**kwargs)