|
from tempfile import TemporaryDirectory |
|
from unittest import TestCase |
|
from unittest.mock import MagicMock, patch |
|
|
|
from transformers import AutoModel, TFAutoModel |
|
from transformers.onnx import FeaturesManager |
|
from transformers.testing_utils import SMALL_MODEL_IDENTIFIER, require_tf, require_torch |
|
|
|
|
|
@require_torch
@require_tf
class DetermineFrameworkTest(TestCase):
    """
    Test `FeaturesManager.determine_framework`

    Three situations are exercised: the framework is passed explicitly, only a local checkpoint
    is passed, and neither is passed so the outcome depends on which frameworks are reported
    as available.
    """

    def setUp(self):
        """Store the model identifier and the framework strings the assertions compare against."""
        self.test_model = SMALL_MODEL_IDENTIFIER
        self.framework_pt = "pt"
        self.framework_tf = "tf"

    def _setup_pt_ckpt(self, save_dir):
        """Load the test model with `AutoModel` and save it into `save_dir`."""
        model_pt = AutoModel.from_pretrained(self.test_model)
        model_pt.save_pretrained(save_dir)

    def _setup_tf_ckpt(self, save_dir):
        """Load the test model with `TFAutoModel` (passing `from_pt=True`) and save it into `save_dir`."""
        model_tf = TFAutoModel.from_pretrained(self.test_model, from_pt=True)
        model_tf.save_pretrained(save_dir)

    def test_framework_provided(self):
        """
        Ensure that the provided framework is returned.
        """
        mock_framework = "mock_framework"

        # Framework provided together with the model identifier: the provided value comes back.
        result = FeaturesManager.determine_framework(self.test_model, mock_framework)
        self.assertEqual(result, mock_framework)

        # Framework provided together with a local checkpoint saved from AutoModel:
        # the provided value still comes back.
        with TemporaryDirectory() as local_pt_ckpt:
            self._setup_pt_ckpt(local_pt_ckpt)
            result = FeaturesManager.determine_framework(local_pt_ckpt, mock_framework)
            self.assertEqual(result, mock_framework)

        # Same check with a local checkpoint saved from TFAutoModel.
        with TemporaryDirectory() as local_tf_ckpt:
            self._setup_tf_ckpt(local_tf_ckpt)
            result = FeaturesManager.determine_framework(local_tf_ckpt, mock_framework)
            self.assertEqual(result, mock_framework)

    def test_checkpoint_provided(self):
        """
        Ensure that the determined framework is the one used for the local checkpoint.

        For the functionality to execute, local checkpoints are provided but framework is not.
        """
        # Checkpoint saved from AutoModel -> "pt"
        with TemporaryDirectory() as local_pt_ckpt:
            self._setup_pt_ckpt(local_pt_ckpt)
            result = FeaturesManager.determine_framework(local_pt_ckpt)
            self.assertEqual(result, self.framework_pt)

        # Checkpoint saved from TFAutoModel -> "tf"
        with TemporaryDirectory() as local_tf_ckpt:
            self._setup_tf_ckpt(local_tf_ckpt)
            result = FeaturesManager.determine_framework(local_tf_ckpt)
            self.assertEqual(result, self.framework_tf)

        # Empty directory (nothing was saved into it) -> FileNotFoundError.
        # The return value is deliberately not bound: the call is expected to raise.
        with TemporaryDirectory() as local_invalid_ckpt:
            with self.assertRaises(FileNotFoundError):
                FeaturesManager.determine_framework(local_invalid_ckpt)

    def test_from_environment(self):
        """
        Ensure that the determined framework is the one available in the environment.

        For the functionality to execute, framework and local checkpoints are not provided.
        """
        # TensorFlow reported as unavailable -> "pt"
        mock_tf_available = MagicMock(return_value=False)
        with patch("transformers.onnx.features.is_tf_available", mock_tf_available):
            result = FeaturesManager.determine_framework(self.test_model)
            self.assertEqual(result, self.framework_pt)

        # PyTorch reported as unavailable -> "tf"
        mock_torch_available = MagicMock(return_value=False)
        with patch("transformers.onnx.features.is_torch_available", mock_torch_available):
            result = FeaturesManager.determine_framework(self.test_model)
            self.assertEqual(result, self.framework_tf)

        # Both reported as available -> "pt"
        mock_tf_available = MagicMock(return_value=True)
        mock_torch_available = MagicMock(return_value=True)
        with patch("transformers.onnx.features.is_tf_available", mock_tf_available), patch(
            "transformers.onnx.features.is_torch_available", mock_torch_available
        ):
            result = FeaturesManager.determine_framework(self.test_model)
            self.assertEqual(result, self.framework_pt)

        # Both reported as unavailable -> EnvironmentError.
        # The return value is deliberately not bound: the call is expected to raise.
        mock_tf_available = MagicMock(return_value=False)
        mock_torch_available = MagicMock(return_value=False)
        with patch("transformers.onnx.features.is_tf_available", mock_tf_available), patch(
            "transformers.onnx.features.is_torch_available", mock_torch_available
        ):
            with self.assertRaises(EnvironmentError):
                FeaturesManager.determine_framework(self.test_model)
|
|