File size: 2,486 Bytes
72a1159
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ee83d59
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
import torch
import transformers
from utils import load_model, static_init
from global_config import GlobalConfig


@static_init
class ModelFactory:
    """Class-level registry that loads models lazily and keeps one "active".

    All state lives on the class itself (no instances are needed):
    models are loaded on first request onto ``load_device`` and cached;
    the most recently requested model is moved to ``run_device`` while the
    previously active one is moved back to ``load_device``.
    """

    # Friendly name -> identifier handed to ``utils.load_model``.
    models_names = {}
    # Friendly name -> loaded model object (cache).
    models = {}
    # Friendly name -> tokenizer loaded together with the model (cache).
    tokenizers = {}
    # Friendly name of the model currently placed on ``run_device``, if any.
    run_model = None
    # NOTE(review): ``dtype`` is configured in ``__static_init__`` but is not
    # read anywhere inside this class -- confirm whether it is consumed
    # elsewhere or should be forwarded to ``utils.load_model``.
    dtype = torch.bfloat16
    load_device = torch.device("cpu")
    run_device = torch.device("cpu")

    @classmethod
    def __static_init__(cls):
        """Populate class-level settings from ``GlobalConfig``.

        Reads the ``models.names`` section into ``models_names`` and, when a
        ``models.params`` section exists, the ``dtype``, ``load_device`` and
        ``run_device`` options. Missing or unrecognised values leave the
        class defaults untouched.

        Presumably invoked once by the ``@static_init`` decorator -- confirm
        against ``utils.static_init``.
        """
        names_sec = GlobalConfig.get_section("models.names")
        if names_sec is not None:
            for name in names_sec:
                cls.models_names[name] = GlobalConfig.get("models.names", name)

        if GlobalConfig.get_section("models.params") is None:
            return

        supported_dtypes = {
            "bfloat16": torch.bfloat16,
            "float16": torch.float16,
            "float32": torch.float32,
        }
        dtype_name = GlobalConfig.get("models.params", "dtype")
        # Only plain strings can match; anything else keeps the default.
        if isinstance(dtype_name, str) and dtype_name in supported_dtypes:
            cls.dtype = supported_dtypes[dtype_name]

        load_device = GlobalConfig.get("models.params", "load_device")
        run_device = GlobalConfig.get("models.params", "run_device")
        if load_device is not None:
            cls.load_device = torch.device(str(load_device))
        if run_device is not None:
            cls.run_device = torch.device(str(run_device))

    @classmethod
    def __load_model(cls, name):
        """Return ``(model, tokenizer)`` for ``name``, loading it if needed.

        A model that is not cached yet is loaded through ``utils.load_model``
        onto ``load_device`` and stored in the class caches.

        Returns:
            The ``(model, tokenizer)`` pair, or ``None`` (after printing a
            message) when ``name`` is not a configured model name.
        """
        if name not in cls.models_names:
            print(f"{name} is not a valid model name")
            return None

        if name not in cls.models:
            model, tokenizer = load_model(
                cls.models_names[name], cls.load_device
            )
            cls.models[name] = model
            cls.tokenizers[name] = tokenizer

        return cls.models[name], cls.tokenizers[name]

    @classmethod
    def load_model(cls, name):
        """Make ``name`` the active model and return it with its tokenizer.

        The model is loaded on first use, the previously active model (if a
        different one) is moved back to ``load_device``, and the requested
        model is moved to ``run_device``.

        Args:
            name: Friendly model name, one of ``get_models_names()``.

        Returns:
            The ``(model, tokenizer)`` pair for ``name``.

        Raises:
            KeyError: If ``name`` is not a configured model name. The
                currently active model is left where it is in that case.
        """
        # Validate/load BEFORE touching devices: failing afterwards would
        # evict the active model while ``run_model`` still points at it.
        if name not in cls.models and cls.__load_model(name) is None:
            raise KeyError(name)

        previous = cls.run_model
        if previous is not None and previous != name:
            cls.models[previous].to(cls.load_device)

        cls.models[name].to(cls.run_device)
        cls.run_model = name

        return cls.models[name], cls.tokenizers[name]

    @classmethod
    def get_models_names(cls):
        """Return the configured friendly model names as a new list."""
        return list(cls.models_names)

    @classmethod
    def get_model_max_length(cls, name: str):
        """Return the tokenizer's ``model_max_length`` for ``name``.

        Returns ``0`` when no tokenizer has been loaded for ``name`` yet.
        """
        tokenizer = cls.tokenizers.get(name)
        if tokenizer is None:
            # Preserve the "not loaded" result even if a None was cached.
            return cls.tokenizers[name].model_max_length if name in cls.tokenizers else 0
        return tokenizer.model_max_length