justinhl commited on
Commit
bdc6603
1 Parent(s): 631c9cf

Upload HybridQAPipeline

Browse files
Files changed (1) hide show
  1. hybrid_pipe.py +16 -3
hybrid_pipe.py CHANGED
@@ -1,11 +1,24 @@
1
  from hybrid_model import HybridQAModel
2
- from transformers import QuestionAnsweringPipeline
3
 
4
  class HybridQAPipeline(QuestionAnsweringPipeline):
5
  def __init__(self, model=None, tokenizer=None, **kwargs):
6
- self.config = kwargs['custom']
 
 
7
  super().__init__(model=model, tokenizer=tokenizer, **kwargs)
8
  self.model = HybridQAModel(self.config)
9
 
10
  def __call__(self, question, context):
11
- return self.model.predict(question, context)
 
 
 
 
 
 
 
 
 
 
 
 
1
  from hybrid_model import HybridQAModel
2
+ from transformers import QuestionAnsweringPipeline, PretrainedConfig
3
 
4
  class HybridQAPipeline(QuestionAnsweringPipeline):
5
  def __init__(self, model=None, tokenizer=None, **kwargs):
6
+ extractive_id = "datarpit/distilbert-base-uncased-finetuned-natural-questions"
7
+ generative_id = "MaRiOrOsSi/t5-base-finetuned-question-answering"
8
+ self.config = HybridQAConfig(extractive_id, generative_id)
9
  super().__init__(model=model, tokenizer=tokenizer, **kwargs)
10
  self.model = HybridQAModel(self.config)
11
 
12
  def __call__(self, question, context):
13
+ return self.model.predict(question, context)
14
+
15
+ class HybridQAConfig(PretrainedConfig):
16
+ def __init__(
17
+ self,
18
+ extractive_id=None,
19
+ generative_id = None,
20
+ **kwargs
21
+ ):
22
+ self.extractive_id = extractive_id
23
+ self.generative_id = generative_id
24
+ super().__init__(**kwargs)