import os
import importlib

MODEL_REGISTRY = {}
TOKENIZER_REGISTRY = {}


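def ModelSelect(model_name_or_path):
    """Return the registered model class whose key occurs in the given name/path.

    Matching is a case-insensitive substring test. If several keys match, the
    last one in registration order wins; with no match, 'llama' is the default.
    """
    model = None
    for name, cls in MODEL_REGISTRY.items():
        if name.lower() in model_name_or_path.lower():
            model = cls
    if model is None:
        model = MODEL_REGISTRY['llama']
    return model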


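def TokenizerSelect(model_name_or_path):
    """Return the registered tokenizer initializer whose key occurs in the name/path.

    Matching is a case-insensitive substring test. If several keys match, the
    last one in registration order wins; with no match, 'llama' is the default.
    """
    tokenizer_init = None
    for name, init in TOKENIZER_REGISTRY.items():
        if name.lower() in model_name_or_path.lower():
            tokenizer_init = init
    if tokenizer_init is None:
        tokenizer_init = TOKENIZER_REGISTRY['llama']
    return tokenizer_init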


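def register_model(name):
    """Class decorator that records the decorated class in MODEL_REGISTRY under `name`."""
    def register_model_cls(cls):
        # A name can be registered only once: on a repeat, the class that was
        # registered first is returned and `cls` is not stored.
        if name in MODEL_REGISTRY:
            return MODEL_REGISTRY[name]

        MODEL_REGISTRY[name] = cls
        # FIXME: Find a more elegant way to do this
        # Alias the released checkpoint names to the same class.
        if name == 'phi':
            MODEL_REGISTRY['TinyLLaVA-3.1B'] = cls
        elif name == 'stablelm':
            MODEL_REGISTRY['TinyLLaVA-2.0B'] = cls
        elif name == 'llama':
            MODEL_REGISTRY['TinyLLaVA-1.5B'] = cls
        return cls

    return register_model_cls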


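def register_tokenizer(name):
    """Class decorator that records the decorated class in TOKENIZER_REGISTRY under `name`."""
    def register_tokenizer_cls(cls):
        # A name can be registered only once: on a repeat, the class that was
        # registered first is returned and `cls` is not stored.
        if name in TOKENIZER_REGISTRY:
            return TOKENIZER_REGISTRY[name]

        TOKENIZER_REGISTRY[name] = cls
        # FIXME: Find a more elegant way to do this
        # Alias the released checkpoint names to the same class.
        if name == 'phi':
            TOKENIZER_REGISTRY['TinyLLaVA-3.1B'] = cls
        elif name == 'stablelm':
            TOKENIZER_REGISTRY['TinyLLaVA-2.0B'] = cls
        elif name == 'llama':
            TOKENIZER_REGISTRY['TinyLLaVA-1.5B'] = cls
        return cls

    return register_tokenizer_cls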


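def import_models(models_dir, namespace):
    """Import every public ``*.py`` module found in `models_dir` under `namespace`.

    Importing a module executes any @register_model / @register_tokenizer
    decorators it contains, which is how the registries get populated.
    """
    for file in os.listdir(models_dir):
        # Skip private and hidden files (e.g. __init__.py, dotfiles) and non-Python files.
        if file.startswith(("_", ".")) or not file.endswith(".py"):
            continue
        # Strip only the trailing ".py"; str.find() would cut at the first ".py" in the name.
        model_name = file[: -len(".py")]
        importlib.import_module(namespace + "." + model_name)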


# Automatically import every Python file in the language_model/ directory
models_dir = os.path.join(os.path.dirname(__file__), 'language_model')
import_models(models_dir, "tinyllava.model.language_model")