Spaces:
Sleeping
Sleeping
Add global config and model factory
Browse files- config.ini +34 -0
- global_config.py +36 -0
- model_factory.py +72 -0
config.ini
ADDED
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
[server]
|
2 |
+
port = int:42069
|
3 |
+
|
4 |
+
[models.names]
|
5 |
+
gpt2 = str:openai-community/gpt2
|
6 |
+
gpt2_medium = str:openai-community/gpt2-medium
|
7 |
+
gpt2_large = str:openai-community/gpt2-large
|
8 |
+
gpt2_xl = str:openai-community/gpt2-xl
|
9 |
+
|
10 |
+
llama_3_8b_instruct = str:meta-llama/Meta-Llama-3-8B-Instruct
|
11 |
+
llama_3_70b_instruct = str:meta-llama/Meta-Llama-3-70B-Instruct
|
12 |
+
|
13 |
+
[models.params]
|
14 |
+
dtype = str:bfloat16
|
15 |
+
run_device = str:cpu
|
16 |
+
load_device = str:cuda
|
17 |
+
|
18 |
+
[encrypt.default]
|
19 |
+
gen_model = str:gpt2
|
20 |
+
start_pos = int:0
|
21 |
+
gamma = float:10.0
|
22 |
+
msg_base = int:2
|
23 |
+
seed_scheme = str:sha_left_hash
|
24 |
+
window_length = int:1
|
25 |
+
private_key = int:0
|
26 |
+
max_new_tokens_ratio = float:2.0
|
27 |
+
num_beams = int:4
|
28 |
+
|
29 |
+
[decrypt.default]
|
30 |
+
gen_model = str:gpt2
|
31 |
+
msg_base = int:2
|
32 |
+
seed_scheme = str:sha_left_hash
|
33 |
+
window_length = int:1
|
34 |
+
private_key = int:0
|
global_config.py
ADDED
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import configparser
|
2 |
+
from utils import static_init
|
3 |
+
|
4 |
+
|
5 |
+
@static_init
class GlobalConfig:
    """Typed application configuration loaded from an INI file.

    Values in the INI file are stored as ``<type>:<payload>`` (e.g.
    ``port = int:42069``); :meth:`get` decodes them back to Python
    ``str``/``int``/``float`` values.
    """

    # Path of the INI file read once at class-creation time.
    default_file_name = "config.ini"
    # Shared parser instance; populated by __static_init__.
    config = configparser.ConfigParser()

    @classmethod
    def get_section(cls, section_name):
        """Return the option names of *section_name*, or ``None`` if absent.

        NOTE(review): despite the name, this returns the section's keys
        (an iterable of option names), not the section object itself —
        callers iterate it to enumerate options.
        """
        if section_name in cls.config:
            return cls.config[section_name].keys()
        return None

    @classmethod
    def get(cls, section_name, attr_name):
        """Return the decoded value of option *attr_name* in *section_name*.

        The raw value is expected to look like ``<type>:<payload>`` where
        <type> is ``str``, ``int`` or ``float``; any further ``:`` characters
        belong to the payload.  Returns ``None`` when the section or option
        does not exist.  A value with no type prefix is returned as the raw
        string (the previous implementation silently returned ``""`` in
        that case — split/join bug); an unrecognized type prefix yields the
        payload unchanged.
        """
        if section_name not in cls.config:
            return None
        if attr_name not in cls.config[section_name]:
            return None
        raw = cls.config.get(section_name, attr_name)
        # Split only on the first ":" so payloads may contain colons.
        type_name, sep, payload = raw.partition(":")
        if not sep:
            # No type tag at all: hand back the raw string instead of "".
            return raw
        if type_name == "int":
            return int(payload)
        if type_name == "float":
            return float(payload)
        # "str" and any unknown prefix: return the payload as-is.
        return payload

    @classmethod
    def __static_init__(cls):
        """Load ``default_file_name`` into ``config``.

        Invoked by the ``static_init`` decorator when the class is created;
        a missing file is silently ignored by ``ConfigParser.read``.
        """
        cls.config.read(cls.default_file_name)
|
model_factory.py
ADDED
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import transformers
|
3 |
+
from utils import load_model, static_init
|
4 |
+
from global_config import GlobalConfig
|
5 |
+
|
6 |
+
|
7 |
+
@static_init
class ModelFactory:
    """Lazy loader/cache for the generation models declared in config.ini.

    Models are instantiated on first use and cached.  At most one model
    (``run_model``) resides on ``run_device`` at a time; requesting a
    different model moves the previous one back to ``load_device``.
    """

    # Maps short config names (e.g. "gpt2") to checkpoint identifiers.
    models_names = {}
    # Instance caches keyed by short name.
    models = {}
    tokenizers = {}
    # Short name of the model currently placed on run_device, if any.
    run_model = None
    # Defaults; overridden from [models.params] in __static_init__.
    dtype = torch.bfloat16
    load_device = torch.device("cpu")
    run_device = torch.device("cpu")

    @classmethod
    def __static_init__(cls):
        """Populate names, dtype and devices from GlobalConfig.

        Invoked by the ``static_init`` decorator at class-creation time.
        """
        names_sec = GlobalConfig.get_section("models.names")
        if names_sec is not None:
            for name in names_sec:
                cls.models_names[name] = GlobalConfig.get("models.names", name)

        if GlobalConfig.get_section("models.params") is not None:
            # Only a recognized dtype string overrides the default.
            dtype_map = {
                "bfloat16": torch.bfloat16,
                "float16": torch.float16,
                "float32": torch.float32,
            }
            dtype = GlobalConfig.get("models.params", "dtype")
            if dtype in dtype_map:
                cls.dtype = dtype_map[dtype]

            load_device = GlobalConfig.get("models.params", "load_device")
            run_device = GlobalConfig.get("models.params", "run_device")
            if load_device is not None:
                cls.load_device = torch.device(str(load_device))
            if run_device is not None:
                cls.run_device = torch.device(str(run_device))

    @classmethod
    def __load_model(cls, name):
        """Instantiate (or fetch from cache) the model called *name*.

        Returns ``(model, tokenizer)`` on success, or ``None`` when *name*
        is not one of the configured model names.
        """
        if name not in cls.models_names:
            print(f"{name} is not a valid model name")
            return None

        if name not in cls.models:
            model, tokenizer = load_model(
                cls.models_names[name], cls.load_device
            )
            cls.models[name] = model
            cls.tokenizers[name] = tokenizer
        else:
            model, tokenizer = cls.models[name], cls.tokenizers[name]

        return model, tokenizer

    @classmethod
    def load_model(cls, name):
        """Return ``(model, tokenizer)`` for *name*, placed on run_device.

        Any previously active model is moved back to ``load_device`` first.
        Returns ``None`` for an unknown name (the original indexed
        ``cls.models[name]`` even after ``__load_model`` rejected the name,
        raising an opaque ``KeyError``).
        """
        if name not in cls.models and cls.__load_model(name) is None:
            return None

        if name != cls.run_model and cls.run_model is not None:
            # Evict the previously active model from the run device.
            cls.models[cls.run_model].to(cls.load_device)

        cls.models[name].to(cls.run_device)
        cls.run_model = name

        return cls.models[name], cls.tokenizers[name]

    @classmethod
    def get_models_names(cls):
        """Return the list of configured short model names."""
        return list(cls.models_names.keys())
|