File size: 7,025 Bytes
fb106d4 f0d385b fb106d4 f0d385b fb106d4 f0d385b fb106d4 f0d385b fb106d4 f0d385b fb106d4 f0d385b fb106d4 f0d385b fb106d4 f0d385b fb106d4 f0d385b fb106d4 f0d385b fb106d4 f0d385b fb106d4 f0d385b fb106d4 f0d385b fb106d4 f0d385b fb106d4 f0d385b fb106d4 f0d385b fb106d4 f0d385b fb106d4 a85c2d0 fb106d4 a85c2d0 fb106d4 a85c2d0 fb106d4 f0d385b fb106d4 f0d385b fb106d4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 |
#!/usr/bin/env python3
import tempfile
import random
import numpy as np
import torch
import optax
import jax
import sys
from flax.training.common_utils import onehot
from flax.traverse_util import flatten_dict
def ids_tensor(shape, vocab_size, rng=None):
    """Create a random int32 array of ``shape`` with values in ``[0, vocab_size - 1]``.

    Args:
        shape: sequence of dimension sizes (an empty sequence gives a 0-d array).
        vocab_size: number of distinct ids; ``random.randint`` raises
            ``ValueError`` when it is < 1.
        rng: optional ``random.Random`` for reproducible draws; a fresh,
            unseeded generator is used when omitted.

    Returns:
        ``np.ndarray`` of dtype int32 and the requested shape.
    """
    if rng is None:
        rng = random.Random()
    total_dims = 1
    for dim in shape:
        total_dims *= dim
    # randint is inclusive on both ends, hence vocab_size - 1.
    values = [rng.randint(0, vocab_size - 1) for _ in range(total_dims)]
    # Explicit dtype: without it numpy uses the platform default integer
    # (int64 on most 64-bit Unix builds), not the int32 documented above.
    return np.array(values, dtype=np.int32).reshape(shape)
def random_attention_mask(shape, rng=None):
    """Return a random 0/1 attention mask of ``shape`` whose last column is all ones.

    The mask is drawn with ``ids_tensor`` using a vocabulary of size 2. The
    final position of every row is then forced to 1, so each batch entry
    attends to at least one token. The mask is indexed as ``[:, -1]``, so
    ``shape`` must have at least two dimensions.
    """
    mask = ids_tensor(shape, 2, rng)
    mask[:, -1] = 1
    return mask
def floats_tensor(shape, scale=1.0, rng=None):
    """Return a float32 array of ``shape`` filled with ``rng.random() * scale`` samples.

    Each element is drawn independently from ``[0, 1)`` and multiplied by
    ``scale``. A fresh, unseeded ``random.Random()`` is used when ``rng`` is
    not supplied.
    """
    generator = random.Random() if rng is None else rng
    count = 1
    for extent in shape:
        count *= extent
    samples = [generator.random() * scale for _ in range(count)]
    return np.array(samples, dtype=np.float32).reshape(shape)
def shift_tokens_right(input_ids: np.ndarray, pad_token_id: int, decoder_start_token_id: int) -> np.ndarray:
    """
    Shift input ids one token to the right to build decoder inputs.

    The first column of the result is ``decoder_start_token_id``. The
    remaining columns are ``input_ids`` without its last column. Any ``-100``
    placeholder is then replaced by ``pad_token_id``. ``input_ids`` is
    indexed as ``[:, ...]``, so it must be at least 2-D; it is not modified.

    If ``pad_token_id`` is ``None`` the ``-100`` replacement is skipped.
    ``np.where`` with a ``None`` fill value would otherwise upcast the whole
    result to ``object`` dtype.
    """
    shifted_input_ids = np.zeros_like(input_ids)
    shifted_input_ids[:, 1:] = input_ids[:, :-1]
    shifted_input_ids[:, 0] = decoder_start_token_id
    if pad_token_id is not None:
        shifted_input_ids = np.where(shifted_input_ids == -100, pad_token_id, shifted_input_ids)
    return shifted_input_ids
def assert_almost_equals(a: np.ndarray, b: np.ndarray, tol: float = 1e-2) -> bool:
    """Report whether the max absolute elementwise difference of ``a`` and ``b`` is below ``tol``.

    Despite its name this never raises. It prints a ✅/❌ line so the script
    keeps going through the remaining checks.

    Args:
        a: Flax-side values (anything ``np.asarray`` accepts).
        b: PyTorch-side values, already converted to numpy.
        tol: strict upper bound on the max absolute difference.

    Returns:
        True when the difference is below ``tol``, otherwise False. A NaN
        difference counts as a failure.
    """
    diff = np.abs(np.asarray(a) - np.asarray(b)).max()
    passed = bool(diff < tol)
    if passed:
        print(f"✅ Difference between Flax and PyTorch is {diff} (< {tol})")
    else:
        print(f"❌ Difference between Flax and PyTorch is {diff} (>= {tol})")
    return passed
