patrickvonplaten committed
Commit f0d385b
1 Parent(s): fb106d4

save intermediate

Files changed (2):
  1. check_gradients_pt_flax.py +64 -36
  2. run_models.sh +17 -0
check_gradients_pt_flax.py CHANGED

@@ -1,11 +1,11 @@
 #!/usr/bin/env python3
-from transformers import SpeechEncoderDecoderModel, FlaxSpeechEncoderDecoderModel
 import tempfile
 import random
 import numpy as np
 import torch
 import optax
 import jax
+import sys
 from flax.training.common_utils import onehot
 from flax.traverse_util import flatten_dict
 
@@ -63,7 +63,7 @@ def shift_tokens_right(input_ids: np.array, pad_token_id: int, decoder_start_tok
     return shifted_input_ids
 
 
-def assert_almost_equals(a: np.ndarray, b: np.ndarray, tol: float = 4e-2):
+def assert_almost_equals(a: np.ndarray, b: np.ndarray, tol: float = 1e-2):
     diff = np.abs((a - b)).max()
     if diff < tol:
         print(f"βœ… Difference between Flax and PyTorch is {diff} (< {tol})")
@@ -71,46 +71,60 @@ def assert_almost_equals(a: np.ndarray, b: np.ndarray, tol: float = 4e-2):
         print(f"❌ Difference between Flax and PyTorch is {diff} (>= {tol})")
 
 
-def assert_dict_equal(a: dict, b: dict, tol: float = 4e-2):
+def assert_dict_equal(a: dict, b: dict, tol: float = 1e-2):
     if a.keys() != b.keys():
         print("❌ Dictionary keys for PyTorch and Flax do not match")
+    results_fail = []
+    results_correct = []
+
+    results_fail_rel = []
+    results_correct_rel = []
     for k in a:
-        diff = np.abs((a[k] - b[k])).max()
+        ak_norm = np.linalg.norm(a[k])
+        bk_norm = np.linalg.norm(b[k])
+        diff = np.abs(ak_norm - bk_norm)
+        diff_rel = np.abs(ak_norm - bk_norm) / np.abs(ak_norm)
         if diff < tol:
-            print(f"βœ… Layer {k} diff is {diff} < {tol}).")
+            results_correct.append(f"βœ… Layer {k} diff is {diff} (< {tol}).")
+        else:
+            results_fail.append(f"❌ Layer {k} has PT grad norm {bk_norm} and flax grad norm {ak_norm}.")
+        if diff_rel < tol:
+            results_correct_rel.append(f"βœ… Layer {k} rel diff is {diff_rel} (< {tol}).")
         else:
-            print(f"❌ Layer {k} diff is {diff} (>= {tol}).")
+            results_fail_rel.append(f"❌ Layer {k} has PT grad norm {bk_norm} and flax grad norm {ak_norm}.")
+    return results_fail_rel, results_correct_rel, results_fail, results_correct
 
 
-def main():
-    encoder_id = "hf-internal-testing/tiny-random-wav2vec2"
-    decoder_id = "hf-internal-testing/tiny-random-bart"
+def compare_grads(model_id, pt_architecture):
+    transformers_module = __import__("transformers", fromlist=[pt_architecture])
 
-    use_decoder_attention_mask = False
-    freeze_feature_encoder = False
+    model_cls = getattr(transformers_module, pt_architecture)
+    flax_model_cls = getattr(transformers_module, "Flax" + pt_architecture)
 
-    pt_model = SpeechEncoderDecoderModel.from_encoder_decoder_pretrained(encoder_id, decoder_id,
-                                                                         encoder_add_adapter=True)
+    pt_model, model_info = model_cls.from_pretrained(model_id, output_loading_info=True)
 
-    with tempfile.TemporaryDirectory() as tmpdirname:
-        pt_model.save_pretrained(tmpdirname)
-        fx_model = FlaxSpeechEncoderDecoderModel.from_pretrained(tmpdirname, from_pt=True)
+    if len(model_info["missing_keys"]) > 0:
+        raise ValueError(f"{model_id} with {pt_architecture} has missing keys: {model_info['missing_keys']}")
 
-    batch_size = 13
-    input_values = floats_tensor([batch_size, 512], fx_model.config.encoder.vocab_size)
-    attention_mask = random_attention_mask([batch_size, 512])
-    label_ids = ids_tensor([batch_size, 4], fx_model.config.decoder.vocab_size)
-    decoder_input_ids = shift_tokens_right(input_ids=label_ids, pad_token_id=fx_model.config.decoder.pad_token_id,
-                                           decoder_start_token_id=fx_model.config.decoder.decoder_start_token_id)
-    decoder_attention_mask = random_attention_mask([batch_size, 4])
+    fx_model = flax_model_cls.from_pretrained(model_id, from_pt=True)
+
+    batch_size = 2
+    seq_len = 64
+
+    input_ids = ids_tensor([batch_size, seq_len], fx_model.config.vocab_size)
+    label_ids = ids_tensor([batch_size, seq_len], fx_model.config.vocab_size)
+
+    attention_mask = random_attention_mask([batch_size, seq_len])
+    label_ids = ids_tensor([batch_size, seq_len], fx_model.config.vocab_size)
 
     fx_inputs = {
-        "inputs": input_values,
+        "input_ids": input_ids,
         "attention_mask": attention_mask,
-        "decoder_input_ids": decoder_input_ids,
     }
-    if use_decoder_attention_mask:
-        fx_inputs["decoder_attention_mask"] = decoder_attention_mask
+
+    if pt_model.config.is_encoder_decoder:
+        decoder_input_ids = shift_tokens_right(input_ids=label_ids, pad_token_id=fx_model.config.pad_token_id, decoder_start_token_id=fx_model.config.decoder_start_token_id)
+        fx_inputs["decoder_input_ids"] = decoder_input_ids
 
     pt_inputs = {k: torch.tensor(v.tolist()) for k, v in fx_inputs.items()}
     pt_inputs["labels"] = torch.tensor(label_ids.tolist())
@@ -118,9 +132,6 @@ def main():
     fx_outputs = fx_model(**fx_inputs)
     fx_logits = fx_outputs.logits
 
-    if freeze_feature_encoder:
-        pt_model.freeze_feature_encoder()
-
     pt_outputs = pt_model(**pt_inputs)
     pt_logits = pt_outputs.logits
     pt_loss = pt_outputs.loss
@@ -129,11 +140,10 @@ def main():
     print(f"Flax logits shape: {fx_logits.shape}, PyTorch logits shape: {pt_logits.shape}")
     assert_almost_equals(fx_logits, pt_logits.detach().numpy())
 
-    def fx_train_step(fx_model, batch, freeze_feature_encoder=False):
+    def fx_train_step(fx_model, batch):
         def compute_loss(params):
             label_ids = batch.pop('label_ids')
-            logits = fx_model(**batch, params=params,
-                              freeze_feature_encoder=freeze_feature_encoder).logits
+            logits = fx_model(**batch, params=params).logits
             vocab_size = logits.shape[-1]
             targets = onehot(label_ids, vocab_size)
             loss = optax.softmax_cross_entropy(logits, targets)
@@ -145,7 +155,7 @@ def main():
 
     fx_inputs["label_ids"] = label_ids
 
-    fx_loss, fx_grad = fx_train_step(fx_model, fx_inputs, freeze_feature_encoder=freeze_feature_encoder)
+    fx_loss, fx_grad = fx_train_step(fx_model, fx_inputs)
 
     print("--------------------------Checking losses match--------------------------")
     print(f"Flax loss: {fx_loss}, PyTorch loss: {pt_loss}")
@@ -166,13 +176,31 @@ def main():
 
     with tempfile.TemporaryDirectory() as tmpdirname:
         pt_model.save_pretrained(tmpdirname)
-        pt_grad_model_to_fx = FlaxSpeechEncoderDecoderModel.from_pretrained(tmpdirname, from_pt=True)
+        pt_grad_model_to_fx = flax_model_cls.from_pretrained(tmpdirname, from_pt=True)
 
     pt_grad_to_fx = pt_grad_model_to_fx.params
     fx_grad = flatten_dict(fx_grad)
     pt_grad_to_fx = flatten_dict(pt_grad_to_fx)
     print("--------------------------Checking gradients match--------------------------")
-    assert_dict_equal(fx_grad, pt_grad_to_fx)
+    results_fail_rel, results_correct_rel, results_fail, results_correct = assert_dict_equal(fx_grad, pt_grad_to_fx)
+
+    if len(results_fail) == 0:
+        print("βœ… All grads pass")
+    else:
+        print("\n".join(results_fail))
+
+    print("--------------------------Checking rel gradients match--------------------------")
+
+    if len(results_fail_rel) == 0:
+        print("βœ… All rel grads pass")
+    else:
+        print("\n".join(results_fail_rel))
+
+
+def main():
+    model_id = sys.argv[1]
+    pt_architecture_name = sys.argv[2]
+    compare_grads(model_id, pt_architecture_name)
 
 
 if __name__ == "__main__":
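
This change generalizes the checker: instead of being hard-wired to the tiny Wav2Vec2/BART speech encoder-decoder pair, the script now takes a Hub model id and the PyTorch architecture class name on the command line, loads both the PyTorch and Flax variants, and compares logits, losses, and per-layer gradient norms (absolute and relative) between the two frameworks. A minimal standalone invocation, assuming the script has been made executable, mirrors the call used in run_models.sh below:

    chmod +x check_gradients_pt_flax.py
    ./check_gradients_pt_flax.py hf-internal-testing/tiny-random-roberta RobertaForMaskedLM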
run_models.sh ADDED

@@ -0,0 +1,17 @@
+#!/usr/bin/env bash
+#model_ids=("hf-internal-testing/tiny-random-roberta" "hf-internal-testing/tiny-random-bert" "hf-internal-testing/tiny-random-bart" "tf-internal-testing/tiny-random-t5")
+#model_architectures=("RobertaForMaskedLM" "BertForMaskedLM" "BartForConditionalGeneration" "T5ForConditionalGeneration")
+model_ids=("hf-internal-testing/tiny-random-roberta")
+model_architectures=("RobertaForMaskedLM")
+
+rm -rf log.txt
+touch log.txt
+
+for model_idx in "${!model_ids[@]}"; do
+    model_id=${model_ids[model_idx]}
+    model_architecture=${model_architectures[model_idx]}
+
+    echo "Check ${model_id} ..." >> log.txt
+    ./check_gradients_pt_flax.py "${model_id}" "${model_architecture}" >> log.txt
+    echo "=========================================" >> log.txt
+done
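
The sweep appends one block per checkpoint to log.txt; the commented-out arrays above show how further tiny checkpoints (BERT, BART, T5) could be covered by adding matching entries to model_ids and model_architectures. A sketch of running the sweep and inspecting the results, assuming bash and an executable checker script:

    bash run_models.sh
    cat log.txt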