import json
from dataclasses import dataclass

from transformers.configuration_utils import PretrainedConfig


def load_config_from_json(config_file):
    """Build a RetNetConfig from a JSON file containing its configuration dict."""
    with open(config_file, "r") as f:
        config = json.load(f)
    config = RetNetConfig.from_dict(config)
    return config
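

# Usage sketch (illustrative only; "retnet_config.json" stands in for any config
# file previously written with RetNetConfig.to_json_file or save_pretrained):
#
#     config = load_config_from_json("retnet_config.json")
#     print(config.decoder_layers, config.decoder_embed_dim)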


@dataclass
class RetNetConfig(PretrainedConfig):
    model_type = "retnet"

    # Annotated class attributes document the defaults; instances are
    # configured through the explicit __init__ below.
    initializer_range: float = 0.02
    activation_fn: str = "gelu"
    dropout: float = 0.0
    activation_dropout: float = 0.0
    drop_path_rate: float = 0.0
    decoder_embed_dim: int = 768
    decoder_value_embed_dim: int = 1280
    decoder_ffn_embed_dim: int = 1280
    decoder_layers: int = 12
    decoder_retention_heads: int = 3
    decoder_normalize_before: bool = True
    layernorm_embedding: bool = False
    no_scale_embedding: bool = True
    recurrent_chunk_size: int = 512
    use_lm_decay: bool = False
    use_glu: bool = True
    z_loss_coeff: float = 0.0
    deepnorm: bool = False
    subln: bool = True
    use_ffn_rms_norm: bool = False
    layernorm_eps: float = 1e-6
    tie_word_embeddings: bool = False

    def __init__(
        self,
        vocab_size: int = 50257,
        initializer_range: float = 0.02,
        is_decoder: bool = True,
        pad_token_id: int = 0,
        eos_token_id: int = 0,
        output_retentions: bool = False,
        use_cache: bool = True,
        forward_impl: str = "parallel",
        activation_fn: str = "gelu",
        dropout: float = 0.0,
        activation_dropout: float = 0.0,
        drop_path_rate: float = 0.0,
        decoder_embed_dim: int = 768,
        decoder_value_embed_dim: int = 1280,
        decoder_ffn_embed_dim: int = 1280,
        decoder_layers: int = 12,
        decoder_retention_heads: int = 3,
        decoder_normalize_before: bool = True,
        layernorm_embedding: bool = False,
        no_scale_embedding: bool = True,
        recurrent_chunk_size: int = 512,
        use_glu: bool = True,
        z_loss_coeff: float = 0.0,
        use_lm_decay: bool = False,
        deepnorm: bool = False,
        subln: bool = True,
        use_ffn_rms_norm: bool = False,
        layernorm_eps: float = 1e-6,
        tie_word_embeddings: bool = False,
        **kwargs
    ):
        self.vocab_size = vocab_size
        self.initializer_range = initializer_range
        self.output_retentions = output_retentions
        self.use_lm_decay = use_lm_decay
        self.use_glu = use_glu
        self.z_loss_coeff = z_loss_coeff

        # Decoder width, depth and number of retention heads.
        self.decoder_embed_dim = decoder_embed_dim
        self.decoder_value_embed_dim = decoder_value_embed_dim
        self.decoder_retention_heads = decoder_retention_heads
        self.decoder_ffn_embed_dim = decoder_ffn_embed_dim
        self.decoder_layers = decoder_layers

        # Normalization, activation and regularization settings.
        self.decoder_normalize_before = decoder_normalize_before
        self.activation_fn = activation_fn
        self.dropout = dropout
        self.drop_path_rate = drop_path_rate
        self.activation_dropout = activation_dropout
        self.no_scale_embedding = no_scale_embedding
        self.layernorm_embedding = layernorm_embedding
        self.deepnorm = deepnorm
        self.subln = subln
        self.use_ffn_rms_norm = use_ffn_rms_norm
        self.layernorm_eps = layernorm_eps

        # Retention forward mode ("parallel" by default); recurrent_chunk_size
        # is the chunk length used for chunkwise recurrence.
        self.recurrent_chunk_size = recurrent_chunk_size
        self.forward_impl = forward_impl

        # DeepNorm and SubLN are mutually exclusive; if both are requested,
        # DeepNorm takes precedence.
        if self.deepnorm:
            self.decoder_normalize_before = False
            self.subln = False
        if self.subln:
            self.decoder_normalize_before = True
            self.deepnorm = False

        super().__init__(
            is_decoder=is_decoder,
            pad_token_id=pad_token_id,
            eos_token_id=eos_token_id,
            use_cache=use_cache,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs
        )

    def override(self, args):
        # Copy every attribute that is also present (and not None) on `args`,
        # e.g. an argparse.Namespace of command-line overrides.
        for hp in self.__dict__.keys():
            if getattr(args, hp, None) is not None:
                self.__dict__[hp] = getattr(args, hp, None)
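

if __name__ == "__main__":
    # Minimal usage sketch, not part of the library code: it assumes an
    # argparse-style Namespace for `override` and writes the JSON file to a
    # temporary directory ("retnet_config.json" is just an example name).
    import os
    import tempfile
    from argparse import Namespace

    config = RetNetConfig(decoder_layers=2)

    # `override` copies every matching, non-None attribute from the namespace.
    config.override(Namespace(dropout=0.1, decoder_layers=None))
    assert config.dropout == 0.1 and config.decoder_layers == 2

    # DeepNorm and SubLN are mutually exclusive; DeepNorm wins if both are set.
    dn_config = RetNetConfig(deepnorm=True, subln=True)
    assert dn_config.deepnorm and not dn_config.subln

    # Round trip: PretrainedConfig.to_json_file writes the config to disk and
    # load_config_from_json (defined above) rebuilds it.
    with tempfile.TemporaryDirectory() as tmp_dir:
        json_path = os.path.join(tmp_dir, "retnet_config.json")
        config.to_json_file(json_path)
        reloaded = load_config_from_json(json_path)
        assert reloaded.dropout == 0.1 and reloaded.decoder_layers == 2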