kanhatakeyama committed on
Commit
b6dcc1d
1 Parent(s): c821aa0

Upload model

Browse files
Files changed (2) hide show
  1. MoEConfig.py +5 -2
  2. MoEModel.py +55 -8
MoEConfig.py CHANGED
@@ -3,8 +3,11 @@ from typing import List
3
 
4
 
5
  class MoEConfig(PretrainedConfig):
6
- model_type = "moewrapper" # モデルの名前を命名?
7
- torch_dtype = "float32",
 
 
 
8
 
9
  def __init__(
10
  self,
 
3
 
4
 
5
  class MoEConfig(PretrainedConfig):
6
+ model_type = "moewrapper"
7
+ model_list = [
8
+ "kanhatakeyama/01b_model_30b_token",
9
+ "kanhatakeyama/01b_model_30b_token",
10
+ ]
11
 
12
  def __init__(
13
  self,
MoEModel.py CHANGED
@@ -1,33 +1,80 @@
1
  from transformers import PreTrainedModel
2
- from MoEConfig import MoEConfig
3
  from transformers import AutoModelForCausalLM
4
  import torch
5
-
6
- model_name = "kanhatakeyama/01b_model_30b_token"
7
 
8
 
9
  class MoeModel(PreTrainedModel):
10
  config_class = MoEConfig
 
 
11
 
12
  def __init__(self, config):
13
  super().__init__(config)
 
 
 
14
 
15
- self.model = None
16
- self.set_model()
17
 
18
- def set_model(self):
 
19
  self.model = AutoModelForCausalLM.from_pretrained(
20
  model_name,
21
  device_map="auto",
22
  torch_dtype=torch.float16
23
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
 
25
  def generate(self, input_ids, attention_mask,
26
  **generate_kwargs):
27
- if self.model is None:
28
- self.set_model()
 
 
 
 
 
 
 
29
 
30
  ret = self.model.generate(input_ids=input_ids,
31
  attention_mask=attention_mask,
32
  **generate_kwargs)
33
  return ret
 
 
 
 
 
 
 
 
1
  from transformers import PreTrainedModel
2
+ from .MoEConfig import MoEConfig
3
  from transformers import AutoModelForCausalLM
4
  import torch
5
+ import numpy as np
 
6
 
7
 
8
  class MoeModel(PreTrainedModel):
9
  config_class = MoEConfig
10
+ verbose = True
11
+ fix_mode = False
12
 
13
  def __init__(self, config):
14
  super().__init__(config)
15
+ self.model_list = []
16
+ for model_name in self.config_class.model_list:
17
+ self.append_model(model_name)
18
 
19
+ self.set_model_id(0)
 
20
 
21
+ """
22
+ def set_model(self, model_name):
23
  self.model = AutoModelForCausalLM.from_pretrained(
24
  model_name,
25
  device_map="auto",
26
  torch_dtype=torch.float16
27
  )
28
+ """
29
+
30
+ def append_model(self, model_name):
31
+ print("loading ", model_name)
32
+ model = AutoModelForCausalLM.from_pretrained(
33
+ model_name,
34
+ device_map="auto",
35
+ torch_dtype=torch.float16
36
+ )
37
+ self.model_list.append(model)
38
+
39
+ # def set_tokenizer(self, tokenizer):
40
+ # self.tokenizer = tokenizer
41
+
42
+ def set_model_id(self, model_id):
43
+ self.model = self.model_list[model_id]
44
+
45
+ def calc_perplexity(self, tokenized_input):
46
+ ppl_list = []
47
+ for model in self.model_list:
48
+ ppl_list.append(perplexity(model, tokenized_input))
49
+ return np.array(ppl_list)
50
+
51
+ def fix_model(self, model_id):
52
+ self.set_model_id(model_id)
53
+ self.fix_mode = True
54
+
55
+ def set_flexible_mode(self):
56
+ self.fix_mode = False
57
 
58
  def generate(self, input_ids, attention_mask,
59
  **generate_kwargs):
60
+
61
+ if not self.fix_mode:
62
+ ppl_array = self.calc_perplexity(input_ids)
63
+ best_model_id = np.where(ppl_array == min(ppl_array))[0][0]
64
+ self.set_model_id(best_model_id)
65
+
66
+ if self.verbose:
67
+ print(f"model {best_model_id} will be used")
68
+ print("ppl array: ", ppl_array)
69
 
70
  ret = self.model.generate(input_ids=input_ids,
71
  attention_mask=attention_mask,
72
  **generate_kwargs)
73
  return ret
74
+
75
+
76
def perplexity(model, tokenized_input) -> float:
    """Return the perplexity that ``model`` assigns to ``tokenized_input``.

    The input is passed both as the model input and as the labels, and the
    result is ``exp(output.loss)`` converted to a Python float.

    Args:
        model: Callable invoked as ``model(tokenized_input, labels=...)``
            whose result exposes a single-element tensor ``loss``.
        tokenized_input: Token ids to score (presumably a tensor of ids;
            confirm shape against the caller).

    Returns:
        The perplexity as a plain float, not a tensor.
    """
    # NOTE(review): no attention mask is passed and the labels are not
    # masked, so any padding in the input would count towards the loss.
    with torch.inference_mode():
        output = model(tokenized_input, labels=tokenized_input)
        return torch.exp(output.loss).item()