# system-with-gen-pipeline / QBGenModelWrapper.py
# Author: nes470
# Last commit: "Update QBGenModelWrapper.py" (9c7092d, verified)
from typing import List
from transformers import PreTrainedModel
from transformers import PretrainedConfig
from .QBModelConfig import QBModelConfig
from .qbmodel import QuizBowlModel
class QBGenModelWrapper(PreTrainedModel):
    """Wrap ``QuizBowlModel`` in the Hugging Face ``PreTrainedModel`` interface.

    ``config_class`` ties this wrapper to ``QBModelConfig`` for the
    transformers loading machinery. The config object is only forwarded to
    ``PreTrainedModel.__init__``; it is not used when constructing the
    underlying ``QuizBowlModel``.

    Attributes:
        model: the ``QuizBowlModel`` instance that does the actual work.
        tfmodel: the ``guesser`` attribute of ``model``, exposed directly.
    """

    config_class = QBModelConfig

    def __init__(self, config):
        """Initialise the base class with *config* and build the pipeline.

        Args:
            config: configuration object handed straight to
                ``PreTrainedModel.__init__`` (expected to be a
                ``QBModelConfig`` per ``config_class``).
        """
        super().__init__(config)
        # NOTE(review): what use_hf_pkl=True selects is defined inside
        # QuizBowlModel and is not visible here -- confirm in qbmodel.
        pipeline = QuizBowlModel(use_hf_pkl=True)
        self.model = pipeline
        # Keep a direct handle on the pipeline's guesser component.
        self.tfmodel = pipeline.guesser

    def forward(self, question):
        """Run ``guess_and_buzz`` on a single question and return its result.

        Args:
            question: a single question; it is wrapped in a one-element list
                before being passed to ``QuizBowlModel.guess_and_buzz``.

        Returns:
            Element ``0`` of the value returned by ``guess_and_buzz`` --
            presumably the result for *question* (TODO confirm the result
            format in qbmodel).
        """
        batch = [question]
        results = self.model.guess_and_buzz(batch)
        return results[0]