susnato commited on
Commit
8deba90
1 Parent(s): 05a49c5

Update convert.py

Browse files
Files changed (1) hide show
  1. convert.py +3 -13
convert.py CHANGED
@@ -1,4 +1,5 @@
1
  # Copied from https://github.com/nghuyong/ERNIE-Pytorch/blob/master/convert.py
 
2
 
3
 
4
  #!/usr/bin/env python
@@ -99,19 +100,8 @@ def extract_and_convert(input_dir, output_dir):
99
  paddle_paddle_params, _ = D.load_dygraph(os.path.join(input_dir, 'model_state.pdparams'))
100
  for weight_name, weight_value in paddle_paddle_params.items():
101
  if 'weight' in weight_name:
102
- # if 'encoder' in weight_name or 'pooler' in weight_name or 'cls.' in weight_name:
103
- # weight_value = weight_value.transpose()
104
-
105
- # if 'encoder' in weight_name or 'pooler' in weight_name or 'cls.' in weight_name and \
106
- # "k_proj" not in weight_name and "v_proj" not in weight_name and \
107
- # "out_proj" not in weight_name and "linear1" not in weight_name and \
108
- # "linear2" not in weight_name:
109
- # weight_value = weight_value.transpose()
110
- if "encoder" in weight_name:
111
- if "linear1" in weight_name or "linear2" in weight_name:
112
- weight_value = weight_value.transpose()
113
- else:
114
- weight_value = weight_value.transpose()
115
 
116
  if weight_name not in weight_map:
117
  print('=' * 20, '[SKIP]', weight_name, '=' * 20)
 
1
  # Copied from https://github.com/nghuyong/ERNIE-Pytorch/blob/master/convert.py
2
+ # with some modifications for ernie-m
3
 
4
 
5
  #!/usr/bin/env python
 
100
  paddle_paddle_params, _ = D.load_dygraph(os.path.join(input_dir, 'model_state.pdparams'))
101
  for weight_name, weight_value in paddle_paddle_params.items():
102
  if 'weight' in weight_name:
103
+ if 'encoder' in weight_name or 'pooler' in weight_name or 'cls.' in weight_name:
104
+ weight_value = weight_value.transpose()
 
 
 
 
 
 
 
 
 
 
 
105
 
106
  if weight_name not in weight_map:
107
  print('=' * 20, '[SKIP]', weight_name, '=' * 20)