File size: 545 Bytes
e6a7c75 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 |
from typing import List
from transformers import PreTrainedModel
from transformers import PretrainedConfig
from QBModelConfig import QBModelConfig
from qbmodel import QuizBowlModel
class QBModelWrapper(PreTrainedModel):
    """Expose a ``QuizBowlModel`` through the Hugging Face ``PreTrainedModel`` API.

    ``forward`` delegates to ``QuizBowlModel.predict`` and returns the first
    element of whatever that call returns.
    """

    # Config class transformers uses when loading/saving this model.
    config_class = QBModelConfig
    # NOTE(review): ``PreTrainedModel.__init__`` stores the passed config on
    # the instance as ``self.config``, which shadows this class attribute.
    # It therefore only affects code that reads ``QBModelWrapper.config``
    # directly. It is kept for backward compatibility; consider removing it.
    config = QBModelConfig

    def __init__(self, config) -> None:
        """Initialise the wrapper and construct the underlying quiz-bowl model.

        Args:
            config: Model configuration forwarded to
                ``PreTrainedModel.__init__`` (presumably a ``QBModelConfig``
                instance, matching ``config_class``).
        """
        super().__init__(config)
        # NOTE(review): the meaning of ``clear=True`` is defined inside
        # QuizBowlModel and is not visible here -- confirm.
        self.model = QuizBowlModel(clear=True)
        # Bound-method alias for ``self.model.predict``. It is unused within
        # this class but kept as a public attribute for external callers.
        self.tfmodel = self.model.predict

    def forward(self, question):
        """Run a prediction on *question* and return the first result.

        Args:
            question: Input passed unchanged to ``QuizBowlModel.predict``.

        Returns:
            ``output[0]``, where ``output`` is the value returned by
            ``self.model.predict(question)``.
        """
        output = self.model.predict(question)
        return output[0]