File size: 539 Bytes
ec31fb7
 
 
6f557e1
 
ec31fb7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
from typing import List
from transformers import PreTrainedModel
from transformers import PretrainedConfig
from .QBModelConfig import QBModelConfig
from .qbmodel import QuizBowlModel

class QBGenModelWrapper(PreTrainedModel):
    """Expose a ``QuizBowlModel`` through the ``PreTrainedModel`` interface.

    The wrapper is bound to ``QBModelConfig`` via ``config_class`` and
    delegates each forward call to ``QuizBowlModel.guess_and_buzz``.
    """

    # Configuration class associated with this model type.
    config_class = QBModelConfig

    def __init__(self, config):
        """Initialise the base class and build the wrapped quiz bowl model.

        Args:
            config: Configuration object forwarded unchanged to
                ``PreTrainedModel.__init__``.
        """
        super().__init__(config)

        # NOTE(review): ``use_hf_pkl=True`` presumably selects loading pickled
        # artefacts from the Hugging Face repo; confirm against QuizBowlModel.
        quiz_bowl = QuizBowlModel(use_hf_pkl=True)
        self.model = quiz_bowl
        # Keep a direct handle on the wrapped model's ``guesser`` component.
        self.tfmodel = quiz_bowl.guesser

    def forward(self, question):
        """Run ``guess_and_buzz`` on one question and return its result.

        Args:
            question: A single question (presumably a string; confirm against
                ``QuizBowlModel.guess_and_buzz``).

        Returns:
            The first element of what ``guess_and_buzz`` returns for a
            one-item batch containing ``question``.
        """
        # ``guess_and_buzz`` takes a list, so wrap the lone question in one
        # and unwrap the matching first result.
        batch = [question]
        results = self.model.guess_and_buzz(batch)
        return results[0]