voice_clone_v3 / transformers /tests /test_modeling_tf_utils.py
ahassoun's picture
Upload 3018 files
ee6e328
raw history blame
No virus
27.9 kB
# coding=utf-8
# Copyright 2019 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import inspect
import json
import os
import random
import tempfile
import unittest
import unittest.mock as mock
from huggingface_hub import HfFolder, Repository, delete_repo
from huggingface_hub.file_download import http_get
from requests.exceptions import HTTPError
from transformers import is_tf_available, is_torch_available
from transformers.configuration_utils import PretrainedConfig
from transformers.testing_utils import ( # noqa: F401
TOKEN,
USER,
CaptureLogger,
_tf_gpu_memory_limit,
is_pt_tf_cross_test,
is_staging_test,
require_safetensors,
require_tf,
slow,
)
from transformers.utils import SAFE_WEIGHTS_NAME, TF2_WEIGHTS_INDEX_NAME, TF2_WEIGHTS_NAME, logging
logger = logging.get_logger(__name__)
if is_tf_available():
import h5py
import numpy as np
import tensorflow as tf
from transformers import (
BertConfig,
PreTrainedModel,
PushToHubCallback,
RagRetriever,
TFBertForMaskedLM,
TFBertForSequenceClassification,
TFBertModel,
TFPreTrainedModel,
TFRagModel,
)
from transformers.modeling_tf_utils import tf_shard_checkpoint, unpack_inputs
from transformers.tf_utils import stable_softmax
tf.config.experimental.enable_tensor_float_32_execution(False)
if _tf_gpu_memory_limit is not None:
gpus = tf.config.list_physical_devices("GPU")
for gpu in gpus:
# Restrict TensorFlow to only allocate x GB of memory on the GPUs
try:
tf.config.set_logical_device_configuration(
gpu, [tf.config.LogicalDeviceConfiguration(memory_limit=_tf_gpu_memory_limit)]
)
logical_gpus = tf.config.list_logical_devices("GPU")
print("Logical GPUs", logical_gpus)
except RuntimeError as e:
# Virtual devices must be set before GPUs have been initialized
print(e)
if is_torch_available():
from transformers import BertModel
@require_tf
class TFModelUtilsTest(unittest.TestCase):
def test_cached_files_are_used_when_internet_is_down(self):
# A mock response for an HTTP head request to emulate server down
response_mock = mock.Mock()
response_mock.status_code = 500
response_mock.headers = {}
response_mock.raise_for_status.side_effect = HTTPError
response_mock.json.return_value = {}
# Download this model to make sure it's in the cache.
_ = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
# Under the mock environment we get a 500 error when trying to reach the model.
with mock.patch("requests.Session.request", return_value=response_mock) as mock_head:
_ = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
# This check we did call the fake head request
mock_head.assert_called()
def test_load_from_one_file(self):
try:
tmp_file = tempfile.mktemp()
with open(tmp_file, "wb") as f:
http_get("https://huggingface.co/hf-internal-testing/tiny-random-bert/resolve/main/tf_model.h5", f)
config = BertConfig.from_pretrained("hf-internal-testing/tiny-random-bert")
_ = TFBertModel.from_pretrained(tmp_file, config=config)
finally:
os.remove(tmp_file)
def test_legacy_load_from_url(self):
# This test is for deprecated behavior and can be removed in v5
config = BertConfig.from_pretrained("hf-internal-testing/tiny-random-bert")
_ = TFBertModel.from_pretrained(
"https://huggingface.co/hf-internal-testing/tiny-random-bert/resolve/main/tf_model.h5", config=config
)
# tests whether the unpack_inputs function behaves as expected
def test_unpack_inputs(self):
class DummyModel:
def __init__(self):
config_kwargs = {"output_attentions": False, "output_hidden_states": False, "return_dict": False}
self.config = PretrainedConfig(**config_kwargs)
self.main_input_name = "input_ids"
@unpack_inputs
def call(
self,
input_ids=None,
past_key_values=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
):
return input_ids, past_key_values, output_attentions, output_hidden_states, return_dict
@unpack_inputs
def foo(self, pixel_values, output_attentions=None, output_hidden_states=None, return_dict=None):
return pixel_values, output_attentions, output_hidden_states, return_dict
dummy_model = DummyModel()
input_ids = tf.constant([0, 1, 2, 3], dtype=tf.int32)
past_key_values = tf.constant([4, 5, 6, 7], dtype=tf.int32)
pixel_values = tf.constant([8, 9, 10, 11], dtype=tf.int32)
# test case 1: Pass inputs as keyword arguments; Booleans are inherited from the config.
output = dummy_model.call(input_ids=input_ids, past_key_values=past_key_values)
tf.debugging.assert_equal(output[0], input_ids)
tf.debugging.assert_equal(output[1], past_key_values)
self.assertFalse(output[2])
self.assertFalse(output[3])
self.assertFalse(output[4])
# test case 2: Same as above, but with positional arguments.
output = dummy_model.call(input_ids, past_key_values)
tf.debugging.assert_equal(output[0], input_ids)
tf.debugging.assert_equal(output[1], past_key_values)
self.assertFalse(output[2])
self.assertFalse(output[3])
self.assertFalse(output[4])
# test case 3: We can also pack everything in the first input.
output = dummy_model.call(input_ids={"input_ids": input_ids, "past_key_values": past_key_values})
tf.debugging.assert_equal(output[0], input_ids)
tf.debugging.assert_equal(output[1], past_key_values)
self.assertFalse(output[2])
self.assertFalse(output[3])
self.assertFalse(output[4])
# test case 4: Explicit boolean arguments should override the config.
output = dummy_model.call(
input_ids=input_ids, past_key_values=past_key_values, output_attentions=False, return_dict=True
)
tf.debugging.assert_equal(output[0], input_ids)
tf.debugging.assert_equal(output[1], past_key_values)
self.assertFalse(output[2])
self.assertFalse(output[3])
self.assertTrue(output[4])
# test case 5: Unexpected arguments should raise an exception.
with self.assertRaises(ValueError):
output = dummy_model.call(input_ids=input_ids, past_key_values=past_key_values, foo="bar")
# test case 6: the decorator is independent from `main_input_name` -- it treats the first argument of the
# decorated function as its main input.
output = dummy_model.foo(pixel_values=pixel_values)
tf.debugging.assert_equal(output[0], pixel_values)
self.assertFalse(output[1])
self.assertFalse(output[2])
self.assertFalse(output[3])
# Tests whether the stable softmax is stable on CPU, with and without XLA
def test_xla_stable_softmax(self):
large_penalty = -1e9
n_tokens = 10
batch_size = 8
def masked_softmax(x, boolean_mask):
numerical_mask = (1.0 - tf.cast(boolean_mask, dtype=tf.float32)) * large_penalty
masked_x = x + numerical_mask
return stable_softmax(masked_x)
xla_masked_softmax = tf.function(masked_softmax, jit_compile=True)
xla_stable_softmax = tf.function(stable_softmax, jit_compile=True)
x = tf.random.normal((batch_size, n_tokens))
# Same outcome regardless of the boolean mask here
masked_tokens = random.randint(0, n_tokens)
boolean_mask = tf.convert_to_tensor([[1] * (n_tokens - masked_tokens) + [0] * masked_tokens], dtype=tf.int32)
# We can randomly mask a random numerical input OUTSIDE XLA
numerical_mask = (1.0 - tf.cast(boolean_mask, dtype=tf.float32)) * large_penalty
masked_x = x + numerical_mask
xla_out = xla_stable_softmax(masked_x)
out = stable_softmax(masked_x)
assert tf.experimental.numpy.allclose(xla_out, out)
# The stable softmax has the same output as the original softmax
unstable_out = tf.nn.softmax(masked_x)
assert tf.experimental.numpy.allclose(unstable_out, out)
# We can randomly mask a random numerical input INSIDE XLA
xla_out = xla_masked_softmax(x, boolean_mask)
out = masked_softmax(x, boolean_mask)
assert tf.experimental.numpy.allclose(xla_out, out)
def test_checkpoint_sharding_from_hub(self):
model = TFBertModel.from_pretrained("ArthurZ/tiny-random-bert-sharded")
# the model above is the same as the model below, just a sharded version.
ref_model = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
for p1, p2 in zip(model.weights, ref_model.weights):
assert np.allclose(p1.numpy(), p2.numpy())
def test_sharded_checkpoint_with_prefix(self):
model = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert", load_weight_prefix="a/b")
sharded_model = TFBertModel.from_pretrained("ArthurZ/tiny-random-bert-sharded", load_weight_prefix="a/b")
for p1, p2 in zip(model.weights, sharded_model.weights):
self.assertTrue(np.allclose(p1.numpy(), p2.numpy()))
self.assertTrue(p1.name.startswith("a/b/"))
self.assertTrue(p2.name.startswith("a/b/"))
def test_sharded_checkpoint_transfer(self):
# If this doesn't throw an error then the test passes
TFBertForSequenceClassification.from_pretrained("ArthurZ/tiny-random-bert-sharded")
@is_pt_tf_cross_test
def test_checkpoint_sharding_local_from_pt(self):
with tempfile.TemporaryDirectory() as tmp_dir:
_ = Repository(local_dir=tmp_dir, clone_from="hf-internal-testing/tiny-random-bert-sharded")
model = TFBertModel.from_pretrained(tmp_dir, from_pt=True)
# the model above is the same as the model below, just a sharded pytorch version.
ref_model = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
for p1, p2 in zip(model.weights, ref_model.weights):
assert np.allclose(p1.numpy(), p2.numpy())
@is_pt_tf_cross_test
def test_checkpoint_loading_with_prefix_from_pt(self):
model = TFBertModel.from_pretrained(
"hf-internal-testing/tiny-random-bert", from_pt=True, load_weight_prefix="a/b"
)
ref_model = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert", from_pt=True)
for p1, p2 in zip(model.weights, ref_model.weights):
self.assertTrue(np.allclose(p1.numpy(), p2.numpy()))
self.assertTrue(p1.name.startswith("a/b/"))
@is_pt_tf_cross_test
def test_checkpoint_sharding_hub_from_pt(self):
model = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert-sharded", from_pt=True)
# the model above is the same as the model below, just a sharded pytorch version.
ref_model = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
for p1, p2 in zip(model.weights, ref_model.weights):
assert np.allclose(p1.numpy(), p2.numpy())
def test_shard_checkpoint(self):
# This is the model we will use, total size 340,000 bytes.
model = tf.keras.Sequential(
[
tf.keras.layers.Dense(200, use_bias=False), # size 80,000
tf.keras.layers.Dense(200, use_bias=False), # size 160,000
tf.keras.layers.Dense(100, use_bias=False), # size 80,000
tf.keras.layers.Dense(50, use_bias=False), # size 20,000
]
)
inputs = tf.zeros((1, 100), dtype=tf.float32)
model(inputs)
weights = model.weights
weights_dict = {w.name: w for w in weights}
with self.subTest("No shard when max size is bigger than model size"):
shards, index = tf_shard_checkpoint(weights)
self.assertIsNone(index)
self.assertDictEqual(shards, {TF2_WEIGHTS_NAME: weights})
with self.subTest("Test sharding, no weights bigger than max size"):
shards, index = tf_shard_checkpoint(weights, max_shard_size="300kB")
# Split is first two layers then last two.
self.assertDictEqual(
index,
{
"metadata": {"total_size": 340000},
"weight_map": {
"dense/kernel:0": "tf_model-00001-of-00002.h5",
"dense_1/kernel:0": "tf_model-00001-of-00002.h5",
"dense_2/kernel:0": "tf_model-00002-of-00002.h5",
"dense_3/kernel:0": "tf_model-00002-of-00002.h5",
},
},
)
shard1 = [weights_dict["dense/kernel:0"], weights_dict["dense_1/kernel:0"]]
shard2 = [weights_dict["dense_2/kernel:0"], weights_dict["dense_3/kernel:0"]]
self.assertDictEqual(shards, {"tf_model-00001-of-00002.h5": shard1, "tf_model-00002-of-00002.h5": shard2})
with self.subTest("Test sharding with weights bigger than max size"):
shards, index = tf_shard_checkpoint(weights, max_shard_size="100kB")
# Split is first layer, second layer then last 2.
self.assertDictEqual(
index,
{
"metadata": {"total_size": 340000},
"weight_map": {
"dense/kernel:0": "tf_model-00001-of-00003.h5",
"dense_1/kernel:0": "tf_model-00002-of-00003.h5",
"dense_2/kernel:0": "tf_model-00003-of-00003.h5",
"dense_3/kernel:0": "tf_model-00003-of-00003.h5",
},
},
)
shard1 = [weights_dict["dense/kernel:0"]]
shard2 = [weights_dict["dense_1/kernel:0"]]
shard3 = [weights_dict["dense_2/kernel:0"], weights_dict["dense_3/kernel:0"]]
self.assertDictEqual(
shards,
{
"tf_model-00001-of-00003.h5": shard1,
"tf_model-00002-of-00003.h5": shard2,
"tf_model-00003-of-00003.h5": shard3,
},
)
@slow
def test_special_layer_name_sharding(self):
retriever = RagRetriever.from_pretrained("facebook/rag-token-nq", index_name="exact", use_dummy_dataset=True)
model = TFRagModel.from_pretrained("facebook/rag-token-nq", retriever=retriever)
with tempfile.TemporaryDirectory() as tmp_dir:
for max_size in ["150kB", "150kiB", "200kB", "200kiB"]:
model.save_pretrained(tmp_dir, max_shard_size=max_size)
ref_model = TFRagModel.from_pretrained(tmp_dir, retriever=retriever)
for p1, p2 in zip(model.weights, ref_model.weights):
assert np.allclose(p1.numpy(), p2.numpy())
def test_checkpoint_sharding_local(self):
model = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
with tempfile.TemporaryDirectory() as tmp_dir:
# We use the same folder for various sizes to make sure a new save erases the old checkpoint.
for max_size in ["150kB", "150kiB", "200kB", "200kiB"]:
model.save_pretrained(tmp_dir, max_shard_size=max_size)
# Get each shard file and its size
shard_to_size = {}
for shard in os.listdir(tmp_dir):
if shard.endswith(".h5"):
shard_file = os.path.join(tmp_dir, shard)
shard_to_size[shard_file] = os.path.getsize(shard_file)
index_file = os.path.join(tmp_dir, TF2_WEIGHTS_INDEX_NAME)
# Check there is an index but no regular weight file
self.assertTrue(os.path.isfile(index_file))
self.assertFalse(os.path.isfile(os.path.join(tmp_dir, TF2_WEIGHTS_NAME)))
# Check a file is bigger than max_size only when it has a single weight
for shard_file, size in shard_to_size.items():
if max_size.endswith("kiB"):
max_size_int = int(max_size[:-3]) * 2**10
else:
max_size_int = int(max_size[:-2]) * 10**3
# Note: pickle adds some junk so the weight of the file can end up being slightly bigger than
# the size asked for (since we count parameters)
if size >= max_size_int + 50000:
with h5py.File(shard_file, "r") as state_file:
self.assertEqual(len(state_file), 1)
# Check the index and the shard files found match
with open(index_file, "r", encoding="utf-8") as f:
index = json.loads(f.read())
all_shards = set(index["weight_map"].values())
shards_found = {f for f in os.listdir(tmp_dir) if f.endswith(".h5")}
self.assertSetEqual(all_shards, shards_found)
# Finally, check the model can be reloaded
new_model = TFBertModel.from_pretrained(tmp_dir)
model.build()
new_model.build()
for p1, p2 in zip(model.weights, new_model.weights):
self.assertTrue(np.allclose(p1.numpy(), p2.numpy()))
@slow
def test_save_pretrained_signatures(self):
model = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
# Short custom TF signature function.
# `input_signature` is specific to BERT.
@tf.function(
input_signature=[
[
tf.TensorSpec([None, None], tf.int32, name="input_ids"),
tf.TensorSpec([None, None], tf.int32, name="token_type_ids"),
tf.TensorSpec([None, None], tf.int32, name="attention_mask"),
]
]
)
def serving_fn(input):
return model(input)
# Using default signature (default behavior) overrides 'serving_default'
with tempfile.TemporaryDirectory() as tmp_dir:
model.save_pretrained(tmp_dir, saved_model=True, signatures=None)
model_loaded = tf.keras.models.load_model(f"{tmp_dir}/saved_model/1")
self.assertTrue("serving_default" in list(model_loaded.signatures.keys()))
# Providing custom signature function
with tempfile.TemporaryDirectory() as tmp_dir:
model.save_pretrained(tmp_dir, saved_model=True, signatures={"custom_signature": serving_fn})
model_loaded = tf.keras.models.load_model(f"{tmp_dir}/saved_model/1")
self.assertTrue("custom_signature" in list(model_loaded.signatures.keys()))
# Providing multiple custom signature function
with tempfile.TemporaryDirectory() as tmp_dir:
model.save_pretrained(
tmp_dir,
saved_model=True,
signatures={"custom_signature_1": serving_fn, "custom_signature_2": serving_fn},
)
model_loaded = tf.keras.models.load_model(f"{tmp_dir}/saved_model/1")
self.assertTrue("custom_signature_1" in list(model_loaded.signatures.keys()))
self.assertTrue("custom_signature_2" in list(model_loaded.signatures.keys()))
@require_safetensors
def test_safetensors_save_and_load(self):
model = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
with tempfile.TemporaryDirectory() as tmp_dir:
model.save_pretrained(tmp_dir, safe_serialization=True)
# No tf_model.h5 file, only a model.safetensors
self.assertTrue(os.path.isfile(os.path.join(tmp_dir, SAFE_WEIGHTS_NAME)))
self.assertFalse(os.path.isfile(os.path.join(tmp_dir, TF2_WEIGHTS_NAME)))
new_model = TFBertModel.from_pretrained(tmp_dir)
# Check models are equal
for p1, p2 in zip(model.weights, new_model.weights):
self.assertTrue(np.allclose(p1.numpy(), p2.numpy()))
@is_pt_tf_cross_test
def test_safetensors_save_and_load_pt_to_tf(self):
model = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
pt_model = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
with tempfile.TemporaryDirectory() as tmp_dir:
pt_model.save_pretrained(tmp_dir, safe_serialization=True)
# Check we have a model.safetensors file
self.assertTrue(os.path.isfile(os.path.join(tmp_dir, SAFE_WEIGHTS_NAME)))
new_model = TFBertModel.from_pretrained(tmp_dir)
# Check models are equal
for p1, p2 in zip(model.weights, new_model.weights):
self.assertTrue(np.allclose(p1.numpy(), p2.numpy()))
@require_safetensors
def test_safetensors_load_from_hub(self):
tf_model = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
# Can load from the TF-formatted checkpoint
safetensors_model = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert-safetensors-tf")
# Check models are equal
for p1, p2 in zip(safetensors_model.weights, tf_model.weights):
self.assertTrue(np.allclose(p1.numpy(), p2.numpy()))
# Can load from the PyTorch-formatted checkpoint
safetensors_model = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert-safetensors")
# Check models are equal
for p1, p2 in zip(safetensors_model.weights, tf_model.weights):
self.assertTrue(np.allclose(p1.numpy(), p2.numpy()))
@require_tf
@is_staging_test
class TFModelPushToHubTester(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls._token = TOKEN
HfFolder.save_token(TOKEN)
@classmethod
def tearDownClass(cls):
try:
delete_repo(token=cls._token, repo_id="test-model-tf")
except HTTPError:
pass
try:
delete_repo(token=cls._token, repo_id="test-model-tf-callback")
except HTTPError:
pass
try:
delete_repo(token=cls._token, repo_id="valid_org/test-model-tf-org")
except HTTPError:
pass
def test_push_to_hub(self):
config = BertConfig(
vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
)
model = TFBertModel(config)
# Make sure model is properly initialized
model.build()
logging.set_verbosity_info()
logger = logging.get_logger("transformers.utils.hub")
with CaptureLogger(logger) as cl:
model.push_to_hub("test-model-tf", use_auth_token=self._token)
logging.set_verbosity_warning()
# Check the model card was created and uploaded.
self.assertIn("Uploading the following files to __DUMMY_TRANSFORMERS_USER__/test-model-tf", cl.out)
new_model = TFBertModel.from_pretrained(f"{USER}/test-model-tf")
models_equal = True
for p1, p2 in zip(model.weights, new_model.weights):
if not tf.math.reduce_all(p1 == p2):
models_equal = False
break
self.assertTrue(models_equal)
# Reset repo
delete_repo(token=self._token, repo_id="test-model-tf")
# Push to hub via save_pretrained
with tempfile.TemporaryDirectory() as tmp_dir:
model.save_pretrained(tmp_dir, repo_id="test-model-tf", push_to_hub=True, use_auth_token=self._token)
new_model = TFBertModel.from_pretrained(f"{USER}/test-model-tf")
models_equal = True
for p1, p2 in zip(model.weights, new_model.weights):
if not tf.math.reduce_all(p1 == p2):
models_equal = False
break
self.assertTrue(models_equal)
@is_pt_tf_cross_test
def test_push_to_hub_callback(self):
config = BertConfig(
vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
)
model = TFBertForMaskedLM(config)
model.compile()
with tempfile.TemporaryDirectory() as tmp_dir:
push_to_hub_callback = PushToHubCallback(
output_dir=tmp_dir,
hub_model_id="test-model-tf-callback",
hub_token=self._token,
)
model.fit(model.dummy_inputs, model.dummy_inputs, epochs=1, callbacks=[push_to_hub_callback])
new_model = TFBertForMaskedLM.from_pretrained(f"{USER}/test-model-tf-callback")
models_equal = True
for p1, p2 in zip(model.weights, new_model.weights):
if not tf.math.reduce_all(p1 == p2):
models_equal = False
break
self.assertTrue(models_equal)
tf_push_to_hub_params = dict(inspect.signature(TFPreTrainedModel.push_to_hub).parameters)
tf_push_to_hub_params.pop("base_model_card_args")
pt_push_to_hub_params = dict(inspect.signature(PreTrainedModel.push_to_hub).parameters)
pt_push_to_hub_params.pop("deprecated_kwargs")
self.assertDictEaual(tf_push_to_hub_params, pt_push_to_hub_params)
def test_push_to_hub_in_organization(self):
config = BertConfig(
vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
)
model = TFBertModel(config)
# Make sure model is properly initialized
model.build()
model.push_to_hub("valid_org/test-model-tf-org", use_auth_token=self._token)
new_model = TFBertModel.from_pretrained("valid_org/test-model-tf-org")
models_equal = True
for p1, p2 in zip(model.weights, new_model.weights):
if not tf.math.reduce_all(p1 == p2):
models_equal = False
break
self.assertTrue(models_equal)
# Reset repo
delete_repo(token=self._token, repo_id="valid_org/test-model-tf-org")
# Push to hub via save_pretrained
with tempfile.TemporaryDirectory() as tmp_dir:
model.save_pretrained(
tmp_dir, push_to_hub=True, use_auth_token=self._token, repo_id="valid_org/test-model-tf-org"
)
new_model = TFBertModel.from_pretrained("valid_org/test-model-tf-org")
models_equal = True
for p1, p2 in zip(model.weights, new_model.weights):
if not tf.math.reduce_all(p1 == p2):
models_equal = False
break
self.assertTrue(models_equal)