# lc_video_description_videomae_gpt2 / configuration_decoder.py
from typing import Type

from transformers import AutoConfig, PretrainedConfig


def register_to_hf_auto_config(
    config_class: Type[PretrainedConfig],
) -> Type[PretrainedConfig]:
    """Class decorator: register ``config_class`` with AutoConfig under its model_type."""
    AutoConfig.register(config_class.model_type, config_class)
    return config_class
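

# A minimal sketch of what registration enables (hypothetical usage, not part
# of the original module): once the decorator below has run, AutoConfig can
# build the custom config straight from its model_type string, e.g.
#
#   config = AutoConfig.for_model("length_control_lm")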


@register_to_hf_auto_config
class LengthControlLMConfig(PretrainedConfig):
    model_type = "length_control_lm"

    def __init__(
        self,
        max_length_levels: int = 256,
        length_embed_hidden_dim: int = 512,
        model_name: str = "openai-community/gpt2",
        **kwargs,
    ):
"""constructor of config object for GPT2LengthControl
Args:
max_length_levels (int, optional):
max length (tokens or levels) of a generated caption. Defaults to 256.
length_embed_hidden_dim (int, optional):
dimensions of nonlinear length embedding (dimensions of embedding MLP layers). Defaults to 512.
"""
        super().__init__(**kwargs)
        self.max_length_levels = max_length_levels
        self.length_embed_hidden_dim = length_embed_hidden_dim
        self.model_name = model_name

        # Copy every attribute of the base model's config (e.g. GPT-2's
        # vocab_size, n_positions, n_embd, ...) onto this config.
        auto_config = AutoConfig.from_pretrained(model_name)
        for key, value in vars(auto_config).items():
            setattr(self, key, value)

        # GPT-2 aliases hidden_size/num_attention_heads to n_embd/n_head via
        # its attribute_map, so vars() misses them; set them explicitly.
        self.hidden_size = auto_config.hidden_size
        self.num_attention_heads = auto_config.num_attention_heads

        # The LM is used as a decoder with cross-attention over the video
        # encoder's outputs.
        self.is_decoder = True
        self.add_cross_attention = True

        # GPT-2 defines no pad token; fall back to the EOS token for padding.
        if self.pad_token_id is None:
            self.pad_token_id = self.eos_token_id
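

# A minimal usage sketch (not part of the original module). Assumptions:
# network access to fetch the base GPT-2 config, and "demo_config" as a
# hypothetical local output directory; the AutoConfig round trip relies on
# the registration performed by the decorator above.
if __name__ == "__main__":
    config = LengthControlLMConfig(max_length_levels=128)
    print(config.model_type)   # "length_control_lm"
    print(config.hidden_size)  # 768, copied from the GPT-2 base config

    # Save, then reload through AutoConfig; the registered model_type routes
    # the saved config back to LengthControlLMConfig.
    config.save_pretrained("demo_config")
    reloaded = AutoConfig.from_pretrained("demo_config")
    assert isinstance(reloaded, LengthControlLMConfig)
    assert reloaded.max_length_levels == 128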