File size: 957 Bytes
04993ef 186d760 cbb881a 186d760 04993ef a69b0eb 186d760 7c35b20 186d760 cbb881a 373ed8b 186d760 04993ef |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 |
from typing import List
from transformers import PreTrainedModel
from transformers import PretrainedConfig
from huggingface_hub import hf_hub_download
from .QBModelConfig import QBModelConfig
from .qbmodel import QuizBowlModel
# Hugging Face Hub repository that hosts this model's supporting files.
REPO_ID: str = "Backedman/TriviaAnsweringMachineREAL"
# Path (inside REPO_ID) of a pickled TF-IDF model; not referenced by the code in this module.
FILENAME: str = "models/Mythology_tfidf.pkl"
class QBModelWrapper(PreTrainedModel):
    """``PreTrainedModel`` adapter around ``QuizBowlModel``.

    Defining this class (i.e. importing this module) downloads helper files
    from the Hub repository ``REPO_ID`` into the current working directory.
    The wrapped ``QuizBowlModel`` is constructed in ``__init__``.
    """

    config_class = QBModelConfig
    # NOTE(review): this class attribute holds the config *class*, while
    # PreTrainedModel.__init__ stores the config *instance* on ``self.config``.
    # Looks redundant with ``config_class``; kept in case external code reads
    # ``QBModelWrapper.config`` — confirm before removing.
    config = QBModelConfig

    # Fetch supporting files at class-definition (import) time, before any
    # instance is built. local_dir='.' writes them into the current working
    # directory. Presumably QuizBowlModel loads them from there — TODO confirm.
    hf_hub_download(repo_id=REPO_ID, filename='tfidf_model.py', local_dir='.')
    hf_hub_download(repo_id=REPO_ID, filename='question_categorizer.py', local_dir='.')
    hf_hub_download(repo_id=REPO_ID, filename='models/categorizer', local_dir='.')

    def __init__(self, config: QBModelConfig):
        """Initialise the wrapper and the underlying ``QuizBowlModel``.

        Args:
            config: Configuration object forwarded to ``PreTrainedModel``;
                it is not otherwise used here.
        """
        super().__init__(config)
        self.model = QuizBowlModel()
        # Expose the wrapped model's predict directly on the wrapper instance.
        self.predict = self.model.predict

    def forward(self, question):
        """Run ``QuizBowlModel.predict`` on *question* and return its first item.

        Args:
            question: Input passed unchanged to ``self.model.predict``.

        Returns:
            ``predict(question)[0]``. Presumably ``predict`` returns an
            indexable sequence (e.g. answers/scores) — TODO confirm; an empty
            result would raise on the ``[0]`` index.
        """
        output = self.model.predict(question)
        return output[0]