from typing import Type

from transformers import AutoConfig, PretrainedConfig


def register_to_hf_auto_config(
    config_class: Type[PretrainedConfig],
) -> Type[PretrainedConfig]:
    """Class decorator that registers a custom config class with AutoConfig
    under its ``model_type`` so it can be resolved by ``AutoConfig.from_pretrained``."""
    AutoConfig.register(config_class.model_type, config_class)
    return config_class


@register_to_hf_auto_config
class LengthControlLMConfig(PretrainedConfig):
    model_type = "length_control_lm"

    def __init__(
        self,
        max_length_levels: int = 256,
        length_embed_hidden_dim: int = 512,
        model_name: str = "openai-community/gpt2",
        **kwargs,
    ):
        """constructor of config object for GPT2LengthControl

        Args:
            max_length_levels (int, optional):
                max length (tokens or levels) of a generated caption. Defaults to 256.
            length_embed_hidden_dim (int, optional):
                dimensions of nonlinear length embedding (dimensions of embedding MLP layers). Defaults to 512.
        """
        super().__init__(**kwargs)
        self.max_length_levels = max_length_levels
        self.length_embed_hidden_dim = length_embed_hidden_dim
        self.model_name = model_name

        # Copy every attribute of the backbone model's config onto this config
        # so it exposes the same fields (hidden_size, vocab_size, etc.).
        auto_config = AutoConfig.from_pretrained(model_name)
        for key, value in vars(auto_config).items():
            setattr(self, key, value)
        self.hidden_size = auto_config.hidden_size
        self.num_attention_heads = auto_config.num_attention_heads
        # Configure the backbone as a decoder with cross-attention layers.
        self.is_decoder = True
        self.add_cross_attention = True
        # Fall back to the EOS token when the backbone defines no pad token (e.g. GPT-2).
        if self.pad_token_id is None:
            self.pad_token_id = self.eos_token_id
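

if __name__ == "__main__":
    # Minimal usage sketch (illustrative only, not part of the project API):
    # the argument values below are example settings chosen for demonstration.
    config = LengthControlLMConfig(
        max_length_levels=128,
        length_embed_hidden_dim=256,
    )
    print(config.model_type)           # -> "length_control_lm"
    print(config.hidden_size)          # copied from the GPT-2 backbone config
    print(config.add_cross_attention)  # -> True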