File size: 940 Bytes
43dadf6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
from hybrid_model import HybridQAModel
from transformers import QuestionAnsweringPipeline, PretrainedConfig

class HybridQAPipeline(QuestionAnsweringPipeline):
    """Question-answering pipeline backed by a ``HybridQAModel``.

    The hybrid model is configured through a ``HybridQAConfig`` carrying two
    checkpoint IDs: one for the extractive model, one for the generative
    model. Calling the pipeline delegates directly to
    ``HybridQAModel.predict``.
    """

    # Checkpoints used when the caller does not supply its own IDs.
    DEFAULT_EXTRACTIVE_ID = "datarpit/distilbert-base-uncased-finetuned-natural-questions"
    DEFAULT_GENERATIVE_ID = "MaRiOrOsSi/t5-base-finetuned-question-answering"

    def __init__(
        self,
        model=None,
        tokenizer=None,
        *,
        extractive_id=DEFAULT_EXTRACTIVE_ID,
        generative_id=DEFAULT_GENERATIVE_ID,
        **kwargs,
    ):
        """Build the hybrid config, initialise the base pipeline, then the model.

        Args:
            model: Forwarded to ``QuestionAnsweringPipeline.__init__``. Note that
                ``self.model`` is replaced by a new ``HybridQAModel`` afterwards.
            tokenizer: Forwarded to ``QuestionAnsweringPipeline.__init__``.
            extractive_id: Stored as ``HybridQAConfig.extractive_id``; defaults
                to ``DEFAULT_EXTRACTIVE_ID``.
            generative_id: Stored as ``HybridQAConfig.generative_id``; defaults
                to ``DEFAULT_GENERATIVE_ID``.
            **kwargs: Forwarded unchanged to ``QuestionAnsweringPipeline.__init__``.
        """
        self.config = HybridQAConfig(extractive_id, generative_id)
        # NOTE(review): the base initializer receives ``model`` as given (``None``
        # by default). Some transformers versions dereference ``model.config``
        # there, so confirm this path works with the installed version.
        super().__init__(model=model, tokenizer=tokenizer, **kwargs)
        # Overrides whatever the base class stored in ``self.model``.
        self.model = HybridQAModel(self.config)

    def __call__(self, question, context):
        """Answer *question* from *context* via ``HybridQAModel.predict``.

        The return value is whatever ``HybridQAModel.predict`` returns.
        """
        return self.model.predict(question, context)

class HybridQAConfig(PretrainedConfig):
    """Configuration holding the two checkpoint IDs used by the hybrid QA model.

    Attributes:
        extractive_id: ID of the extractive checkpoint, or ``None``.
        generative_id: ID of the generative checkpoint, or ``None``.

    Any remaining keyword arguments are handed to ``PretrainedConfig``.
    """

    def __init__(self, extractive_id=None, generative_id=None, **kwargs):
        # Store the hybrid-specific fields first, then let the base class
        # process the generic configuration kwargs.
        self.extractive_id, self.generative_id = extractive_id, generative_id
        super().__init__(**kwargs)