Spaces:
Runtime error
Runtime error
Update Utils/PLBERT/util.py
Browse files- Utils/PLBERT/util.py +19 -19
Utils/PLBERT/util.py
CHANGED
@@ -3,7 +3,6 @@ import yaml
|
|
3 |
import torch
|
4 |
from transformers import AlbertConfig, AlbertModel
|
5 |
|
6 |
-
|
7 |
class CustomAlbert(AlbertModel):
|
8 |
def forward(self, *args, **kwargs):
|
9 |
# Call the original forward method
|
@@ -16,34 +15,35 @@ class CustomAlbert(AlbertModel):
|
|
16 |
def load_plbert(log_dir):
|
17 |
config_path = os.path.join(log_dir, "config.yml")
|
18 |
plbert_config = yaml.safe_load(open(config_path))
|
19 |
-
|
20 |
-
albert_base_configuration = AlbertConfig(**plbert_config[
|
21 |
bert = CustomAlbert(albert_base_configuration)
|
22 |
|
23 |
files = os.listdir(log_dir)
|
24 |
ckpts = []
|
25 |
for f in os.listdir(log_dir):
|
26 |
-
if f.startswith("step_"):
|
27 |
-
|
28 |
-
|
29 |
-
iters = [
|
30 |
-
int(f.split("_")[-1].split(".")[0])
|
31 |
-
for f in ckpts
|
32 |
-
if os.path.isfile(os.path.join(log_dir, f))
|
33 |
-
]
|
34 |
iters = sorted(iters)[-1]
|
35 |
|
36 |
-
checkpoint = torch.load(log_dir + "/step_" + str(iters) + ".t7", map_location=
|
37 |
-
state_dict = checkpoint[
|
38 |
from collections import OrderedDict
|
39 |
-
|
40 |
new_state_dict = OrderedDict()
|
41 |
for k, v in state_dict.items():
|
42 |
-
name = k[7:]
|
43 |
-
if name.startswith(
|
44 |
-
name = name[8:]
|
45 |
new_state_dict[name] = v
|
46 |
-
|
|
|
|
|
|
|
|
|
|
|
47 |
bert.load_state_dict(new_state_dict, strict=False)
|
48 |
-
|
49 |
return bert
|
|
|
|
|
|
3 |
import torch
|
4 |
from transformers import AlbertConfig, AlbertModel
|
5 |
|
|
|
6 |
class CustomAlbert(AlbertModel):
|
7 |
def forward(self, *args, **kwargs):
|
8 |
# Call the original forward method
|
|
|
15 |
def load_plbert(log_dir):
|
16 |
config_path = os.path.join(log_dir, "config.yml")
|
17 |
plbert_config = yaml.safe_load(open(config_path))
|
18 |
+
|
19 |
+
albert_base_configuration = AlbertConfig(**plbert_config['model_params'])
|
20 |
bert = CustomAlbert(albert_base_configuration)
|
21 |
|
22 |
files = os.listdir(log_dir)
|
23 |
ckpts = []
|
24 |
for f in os.listdir(log_dir):
|
25 |
+
if f.startswith("step_"): ckpts.append(f)
|
26 |
+
|
27 |
+
iters = [int(f.split('_')[-1].split('.')[0]) for f in ckpts if os.path.isfile(os.path.join(log_dir, f))]
|
|
|
|
|
|
|
|
|
|
|
28 |
iters = sorted(iters)[-1]
|
29 |
|
30 |
+
checkpoint = torch.load(log_dir + "/step_" + str(iters) + ".t7", map_location='cpu')
|
31 |
+
state_dict = checkpoint['net']
|
32 |
from collections import OrderedDict
|
|
|
33 |
new_state_dict = OrderedDict()
|
34 |
for k, v in state_dict.items():
|
35 |
+
name = k[7:] # remove `module.`
|
36 |
+
if name.startswith('encoder.'):
|
37 |
+
name = name[8:] # remove `encoder.`
|
38 |
new_state_dict[name] = v
|
39 |
+
|
40 |
+
# Check if 'embeddings.position_ids' exists before attempting to delete it
|
41 |
+
if not hasattr(bert.embeddings, 'position_ids'):
|
42 |
+
del new_state_dict["embeddings.position_ids"]
|
43 |
+
|
44 |
+
|
45 |
bert.load_state_dict(new_state_dict, strict=False)
|
46 |
+
|
47 |
return bert
|
48 |
+
|
49 |
+
|