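"""Compare a PyTorch SpeechEncoderDecoderModel with its Flax counterpart.

The script converts a tiny Wav2Vec2 + BART encoder-decoder model from PyTorch to Flax
and checks that the logits, the training loss, and the gradients computed by the two
frameworks agree within a small tolerance.
"""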
import random
import tempfile

import jax
import numpy as np
import optax
import torch
from flax.training.common_utils import onehot
from flax.traverse_util import flatten_dict

from transformers import FlaxSpeechEncoderDecoderModel, SpeechEncoderDecoderModel


def ids_tensor(shape, vocab_size, rng=None):
    """Creates a random integer tensor of the given shape with values in [0, vocab_size)."""
    if rng is None:
        rng = random.Random()

    total_dims = 1
    for dim in shape:
        total_dims *= dim

    values = []
    for _ in range(total_dims):
        values.append(rng.randint(0, vocab_size - 1))

    output = np.array(values).reshape(shape)

    return output


def random_attention_mask(shape, rng=None):
    attn_mask = ids_tensor(shape, vocab_size=2, rng=rng)
    # make sure that at least one token is attended to in each example
    attn_mask[:, -1] = 1
    return attn_mask


def floats_tensor(shape, scale=1.0, rng=None):
    """Creates a random float32 tensor"""
    if rng is None:
        rng = random.Random()

    total_dims = 1
    for dim in shape:
        total_dims *= dim

    values = []
    for _ in range(total_dims):
        values.append(rng.random() * scale)

    return np.array(values, dtype=np.float32).reshape(shape)


def shift_tokens_right(input_ids: np.ndarray, pad_token_id: int, decoder_start_token_id: int) -> np.ndarray:
    """
    Shift input ids one token to the right.
    """
    shifted_input_ids = np.zeros_like(input_ids)
    shifted_input_ids[:, 1:] = input_ids[:, :-1]
    shifted_input_ids[:, 0] = decoder_start_token_id

    # replace possible -100 values (ignored label positions) by `pad_token_id`
    shifted_input_ids = np.where(shifted_input_ids == -100, pad_token_id, shifted_input_ids)
    return shifted_input_ids


def assert_almost_equals(a: np.ndarray, b: np.ndarray, tol: float = 4e-2):
    diff = np.abs(a - b).max()
    if diff < tol:
        print(f"✅ Difference between Flax and PyTorch is {diff} (< {tol})")
    else:
        print(f"❌ Difference between Flax and PyTorch is {diff} (>= {tol})")


def assert_dict_equal(a: dict, b: dict, tol: float = 4e-2):
    if a.keys() != b.keys():
        print("❌ Dictionary keys for PyTorch and Flax do not match")
    for k in a:
        diff = np.abs(a[k] - b[k]).max()
        if diff < tol:
            print(f"✅ Layer {k} diff is {diff} (< {tol}).")
        else:
            print(f"❌ Layer {k} diff is {diff} (>= {tol}).")


def main():
    encoder_id = "hf-internal-testing/tiny-random-wav2vec2"
    decoder_id = "hf-internal-testing/tiny-random-bart"

    # toggles for the comparison: pass a decoder attention mask and/or freeze the feature encoder
    use_decoder_attention_mask = False
    freeze_feature_encoder = False

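    # Instantiate the PyTorch encoder-decoder model (with an adapter on top of the Wav2Vec2
    # encoder) and convert it to Flax by round-tripping through a temporary checkpoint.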
    pt_model = SpeechEncoderDecoderModel.from_encoder_decoder_pretrained(
        encoder_id, decoder_id, encoder_add_adapter=True
    )

    with tempfile.TemporaryDirectory() as tmpdirname:
        pt_model.save_pretrained(tmpdirname)
        fx_model = FlaxSpeechEncoderDecoderModel.from_pretrained(tmpdirname, from_pt=True)

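    # Build random dummy inputs that are shared by both frameworks.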
    batch_size = 13
    input_values = floats_tensor([batch_size, 512], fx_model.config.encoder.vocab_size)
    attention_mask = random_attention_mask([batch_size, 512])
    label_ids = ids_tensor([batch_size, 4], fx_model.config.decoder.vocab_size)
    decoder_input_ids = shift_tokens_right(
        input_ids=label_ids,
        pad_token_id=fx_model.config.decoder.pad_token_id,
        decoder_start_token_id=fx_model.config.decoder.decoder_start_token_id,
    )
    decoder_attention_mask = random_attention_mask([batch_size, 4])

    fx_inputs = {
        "inputs": input_values,
        "attention_mask": attention_mask,
        "decoder_input_ids": decoder_input_ids,
    }
    if use_decoder_attention_mask:
        fx_inputs["decoder_attention_mask"] = decoder_attention_mask

    pt_inputs = {k: torch.tensor(v.tolist()) for k, v in fx_inputs.items()}
    pt_inputs["labels"] = torch.tensor(label_ids.tolist())

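    # Run a forward pass in both frameworks; the PyTorch model also returns the loss
    # since `labels` are passed.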
    fx_outputs = fx_model(**fx_inputs)
    fx_logits = fx_outputs.logits

    if freeze_feature_encoder:
        pt_model.freeze_feature_encoder()

    pt_outputs = pt_model(**pt_inputs)
    pt_logits = pt_outputs.logits
    pt_loss = pt_outputs.loss

print("--------------------------Checking logits match--------------------------") |
|
print(f"Flax logits shape: {fx_logits.shape}, PyTorch logits shape: {pt_logits.shape}") |
|
assert_almost_equals(fx_logits, pt_logits.detach().numpy()) |
|
|
|
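    # Single Flax training step: compute the cross-entropy loss and its gradients
    # with respect to the model parameters via jax.value_and_grad.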
    def fx_train_step(fx_model, batch, freeze_feature_encoder=False):
        def compute_loss(params):
            label_ids = batch.pop("label_ids")
            logits = fx_model(
                **batch, params=params, freeze_feature_encoder=freeze_feature_encoder
            ).logits
            vocab_size = logits.shape[-1]
            targets = onehot(label_ids, vocab_size)
            loss = optax.softmax_cross_entropy(logits, targets)
            return loss.mean()

        grad_fn = jax.value_and_grad(compute_loss)
        loss, grad = grad_fn(fx_model.params)
        return loss, grad

fx_inputs["label_ids"] = label_ids |
|
|
|
fx_loss, fx_grad = fx_train_step(fx_model, fx_inputs, freeze_feature_encoder=freeze_feature_encoder) |
|
|
|
print("--------------------------Checking losses match--------------------------") |
|
print(f"Flax loss: {fx_loss}, PyTorch loss: {pt_loss}") |
|
assert_almost_equals(fx_loss, pt_loss.detach().numpy()) |
|
|
|
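    # Backpropagate the PyTorch loss, collect the gradients in a state-dict-shaped
    # dictionary (zero-filling any entries without gradients), and load them back into
    # the PyTorch model so they can be converted to Flax below.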
    pt_loss.backward()

    pt_grad_dict = {k: v.grad if v.grad is not None else torch.zeros_like(v) for k, v in pt_model.named_parameters()}

    for k in pt_model.state_dict():
        if k not in pt_grad_dict:
            # buffers and other non-parameter entries have no gradient
            pt_grad_dict[k] = torch.zeros_like(pt_model.state_dict()[k])

    pt_model.load_state_dict(pt_grad_dict)

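    # Convert the PyTorch gradients to Flax parameter format by saving the gradient-filled
    # model and reloading it as a Flax model, then compare the gradients layer by layer.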
    with tempfile.TemporaryDirectory() as tmpdirname:
        pt_model.save_pretrained(tmpdirname)
        pt_grad_model_to_fx = FlaxSpeechEncoderDecoderModel.from_pretrained(tmpdirname, from_pt=True)

    pt_grad_to_fx = pt_grad_model_to_fx.params
    fx_grad = flatten_dict(fx_grad)
    pt_grad_to_fx = flatten_dict(pt_grad_to_fx)
    print("--------------------------Checking gradients match--------------------------")
    assert_dict_equal(fx_grad, pt_grad_to_fx)


if __name__ == "__main__":
    main()