Spaces:
Paused
Paused
# Copyright 2020 The HuggingFace Team. All rights reserved. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
import unittest | |
from transformers import BigBirdConfig, is_flax_available | |
from transformers.testing_utils import require_flax, slow | |
from ...test_modeling_flax_common import FlaxModelTesterMixin, ids_tensor, random_attention_mask | |
if is_flax_available(): | |
import jax | |
from transformers.models.big_bird.modeling_flax_big_bird import ( | |
FlaxBigBirdForCausalLM, | |
FlaxBigBirdForMaskedLM, | |
FlaxBigBirdForMultipleChoice, | |
FlaxBigBirdForPreTraining, | |
FlaxBigBirdForQuestionAnswering, | |
FlaxBigBirdForSequenceClassification, | |
FlaxBigBirdForTokenClassification, | |
FlaxBigBirdModel, | |
) | |
class FlaxBigBirdModelTester(unittest.TestCase): | |
def __init__( | |
self, | |
parent, | |
batch_size=2, | |
seq_length=56, | |
is_training=True, | |
use_attention_mask=True, | |
use_token_type_ids=True, | |
use_labels=True, | |
vocab_size=99, | |
hidden_size=32, | |
num_hidden_layers=2, | |
num_attention_heads=2, | |
intermediate_size=7, | |
hidden_act="gelu_new", | |
hidden_dropout_prob=0.1, | |
attention_probs_dropout_prob=0.1, | |
max_position_embeddings=512, | |
type_vocab_size=16, | |
type_sequence_label_size=2, | |
initializer_range=0.02, | |
num_choices=4, | |
attention_type="block_sparse", | |
use_bias=True, | |
rescale_embeddings=False, | |
block_size=2, | |
num_random_blocks=3, | |
): | |
self.parent = parent | |
self.batch_size = batch_size | |
self.seq_length = seq_length | |
self.is_training = is_training | |
self.use_attention_mask = use_attention_mask | |
self.use_token_type_ids = use_token_type_ids | |
self.use_labels = use_labels | |
self.vocab_size = vocab_size | |
self.hidden_size = hidden_size | |
self.num_hidden_layers = num_hidden_layers | |
self.num_attention_heads = num_attention_heads | |
self.intermediate_size = intermediate_size | |
self.hidden_act = hidden_act | |
self.hidden_dropout_prob = hidden_dropout_prob | |
self.attention_probs_dropout_prob = attention_probs_dropout_prob | |
self.max_position_embeddings = max_position_embeddings | |
self.type_vocab_size = type_vocab_size | |
self.type_sequence_label_size = type_sequence_label_size | |
self.initializer_range = initializer_range | |
self.num_choices = num_choices | |
self.rescale_embeddings = rescale_embeddings | |
self.attention_type = attention_type | |
self.use_bias = use_bias | |
self.block_size = block_size | |
self.num_random_blocks = num_random_blocks | |
def prepare_config_and_inputs(self): | |
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size) | |
attention_mask = None | |
if self.use_attention_mask: | |
attention_mask = random_attention_mask([self.batch_size, self.seq_length]) | |
token_type_ids = None | |
if self.use_token_type_ids: | |
token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size) | |
config = BigBirdConfig( | |
vocab_size=self.vocab_size, | |
hidden_size=self.hidden_size, | |
num_hidden_layers=self.num_hidden_layers, | |
num_attention_heads=self.num_attention_heads, | |
intermediate_size=self.intermediate_size, | |
hidden_act=self.hidden_act, | |
hidden_dropout_prob=self.hidden_dropout_prob, | |
attention_probs_dropout_prob=self.attention_probs_dropout_prob, | |
max_position_embeddings=self.max_position_embeddings, | |
type_vocab_size=self.type_vocab_size, | |
is_decoder=False, | |
initializer_range=self.initializer_range, | |
attention_type=self.attention_type, | |
block_size=self.block_size, | |
num_random_blocks=self.num_random_blocks, | |
use_bias=self.use_bias, | |
rescale_embeddings=self.rescale_embeddings, | |
) | |
return config, input_ids, token_type_ids, attention_mask | |
def prepare_config_and_inputs_for_common(self): | |
config_and_inputs = self.prepare_config_and_inputs() | |
config, input_ids, token_type_ids, attention_mask = config_and_inputs | |
inputs_dict = { | |
"input_ids": input_ids, | |
"token_type_ids": token_type_ids, | |
"attention_mask": attention_mask, | |
} | |
return config, inputs_dict | |
class FlaxBigBirdModelTest(FlaxModelTesterMixin, unittest.TestCase): | |
all_model_classes = ( | |
( | |
FlaxBigBirdForCausalLM, | |
FlaxBigBirdModel, | |
FlaxBigBirdForPreTraining, | |
FlaxBigBirdForMaskedLM, | |
FlaxBigBirdForMultipleChoice, | |
FlaxBigBirdForQuestionAnswering, | |
FlaxBigBirdForSequenceClassification, | |
FlaxBigBirdForTokenClassification, | |
) | |
if is_flax_available() | |
else () | |
) | |
test_attn_probs = False | |
test_mismatched_shapes = False | |
def setUp(self): | |
self.model_tester = FlaxBigBirdModelTester(self) | |
# copied from `test_modeling_flax_common` because it takes much longer than other models | |
def test_from_pretrained_save_pretrained(self): | |
super().test_from_pretrained_save_pretrained() | |
# copied from `test_modeling_flax_common` because it takes much longer than other models | |
def test_from_pretrained_with_no_automatic_init(self): | |
super().test_from_pretrained_with_no_automatic_init() | |
# copied from `test_modeling_flax_common` because it takes much longer than other models | |
def test_no_automatic_init(self): | |
super().test_no_automatic_init() | |
# copied from `test_modeling_flax_common` because it takes much longer than other models | |
def test_hidden_states_output(self): | |
super().test_hidden_states_output() | |
def test_model_from_pretrained(self): | |
for model_class_name in self.all_model_classes: | |
model = model_class_name.from_pretrained("google/bigbird-roberta-base") | |
self.assertIsNotNone(model) | |
def test_attention_outputs(self): | |
if self.test_attn_probs: | |
super().test_attention_outputs() | |
# copied from `test_modeling_flax_common` because it takes much longer than other models | |
def test_jit_compilation(self): | |
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() | |
for model_class in self.all_model_classes: | |
with self.subTest(model_class.__name__): | |
prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class) | |
model = model_class(config) | |
def model_jitted(input_ids, attention_mask=None, **kwargs): | |
return model(input_ids=input_ids, attention_mask=attention_mask, **kwargs) | |
with self.subTest("JIT Enabled"): | |
jitted_outputs = model_jitted(**prepared_inputs_dict).to_tuple() | |
with self.subTest("JIT Disabled"): | |
with jax.disable_jit(): | |
outputs = model_jitted(**prepared_inputs_dict).to_tuple() | |
self.assertEqual(len(outputs), len(jitted_outputs)) | |
for jitted_output, output in zip(jitted_outputs, outputs): | |
self.assertEqual(jitted_output.shape, output.shape) | |
# overwrite from common in order to skip the check on `attentions` | |
def check_pt_flax_outputs(self, fx_outputs, pt_outputs, model_class, tol=1e-5, name="outputs", attributes=None): | |
# `bigbird_block_sparse_attention` in `FlaxBigBird` returns `attention_probs = None`, while in PyTorch version, | |
# an effort was done to return `attention_probs` (yet to be verified). | |
if name.startswith("outputs.attentions"): | |
return | |
else: | |
super().check_pt_flax_outputs(fx_outputs, pt_outputs, model_class, tol, name, attributes) | |