Respair commited on
Commit
d5a51ca
1 Parent(s): 2536b89

Update Utils/PLBERT/util.py

Browse files
Files changed (1) hide show
  1. Utils/PLBERT/util.py +19 -19
Utils/PLBERT/util.py CHANGED
@@ -3,7 +3,6 @@ import yaml
3
  import torch
4
  from transformers import AlbertConfig, AlbertModel
5
 
6
-
7
  class CustomAlbert(AlbertModel):
8
  def forward(self, *args, **kwargs):
9
  # Call the original forward method
@@ -16,34 +15,35 @@ class CustomAlbert(AlbertModel):
16
  def load_plbert(log_dir):
17
  config_path = os.path.join(log_dir, "config.yml")
18
  plbert_config = yaml.safe_load(open(config_path))
19
-
20
- albert_base_configuration = AlbertConfig(**plbert_config["model_params"])
21
  bert = CustomAlbert(albert_base_configuration)
22
 
23
  files = os.listdir(log_dir)
24
  ckpts = []
25
  for f in os.listdir(log_dir):
26
- if f.startswith("step_"):
27
- ckpts.append(f)
28
-
29
- iters = [
30
- int(f.split("_")[-1].split(".")[0])
31
- for f in ckpts
32
- if os.path.isfile(os.path.join(log_dir, f))
33
- ]
34
  iters = sorted(iters)[-1]
35
 
36
- checkpoint = torch.load(log_dir + "/step_" + str(iters) + ".t7", map_location="cpu")
37
- state_dict = checkpoint["net"]
38
  from collections import OrderedDict
39
-
40
  new_state_dict = OrderedDict()
41
  for k, v in state_dict.items():
42
- name = k[7:] # remove `module.`
43
- if name.startswith("encoder."):
44
- name = name[8:] # remove `encoder.`
45
  new_state_dict[name] = v
46
- del new_state_dict["embeddings.position_ids"]
 
 
 
 
 
47
  bert.load_state_dict(new_state_dict, strict=False)
48
-
49
  return bert
 
 
 
3
  import torch
4
  from transformers import AlbertConfig, AlbertModel
5
 
 
6
  class CustomAlbert(AlbertModel):
7
  def forward(self, *args, **kwargs):
8
  # Call the original forward method
 
15
  def load_plbert(log_dir):
16
  config_path = os.path.join(log_dir, "config.yml")
17
  plbert_config = yaml.safe_load(open(config_path))
18
+
19
+ albert_base_configuration = AlbertConfig(**plbert_config['model_params'])
20
  bert = CustomAlbert(albert_base_configuration)
21
 
22
  files = os.listdir(log_dir)
23
  ckpts = []
24
  for f in os.listdir(log_dir):
25
+ if f.startswith("step_"): ckpts.append(f)
26
+
27
+ iters = [int(f.split('_')[-1].split('.')[0]) for f in ckpts if os.path.isfile(os.path.join(log_dir, f))]
 
 
 
 
 
28
  iters = sorted(iters)[-1]
29
 
30
+ checkpoint = torch.load(log_dir + "/step_" + str(iters) + ".t7", map_location='cpu')
31
+ state_dict = checkpoint['net']
32
  from collections import OrderedDict
 
33
  new_state_dict = OrderedDict()
34
  for k, v in state_dict.items():
35
+ name = k[7:] # remove `module.`
36
+ if name.startswith('encoder.'):
37
+ name = name[8:] # remove `encoder.`
38
  new_state_dict[name] = v
39
+
40
+ # Check if 'embeddings.position_ids' exists before attempting to delete it
41
+ if not hasattr(bert.embeddings, 'position_ids'):
42
+ del new_state_dict["embeddings.position_ids"]
43
+
44
+
45
  bert.load_state_dict(new_state_dict, strict=False)
46
+
47
  return bert
48
+
49
+