|
from typing import List |
|
from transformers import PreTrainedModel |
|
from transformers import PretrainedConfig |
|
from huggingface_hub import hf_hub_download |
|
from .QBModelConfig import QBModelConfig |
|
from .qbmodel import QuizBowlModel |
|
|
|
# Hugging Face Hub repository that hosts this model's supporting files.
REPO_ID = "Backedman/TriviaAnsweringMachineREAL"

# Path of a pickled TF-IDF artifact inside REPO_ID.
# NOTE(review): not referenced in this module's visible code; presumably used
# by other modules — confirm before removing.
FILENAME = "models/Mythology_tfidf.pkl"
|
|
|
class QBModelWrapper(PreTrainedModel):
    """Hugging Face ``PreTrainedModel`` wrapper around ``QuizBowlModel``.

    Executing this class statement (i.e. importing the module) downloads
    three helper files from the ``REPO_ID`` Hub repository into the current
    working directory. Inference is delegated to ``QuizBowlModel.predict``.
    """

    # Configuration class that the PreTrainedModel machinery associates
    # with this model.
    config_class = QBModelConfig

    # NOTE(review): class-level alias of the config *class*. It looks redundant
    # with ``config_class`` and is presumably shadowed on instances by the
    # config stored during ``PreTrainedModel.__init__`` — kept for backward
    # compatibility; confirm nothing reads ``QBModelWrapper.config`` before
    # removing it.
    config = QBModelConfig

    # Fetch supporting files into the current working directory ('.') at
    # class-definition (import) time. Presumably QuizBowlModel relies on these
    # being present locally — TODO confirm.
    hf_hub_download(repo_id=REPO_ID, filename='tfidf_model.py', local_dir='.')
    hf_hub_download(repo_id=REPO_ID, filename='question_categorizer.py', local_dir='.')
    hf_hub_download(repo_id=REPO_ID, filename='models/categorizer', local_dir='.')

    def __init__(self, config):
        """Initialise the wrapper and build the underlying ``QuizBowlModel``.

        Args:
            config: Configuration object forwarded to
                ``PreTrainedModel.__init__`` (presumably a ``QBModelConfig``).
                It is not passed on to ``QuizBowlModel``, which is constructed
                with no arguments.
        """
        super().__init__(config)

        self.model = QuizBowlModel()
        # Expose the wrapped model's ``predict`` directly on the wrapper so
        # callers can use ``wrapper.predict(...)``.
        self.predict = self.model.predict

    def forward(self, question):
        """Run the wrapped model on ``question`` and return its first result.

        Args:
            question: Input passed unchanged to ``QuizBowlModel.predict``.

        Returns:
            ``output[0]``, where ``output`` is the value returned by
            ``predict``. Assumes ``predict`` returns an indexable sequence —
            TODO confirm.
        """
        output = self.model.predict(question)
        return output[0]