from typing import List

from transformers import PreTrainedModel
from transformers import PretrainedConfig

from .QBModelConfig import QBModelConfig
from .qbmodel import QuizBowlModel


class QBModelWrapper(PreTrainedModel):
    """Expose a ``QuizBowlModel`` through the ``PreTrainedModel`` interface.

    The wrapper builds its own ``QuizBowlModel`` instance and delegates
    inference to that instance's ``predict`` method.
    """

    # Configuration class associated with this wrapper.
    config_class = QBModelConfig

    def __init__(self, config):
        """Initialise the base class with *config* and build the inner model.

        Args:
            config: Configuration object handed unchanged to
                ``PreTrainedModel.__init__``.
        """
        super().__init__(config)
        self.model = QuizBowlModel()
        # Alias for the bound ``predict`` method of the model created above;
        # it is captured once here and is not refreshed if ``self.model`` is
        # later replaced.
        self.tfmodel = self.model.predict

    def forward(self, question, context):
        """Return the first element of the wrapped model's prediction.

        Args:
            question: Value passed as the sole argument to
                ``self.model.predict``.
            context: Accepted as part of the signature but not used by this
                method.

        Returns:
            Element ``[0]`` of whatever ``self.model.predict(question)``
            returns.
        """
        # NOTE(review): ``context`` is never forwarded to the model -- confirm
        # whether ``QuizBowlModel.predict`` is meant to receive it.
        return self.model.predict(question)[0]