xwjzds committed on
Commit
2a1e94d
1 Parent(s): 81e759b

Upload DeTiME

Browse files
Files changed (4) hide show
  1. config.json +39 -0
  2. configuration_detime.py +26 -0
  3. modeling_detime.py +120 -0
  4. pytorch_model.bin +3 -0
config.json ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "DeTiME"
4
+ ],
5
+ "auto_map": {
6
+ "AutoConfig": "configuration_detime.DeTiMEAutoConfig",
7
+ "AutoModel": "modeling_detime.DeTiME"
8
+ },
9
+ "d_ff": 2048,
10
+ "d_kv": 64,
11
+ "d_model": 512,
12
+ "dense_act_fn": "relu",
13
+ "dropout": 0.1,
14
+ "dropout_rate": 0.1,
15
+ "eos_token_id": 1,
16
+ "feed_forward_proj": "relu",
17
+ "hidden_size1": 512,
18
+ "hidden_size2": 768,
19
+ "hidden_size3": 4,
20
+ "initializer_factor": 1.0,
21
+ "is_encoder_decoder": true,
22
+ "is_gated_act": false,
23
+ "layer_norm_epsilon": 1e-06,
24
+ "model": "google/flan-t5-large",
25
+ "model_name": null,
26
+ "model_type": "detime",
27
+ "num_decoder_layers": 6,
28
+ "num_heads": 8,
29
+ "num_layer": 1,
30
+ "num_layers": 6,
31
+ "output_size": 3072,
32
+ "pad_token_id": 0,
33
+ "relative_attention_max_distance": 128,
34
+ "relative_attention_num_buckets": 32,
35
+ "torch_dtype": "float32",
36
+ "transformers_version": "4.30.1",
37
+ "use_cache": true,
38
+ "vocab_size": 32128
39
+ }
configuration_detime.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ from transformers import T5Config, PretrainedConfig
3
+ from typing import List
4
+
5
+
6
# Configuration for the Flan-T5 nested CNN autoencoder.
class DeTiMEAutoConfig(T5Config):
    """T5-derived configuration carrying the DeTiME autoencoder hyper-parameters.

    Extra fields on top of ``T5Config``:
        hidden_size1: channel width fed into the CNN encoder.
        hidden_size3: bottleneck channel width produced by the CNN encoder.
        num_layer: number of autoencoder layers.
        dropout: dropout probability.
        max_length: maximum sequence length.
        model_name: optional backbone checkpoint identifier.
    """

    model_type = "detime"

    def __init__(
        self,
        hidden_size1: int = 512,
        hidden_size3: int = 512,
        num_layer: int = 1,
        dropout: float = 0.1,
        max_length: int = 512,
        model_name: str = None,
        **kwargs,
    ):
        # Record the DeTiME-specific settings, then let T5Config fill in
        # everything else from **kwargs.
        extras = {
            "hidden_size1": hidden_size1,
            "hidden_size3": hidden_size3,
            "num_layer": num_layer,
            "dropout": dropout,
            "max_length": max_length,
            "model_name": model_name,
        }
        for key, value in extras.items():
            setattr(self, key, value)
        super().__init__(**kwargs)
