|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import os |
|
import tempfile |
|
import unittest |
|
|
|
from transformers import BertConfig, is_torch_available |
|
from transformers.models.auto import get_values |
|
from transformers.testing_utils import CaptureLogger, require_torch, require_torch_gpu, slow, torch_device |
|
|
|
from ...generation.test_utils import GenerationTesterMixin |
|
from ...test_configuration_common import ConfigTester |
|
from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask |
|
from ...test_pipeline_mixin import PipelineTesterMixin |
|
|
|
|
|
if is_torch_available(): |
|
import torch |
|
|
|
from transformers import ( |
|
MODEL_FOR_PRETRAINING_MAPPING, |
|
BertForMaskedLM, |
|
BertForMultipleChoice, |
|
BertForNextSentencePrediction, |
|
BertForPreTraining, |
|
BertForQuestionAnswering, |
|
BertForSequenceClassification, |
|
BertForTokenClassification, |
|
BertLMHeadModel, |
|
BertModel, |
|
logging, |
|
) |
|
from transformers.models.bert.modeling_bert import BERT_PRETRAINED_MODEL_ARCHIVE_LIST |
|
|
|
|
|
class BertModelTester:
    """Tiny-model factory and shape checker for every BERT head.

    ``parent`` is the ``unittest.TestCase`` driving the checks; its ``assert*`` methods are
    used to report failures. The default sizes are deliberately small so each check stays
    cheap. The ``use_input_mask`` / ``use_token_type_ids`` / ``use_labels`` flags control
    whether the corresponding tensors are generated or left as ``None``.
    """

    def __init__(
        self,
        parent,
        batch_size=13,
        seq_length=7,
        is_training=True,
        use_input_mask=True,
        use_token_type_ids=True,
        use_labels=True,
        vocab_size=99,
        hidden_size=32,
        num_hidden_layers=5,
        num_attention_heads=4,
        intermediate_size=37,
        hidden_act="gelu",
        hidden_dropout_prob=0.1,
        attention_probs_dropout_prob=0.1,
        max_position_embeddings=512,
        type_vocab_size=16,
        type_sequence_label_size=2,
        initializer_range=0.02,
        num_labels=3,
        num_choices=4,
        scope=None,
    ):
        self.parent = parent
        self.batch_size = batch_size
        self.seq_length = seq_length
        self.is_training = is_training
        self.use_input_mask = use_input_mask
        self.use_token_type_ids = use_token_type_ids
        self.use_labels = use_labels
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.hidden_act = hidden_act
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.max_position_embeddings = max_position_embeddings
        self.type_vocab_size = type_vocab_size
        self.type_sequence_label_size = type_sequence_label_size
        self.initializer_range = initializer_range
        self.num_labels = num_labels
        self.num_choices = num_choices
        self.scope = scope

    def prepare_config_and_inputs(self):
        """Build a tiny encoder config plus random inputs and labels.

        Returns:
            ``(config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels,
            choice_labels)``. ``input_ids``, ``token_type_ids`` and ``token_labels`` are created
            with shape ``(batch_size, seq_length)``; ``sequence_labels`` and ``choice_labels``
            with shape ``(batch_size,)``. ``input_mask`` comes from ``random_attention_mask`` for
            the same ``(batch_size, seq_length)`` shape. Optional members are ``None`` when the
            matching ``use_*`` flag is ``False``.
        """
        # NOTE: the order of the random-tensor calls below is kept stable so that runs
        # seeded the same way keep producing the same inputs.
        input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)

        input_mask = None
        if self.use_input_mask:
            input_mask = random_attention_mask([self.batch_size, self.seq_length])

        token_type_ids = None
        if self.use_token_type_ids:
            token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size)

        sequence_labels = None
        token_labels = None
        choice_labels = None
        if self.use_labels:
            sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
            token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
            choice_labels = ids_tensor([self.batch_size], self.num_choices)

        config = self.get_config()

        return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels

    def get_config(self):
        """
        Returns a tiny configuration by default.

        The config is encoder-only (``is_decoder=False``) and mirrors this tester's sizes.
        """
        return BertConfig(
            vocab_size=self.vocab_size,
            hidden_size=self.hidden_size,
            num_hidden_layers=self.num_hidden_layers,
            num_attention_heads=self.num_attention_heads,
            intermediate_size=self.intermediate_size,
            hidden_act=self.hidden_act,
            hidden_dropout_prob=self.hidden_dropout_prob,
            attention_probs_dropout_prob=self.attention_probs_dropout_prob,
            max_position_embeddings=self.max_position_embeddings,
            type_vocab_size=self.type_vocab_size,
            is_decoder=False,
            initializer_range=self.initializer_range,
        )

    def prepare_config_and_inputs_for_decoder(self):
        """Like :meth:`prepare_config_and_inputs` but with a decoder config and encoder outputs.

        Returns the same seven items (with ``config.is_decoder`` set to ``True``) followed by
        ``encoder_hidden_states`` (random floats, shape ``(batch_size, seq_length, hidden_size)``)
        and ``encoder_attention_mask`` (random 0/1 ids, shape ``(batch_size, seq_length)``).
        """
        config, *encoder_free_inputs = self.prepare_config_and_inputs()

        config.is_decoder = True
        encoder_hidden_states = floats_tensor([self.batch_size, self.seq_length, self.hidden_size])
        encoder_attention_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2)

        return (config, *encoder_free_inputs, encoder_hidden_states, encoder_attention_mask)

    def create_and_check_model(
        self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
    ):
        """Run ``BertModel`` with full, partial and minimal inputs; check the last call's output shapes."""
        model = BertModel(config=config)
        model.to(torch_device)
        model.eval()
        # The first two calls only need to succeed; the shapes are asserted on the last one.
        result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids)
        result = model(input_ids, token_type_ids=token_type_ids)
        result = model(input_ids)
        self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
        self.parent.assertEqual(result.pooler_output.shape, (self.batch_size, self.hidden_size))

    def create_and_check_model_as_decoder(
        self,
        config,
        input_ids,
        token_type_ids,
        input_mask,
        sequence_labels,
        token_labels,
        choice_labels,
        encoder_hidden_states,
        encoder_attention_mask,
    ):
        """Run ``BertModel`` with cross-attention enabled, with and without encoder inputs/masks."""
        config.add_cross_attention = True
        model = BertModel(config)
        model.to(torch_device)
        model.eval()
        result = model(
            input_ids,
            attention_mask=input_mask,
            token_type_ids=token_type_ids,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
        )
        result = model(
            input_ids,
            attention_mask=input_mask,
            token_type_ids=token_type_ids,
            encoder_hidden_states=encoder_hidden_states,
        )
        result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids)
        self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
        self.parent.assertEqual(result.pooler_output.shape, (self.batch_size, self.hidden_size))

    def create_and_check_for_causal_lm(
        self,
        config,
        input_ids,
        token_type_ids,
        input_mask,
        sequence_labels,
        token_labels,
        choice_labels,
        encoder_hidden_states,
        encoder_attention_mask,
    ):
        """Check ``BertLMHeadModel`` produces ``(batch, seq, vocab)`` logits with token labels."""
        model = BertLMHeadModel(config=config)
        model.to(torch_device)
        model.eval()
        result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels)
        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))

    def create_and_check_for_masked_lm(
        self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
    ):
        """Check ``BertForMaskedLM`` produces ``(batch, seq, vocab)`` logits with token labels."""
        model = BertForMaskedLM(config=config)
        model.to(torch_device)
        model.eval()
        result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels)
        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))

    def create_and_check_model_for_causal_lm_as_decoder(
        self,
        config,
        input_ids,
        token_type_ids,
        input_mask,
        sequence_labels,
        token_labels,
        choice_labels,
        encoder_hidden_states,
        encoder_attention_mask,
    ):
        """Check ``BertLMHeadModel`` with cross-attention, with and without an encoder mask."""
        config.add_cross_attention = True
        model = BertLMHeadModel(config=config)
        model.to(torch_device)
        model.eval()
        result = model(
            input_ids,
            attention_mask=input_mask,
            token_type_ids=token_type_ids,
            labels=token_labels,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
        )
        result = model(
            input_ids,
            attention_mask=input_mask,
            token_type_ids=token_type_ids,
            labels=token_labels,
            encoder_hidden_states=encoder_hidden_states,
        )
        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))

    def create_and_check_decoder_model_past_large_inputs(
        self,
        config,
        input_ids,
        token_type_ids,
        input_mask,
        sequence_labels,
        token_labels,
        choice_labels,
        encoder_hidden_states,
        encoder_attention_mask,
    ):
        """Check that decoding 3 new tokens with ``past_key_values`` matches a full re-run.

        The comparison is done on one randomly chosen hidden feature of
        ``hidden_states[0]`` for the last three positions, with ``atol=1e-3``.
        """
        config.is_decoder = True
        config.add_cross_attention = True
        model = BertLMHeadModel(config=config).to(torch_device).eval()

        # First forward pass over the prompt, keeping the key/value cache.
        outputs = model(
            input_ids,
            attention_mask=input_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            use_cache=True,
        )
        past_key_values = outputs.past_key_values

        # Three hypothetical next tokens and a random 0/1 mask for them.
        next_tokens = ids_tensor((self.batch_size, 3), config.vocab_size)
        next_mask = ids_tensor((self.batch_size, 3), vocab_size=2)

        # The tester may have been built with use_input_mask=False; an all-ones prefix stands
        # in for "no mask" so it can still be extended with the mask for the new tokens.
        prefix_mask = input_mask if input_mask is not None else torch.ones_like(input_ids)
        next_input_ids = torch.cat([input_ids, next_tokens], dim=-1)
        next_attention_mask = torch.cat([prefix_mask, next_mask], dim=-1)

        output_from_no_past = model(
            next_input_ids,
            attention_mask=next_attention_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            output_hidden_states=True,
        )["hidden_states"][0]
        output_from_past = model(
            next_tokens,
            attention_mask=next_attention_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            past_key_values=past_key_values,
            output_hidden_states=True,
        )["hidden_states"][0]

        # Compare a single random feature column over the newly generated positions.
        random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()
        output_from_no_past_slice = output_from_no_past[:, -3:, random_slice_idx].detach()
        output_from_past_slice = output_from_past[:, :, random_slice_idx].detach()

        self.parent.assertTrue(output_from_past_slice.shape[1] == next_tokens.shape[1])
        self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3))

    def create_and_check_for_next_sequence_prediction(
        self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
    ):
        """Check ``BertForNextSentencePrediction`` returns two logits per example."""
        model = BertForNextSentencePrediction(config=config)
        model.to(torch_device)
        model.eval()
        result = model(
            input_ids,
            attention_mask=input_mask,
            token_type_ids=token_type_ids,
            labels=sequence_labels,
        )
        self.parent.assertEqual(result.logits.shape, (self.batch_size, 2))

    def create_and_check_for_pretraining(
        self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
    ):
        """Check both ``BertForPreTraining`` heads (masked-LM and sentence relationship) by shape."""
        model = BertForPreTraining(config=config)
        model.to(torch_device)
        model.eval()
        result = model(
            input_ids,
            attention_mask=input_mask,
            token_type_ids=token_type_ids,
            labels=token_labels,
            next_sentence_label=sequence_labels,
        )
        self.parent.assertEqual(result.prediction_logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
        self.parent.assertEqual(result.seq_relationship_logits.shape, (self.batch_size, 2))

    def create_and_check_for_question_answering(
        self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
    ):
        """Check ``BertForQuestionAnswering`` start/end logits are ``(batch, seq)``."""
        model = BertForQuestionAnswering(config=config)
        model.to(torch_device)
        model.eval()
        result = model(
            input_ids,
            attention_mask=input_mask,
            token_type_ids=token_type_ids,
            start_positions=sequence_labels,
            end_positions=sequence_labels,
        )
        self.parent.assertEqual(result.start_logits.shape, (self.batch_size, self.seq_length))
        self.parent.assertEqual(result.end_logits.shape, (self.batch_size, self.seq_length))

    def create_and_check_for_sequence_classification(
        self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
    ):
        """Check ``BertForSequenceClassification`` logits are ``(batch, num_labels)``."""
        config.num_labels = self.num_labels
        model = BertForSequenceClassification(config)
        model.to(torch_device)
        model.eval()
        result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=sequence_labels)
        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_labels))

    def create_and_check_for_token_classification(
        self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
    ):
        """Check ``BertForTokenClassification`` logits are ``(batch, seq, num_labels)``."""
        config.num_labels = self.num_labels
        model = BertForTokenClassification(config=config)
        model.to(torch_device)
        model.eval()
        result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels)
        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.num_labels))

    def _expand_for_choices(self, tensor):
        """Repeat a ``(batch, seq)`` tensor along a new axis 1 -> ``(batch, num_choices, seq)``.

        ``None`` is passed through unchanged so optional inputs (disabled via the
        ``use_input_mask`` / ``use_token_type_ids`` flags) stay optional.
        """
        if tensor is None:
            return None
        return tensor.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()

    def create_and_check_for_multiple_choice(
        self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
    ):
        """Check ``BertForMultipleChoice`` logits are ``(batch, num_choices)``.

        Every choice reuses the same sequence, so inputs are simply repeated per choice.
        """
        config.num_choices = self.num_choices
        model = BertForMultipleChoice(config=config)
        model.to(torch_device)
        model.eval()
        multiple_choice_inputs_ids = self._expand_for_choices(input_ids)
        multiple_choice_token_type_ids = self._expand_for_choices(token_type_ids)
        multiple_choice_input_mask = self._expand_for_choices(input_mask)
        result = model(
            multiple_choice_inputs_ids,
            attention_mask=multiple_choice_input_mask,
            token_type_ids=multiple_choice_token_type_ids,
            labels=choice_labels,
        )
        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_choices))

    def prepare_config_and_inputs_for_common(self):
        """Return ``(config, inputs_dict)`` for the shared ``ModelTesterMixin`` tests.

        ``inputs_dict`` holds ``input_ids``, ``token_type_ids`` and ``attention_mask``; labels
        are dropped.
        """
        config, input_ids, token_type_ids, input_mask, *_labels = self.prepare_config_and_inputs()
        inputs_dict = {"input_ids": input_ids, "token_type_ids": token_type_ids, "attention_mask": input_mask}
        return config, inputs_dict
