from transformers import (
    AutoModel,
    AutoModelForCausalLM,
    AutoModelForMaskedLM,
    AutoModelForMultipleChoice,
    AutoModelForQuestionAnswering,
    AutoModelForSeq2SeqLM,
    AutoModelForSequenceClassification,
    AutoModelForTokenClassification,
)

from .audio_models.whisper_handler import WhisperHandler
from .base_handler import ModelHandler
from .img_models.image_classification_handler import ImageClassificationHandler
from .nlp_models.causal_lm_handler import CausalLMHandler
from .nlp_models.embedding_model_handler import EmbeddingModelHandler
from .nlp_models.masked_lm_handler import MaskedLMHandler
from .nlp_models.multiple_choice_handler import MultipleChoiceHandler
from .nlp_models.question_answering_handler import QuestionAnsweringHandler
from .nlp_models.seq2seq_lm_handler import Seq2SeqLMHandler
from .nlp_models.sequence_classification_handler import SequenceClassificationHandler
from .nlp_models.token_classification_handler import TokenClassificationHandler

TASK_CONFIGS = {
"embedding": {
"model_class": AutoModel,
"handler_class": EmbeddingModelHandler,
"example_text": "Hey, I am feeling way to good to be true.",
},
"ner": {
"model_class": AutoModelForTokenClassification,
"handler_class": TokenClassificationHandler,
"example_text": "John works at Google in New York as a software engineer.",
},
"text_classification": {
"model_class": AutoModelForSequenceClassification,
"handler_class": SequenceClassificationHandler,
"example_text": "This movie was great and I loved it.",
},
"question_answering": {
"model_class": AutoModelForQuestionAnswering,
"handler_class": QuestionAnsweringHandler,
"example_text": "The pyramids were built in ancient Egypt. QUES: Where were the pyramids built?",
},
"causal_lm": {
"model_class": AutoModelForCausalLM,
"handler_class": CausalLMHandler,
"example_text": "Once upon a time, there was ",
},
"mask_lm": {
"model_class": AutoModelForMaskedLM,
"handler_class": MaskedLMHandler,
"example_text": "The quick brown [MASK] jumps over the lazy dog.",
},
"seq2seq_lm": {
"model_class": AutoModelForSeq2SeqLM,
"handler_class": Seq2SeqLMHandler,
"example_text": "Translate English to French: The house is wonderful.",
},
"multiple_choice": {
"model_class": AutoModelForMultipleChoice,
"handler_class": MultipleChoiceHandler,
"example_text": "What is the capital of France? (A) Paris (B) London (C) Berlin (D) Rome",
},
"whisper_finetuning": {
"model_class": None, # Not implemented
"handler_class": WhisperHandler,
"example_text": "!!!!!NOT IMPLEMENTED!!!!!",
},
"image_classification": {
"model_class": None, # Not implemented
"handler_class": ImageClassificationHandler,
"example_text": "!!!!!NOT IMPLEMENTED!!!!!",
},
}


def get_model_handler(task: str, model_name: str, quantization_type: str, test_text: str):
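    """Instantiate the handler registered for ``task``."""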
task_config = TASK_CONFIGS.get(task)
if not task_config:
raise ValueError(f"No configuration found for task: {task}")
handler_class = task_config["handler_class"]
model_class = task_config["model_class"]
    return handler_class(model_name, model_class, quantization_type, test_text)
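

if __name__ == "__main__":
    # Minimal usage sketch: the model name "distilbert-base-uncased" and the
    # quantization_type value "int8" are illustrative assumptions; the values
    # actually accepted depend on the individual handler implementations.
    handler = get_model_handler(
        task="text_classification",
        model_name="distilbert-base-uncased",
        quantization_type="int8",
        test_text=TASK_CONFIGS["text_classification"]["example_text"],
    )
    print(type(handler).__name__)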