"""
This module includes all the classes and functions for the nested autoencoder.
"""

from transformers import PreTrainedModel
from transformers import T5ForConditionalGeneration, AutoModelForSeq2SeqLM
import datasets
import torch
import torch.nn.functional as F
from torch import nn
import random
import os
from .configuration_detime import DeTiMEAutoConfig
os.environ["TOKENIZERS_PARALLELISM"] = "false"



# Define the CNN encoder and decoder model
class CNNEncoder(nn.Module):
    def __init__(self, hidden_size1, hidden_size3):
        super().__init__()
        # Define the encoder
        self.encoder = nn.Sequential(
            nn.Conv1d(in_channels=hidden_size1, out_channels=128, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv1d(in_channels=128, out_channels=16, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
#             nn.Conv1d(in_channels=16, out_channels=4, kernel_size=3, stride=1, padding=1),
#             nn.ReLU(),
            nn.Conv1d(in_channels=16, out_channels=hidden_size3, kernel_size=3, stride=1, padding=1)
        )

    def forward(self, x):
        # x = x.permute(0, 2, 1)
        # Encode the input
        encoded = self.encoder(x)
        return encoded

class CNNDecoder(nn.Module):
    """1-D convolutional decoder mirroring :class:`CNNEncoder`.

    Maps a tensor of shape (batch, hidden_size3, length) back to
    (batch, hidden_size1, length); the final Sigmoid squashes every
    output value into (0, 1).
    """

    def __init__(self, hidden_size1, hidden_size3) -> None:
        super().__init__()
        # Channel pipeline: hidden_size3 -> 16 -> 128 -> hidden_size1.
        self.decoder = nn.Sequential(
            nn.Conv1d(in_channels=hidden_size3, out_channels=16, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv1d(in_channels=16, out_channels=128, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv1d(in_channels=128, out_channels=hidden_size1, kernel_size=3, stride=1, padding=1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Decode *x* (batch, hidden_size3, length) -> (batch, hidden_size1, length)."""
        return self.decoder(x)


class DeTiME(PreTrainedModel):
    """Nested autoencoder built on a pretrained seq2seq backbone.

    The backbone encoder's hidden states are compressed by a CNN encoder
    and reconstructed by a CNN decoder; the reconstructed states are then
    fed to the backbone as ``encoder_outputs`` for loss computation
    (:meth:`forward`) or text generation (:meth:`generate`).
    """

    config_class = DeTiMEAutoConfig

    def __init__(self, config):
        super().__init__(config)
        # Backbone seq2seq model; the checkpoint name comes from the config
        # (e.g. a T5 variant).
        self.model = AutoModelForSeq2SeqLM.from_pretrained(config.model)
        # Only the CNN bottleneck is currently wired up; the 'RNN' branch in
        # generate() is legacy and unreachable while this stays hard-coded.
        self.config_model = 'CNN'
        if self.config_model == 'CNN':
            self.encoder = CNNEncoder(config.hidden_size1, config.hidden_size3)
            self.decoder = CNNDecoder(config.hidden_size1, config.hidden_size3)
            # Mirror the backbone's expected input name so HF utilities that
            # inspect submodules keep working.  (The original assigned this
            # twice; once is enough.)
            self.encoder.main_input_name = self.model.main_input_name
        self.main_input_name = self.model.main_input_name

    def forward(self, input_ids, attention_mask, labels, **kwargs):
        """Encode, pass through the CNN bottleneck, and compute the seq2seq loss.

        NOTE(review): the CNN treats dim 1 of the encoder output
        (batch, seq_len, hidden) as its channel axis, so hidden_size1 must
        equal the tokenized sequence length — confirm against the data
        pipeline.
        """
        output = self.model.encoder(
            input_ids=input_ids, attention_mask=attention_mask
        ).last_hidden_state  # (batch, seq_len, hidden)
        if self.config_model == 'CNN':
            # Compress then reconstruct the encoder hidden states.
            output = self.decoder(self.encoder(output))
        return self.model.forward(
            input_ids=input_ids.contiguous(),
            encoder_outputs=(output.contiguous(),),
            labels=labels.contiguous(),
            **kwargs,
        )

    def generate(self, input_ids, attention_mask, **kwargs):
        """Generate text conditioned on the reconstructed encoder states."""
        output = self.model.encoder(
            input_ids=input_ids, attention_mask=attention_mask
        ).last_hidden_state  # (batch, seq_len, hidden)
        if self.config_model == 'CNN':
            output = self.decoder(self.encoder(output))
        elif self.config_model == 'RNN':
            # Legacy path: unreachable while config_model is hard-coded to 'CNN'.
            output = self.encoder(output)
        return self.model.generate(
            input_ids=input_ids.contiguous(),
            encoder_outputs=(output.contiguous(),),
            **kwargs,
        )