tnk2908 committed on
Commit
72a1159
1 Parent(s): e350046

Add global config and model factory

Browse files
Files changed (3) hide show
  1. config.ini +34 -0
  2. global_config.py +36 -0
  3. model_factory.py +72 -0
config.ini ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [server]
2
+ port = int:42069
3
+
4
+ [models.names]
5
+ gpt2 = str:openai-community/gpt2
6
+ gpt2_medium = str:openai-community/gpt2-medium
7
+ gpt2_large = str:openai-community/gpt2-large
8
+ gpt2_xl = str:openai-community/gpt2-xl
9
+
10
+ llama_3_8b_intruct = str:meta-llama/Meta-Llama-3-8B-Instruct
11
+ llama_3_70b_instruct = str:meta-llama/Meta-Llama-3-70B-Instruct
12
+
13
+ [models.params]
14
+ dtype = str:bfloat16
15
+ run_device = str:cpu
16
+ load_device = str:cuda
17
+
18
+ [encrypt.default]
19
+ gen_model = str:gpt2
20
+ start_pos = int:0
21
+ gamma = float:10.0
22
+ msg_base = int:2
23
+ seed_scheme = str:sha_left_hash
24
+ window_length = int:1
25
+ private_key = int:0
26
+ max_new_tokens_ratio = float:2.0
27
+ num_beams = int:4
28
+
29
+ [decrypt.default]
30
+ gen_model = str:gpt2
31
+ msg_base = int:2
32
+ seed_scheme = str:sha_left_hash
33
+ window_length = int:1
34
+ private_key = int:0
global_config.py ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import configparser
2
+ from utils import static_init
3
+
4
+
5
@static_init
class GlobalConfig:
    """Class-level accessor for the typed settings stored in ``config.ini``.

    Values in the file are written as ``<type>:<payload>`` (for example
    ``int:42069`` or ``str:openai-community/gpt2``).  :meth:`get` strips the
    type tag and converts the payload accordingly.

    The class is never instantiated; all state lives on the class itself.
    NOTE(review): ``static_init`` comes from ``utils`` and presumably calls
    ``__static_init__`` once the class body has been evaluated -- confirm.
    """

    # Path handed to ``ConfigParser.read``; a relative path is resolved
    # against the current working directory.
    default_file_name = "config.ini"
    config = configparser.ConfigParser()

    # Converters for the type tags recognised in front of the first ":".
    _converters = {"str": str, "float": float, "int": int}

    @classmethod
    def get_section(cls, section_name):
        """Return the option names of ``section_name``.

        Returns:
            A view over the option names, or ``None`` if the section is not
            present in the loaded configuration.
        """
        if section_name not in cls.config:
            return None
        return cls.config[section_name].keys()

    @classmethod
    def get(cls, section_name, attr_name):
        """Return the converted value of ``attr_name`` in ``section_name``.

        Returns:
            The payload converted according to its ``str``/``float``/``int``
            tag.  A value that carries no recognised tag is returned as the
            raw string, unchanged.  ``None`` is returned when the section or
            the option does not exist.

        Raises:
            ValueError: if the payload of an ``int`` or ``float`` value
                cannot be parsed.
        """
        if section_name not in cls.config:
            return None
        if attr_name not in cls.config[section_name]:
            return None

        raw = cls.config.get(section_name, attr_name)
        # Split on the first ":" only, so payloads may contain colons.
        type_name, separator, payload = raw.partition(":")
        converter = cls._converters.get(type_name) if separator else None
        if converter is None:
            # No (known) type tag: hand the value back untouched rather than
            # silently dropping part of it.
            return raw
        return converter(payload)

    @classmethod
    def __static_init__(cls):
        """Load ``default_file_name`` into the shared parser.

        ``ConfigParser.read`` ignores files it cannot open, so a missing
        file leaves the configuration empty instead of raising.
        """
        cls.config.read(cls.default_file_name)
model_factory.py ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import transformers
3
+ from utils import load_model, static_init
4
+ from global_config import GlobalConfig
5
+
6
+
7
@static_init
class ModelFactory:
    """Class-level cache of loaded models with a single "active" model.

    Models are loaded onto ``load_device`` the first time they are requested
    and cached together with their tokenizer.  The model returned by the most
    recent :meth:`load_model` call is kept on ``run_device``; the previously
    active model is moved back to ``load_device``.

    The class is never instantiated; all state lives on the class itself.
    NOTE(review): ``static_init`` comes from ``utils`` and presumably calls
    ``__static_init__`` once the class body has been evaluated -- confirm.
    """

    # Short name -> model identifier, filled from the [models.names] section.
    models_names = {}
    # Short name -> loaded model / tokenizer.
    models = {}
    tokenizers = {}
    # Short name of the model currently on ``run_device`` (None if none yet).
    run_model = None
    # NOTE(review): ``dtype`` is read from the config but is not passed to
    # ``utils.load_model`` anywhere in this class -- confirm whether it should be.
    dtype = torch.bfloat16
    load_device = torch.device("cpu")
    run_device = torch.device("cpu")

    @classmethod
    def __static_init__(cls):
        """Populate names, dtype and devices from ``GlobalConfig``.

        Settings that are missing or unrecognised keep their class defaults.
        """
        names_sec = GlobalConfig.get_section("models.names")
        if names_sec is not None:
            for name in names_sec:
                cls.models_names[name] = GlobalConfig.get("models.names", name)

        if GlobalConfig.get_section("models.params") is None:
            return

        known_dtypes = {
            "bfloat16": torch.bfloat16,
            "float16": torch.float16,
            "float32": torch.float32,
        }
        # An absent or unknown dtype name leaves the current default in place.
        configured_dtype = GlobalConfig.get("models.params", "dtype")
        cls.dtype = known_dtypes.get(configured_dtype, cls.dtype)

        load_device = GlobalConfig.get("models.params", "load_device")
        run_device = GlobalConfig.get("models.params", "run_device")
        if load_device is not None:
            cls.load_device = torch.device(str(load_device))
        if run_device is not None:
            cls.run_device = torch.device(str(run_device))

    @classmethod
    def __load_model(cls, name):
        """Return the cached ``(model, tokenizer)`` for ``name``, loading once.

        Returns:
            The ``(model, tokenizer)`` pair, or ``None`` (after printing a
            message) when ``name`` is not a configured model name.
        """
        if name not in cls.models_names:
            print(f"{name} is not a valid model name")
            return None

        if name not in cls.models:
            model, tokenizer = load_model(cls.models_names[name], cls.load_device)
            cls.models[name] = model
            cls.tokenizers[name] = tokenizer

        return cls.models[name], cls.tokenizers[name]

    @classmethod
    def load_model(cls, name):
        """Make ``name`` the active model and return it with its tokenizer.

        Returns:
            The ``(model, tokenizer)`` pair for ``name``; the model has been
            moved to ``run_device``.

        Raises:
            KeyError: if ``name`` is not a configured model name.  The check
                happens before any device move, so the previously active
                model stays where it is on failure.
        """
        if name not in cls.models and cls.__load_model(name) is None:
            raise KeyError(f"{name} is not a valid model name")

        previous = cls.run_model
        if previous is not None and previous != name:
            # Only one model occupies the run device at a time.
            cls.models[previous].to(cls.load_device)

        cls.models[name].to(cls.run_device)
        cls.run_model = name

        return cls.models[name], cls.tokenizers[name]

    @classmethod
    def get_models_names(cls):
        """Return the configured short model names as a list."""
        return list(cls.models_names)