Spaces:
Running
Running
File size: 4,894 Bytes
6197b2f a96f4dc 6197b2f a5ed112 7e48337 6197b2f 7e48337 a11892f 6197b2f 6f1f2d9 6197b2f a11892f 6197b2f a11892f 972bc8d 1bb3269 6197b2f a11892f a96f4dc a11892f 6197b2f 6f1f2d9 6197b2f 1bb3269 a96f4dc 6197b2f 1bb3269 6197b2f a265819 1bb3269 6197b2f 6f1f2d9 6197b2f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 |
# coding=utf-8
# Copyright 2021 The Fairseq Authors and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" DalleBart model configuration """
import warnings
from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging
from .utils import PretrainedFromWandbMixin
# Module-level logger from transformers' logging utilities; not referenced by
# the code visible in this part of the module.
logger = logging.get_logger(__name__)
class DalleBartConfig(PretrainedFromWandbMixin, PretrainedConfig):
    """Configuration for the DalleBart encoder-decoder model.

    The encoder side is sized by ``encoder_vocab_size`` / ``max_text_length``
    (text tokens) and the decoder side by ``image_vocab_size`` /
    ``image_length`` (encoded image tokens). Special-token ids and generation
    length defaults are derived from the image settings unless passed
    explicitly through ``**kwargs``.

    Args:
        normalize_text (bool): Stored as ``self.normalize_text``; presumably
            toggles text normalization before tokenization -- confirm with
            the consumer.
        encoder_vocab_size (int): Encoder-side vocabulary size.
        image_vocab_size (int): Size of the encoded image token space. Also
            the default id for every special token (bos/eos/pad/decoder
            start), i.e. special tokens are appended after the image vocab.
        image_length (int): Number of encoded image tokens. Generation
            ``min_length``/``max_length`` default to ``image_length + 1`` to
            account for the leading bos token.
        max_text_length (int): Maximum number of text tokens.
        encoder_layers / decoder_layers (int): Number of layers per stack.
        encoder_ffn_dim / decoder_ffn_dim (int): Feed-forward inner sizes.
        encoder_attention_heads / decoder_attention_heads (int): Attention
            head counts per stack.
        encoder_layerdrop / decoder_layerdrop (float): Layer-drop rates.
        activation_function (str): Name of the activation function.
        d_model (int): Model hidden size (exposed as ``hidden_size``).
        dropout, attention_dropout, activation_dropout, classifier_dropout
            (float): Dropout rates stored on the config.
        init_std (float): Standard deviation stored for weight init.
        scale_embedding (bool): If True, embeddings are scaled by
            ``sqrt(d_model)``.
        gradient_checkpointing (bool): Stored on the config.
        use_cache (bool): Stored on the config.
        is_encoder_decoder (bool): Forwarded to ``PretrainedConfig``.
        forced_eos_token_id (int | None): Forwarded to ``PretrainedConfig``.
        tie_word_embeddings (bool): Defaults to False because the encoder
            and decoder cover different modalities and vocabulary sizes.
        do_sample (bool): Generation default forwarded to
            ``PretrainedConfig``.
        **kwargs: Forwarded to ``PretrainedConfig``. ``decoder_start_token_id``,
            ``bos_token_id``, ``pad_token_id``, ``eos_token_id``,
            ``min_length`` and ``max_length`` are intercepted here to apply
            the defaults above.
    """

    model_type = "dallebart"
    keys_to_ignore_at_inference = ["past_key_values"]
    # transformers redirects these standard attribute names to the BART-style
    # fields actually stored on this config.
    attribute_map = {
        "num_attention_heads": "encoder_attention_heads",
        "hidden_size": "d_model",
    }

    def __init__(
        self,
        normalize_text=False,
        encoder_vocab_size=50264,
        image_vocab_size=16384,  # encoded image token space
        image_length=256,  # number of encoded tokens
        max_text_length=64,  # max number of text tokens
        encoder_layers=12,
        encoder_ffn_dim=4096,
        encoder_attention_heads=16,
        decoder_layers=12,
        decoder_ffn_dim=4096,
        decoder_attention_heads=16,
        encoder_layerdrop=0.0,
        decoder_layerdrop=0.0,
        activation_function="gelu",
        d_model=1024,
        dropout=0.1,
        attention_dropout=0.0,
        activation_dropout=0.0,
        init_std=0.02,
        classifier_dropout=0.0,
        scale_embedding=False,
        gradient_checkpointing=False,
        use_cache=True,
        is_encoder_decoder=True,
        forced_eos_token_id=None,
        tie_word_embeddings=False,  # different modalities and sizes
        do_sample=True,
        **kwargs,
    ):
        self.normalize_text = normalize_text
        self.encoder_vocab_size = encoder_vocab_size
        self.image_vocab_size = image_vocab_size
        self.image_length = image_length
        self.max_text_length = max_text_length
        self.d_model = d_model
        self.encoder_ffn_dim = encoder_ffn_dim
        self.encoder_layers = encoder_layers
        self.encoder_attention_heads = encoder_attention_heads
        self.decoder_ffn_dim = decoder_ffn_dim
        self.decoder_layers = decoder_layers
        self.decoder_attention_heads = decoder_attention_heads
        self.dropout = dropout
        self.attention_dropout = attention_dropout
        self.activation_dropout = activation_dropout
        self.activation_function = activation_function
        self.init_std = init_std
        self.encoder_layerdrop = encoder_layerdrop
        self.decoder_layerdrop = decoder_layerdrop
        self.classifier_dropout = classifier_dropout
        self.use_cache = use_cache
        self.gradient_checkpointing = gradient_checkpointing
        # scale factor will be sqrt(d_model) if True
        self.scale_embedding = scale_embedding

        # Special token ids are appended to the image vocab if not provided:
        # by default they all share the first id past the image tokens.
        decoder_start_token_id = kwargs.pop("decoder_start_token_id", image_vocab_size)
        bos_token_id = kwargs.pop("bos_token_id", image_vocab_size)
        pad_token_id = kwargs.pop("pad_token_id", image_vocab_size)
        eos_token_id = kwargs.pop("eos_token_id", image_vocab_size)

        # We generate to image_length + 1 (for bos) by default.
        min_length = kwargs.pop("min_length", image_length + 1)
        max_length = kwargs.pop("max_length", image_length + 1)

        super().__init__(
            # args required in parent class
            is_encoder_decoder=is_encoder_decoder,
            tie_word_embeddings=tie_word_embeddings,
            forced_eos_token_id=forced_eos_token_id,
            decoder_start_token_id=decoder_start_token_id,
            bos_token_id=bos_token_id,
            pad_token_id=pad_token_id,
            eos_token_id=eos_token_id,
            min_length=min_length,
            max_length=max_length,
            do_sample=do_sample,
            **kwargs,
        )

        # Ensure backward compatibility for BART CNN models: the legacy
        # `force_bos_token_to_be_generated` flag (still present in kwargs,
        # since it is only read with .get) maps onto `forced_bos_token_id`,
        # which the parent initializer leaves as None when not supplied.
        if self.forced_bos_token_id is None and kwargs.get(
            "force_bos_token_to_be_generated", False
        ):
            self.forced_bos_token_id = self.bos_token_id
            # Trailing space added: the two literals are concatenated, and
            # previously produced "...versions.The config...".
            warnings.warn(
                f"Please make sure the config includes `forced_bos_token_id={self.bos_token_id}` in future versions. "
                "The config can simply be saved and uploaded again to be fixed."
            )
|