# TriviaAnsweringMachine2 / QBModelWrapper.py
# Uploaded by Backedman ("Upload model", commit c75c67f, verified).
# Hugging Face Hub page residue: raw | history | blame | contribute | delete;
# malware scan: no virus found; file size: 545 bytes.
from typing import List
from transformers import PreTrainedModel
from transformers import PretrainedConfig
from QBModelConfig import QBModelConfig
from qbmodel import QuizBowlModel
class QBModelWrapper(PreTrainedModel):
    """Expose ``QuizBowlModel`` through the Hugging Face ``PreTrainedModel`` API.

    The wrapper owns a ``QuizBowlModel`` instance and delegates inference to
    its ``predict`` method. The ``config`` given to the constructor is passed
    only to ``PreTrainedModel.__init__``; it is not forwarded to
    ``QuizBowlModel``, which is always built with ``clear=True``.
    """

    # Config class that transformers associates with this model type.
    config_class = QBModelConfig
    # NOTE(review): this binds the config *class*, not an instance. It is
    # presumably shadowed by the instance attribute set in ``super().__init__``.
    # Confirm nothing reads the class-level value before removing it.
    config = QBModelConfig

    def __init__(self, config):
        """Initialise the base model with *config* and build the wrapped model.

        Args:
            config: Configuration object passed through to
                ``PreTrainedModel.__init__`` (presumably a ``QBModelConfig``;
                TODO confirm).
        """
        super().__init__(config)
        quiz_model = QuizBowlModel(clear=True)
        self.model = quiz_model
        # Alias to the wrapped model's bound ``predict`` method, captured once
        # at construction time; kept for callers that use ``tfmodel``.
        self.tfmodel = quiz_model.predict

    def forward(self, question):
        """Run the wrapped model's ``predict`` and return its first element.

        Args:
            question: Input handed unchanged to ``QuizBowlModel.predict``
                (assumed to be question text or a batch of it; TODO confirm).

        Returns:
            Element ``0`` of whatever ``self.model.predict(question)`` returns.
        """
        predictions = self.model.predict(question)
        first_prediction = predictions[0]
        return first_prediction