system-with-gen-pipeline / QBGenModelWrapper.py
nes470's picture
Update QBGenModelWrapper.py
9c7092d verified
raw
history blame
539 Bytes
from typing import List
from transformers import PreTrainedModel
from transformers import PretrainedConfig
from .QBModelConfig import QBModelConfig
from .qbmodel import QuizBowlModel
class QBGenModelWrapper(PreTrainedModel):
    """Expose ``QuizBowlModel`` through the Hugging Face ``PreTrainedModel`` API.

    The wrapper owns a ``QuizBowlModel`` instance and forwards single
    questions to its ``guess_and_buzz`` method. The ``config`` object is
    handed to the ``PreTrainedModel`` base class only. It does not influence
    how the inner ``QuizBowlModel`` is built.
    """

    config_class = QBModelConfig

    def __init__(self, config):
        """Initialise the base class with ``config`` and build the inner model.

        Args:
            config: Configuration object forwarded to ``PreTrainedModel``.
        """
        super().__init__(config)
        # NOTE(review): use_hf_pkl=True presumably makes QuizBowlModel load
        # pickled artifacts from the Hub -- confirm in qbmodel.py.
        quiz_model = QuizBowlModel(use_hf_pkl=True)
        self.model = quiz_model
        # Keep a direct handle on the inner model's ``guesser`` component.
        self.tfmodel = quiz_model.guesser

    def forward(self, question):
        """Run the QuizBowl model on one question and return its result.

        Args:
            question: A single question; assumed to be whatever element type
                ``QuizBowlModel.guess_and_buzz`` accepts (TODO confirm).

        Returns:
            The first entry of ``guess_and_buzz``'s output for the one-item
            batch, i.e. the result for ``question``.
        """
        # guess_and_buzz works on a batch, so wrap the question in a
        # one-element list and unwrap the matching result.
        batch = [question]
        results = self.model.guess_and_buzz(batch)
        return results[0]