|
from QBModelWrapperCopy import QBModelWrapper |
|
from QBModelConfig import QBModelConfig |
|
from QBpipeline import QApipeline |
|
from transformers.pipelines import PIPELINE_REGISTRY |
|
from transformers import AutoModelForQuestionAnswering, TFAutoModelForQuestionAnswering |
|
from transformers import pipeline |
|
from transformers import AutoConfig, AutoModel, AutoModelForQuestionAnswering, TFAutoModel |
|
|
|
# Build the custom quiz-bowl QA model from a QBModelConfig created with its
# default constructor arguments (no arguments are passed here).
# `config` is kept as a module-level name alongside the model instance.
config = QBModelConfig()

qb_model = QBModelWrapper(config)
|
|
|
|
|
|
|
|
|
# Register the custom config/model classes with the transformers Auto*
# factories so they can be resolved by model type / config class.
# NOTE(review): AutoConfig.register expects "QA-umd-quizbowl" to equal
# QBModelConfig.model_type — confirm in QBModelConfig.
AutoConfig.register("QA-umd-quizbowl", QBModelConfig)

# Map QBModelConfig -> QBModelWrapper for both the generic and the
# question-answering Auto model factories.
AutoModel.register(QBModelConfig, QBModelWrapper)

AutoModelForQuestionAnswering.register(QBModelConfig, QBModelWrapper)

# Mark the classes for auto-class export; presumably so a later
# save_pretrained records them in the saved config's auto_map for reloading
# with trust_remote_code=True (see transformers custom-model docs).
QBModelConfig.register_for_auto_class()

# NOTE(review): register_for_auto_class appears to keep a single auto class per
# model class (the last call wins), so this "AutoModel" registration is likely
# replaced by the "AutoModelForQuestionAnswering" call below, and only the
# latter ends up in the saved auto_map — confirm against the installed
# transformers version. Keep the order of these two calls as-is.
QBModelWrapper.register_for_auto_class("AutoModel")

QBModelWrapper.register_for_auto_class("AutoModelForQuestionAnswering")
|
|
|
|
|
|
|
|
|
|
|
# Register the custom pipeline task "qa-pipeline-qb", implemented by QApipeline,
# so that pipeline("qa-pipeline-qb", ...) further down can resolve it.
# pt_model / tf_model name the Auto model classes associated with the task for
# the PyTorch and TensorFlow back ends respectively.
PIPELINE_REGISTRY.register_pipeline(
    "qa-pipeline-qb",
    pipeline_class=QApipeline,
    pt_model=AutoModelForQuestionAnswering,
    tf_model=TFAutoModelForQuestionAnswering,
)
|
|
|
# Wrap the in-memory model in an instance of the custom "qa-pipeline-qb" task.
qa_pipe = pipeline(task="qa-pipeline-qb", model=qb_model)

# Persist the pipeline to the relative directory "main", with
# safe_serialization disabled (i.e. not written as safetensors).
qa_pipe.save_pretrained(save_directory="main", safe_serialization=False)

# Smoke test: run a single question/context pair and print the "answer" field.
result = qa_pipe(
    question="This star in the solar system has 8 planets",
    context="Context for the question",
)
print(result["answer"])
|
|
|
|
|
|