|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import os |
|
import unittest |
|
|
|
from transformers.models.auto.configuration_auto import CONFIG_MAPPING, AutoConfig |
|
from transformers.models.bert.configuration_bert import BertConfig |
|
from transformers.models.roberta.configuration_roberta import RobertaConfig |
|
from transformers.testing_utils import DUMMY_UNKWOWN_IDENTIFIER |
|
|
|
|
|
# Absolute path to the RoBERTa config JSON fixture that lives beside this test
# module, under ``fixtures/dummy-config.json`` relative to this file's directory.
SAMPLE_ROBERTA_CONFIG = os.path.join(
    os.path.dirname(os.path.abspath(__file__)),
    "fixtures/dummy-config.json",
)
|
|
|
|
|
class AutoConfigTest(unittest.TestCase):
    """Check that ``AutoConfig`` resolves various inputs to the expected config classes."""

    def test_config_from_model_shortcut(self):
        """Loading by the ``"bert-base-uncased"`` shortcut name yields a ``BertConfig``."""
        config = AutoConfig.from_pretrained("bert-base-uncased")
        self.assertIsInstance(config, BertConfig)

    def test_config_model_type_from_local_file(self):
        """Loading the local JSON fixture ``SAMPLE_ROBERTA_CONFIG`` yields a ``RobertaConfig``."""
        config = AutoConfig.from_pretrained(SAMPLE_ROBERTA_CONFIG)
        self.assertIsInstance(config, RobertaConfig)

    def test_config_model_type_from_model_identifier(self):
        """Loading by the ``DUMMY_UNKWOWN_IDENTIFIER`` model identifier yields a ``RobertaConfig``."""
        config = AutoConfig.from_pretrained(DUMMY_UNKWOWN_IDENTIFIER)
        self.assertIsInstance(config, RobertaConfig)

    def test_config_for_model_str(self):
        """``AutoConfig.for_model("roberta")`` builds a ``RobertaConfig``."""
        config = AutoConfig.for_model("roberta")
        self.assertIsInstance(config, RobertaConfig)

    def test_pattern_matching_fallback(self):
        """
        In cases where config.json doesn't include a model_type,
        perform a few safety checks on the config mapping's order.

        Specifically, no key of ``CONFIG_MAPPING`` may be a substring of any key
        listed after it. Presumably the fallback matches keys against the model
        name/path in mapping order, so a short key placed before a longer key
        containing it (e.g. ``"bert"`` before ``"roberta"``) would shadow the
        longer one -- confirm against ``AutoConfig.from_pretrained``.
        """
        keys = list(CONFIG_MAPPING.keys())
        for i, key in enumerate(keys):
            # Collect the offending later keys so that a failure names them,
            # rather than reporting only "True is not false".
            shadowed = [later_key for later_key in keys[i + 1 :] if key in later_key]
            # subTest keeps checking the remaining keys after a failure, so every
            # ordering problem is reported in a single run.
            with self.subTest(key=key):
                self.assertFalse(
                    shadowed,
                    "{!r} is a substring of later key(s) {!r}; those keys must be "
                    "listed before it in CONFIG_MAPPING".format(key, shadowed),
                )
|
|