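"""Task registry and factory for constructing model handlers."""
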
from .base_handler import ModelHandler
from .nlp_models.sequence_classification_handler import SequenceClassificationHandler
from .nlp_models.question_answering_handler import QuestionAnsweringHandler
from .nlp_models.token_classification_handler import TokenClassificationHandler
from .nlp_models.causal_lm_handler import CausalLMHandler
from .nlp_models.embedding_model_handler import EmbeddingModelHandler
from .nlp_models.masked_lm_handler import MaskedLMHandler
from .nlp_models.seq2seq_lm_handler import Seq2SeqLMHandler
from .nlp_models.multiple_choice_handler import MultipleChoiceHandler
from .audio_models.whisper_handler import WhisperHandler
from .img_models.image_classification_handler import ImageClassificationHandler

from transformers import (
    AutoModel,
    AutoModelForTokenClassification,
    AutoModelForSequenceClassification,
    AutoModelForQuestionAnswering,
    AutoModelForCausalLM,
    AutoModelForMaskedLM,
    AutoModelForSeq2SeqLM,
    AutoModelForMultipleChoice,
)

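# Each task maps to the transformers Auto* class used to load the model, the handler
# class that wraps it, and an example input text for that task.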
TASK_CONFIGS = {
    "embedding": {
        "model_class": AutoModel,
        "handler_class": EmbeddingModelHandler,
        "example_text": "Hey, I am feeling way to good to be true.",
    },
    "ner": {
        "model_class": AutoModelForTokenClassification,
        "handler_class": TokenClassificationHandler,
        "example_text": "John works at Google in New York as a software engineer.",
    },
    "text_classification": {
        "model_class": AutoModelForSequenceClassification,
        "handler_class": SequenceClassificationHandler,
        "example_text": "This movie was great and I loved it.",
    },
    "question_answering": {
        "model_class": AutoModelForQuestionAnswering,
        "handler_class": QuestionAnsweringHandler,
        "example_text": "The pyramids were built in ancient Egypt. QUES: Where were the pyramids built?",
    },
    "causal_lm": {
        "model_class": AutoModelForCausalLM,
        "handler_class": CausalLMHandler,
        "example_text": "Once upon a time, there was ",
    },
    "mask_lm": {
        "model_class": AutoModelForMaskedLM,
        "handler_class": MaskedLMHandler,
        "example_text": "The quick brown [MASK] jumps over the lazy dog.",
    },
    "seq2seq_lm": {
        "model_class": AutoModelForSeq2SeqLM,
        "handler_class": Seq2SeqLMHandler,
        "example_text": "Translate English to French: The house is wonderful.",
    },
    "multiple_choice": {
        "model_class": AutoModelForMultipleChoice,
        "handler_class": MultipleChoiceHandler,
        "example_text": "What is the capital of France? (A) Paris (B) London (C) Berlin (D) Rome",
    },
    "whisper_finetuning": {
        "model_class": None, # Not implemented
        "handler_class": WhisperHandler,
        "example_text": "!!!!!NOT IMPLEMENTED!!!!!",
    },
    "image_classification": {
        "model_class": None,  # Not implemented
        "handler_class": ImageClassificationHandler,
        "example_text": "!!!!!NOT IMPLEMENTED!!!!!",
    },
}

def get_model_handler(task: str, model_name: str, quantization_type: str, test_text: str):
    """Look up `task` in TASK_CONFIGS and return an instance of its handler class.

    Raises:
        ValueError: if no configuration exists for the given task.
    """
    task_config = TASK_CONFIGS.get(task)
    if not task_config:
        raise ValueError(f"No configuration found for task: {task}")

    handler_class = task_config["handler_class"]
    model_class = task_config["model_class"]
    return handler_class(model_name, model_class, quantization_type, test_text)
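

# Illustrative usage sketch only: the checkpoint name and quantization_type value below
# are assumptions, not values defined in this module; the supported quantization options
# depend on the individual handler implementations.
if __name__ == "__main__":
    handler = get_model_handler(
        task="text_classification",
        model_name="distilbert-base-uncased",  # assumed example checkpoint
        quantization_type="dynamic",  # assumed value; check the handler for supported options
        test_text=TASK_CONFIGS["text_classification"]["example_text"],
    )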