# TriviaAnsweringMachine / QBModelWrapper.py
# Author: Backedman
# Last commit: "Update QBModelWrapper.py" (318e687, verified)
from typing import List
from transformers import PreTrainedModel
from transformers import PretrainedConfig
from .QBModelConfig import QBModelConfig
from .qbmodel import QuizBowlModel
class QBModelWrapper(PreTrainedModel):
    """Expose ``QuizBowlModel`` through the Hugging Face ``PreTrainedModel`` API.

    Construction builds a ``QuizBowlModel`` instance, and every ``forward``
    call is delegated to that instance's ``predict`` method.
    """

    # Pairs this model with its configuration class so transformers knows
    # which config type to instantiate (e.g. in ``from_pretrained``).
    config_class = QBModelConfig

    def __init__(self, config):
        """Initialise the base model with *config* and build the wrapped model.

        Note: *config* is only forwarded to ``PreTrainedModel.__init__``;
        ``QuizBowlModel`` is constructed without any arguments.
        """
        # PreTrainedModel is a torch ``nn.Module``, so its initialiser has to
        # run before any attribute (possibly a submodule) is assigned.
        super().__init__(config)
        quiz_model = QuizBowlModel()
        self.model = quiz_model
        # Convenience alias for the wrapped model's bound ``predict`` method;
        # nothing else in this class reads it.
        self.tfmodel = quiz_model.predict

    def forward(self, question, context):
        """Return the first element of ``self.model.predict(question)``.

        Args:
            question: Passed unchanged to ``QuizBowlModel.predict``; its
                expected type is defined there (presumably the question
                text -- TODO confirm).
            context: Accepted to keep the call signature, but ignored.

        Returns:
            Element ``[0]`` of the ``predict`` result -- presumably the
            top-ranked answer; NOTE(review): confirm against
            ``QuizBowlModel.predict``.
        """
        # ``context`` is intentionally unused: only the question is predicted on.
        return self.model.predict(question)[0]