def assert_dict_equal(a: dict, b: dict, tol: float = 1e-2):
    """Compare the L2 norm of every entry of ``a`` with the same key in ``b``.

    Messages label ``a`` as the Flax gradients and ``b`` as the PyTorch
    gradients. Nothing is raised; results are collected as message strings.

    Args:
        a: mapping from parameter key to array (Flax side).
        b: mapping from parameter key to array (PyTorch side).
        tol: threshold used both for the absolute norm difference and for
            that difference relative to ``|norm(a[k])|``.

    Returns:
        ``(results_fail_rel, results_correct_rel, results_fail, results_correct)``,
        four lists of human-readable messages. A key present in only one of
        the dicts is reported in both failure lists.
    """
    if a.keys() != b.keys():
        print("❌ Dictionary keys for PyTorch and Flax do not match")
    results_fail = []
    results_correct = []
    results_fail_rel = []
    results_correct_rel = []
    for k in a:
        if k not in b:
            # Report instead of raising KeyError so the remaining layers are still compared.
            msg = f"❌ Layer {k} has no PT grad (only present in flax grads)."
            results_fail.append(msg)
            results_fail_rel.append(msg)
            continue
        ak_norm = np.linalg.norm(a[k])
        bk_norm = np.linalg.norm(b[k])
        diff = np.abs(ak_norm - bk_norm)
        if ak_norm != 0:
            diff_rel = diff / np.abs(ak_norm)
        else:
            # Avoid 0/0 -> nan: two all-zero gradients agree exactly.
            diff_rel = 0.0 if diff == 0 else np.inf
        if diff < tol:
            results_correct.append(f"✅ Layer {k} diff is {diff} (< {tol}).")
        else:
            results_fail.append(f"❌ Layer {k} has PT grad norm {bk_norm} and flax grad norm {ak_norm}.")
        if diff_rel < tol:
            results_correct_rel.append(f"✅ Layer {k} rel diff is {diff_rel} (< {tol}).")
        else:
            results_fail_rel.append(f"❌ Layer {k} has PT grad norm {bk_norm} and flax grad norm {ak_norm}.")
    for k in b:
        if k not in a:
            msg = f"❌ Layer {k} has no flax grad (only present in PT grads)."
            results_fail.append(msg)
            results_fail_rel.append(msg)
    return results_fail_rel, results_correct_rel, results_fail, results_correct
def _flax_loss_and_grad(fx_model, batch, label_ids):
    """Return ``(loss, grads)`` of the mean token cross-entropy of a Flax model.

    ``batch`` holds the model's keyword inputs and is not modified. The loss
    covers every position and does not shift ``label_ids``.
    """

    def compute_loss(params):
        logits = fx_model(**batch, params=params).logits
        targets = onehot(label_ids, logits.shape[-1])
        return optax.softmax_cross_entropy(logits, targets).mean()

    return jax.value_and_grad(compute_loss)(fx_model.params)


def _pt_grads_as_flax_params(pt_model, flax_model_cls):
    """Convert the accumulated gradients of ``pt_model`` into a Flax params tree.

    The gradients are loaded into ``pt_model`` as its weights, overwriting the
    weights. They are then saved with ``save_pretrained`` and reloaded through
    ``flax_model_cls`` with ``from_pt=True``. This way the PT->Flax
    weight-name conversion is applied to the gradients too.

    Raises:
        RuntimeError: if the keys not covered by gradients differ from the
            keys ``load_state_dict`` reports as missing, or if it reports
            unexpected keys.
    """
    # A parameter that did not take part in the forward pass has ``grad is None``.
    # Its gradient is zero, and ``load_state_dict`` rejects ``None`` values.
    pt_grad_dict = {
        name: torch.zeros_like(param) if param.grad is None else param.grad
        for name, param in pt_model.named_parameters()
    }
    # State-dict entries that are not named parameters are the only keys
    # expected to be reported as missing.
    missing_grads = [k for k in pt_model.state_dict() if k not in pt_grad_dict]
    missing_keys, unexpected_keys = pt_model.load_state_dict(pt_grad_dict, strict=False)
    if sorted(missing_grads) != sorted(missing_keys) or unexpected_keys:
        raise RuntimeError(f"Error with either grads {missing_keys} or keys {unexpected_keys}")
    with tempfile.TemporaryDirectory() as tmpdirname:
        pt_model.save_pretrained(tmpdirname)
        return flax_model_cls.from_pretrained(tmpdirname, from_pt=True).params