modeling_detime.py ADDED
@@ -0,0 +1,120 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ This module includes all the classes and functions for the nested autoencoder.
3
+ """
4
+
5
+ from transformers import PreTrainedModel
6
+ from transformers import T5ForConditionalGeneration, AutoModelForSeq2SeqLM
7
+ import datasets
8
+ import torch
9
+ import torch.nn.functional as F
10
+ from torch import nn
11
+ import random
12
+ import os
13
+ from configuration_detime import DeTiMEAutoConfig
14
# Disable the HF tokenizers library's internal parallelism; this silences its
# fork-related parallelism warning (e.g. with DataLoader workers) — NOTE(review):
# set unconditionally at import time, which affects the whole process.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
15
+
16
+
17
+
18
+ # Define the CNN encoder and decoder model
19
+ class CNNEncoder(nn.Module):
20
+ def __init__(self, hidden_size1, hidden_size3):
21
+ super().__init__()
22
+ # Define the encoder
23
+ self.encoder = nn.Sequential(
24
+ nn.Conv1d(in_channels=hidden_size1, out_channels=128, kernel_size=3, stride=1, padding=1),
25
+ nn.ReLU(),
26
+ nn.Conv1d(in_channels=128, out_channels=16, kernel_size=3, stride=1, padding=1),
27
+ nn.ReLU(),
28
+ # nn.Conv1d(in_channels=16, out_channels=4, kernel_size=3, stride=1, padding=1),
29
+ # nn.ReLU(),
30
+ nn.Conv1d(in_channels=16, out_channels=hidden_size3, kernel_size=3, stride=1, padding=1)
31
+ )
32
+
33
+ def forward(self, x):
34
+ # x = x.permute(0, 2, 1)
35
+ # Encode the input
36
+ encoded = self.encoder(x)
37
+ return encoded
38
+
39
+ class CNNDecoder(nn.Module):
40
+ def __init__(self, hidden_size1, hidden_size3) -> None:
41
+ super().__init__()
42
+
43
+ # Define the decoder
44
+ self.decoder = nn.Sequential(
45
+ nn.Conv1d(in_channels=hidden_size3, out_channels=16, kernel_size=3, stride=1, padding=1),
46
+ nn.ReLU(),
47
+ nn.Conv1d(in_channels=16, out_channels=128, kernel_size=3, stride=1, padding=1),
48
+ nn.ReLU(),
49
+ # nn.Conv1d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1),
50
+ # nn.ReLU(),
51
+ nn.Conv1d(in_channels=128, out_channels=hidden_size1, kernel_size=3, stride=1, padding=1),
52
+ nn.Sigmoid()
53
+ )
54
+
55
+ def forward(self, x):
56
+ # Decode the encoding
57
+ decoded = self.decoder(x)
58
+ # decoded = decoded.permute(0, 2, 1)
59
+ return decoded
60
+
61
+
62
+
63
class DeTiME(PreTrainedModel):
    """Seq2seq model with a nested CNN autoencoder bottleneck.

    A T5-style backbone (loaded from ``config.model``) encodes the input;
    the CNN encoder/decoder pair compresses and reconstructs the backbone's
    encoder hidden states, and the reconstruction is handed back to the
    backbone as ``encoder_outputs`` for decoding/generation.
    """

    # Ties this model class to its custom config so AutoModel can resolve it.
    config_class = DeTiMEAutoConfig

    def __init__(self, config):
        super().__init__(config)
        #change t5-small to config
        # Checkpoint id of the seq2seq backbone (config.json sets
        # "google/flan-t5-large").
        model_name_or_path = config.model
        # peft_config = PrefixTuningConfig(peft_type="PREFIX_TUNING", task_type=TaskType.SEQ_2_SEQ_LM,
        # inference_mode=False, num_virtual_tokens=10)
        # model = AutoModelForSeq2SeqLM.from_pretrained(model_name_or_path)
        # model = get_peft_model(model, peft_config)
        model = AutoModelForSeq2SeqLM.from_pretrained(model_name_or_path)

        #model.print_trainable_parameters()
        self.model = model
        # NOTE(review): hard-coded to 'CNN', so the 'RNN' branch in generate()
        # below is currently unreachable.
        self.config_model = 'CNN'
        if self.config_model == 'CNN':
            # self.model = T5ForConditionalGeneration.from_pretrained("t5-small")
            self.encoder = CNNEncoder(
                config.hidden_size1, config.hidden_size3)
            self.decoder = CNNDecoder(
                config.hidden_size1, config.hidden_size3)
            # Mirror the backbone's main input name so generation utilities
            # that inspect the encoder find the expected entry point.
            self.encoder.main_input_name = self.model.main_input_name


        # NOTE(review): duplicate of the assignment inside the branch above —
        # harmless but redundant.
        self.encoder.main_input_name = self.model.main_input_name
        self.main_input_name = self.model.main_input_name

    def forward(self, input_ids, attention_mask, labels, **kwargs):
        # Run only the backbone's encoder to obtain the hidden states.
        output = self.model.encoder(
            input_ids=input_ids, attention_mask=attention_mask).last_hidden_state #batch size * seq length * embedding size,
        #print(output.shape)
        if self.config_model == 'CNN':
            # NOTE(review): no permute is applied before the Conv1d stack, so
            # the sequence dimension is treated as channels (the permute calls
            # in CNNEncoder/CNNDecoder are commented out) — confirm this is
            # the intended design.
            encoder_output = self.encoder(output) #batch size * seq length * embedding size, 1 * batch size * hidden_size
            #print(encoder_output.shape)

            output = self.decoder(encoder_output) #1 batch_size, hidden_size

        # The reconstructed hidden states replace the backbone's own encoder
        # output; labels drive the standard seq2seq LM loss.
        return self.model.forward(input_ids=input_ids.contiguous(), encoder_outputs=(output.contiguous(), ), labels=labels.contiguous(), **kwargs)

    def generate(self, input_ids, attention_mask, **kwargs):
        # Same encode/compress/reconstruct pipeline as forward(), then hand
        # off to the backbone's generate() with the reconstructed states.
        output = self.model.encoder(
            input_ids=input_ids, attention_mask=attention_mask).last_hidden_state #batch size * seq length * embedding size,
        #print(output.shape)
        # encoder_output = self.encoder(output) #batch size * seq length * embedding size, 1 * batch size * hidden_size
        # #print(encoder_output.shape)
        if self.config_model == 'CNN':
            encoder_output = self.encoder(output) #batch size * seq length * embedding size, 1 * batch size * hidden_size
            #print(encoder_output.shape)

            output = self.decoder(encoder_output) #1 batch_size, hidden_size
        elif self.config_model == 'RNN':
            # NOTE(review): dead branch — config_model is hard-coded to 'CNN'
            # in __init__.
            output = self.encoder(output) #batch size * seq length * embedding size, 1 * batch size * hidden_size

        # output = self.decoder(encoder_output) #1 batch_size, hidden_size

        # NOTE(review): encoder_outputs is passed as a bare tuple rather than a
        # BaseModelOutput; transformers accepts tuples here, but verify against
        # the pinned transformers version (4.30.1).
        return self.model.generate(input_ids=input_ids.contiguous(), encoder_outputs=(output.contiguous(), ), **kwargs)
120
+
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2a9e4b3b793cc18cbca71c6f4b988a6837765380dd02c75bbb8a213891b65d10
3
+ size 3134420082