Boris Albar commited on
Commit
c69543f
1 Parent(s): 9ed7e39

Upload configuration_flash_t5.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. configuration_flash_t5.py +84 -0
configuration_flash_t5.py ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ from collections import OrderedDict
3
+ from typing import Mapping
4
+ import logging
5
+
6
+ from transformers import T5Config
7
+
8
+ AUTO_MAP = {
9
+ "AutoModel": "modeling_flash_t5.FlashT5ForConditionalGeneration",
10
+ "AutoModelForSeq2SeqLM": "modeling_flash_t5.FlashT5ForConditionalGeneration",
11
+ "AutoModelForTokenClassification": "custom_heads_flash_t5.FlashT5ForTokenClassification",
12
+ "AutoModelForQuestionAnswering": "custom_heads_flash_t5.FlashT5ForQuestionAnswering",
13
+ "AutoModelForSequenceClassification": "custom_heads_flash_t5.FlashT5ForSequenceClassification",
14
+ }
15
+
16
+ class FlashT5Config(T5Config):
17
+
18
+ model_type = "flash_t5"
19
+
20
+ def __init__(
21
+ self,
22
+ decoder_start_token_id=0,
23
+ pad_token_id=-100,
24
+ use_glu_mlp=False,
25
+ position_encoding_type="t5",
26
+ use_randomized_position_encoding=False,
27
+ label_smoothing=0.0,
28
+ z_loss=None,
29
+ attention_type="ref",
30
+ max_sequence_length=1024,
31
+ attention_dropout_rate=0.0,
32
+ alibi_mode="symetric",
33
+ use_triton_layernorm=False,
34
+ use_triton_crossentropy=False,
35
+ use_triton_gated_mlp=False,
36
+ use_gelu_act=True,
37
+ use_full_bias_size=False,
38
+ rotary_emb_fraction=1.0,
39
+ rotary_base=10000,
40
+ rotary_interleaved=False,
41
+ rotary_scale_base=None,
42
+ fire_mlp_width=32,
43
+ use_masking=False,
44
+ attention_scale=None,
45
+ **kwargs,
46
+ ):
47
+ super().__init__(**kwargs)
48
+
49
+ self.decoder_start_token_id = decoder_start_token_id
50
+ self.pad_token_id = pad_token_id
51
+ self.use_glu_mlp = use_glu_mlp
52
+ self.position_encoding_type = position_encoding_type
53
+ self.use_randomized_position_encoding = use_randomized_position_encoding
54
+ self.label_smoothing = label_smoothing
55
+ self.z_loss = z_loss
56
+ self.attention_type = attention_type
57
+ self.max_sequence_length = max_sequence_length
58
+ self.alibi_mode = alibi_mode
59
+ self.attention_dropout_rate = attention_dropout_rate
60
+ self.use_triton_layernorm = use_triton_layernorm
61
+ self.use_triton_crossentropy = use_triton_crossentropy
62
+ self.use_triton_gated_mlp = use_triton_gated_mlp
63
+ self.use_gelu_act = use_gelu_act
64
+ self.use_full_bias_size = use_full_bias_size
65
+ self.rotary_base = rotary_base
66
+ self.rotary_interleaved = rotary_interleaved
67
+ self.rotary_scale_base = rotary_scale_base
68
+ self.rotary_emb_fraction = rotary_emb_fraction
69
+ self.fire_mlp_width = fire_mlp_width
70
+ self.use_masking = use_masking
71
+ self.attention_scale = attention_scale
72
+
73
+ self.auto_map = AUTO_MAP
74
+
75
+ def str_to_class(classname):
76
+ return getattr(sys.modules[__name__], classname)
77
+
78
+ # Register model in Auto API
79
+ try:
80
+ FlashT5Config.register_for_auto_class()
81
+ for key, value in AUTO_MAP.items():
82
+ str_to_class(value.split(".")[-1]).register_for_auto_class(key)
83
+ except:
84
+ logging.warn("AutoRegister isn't available.")