|
|
|
|
|
@require_torch
class BertModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
    """Common-mixin, generation, pipeline and head-specific tests for the PyTorch BERT models."""

    all_model_classes = (
        (
            BertModel,
            BertLMHeadModel,
            BertForMaskedLM,
            BertForMultipleChoice,
            BertForNextSentencePrediction,
            BertForPreTraining,
            BertForQuestionAnswering,
            BertForSequenceClassification,
            BertForTokenClassification,
        )
        if is_torch_available()
        else ()
    )
    all_generative_model_classes = (BertLMHeadModel,) if is_torch_available() else ()
    pipeline_model_mapping = (
        {
            "feature-extraction": BertModel,
            "fill-mask": BertForMaskedLM,
            "question-answering": BertForQuestionAnswering,
            "text-classification": BertForSequenceClassification,
            "text-generation": BertLMHeadModel,
            "token-classification": BertForTokenClassification,
            "zero-shot": BertForSequenceClassification,
        }
        if is_torch_available()
        else {}
    )
    fx_compatible = True

    def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
        """Extend the mixin's input preparation with labels for pre-training heads.

        Models registered in ``MODEL_FOR_PRETRAINING_MAPPING`` receive zero-filled ``labels``
        of shape ``(batch_size, seq_length)`` and ``next_sentence_label`` of shape
        ``(batch_size,)``, but only when ``return_labels`` is requested.
        """
        inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels)

        if return_labels:
            if model_class in get_values(MODEL_FOR_PRETRAINING_MAPPING):
                inputs_dict["labels"] = torch.zeros(
                    (self.model_tester.batch_size, self.model_tester.seq_length), dtype=torch.long, device=torch_device
                )
                inputs_dict["next_sentence_label"] = torch.zeros(
                    self.model_tester.batch_size, dtype=torch.long, device=torch_device
                )
        return inputs_dict

    def setUp(self):
        self.model_tester = BertModelTester(self)
        self.config_tester = ConfigTester(self, config_class=BertConfig, hidden_size=37)

    def test_config(self):
        self.config_tester.run_common_tests()

    def test_model(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_model(*config_and_inputs)

    def test_model_various_embeddings(self):
        """Run the base-model check once per supported ``position_embedding_type``."""
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        for position_embedding_type in ["absolute", "relative_key", "relative_key_query"]:
            # subTest reports which embedding type failed instead of stopping at the first one.
            with self.subTest(position_embedding_type=position_embedding_type):
                config_and_inputs[0].position_embedding_type = position_embedding_type
                self.model_tester.create_and_check_model(*config_and_inputs)

    def test_model_as_decoder(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder()
        self.model_tester.create_and_check_model_as_decoder(*config_and_inputs)

    def test_model_as_decoder_with_default_input_mask(self):
        """Run the decoder check with ``input_mask=None`` so the model must build its own mask."""
        (
            config,
            input_ids,
            token_type_ids,
            input_mask,
            sequence_labels,
            token_labels,
            choice_labels,
            encoder_hidden_states,
            encoder_attention_mask,
        ) = self.model_tester.prepare_config_and_inputs_for_decoder()

        input_mask = None

        self.model_tester.create_and_check_model_as_decoder(
            config,
            input_ids,
            token_type_ids,
            input_mask,
            sequence_labels,
            token_labels,
            choice_labels,
            encoder_hidden_states,
            encoder_attention_mask,
        )

    def test_for_causal_lm(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder()
        self.model_tester.create_and_check_for_causal_lm(*config_and_inputs)

    def test_for_masked_lm(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_for_masked_lm(*config_and_inputs)

    def test_for_causal_lm_decoder(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder()
        self.model_tester.create_and_check_model_for_causal_lm_as_decoder(*config_and_inputs)

    def test_decoder_model_past_with_large_inputs(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder()
        self.model_tester.create_and_check_decoder_model_past_large_inputs(*config_and_inputs)

    def test_decoder_model_past_with_large_inputs_relative_pos_emb(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder()
        config_and_inputs[0].position_embedding_type = "relative_key"
        self.model_tester.create_and_check_decoder_model_past_large_inputs(*config_and_inputs)

    def test_for_multiple_choice(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_for_multiple_choice(*config_and_inputs)

    def test_for_next_sequence_prediction(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_for_next_sequence_prediction(*config_and_inputs)

    def test_for_pretraining(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_for_pretraining(*config_and_inputs)

    def test_for_question_answering(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_for_question_answering(*config_and_inputs)

    def test_for_sequence_classification(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_for_sequence_classification(*config_and_inputs)

    def test_for_token_classification(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_for_token_classification(*config_and_inputs)

    def test_for_warning_if_padding_and_no_attention_mask(self):
        """A pad token in ``input_ids`` without an ``attention_mask`` must log a warning."""
        config, input_ids, token_type_ids, *_unused = self.model_tester.prepare_config_and_inputs()

        # Put a pad token into the input so the model has a reason to warn.
        input_ids[0, 0] = config.pad_token_id

        logger = logging.get_logger("transformers.modeling_utils")
        # The warning goes through `warning_once`, whose cache must be cleared so it fires
        # again even if an earlier test already triggered it.
        logger.warning_once.cache_clear()

        with CaptureLogger(logger) as cl:
            model = BertModel(config=config)
            model.to(torch_device)
            model.eval()
            model(input_ids, attention_mask=None, token_type_ids=token_type_ids)
        self.assertIn("We strongly recommend passing in an `attention_mask`", cl.out)

    @slow
    def test_model_from_pretrained(self):
        for model_name in BERT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
            model = BertModel.from_pretrained(model_name)
            self.assertIsNotNone(model)

    @slow
    @require_torch_gpu
    def test_torchscript_device_change(self):
        """Trace each head on CPU, save it, and reload/run it on ``torch_device``."""
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
        config.torchscript = True
        for model_class in self.all_model_classes:
            # BertForMultipleChoice is excluded from this check. Use `continue` (not `return`)
            # so the heads listed after it are still traced.
            if model_class == BertForMultipleChoice:
                continue

            model = model_class(config=config)

            # Keep a per-class copy so one class's preparation never leaks into the next.
            class_inputs = self._prepare_for_class(inputs_dict, model_class)
            traced_model = torch.jit.trace(
                model, (class_inputs["input_ids"].to("cpu"), class_inputs["attention_mask"].to("cpu"))
            )

            with tempfile.TemporaryDirectory() as tmp:
                torch.jit.save(traced_model, os.path.join(tmp, "bert.pt"))
                loaded = torch.jit.load(os.path.join(tmp, "bert.pt"), map_location=torch_device)
                loaded(class_inputs["input_ids"].to(torch_device), class_inputs["attention_mask"].to(torch_device))
|
|
|
|
|
@require_torch
class BertModelIntegrationTest(unittest.TestCase):
    """Slow numerical regression checks of pretrained ``BertModel`` checkpoints (no task head)."""

    # Probe sequence shared by every checkpoint; its first position is masked out.
    _PROBE_INPUT_IDS = [[0, 345, 232, 328, 740, 140, 1695, 69, 6078, 1588, 2]]
    _PROBE_ATTENTION_MASK = [[0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]

    def _check_first_output(self, checkpoint, expected_values):
        """Load ``checkpoint`` and compare its first output on the probe input.

        Asserts the output has shape ``(1, 11, 768)`` and that ``output[:, 1:4, 1:4]``
        equals ``expected_values`` (nested lists) within ``atol=1e-4``.
        """
        model = BertModel.from_pretrained(checkpoint)
        probe_ids = torch.tensor(self._PROBE_INPUT_IDS)
        probe_mask = torch.tensor(self._PROBE_ATTENTION_MASK)
        with torch.no_grad():
            first_output = model(probe_ids, attention_mask=probe_mask)[0]

        self.assertEqual(first_output.shape, torch.Size((1, 11, 768)))
        reference = torch.tensor(expected_values)
        self.assertTrue(torch.allclose(first_output[:, 1:4, 1:4], reference, atol=1e-4))

    @slow
    def test_inference_no_head_absolute_embedding(self):
        self._check_first_output(
            "bert-base-uncased",
            [[[0.4249, 0.1008, 0.7531], [0.3771, 0.1188, 0.7467], [0.4152, 0.1098, 0.7108]]],
        )

    @slow
    def test_inference_no_head_relative_embedding_key(self):
        self._check_first_output(
            "zhiheng-huang/bert-base-uncased-embedding-relative-key",
            [[[0.0756, 0.3142, -0.5128], [0.3761, 0.3462, -0.5477], [0.2052, 0.3760, -0.1240]]],
        )

    @slow
    def test_inference_no_head_relative_embedding_key_query(self):
        self._check_first_output(
            "zhiheng-huang/bert-base-uncased-embedding-relative-key-query",
            [[[0.6496, 0.3784, 0.8203], [0.8148, 0.5656, 0.2636], [-0.0681, 0.5597, 0.7045]]],
        )
|
|