Spaces:
Running
Running
from .base_handler import ModelHandler | |
from .nlp_models.sequence_classification_handler import SequenceClassificationHandler | |
from .nlp_models.question_answering_handler import QuestionAnsweringHandler | |
from .nlp_models.token_classification_handler import TokenClassificationHandler | |
from .nlp_models.causal_lm_handler import CausalLMHandler | |
from .nlp_models.embedding_model_handler import EmbeddingModelHandler | |
from .audio_models.whisper_handler import WhisperHandler | |
from .nlp_models.masked_lm_handler import MaskedLMHandler | |
from .nlp_models.seq2seq_lm_handler import Seq2SeqLMHandler | |
from .nlp_models.multiple_choice_handler import MultipleChoiceHandler | |
from .img_models.image_classification_handler import ImageClassificationHandler | |
from transformers import ( | |
AutoModel, | |
AutoModelForTokenClassification, | |
AutoModelForSequenceClassification, | |
AutoModelForQuestionAnswering, | |
AutoModelForCausalLM, | |
AutoModelForMaskedLM, | |
AutoModelForSeq2SeqLM, | |
AutoModelForMultipleChoice, | |
) | |
# Registry of supported tasks. The keys are the values accepted by the `task`
# argument of `get_model_handler`. Each entry holds:
#   "model_class"   -- the `transformers` class for the task. It is passed as
#                      the second positional argument to the handler
#                      constructor by `get_model_handler`. It is None for the
#                      entries marked "Not implemented".
#   "handler_class" -- the handler class that `get_model_handler` instantiates.
#   "example_text"  -- a sample input string for the task.
#                      NOTE(review): it is not read in this module; presumably
#                      callers use it as a default/placeholder test input --
#                      confirm.
TASK_CONFIGS = {
    "embedding": {
        "model_class": AutoModel,
        "handler_class": EmbeddingModelHandler,
        "example_text": "Hey, I am feeling way too good to be true.",
    },
    "ner": {
        "model_class": AutoModelForTokenClassification,
        "handler_class": TokenClassificationHandler,
        "example_text": "John works at Google in New York as a software engineer.",
    },
    "text_classification": {
        "model_class": AutoModelForSequenceClassification,
        "handler_class": SequenceClassificationHandler,
        "example_text": "This movie was great and I loved it.",
    },
    "question_answering": {
        "model_class": AutoModelForQuestionAnswering,
        "handler_class": QuestionAnsweringHandler,
        # Context and question are packed into one string separated by
        # "QUES:". NOTE(review): presumably QuestionAnsweringHandler splits on
        # this marker -- confirm against the handler.
        "example_text": "The pyramids were built in ancient Egypt. QUES: Where were the pyramids built?",
    },
    "causal_lm": {
        "model_class": AutoModelForCausalLM,
        "handler_class": CausalLMHandler,
        "example_text": "Once upon a time, there was ",
    },
    "mask_lm": {
        "model_class": AutoModelForMaskedLM,
        "handler_class": MaskedLMHandler,
        # "[MASK]" is the BERT-style mask token. NOTE(review): tokenizers that
        # use a different mask token (e.g. "<mask>") may need a different
        # example, unless MaskedLMHandler substitutes it -- confirm.
        "example_text": "The quick brown [MASK] jumps over the lazy dog.",
    },
    "seq2seq_lm": {
        "model_class": AutoModelForSeq2SeqLM,
        "handler_class": Seq2SeqLMHandler,
        "example_text": "Translate English to French: The house is wonderful.",
    },
    "multiple_choice": {
        "model_class": AutoModelForMultipleChoice,
        "handler_class": MultipleChoiceHandler,
        "example_text": "What is the capital of France? (A) Paris (B) London (C) Berlin (D) Rome",
    },
    "whisper_finetuning": {
        "model_class": None,  # Not implemented
        "handler_class": WhisperHandler,
        "example_text": "!!!!!NOT IMPLEMENTED!!!!!",
    },
    "image_classification": {
        "model_class": None,  # Not implemented
        "handler_class": ImageClassificationHandler,
        "example_text": "!!!!!NOT IMPLEMENTED!!!!!",
    },
}
def get_model_handler(task: str, model_name: str, quantization_type: str, test_text: str):
    """Instantiate the handler registered for ``task`` in ``TASK_CONFIGS``.

    The registered ``handler_class`` is called positionally with
    ``(model_name, model_class, quantization_type, test_text)``. Here
    ``model_class`` is the entry's ``"model_class"`` value, which may be None
    for tasks marked "Not implemented".

    Args:
        task: Key into ``TASK_CONFIGS``.
        model_name: Forwarded unchanged to the handler constructor.
        quantization_type: Forwarded unchanged to the handler constructor.
        test_text: Forwarded unchanged to the handler constructor.

    Returns:
        The newly constructed handler instance.

    Raises:
        ValueError: If ``task`` has no (or an empty) entry in ``TASK_CONFIGS``.
    """
    config = TASK_CONFIGS.get(task)
    # A falsy check (not just "is None") also rejects an empty config entry.
    if not config:
        raise ValueError(f"No configuration found for task: {task}")

    build_handler = config["handler_class"]
    return build_handler(
        model_name,
        config["model_class"],
        quantization_type,
        test_text,
    )