from typing import Type, Union

from transformers import AutoConfig, PretrainedConfig


def register_to_hf_auto_config(
    config_class: Type[PretrainedConfig],
) -> Type[PretrainedConfig]:
    """Class decorator that registers a PretrainedConfig subclass with the AutoConfig registry under its model_type."""
    AutoConfig.register(config_class.model_type, config_class)
    return config_class


@register_to_hf_auto_config
class LengthControlLMConfig(PretrainedConfig):
    model_type = "length_control_lm"

    def __init__(
        self,
        max_length_levels: int = 256,
        length_embed_hidden_dim: int = 512,
        model_name: str = "openai-community/gpt2",
        **kwargs,
    ):
"""constructor of config object for GPT2LengthControl |
|
|
|
Args: |
|
max_length_levels (int, optional): |
|
max length (tokens or levels) of a generated caption. Defaults to 256. |
|
length_embed_hidden_dim (int, optional): |
|
dimensions of nonlinear length embedding (dimensions of embedding MLP layers). Defaults to 512. |
|
""" |
|
        super().__init__(**kwargs)
        self.max_length_levels = max_length_levels
        self.length_embed_hidden_dim = length_embed_hidden_dim
        self.model_name = model_name

        # Copy every attribute of the base model's config so this config carries
        # the full GPT-2 architecture settings alongside the length-control fields.
        auto_config = AutoConfig.from_pretrained(model_name)
        for key, value in vars(auto_config).items():
            setattr(self, key, value)
        # GPT-2 stores these as n_embd / n_head and only exposes the canonical
        # names through an attribute map, so mirror them explicitly here.
        self.hidden_size = auto_config.hidden_size
        self.num_attention_heads = auto_config.num_attention_heads

        # Run as a decoder with cross-attention layers enabled.
        self.is_decoder = True
        self.add_cross_attention = True
        # GPT-2 defines no pad token; fall back to the EOS token for padding.
        if self.pad_token_id is None:
            self.pad_token_id = self.eos_token_id
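

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the module's API): instantiating
# the config fetches the base GPT-2 config from the Hub, so this assumes
# network access or a local cache of "openai-community/gpt2".
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import tempfile

    config = LengthControlLMConfig(max_length_levels=128)
    print(config.model_type, config.hidden_size, config.add_cross_attention)

    # Because the class is registered with AutoConfig, a saved config can be
    # reloaded generically, without naming LengthControlLMConfig explicitly.
    with tempfile.TemporaryDirectory() as tmp_dir:
        config.save_pretrained(tmp_dir)
        reloaded = AutoConfig.from_pretrained(tmp_dir)
        assert isinstance(reloaded, LengthControlLMConfig)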