pipeline-as-repo / QBModelWrapper.py
nes470's picture
Upload folder using huggingface_hub
c3a942f verified
raw
history blame
491 Bytes
from typing import List
from transformers import PreTrainedModel
from transformers import PretrainedConfig
from QBModelConfig import QBModelConfig
from qbmodel import QuizBowlModel
class QBModelWrapper(PreTrainedModel):
    """Expose a ``QuizBowlModel`` through the ``PreTrainedModel`` interface.

    The wrapper owns one ``QuizBowlModel`` instance and forwards every
    ``forward`` call to its ``guess_and_buzz`` method. The configuration
    class registered with transformers is ``QBModelConfig``.
    """

    # Tells transformers which config type to instantiate/load for this model.
    config_class = QBModelConfig

    def __init__(self, config):
        """Initialise the base model from *config* and build the wrapped model.

        Args:
            config: The configuration object handed to
                ``PreTrainedModel.__init__``. Note that it is NOT passed on
                to ``QuizBowlModel``, which is constructed with no arguments.
        """
        super().__init__(config)
        quiz_model = QuizBowlModel()
        self.model = quiz_model
        # Expose the wrapped model's ``guesser`` directly on the wrapper;
        # presumably the underlying transformers model used for guessing
        # (hence the name) -- TODO confirm against qbmodel.QuizBowlModel.
        self.tfmodel = quiz_model.guesser

    def forward(self, question):
        """Run the wrapped model's guess-and-buzz step on *question*.

        Args:
            question: Input passed through unchanged to
                ``QuizBowlModel.guess_and_buzz``; presumably question text
                (or a batch of it) -- verify against that method.

        Returns:
            Whatever ``QuizBowlModel.guess_and_buzz`` returns, unmodified.
        """
        guess_and_buzz = self.model.guess_and_buzz
        return guess_and_buzz(question)