def compare_grads(model_id, pt_architecture):
    """Compare logits, loss and per-parameter gradient norms of a PyTorch model and its Flax port.

    Both models are loaded from the same checkpoint and fed the same random
    batch. Results are printed; nothing is returned. The PyTorch model's
    weights are overwritten with its gradients during the gradient check.

    Args:
        model_id: checkpoint name or path passed to ``from_pretrained``.
        pt_architecture: name of a PyTorch class in ``transformers``;
            ``"Flax" + pt_architecture`` must also exist there.

    Raises:
        ValueError: if the PyTorch checkpoint reports missing keys.
        RuntimeError: if the gradients cannot be mapped back onto the model's
            parameters.
    """
    import importlib

    transformers_module = importlib.import_module("transformers")
    model_cls = getattr(transformers_module, pt_architecture)
    flax_model_cls = getattr(transformers_module, "Flax" + pt_architecture)

    pt_model, model_info = model_cls.from_pretrained(model_id, output_loading_info=True)
    if len(model_info["missing_keys"]) > 0:
        raise ValueError(f"{model_id} with {pt_architecture} has missing keys: {model_info['missing_keys']}")
    fx_model = flax_model_cls.from_pretrained(model_id, from_pt=True)

    batch_size = 2
    seq_len = 64
    vocab_size = fx_model.config.vocab_size
    input_ids = ids_tensor([batch_size, seq_len], vocab_size)
    label_ids = ids_tensor([batch_size, seq_len], vocab_size)
    attention_mask = random_attention_mask([batch_size, seq_len])

    fx_inputs = {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
    }
    if pt_model.config.is_encoder_decoder:
        fx_inputs["decoder_input_ids"] = shift_tokens_right(
            input_ids=label_ids,
            pad_token_id=fx_model.config.pad_token_id,
            decoder_start_token_id=fx_model.config.decoder_start_token_id,
        )
    # The .tolist() round-trip yields int64 torch tensors whatever the numpy dtype is.
    pt_inputs = {k: torch.tensor(v.tolist()) for k, v in fx_inputs.items()}
    pt_inputs["labels"] = torch.tensor(label_ids.tolist())

    fx_logits = fx_model(**fx_inputs).logits
    pt_outputs = pt_model(**pt_inputs)
    pt_logits = pt_outputs.logits
    pt_loss = pt_outputs.loss

    print("--------------------------Checking logits match--------------------------")
    print(f"Flax logits shape: {fx_logits.shape}, PyTorch logits shape: {pt_logits.shape}")
    assert_almost_equals(fx_logits, pt_logits.detach().numpy())

    # NOTE(review): the Flax loss is unshifted cross-entropy over all positions.
    # It presumably matches PyTorch's loss only for heads that do not shift
    # labels internally; confirm for causal-LM architectures.
    fx_loss, fx_grad = _flax_loss_and_grad(fx_model, fx_inputs, label_ids)
    print("--------------------------Checking losses match--------------------------")
    print(f"Flax loss: {fx_loss}, PyTorch loss: {pt_loss}")
    assert_almost_equals(fx_loss, pt_loss.detach().numpy())

    pt_loss.backward()
    pt_grad_to_fx = _pt_grads_as_flax_params(pt_model, flax_model_cls)

    print("--------------------------Checking gradients match--------------------------")
    results_fail_rel, _, results_fail, _ = assert_dict_equal(flatten_dict(fx_grad), flatten_dict(pt_grad_to_fx))
    print("\n".join(results_fail) if results_fail else "✅ All grads pass")
    print("--------------------------Checking rel gradients match--------------------------")
    print("\n".join(results_fail_rel) if results_fail_rel else "✅ All rel grads pass")
def main():
    """Command-line entry point: ``<script> MODEL_ID PT_ARCHITECTURE``.

    Missing or extra arguments print a usage message and exit with status 2
    instead of failing with an ``IndexError`` on ``sys.argv``.
    """
    import argparse

    parser = argparse.ArgumentParser(
        description="Compare logits, loss and gradients of a PyTorch transformers model with its Flax port."
    )
    parser.add_argument("model_id", help="Checkpoint name or path passed to from_pretrained.")
    parser.add_argument(
        "pt_architecture",
        help="PyTorch model class name in transformers; 'Flax' + this name must also exist.",
    )
    args = parser.parse_args()
    compare_grads(args.model_id, args.pt_architecture)


if __name__ == "__main__":
    main()
|