TriviaAnsweringMachineREAL / QBModelWrapper.py
Backedman's picture
Upload QApipeline
186d760 verified
raw
history blame
No virus
923 Bytes
from typing import List
from transformers import PreTrainedModel
from transformers import PretrainedConfig
from QBModelConfig import QBModelConfig
from qbmodel import QuizBowlModel
from huggingface_hub import hf_hub_download
# Hugging Face Hub repository that hosts this model's code and artifacts.
REPO_ID: str = "Backedman/TriviaAnsweringMachineREAL"
# Path of the Mythology TF-IDF artifact inside REPO_ID.
# NOTE(review): not referenced by the code in this file; confirm whether
# other modules still import it before removing.
FILENAME: str = "models/Mythology_tfidf.pkl"
# Support files fetched from REPO_ID before the wrapper class is defined.
# The order matches the original download sequence.
_SUPPORT_FILES = (
    "tfidf_model.py",
    "question_categorizer.py",
    "models/categorizer",
)


def _download_support_files(repo_id=REPO_ID, filenames=_SUPPORT_FILES):
    """Download each support file from the Hub repository ``repo_id``.

    Files are written under ``local_dir=''`` (presumably the current working
    directory - confirm against the installed huggingface_hub version).
    This performs network I/O.
    """
    for filename in filenames:
        hf_hub_download(repo_id=repo_id, filename=filename, local_dir='')


# Runs at import time, exactly as the original class-body calls did.
# NOTE(review): ``from qbmodel import QuizBowlModel`` executes earlier in this
# module; if qbmodel imports these downloaded files eagerly, a fresh checkout
# may fail before they exist - confirm the import order is safe.
_download_support_files()


class QBModelWrapper(PreTrainedModel):
    """``PreTrainedModel`` wrapper that delegates answering to ``QuizBowlModel``."""

    # Config class transformers associates with this model type.
    config_class = QBModelConfig
    # NOTE(review): redundant with ``config_class`` and presumably shadowed on
    # instances by the config object stored by ``PreTrainedModel.__init__``.
    # Kept only so existing code reading ``QBModelWrapper.config`` still works.
    config = QBModelConfig

    def __init__(self, config):
        """Initialise the base model with ``config`` and build the QuizBowl model.

        Args:
            config: configuration passed straight to ``PreTrainedModel.__init__``.
        """
        super().__init__(config)
        self.model = QuizBowlModel()
        # Expose the underlying predictor directly on the wrapper.
        self.predict = self.model.predict

    def forward(self, question):
        """Run ``QuizBowlModel.predict`` on ``question`` and return its first element.

        NOTE(review): assumes ``predict`` returns a non-empty indexable
        sequence; an empty result would raise ``IndexError`` here.
        """
        output = self.model.predict(question)
        return output[0]