| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | import os |
| | import unittest |
| |
|
| | from transformers.models.auto.configuration_auto import CONFIG_MAPPING, AutoConfig |
| | from transformers.models.bert.configuration_bert import BertConfig |
| | from transformers.models.roberta.configuration_roberta import RobertaConfig |
| | from transformers.testing_utils import DUMMY_UNKWOWN_IDENTIFIER |
| |
|
| |
|
# Absolute path to the RoBERTa config fixture, resolved relative to this test
# module's own directory so the tests work from any current working directory.
SAMPLE_ROBERTA_CONFIG = os.path.join(
    os.path.dirname(os.path.abspath(__file__)),
    "fixtures/dummy-config.json",
)
| |
|
| |
|
class AutoConfigTest(unittest.TestCase):
    """Check that ``AutoConfig`` resolves inputs to the expected config classes."""

    def test_config_from_model_shortcut(self):
        # NOTE(review): "bert-base-uncased" is a hub shortcut name; this test
        # presumably needs network access or a populated cache — confirm.
        config = AutoConfig.from_pretrained("bert-base-uncased")
        self.assertIsInstance(config, BertConfig)

    def test_config_model_type_from_local_file(self):
        # Load from the local JSON fixture rather than a model name.
        config = AutoConfig.from_pretrained(SAMPLE_ROBERTA_CONFIG)
        self.assertIsInstance(config, RobertaConfig)

    def test_config_model_type_from_model_identifier(self):
        # The constant's spelling ("UNKWOWN") matches the name imported from
        # transformers.testing_utils; do not "fix" it here.
        config = AutoConfig.from_pretrained(DUMMY_UNKWOWN_IDENTIFIER)
        self.assertIsInstance(config, RobertaConfig)

    def test_config_for_model_str(self):
        # Build a config directly from a model-type string.
        config = AutoConfig.for_model("roberta")
        self.assertIsInstance(config, RobertaConfig)

    def test_pattern_matching_fallback(self):
        """
        In cases where config.json doesn't include a model_type,
        perform a few safety checks on the config mapping's order.

        Concretely: assert that no key of ``CONFIG_MAPPING`` is a substring of
        any key that appears after it in the mapping's iteration order.
        Presumably the fallback matches keys by substring in mapping order, so
        an earlier, shorter key would otherwise shadow a longer, more specific
        one — confirm against ``AutoConfig``'s fallback implementation.
        """
        keys = list(CONFIG_MAPPING.keys())
        for i, key in enumerate(keys):
            # Collect every offending later key so a failure names them all,
            # instead of the uninformative "True is not false".
            collisions = [later_key for later_key in keys[i + 1 :] if key in later_key]
            self.assertEqual(
                collisions,
                [],
                f"CONFIG_MAPPING key {key!r} is a substring of later key(s) "
                f"{collisions!r}; move the longer key(s) ahead of it.",
            )
| |
|