Update convert.py
Browse files- convert.py +3 -13
convert.py
CHANGED
@@ -1,4 +1,5 @@
|
|
1 |
# Copied from https://github.com/nghuyong/ERNIE-Pytorch/blob/master/convert.py
|
|
|
2 |
|
3 |
|
4 |
#!/usr/bin/env python
|
@@ -99,19 +100,8 @@ def extract_and_convert(input_dir, output_dir):
|
|
99 |
paddle_paddle_params, _ = D.load_dygraph(os.path.join(input_dir, 'model_state.pdparams'))
|
100 |
for weight_name, weight_value in paddle_paddle_params.items():
|
101 |
if 'weight' in weight_name:
|
102 |
-
|
103 |
-
|
104 |
-
|
105 |
-
# if 'encoder' in weight_name or 'pooler' in weight_name or 'cls.' in weight_name and \
|
106 |
-
# "k_proj" not in weight_name and "v_proj" not in weight_name and \
|
107 |
-
# "out_proj" not in weight_name and "linear1" not in weight_name and \
|
108 |
-
# "linear2" not in weight_name:
|
109 |
-
# weight_value = weight_value.transpose()
|
110 |
-
if "encoder" in weight_name:
|
111 |
-
if "linear1" in weight_name or "linear2" in weight_name:
|
112 |
-
weight_value = weight_value.transpose()
|
113 |
-
else:
|
114 |
-
weight_value = weight_value.transpose()
|
115 |
|
116 |
if weight_name not in weight_map:
|
117 |
print('=' * 20, '[SKIP]', weight_name, '=' * 20)
|
|
|
1 |
# Copied from https://github.com/nghuyong/ERNIE-Pytorch/blob/master/convert.py
|
2 |
+
# with some modifications for ernie-m
|
3 |
|
4 |
|
5 |
#!/usr/bin/env python
|
|
|
100 |
paddle_paddle_params, _ = D.load_dygraph(os.path.join(input_dir, 'model_state.pdparams'))
|
101 |
for weight_name, weight_value in paddle_paddle_params.items():
|
102 |
if 'weight' in weight_name:
|
103 |
+
if 'encoder' in weight_name or 'pooler' in weight_name or 'cls.' in weight_name:
|
104 |
+
weight_value = weight_value.transpose()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
105 |
|
106 |
if weight_name not in weight_map:
|
107 |
print('=' * 20, '[SKIP]', weight_name, '=' * 20)
|