from transformers import T5Config

# Identifiers for the supported position-encoding schemes.
POSITION_ENCODING_REL_T5_BIAS = "t5_relative_bias"
POSITION_ENCODING_REL_TRANSFORMER_XL = "transformer_xl_relative_encoding"
POSITION_ENCODING_ROTARY = "rotary"
POSITION_ENCODING_ROTARY_RERUN = "rotary_rerun"
POSITION_ENCODING_ROTARY_NEW = "new_rotary"
POSITION_ENCODING_ABS_LEARNED = "abs_learned"
POSITION_ENCODING_ABS_SINUSOID = "abs_sinusoid"
POSITION_ENCODING_ALiBi = "alibi"
POSITION_ENCODING_ALiBi_LEARNED = "alibi_learned"
POSITION_ENCODING_NONE = "none"
POSITION_ENCODING_NONE_WINDOW = "none_window"


class CustomT5Config(T5Config):
    """T5Config variant for a decoder-only model with a selectable
    position-encoding scheme. All other arguments are forwarded to T5Config."""

    model_type = "custom_decoder_only_t5"

    def __init__(
        self,
        position_encoding_type=POSITION_ENCODING_REL_T5_BIAS,
        **kwargs,
    ):
        # Reject unknown encoding types before the parent constructor runs.
        if position_encoding_type not in [
            POSITION_ENCODING_ALiBi,
            POSITION_ENCODING_ALiBi_LEARNED,
            POSITION_ENCODING_ABS_LEARNED,
            POSITION_ENCODING_ABS_SINUSOID,
            POSITION_ENCODING_REL_T5_BIAS,
            POSITION_ENCODING_REL_TRANSFORMER_XL,
            POSITION_ENCODING_ROTARY,
            POSITION_ENCODING_ROTARY_NEW,
            POSITION_ENCODING_NONE,
            POSITION_ENCODING_NONE_WINDOW,
        ]:
            raise ValueError(
                f"Invalid position_encoding_type: {position_encoding_type}"
            )
        self.position_encoding_type = position_encoding_type
        super().__init__(**kwargs)
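

# --- Usage sketch (illustrative only) ---------------------------------------
# A minimal example of constructing the config; it is an assumption about
# typical usage, not part of the original module. The extra keyword arguments
# (d_model, num_layers, num_heads) are standard T5Config parameters forwarded
# through **kwargs.
if __name__ == "__main__":
    config = CustomT5Config(
        position_encoding_type=POSITION_ENCODING_ROTARY,
        d_model=512,
        num_layers=6,
        num_heads=8,
    )
    print(config.position_encoding_type)  # -> "rotary"
    print(config.model_type)              # -> "custom_decoder_only_t5"

    # Any value outside the allowed list raises ValueError, e.g.:
    # CustomT5Config(position_encoding_type="made_up")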