Spaces:
Sleeping
Sleeping
File size: 2,486 Bytes
72a1159 ee83d59 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 |
import torch
import transformers
from utils import load_model, static_init
from global_config import GlobalConfig
@static_init
class ModelFactory:
    """Class-level cache of models and tokenizers, keyed by configured name.

    All state lives on the class itself and every method is a classmethod,
    so the class is used directly rather than instantiated. The name of the
    model that was most recently moved to ``run_device`` is tracked in
    ``run_model``; when a different model is requested, the previous one is
    moved back to ``load_device`` first.

    NOTE(review): ``static_init`` comes from ``utils``; presumably it calls
    ``__static_init__`` once when the class is created -- confirm there.
    """

    # Configured short name -> value read from the "models.names" section
    # (handed to ``utils.load_model`` as its first argument).
    models_names = {}
    # Short name -> loaded model object.
    models = {}
    # Short name -> tokenizer loaded together with the model.
    tokenizers = {}
    # Short name of the model last moved to ``run_device``; None before the
    # first successful ``load_model`` call.
    run_model = None
    # Set from "models.params"/"dtype". NOTE(review): no method in this class
    # reads it -- in particular it is not passed to ``utils.load_model``.
    dtype = torch.bfloat16
    # Device passed to ``utils.load_model`` and used to park inactive models.
    load_device = torch.device("cpu")
    # Device the requested model is moved to by ``load_model``.
    run_device = torch.device("cpu")

    @classmethod
    def __static_init__(cls):
        """Populate the class-level settings from ``GlobalConfig``.

        Reads every entry of the "models.names" section into
        ``models_names`` and, when a "models.params" section exists, the
        optional ``dtype``, ``load_device`` and ``run_device`` entries.
        Missing or unrecognised values leave the class defaults untouched.
        """
        names_sec = GlobalConfig.get_section("models.names")
        if names_sec is not None:
            for name in names_sec:
                cls.models_names[name] = GlobalConfig.get("models.names", name)

        if GlobalConfig.get_section("models.params") is not None:
            dtype = GlobalConfig.get("models.params", "dtype")
            # Only these three spellings are recognised; any other value
            # keeps the default dtype.
            if dtype == "bfloat16":
                cls.dtype = torch.bfloat16
            elif dtype == "float16":
                cls.dtype = torch.float16
            elif dtype == "float32":
                cls.dtype = torch.float32

            load_device = GlobalConfig.get("models.params", "load_device")
            run_device = GlobalConfig.get("models.params", "run_device")
            # str() because the config value is not guaranteed to be a str.
            if load_device is not None:
                cls.load_device = torch.device(str(load_device))
            if run_device is not None:
                cls.run_device = torch.device(str(run_device))

    @classmethod
    def __load_model(cls, name):
        """Return ``(model, tokenizer)`` for *name*, loading it on first use.

        Prints a message and returns ``None`` when *name* is not one of the
        configured model names. Newly loaded models are created on
        ``load_device`` and cached in ``models`` / ``tokenizers``.
        """
        if name not in cls.models_names:
            print(f"{name} is not a valid model name")
            return None

        if name in cls.models:
            return cls.models[name], cls.tokenizers[name]

        model, tokenizer = load_model(cls.models_names[name], cls.load_device)
        cls.models[name] = model
        cls.tokenizers[name] = tokenizer
        return model, tokenizer

    @classmethod
    def load_model(cls, name):
        """Make *name* the running model and return ``(model, tokenizer)``.

        The model is loaded on first use, the previously running model (if a
        different one) is moved back to ``load_device``, and the requested
        model is moved to ``run_device``.

        Raises:
            KeyError: if *name* is not a configured model name. The check
                happens before any device move, so a failed call leaves the
                currently running model where it is.
        """
        # Validate/load first: previously an unknown name evicted the running
        # model from run_device and only then failed with KeyError.
        if name not in cls.models and cls.__load_model(name) is None:
            raise KeyError(name)

        previous = cls.run_model
        if previous is not None and previous != name:
            cls.models[previous].to(cls.load_device)

        cls.models[name].to(cls.run_device)
        cls.run_model = name

        return cls.models[name], cls.tokenizers[name]

    @classmethod
    def get_models_names(cls):
        """Return the configured model names as a new list."""
        return list(cls.models_names.keys())

    @classmethod
    def get_model_max_length(cls, name: str):
        """Return the ``model_max_length`` of *name*'s tokenizer.

        Returns 0 when no tokenizer is cached under *name*, i.e. the model
        has not been loaded yet or the name is unknown.
        """
        if name in cls.tokenizers:
            return cls.tokenizers[name].model_max_length
        return 0
|