File size: 5,194 Bytes
2a1e94d d8bd233 2a1e94d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 |
"""
This module includes all the classes and functions for the nested autoencoder.
"""
from transformers import PreTrainedModel
from transformers import T5ForConditionalGeneration, AutoModelForSeq2SeqLM
import datasets
import torch
import torch.nn.functional as F
from torch import nn
import random
import os
from .configuration_detime import DeTiMEAutoConfig
# Import-time side effect: force the TOKENIZERS_PARALLELISM environment
# variable to "false" for this process.
# NOTE(review): presumably this is meant to stop the Hugging Face `tokenizers`
# library from using (and warning about) thread parallelism after a fork --
# confirm against the tokenizers documentation.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# Define the CNN encoder and decoder model
class CNNEncoder(nn.Module):
    """Three-layer 1-D convolutional encoder.

    Channel widths run ``hidden_size1 -> 128 -> 16 -> hidden_size3``. Every
    convolution uses kernel size 3, stride 1 and padding 1, so the length
    axis of the input is preserved. A ReLU follows the first two
    convolutions; the last convolution has no activation.
    """

    def __init__(self, hidden_size1, hidden_size3):
        super().__init__()
        # Settings shared by every convolution in the stack.
        conv_opts = dict(kernel_size=3, stride=1, padding=1)
        # Layers are created in stack order so seeded initialisation and
        # the ``encoder.<index>`` state-dict keys stay stable.
        stem = nn.Conv1d(hidden_size1, 128, **conv_opts)
        middle = nn.Conv1d(128, 16, **conv_opts)
        head = nn.Conv1d(16, hidden_size3, **conv_opts)
        self.encoder = nn.Sequential(stem, nn.ReLU(), middle, nn.ReLU(), head)

    def forward(self, x):
        """Return the encoding of ``x``.

        ``x`` is fed to ``nn.Conv1d`` unchanged (no permute), so its axis 1
        is the channel axis and must have size ``hidden_size1``.
        """
        return self.encoder(x)
class CNNDecoder(nn.Module):
    """Three-layer 1-D convolutional decoder, the mirror of ``CNNEncoder``.

    Channel widths run ``hidden_size3 -> 16 -> 128 -> hidden_size1``. Every
    convolution uses kernel size 3, stride 1 and padding 1, so the length
    axis of the input is preserved. A ReLU follows the first two
    convolutions and a Sigmoid follows the last, so every output value lies
    in the open interval (0, 1).
    """

    def __init__(self, hidden_size1, hidden_size3) -> None:
        super().__init__()
        # Settings shared by every convolution in the stack.
        conv_opts = dict(kernel_size=3, stride=1, padding=1)
        # Layers are created in stack order so seeded initialisation and
        # the ``decoder.<index>`` state-dict keys stay stable.
        expand_a = nn.Conv1d(hidden_size3, 16, **conv_opts)
        expand_b = nn.Conv1d(16, 128, **conv_opts)
        project = nn.Conv1d(128, hidden_size1, **conv_opts)
        self.decoder = nn.Sequential(
            expand_a,
            nn.ReLU(),
            expand_b,
            nn.ReLU(),
            project,
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Return the reconstruction of the encoding ``x``.

        ``x`` is fed to ``nn.Conv1d`` unchanged (no permute), so its axis 1
        is the channel axis and must have size ``hidden_size3``.
        """
        return self.decoder(x)
class DeTiME(PreTrainedModel):
    """Seq2seq language model with a CNN bottleneck on its encoder output.

    The pretrained model named by ``config.model`` is loaded through
    ``AutoModelForSeq2SeqLM``. Its encoder hidden states are passed through
    ``CNNEncoder`` and then ``CNNDecoder``; the reconstructed states are
    handed back to the wrapped model as ``encoder_outputs`` for decoding.
    """

    config_class = DeTiMEAutoConfig

    def __init__(self, config):
        """Load the wrapped seq2seq model and build the CNN bottleneck.

        Args:
            config: Configuration object; this constructor reads
                ``config.model`` (name or path of the pretrained seq2seq
                model), ``config.hidden_size1`` and ``config.hidden_size3``
                (channel sizes given to ``CNNEncoder`` / ``CNNDecoder``).
        """
        super().__init__(config)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(config.model)
        # The bottleneck type is fixed to 'CNN' here; the 'RNN' branch in
        # ``_reconstruct_encoder_states`` is therefore currently unreachable.
        self.config_model = 'CNN'
        if self.config_model == 'CNN':
            self.encoder = CNNEncoder(config.hidden_size1, config.hidden_size3)
            self.decoder = CNNDecoder(config.hidden_size1, config.hidden_size3)
            self.encoder.main_input_name = self.model.main_input_name
            # The original assigned the encoder's attribute twice in a row;
            # the second assignment was evidently meant for the decoder.
            self.decoder.main_input_name = self.model.main_input_name
        self.main_input_name = self.model.main_input_name

    def _reconstruct_encoder_states(self, input_ids, attention_mask):
        """Run the wrapped encoder, then the bottleneck, on one batch.

        Shared by ``forward`` and ``generate`` so both code paths feed the
        wrapped model exactly the same encoder states.

        Returns:
            A contiguous tensor to be used as the wrapped model's
            ``encoder_outputs[0]``.
        """
        # batch size * seq length * embedding size
        hidden = self.model.encoder(
            input_ids=input_ids, attention_mask=attention_mask
        ).last_hidden_state
        # No permute is applied, so Conv1d treats axis 1 of ``hidden`` as its
        # channel axis. NOTE(review): this implies inputs are padded or
        # truncated to ``config.hidden_size1`` tokens -- confirm with callers.
        if self.config_model == 'CNN':
            hidden = self.decoder(self.encoder(hidden))
        elif self.config_model == 'RNN':
            hidden = self.encoder(hidden)
        return hidden.contiguous()

    def forward(self, input_ids, attention_mask, labels=None, **kwargs):
        """Run the wrapped model on the bottlenecked encoder states.

        Args:
            input_ids: Token ids for the wrapped model's encoder.
            attention_mask: Attention mask for the wrapped model's encoder.
                It is used for the encoder pass only and is not forwarded
                to the decoding call.
            labels: Optional target token ids, passed to the wrapped model
                as ``labels``. May be ``None`` (e.g. when the caller supplies
                ``decoder_input_ids`` through ``kwargs``).
            **kwargs: Forwarded unchanged to the wrapped model.

        Returns:
            Whatever the wrapped model returns.
        """
        output = self._reconstruct_encoder_states(input_ids, attention_mask)
        if labels is not None:
            labels = labels.contiguous()
        # Call the module itself rather than ``.forward`` so that any hooks
        # registered on the wrapped model are run.
        return self.model(
            input_ids=input_ids.contiguous(),
            encoder_outputs=(output,),
            labels=labels,
            **kwargs,
        )

    def generate(self, input_ids, attention_mask, **kwargs):
        """Generate sequences from the bottlenecked encoder states.

        Args:
            input_ids: Token ids for the wrapped model's encoder.
            attention_mask: Attention mask for the wrapped model's encoder.
            **kwargs: Forwarded unchanged to the wrapped model's
                ``generate``.

        Returns:
            Whatever the wrapped model's ``generate`` returns.
        """
        # Generation yields token ids, so no gradient can flow back through
        # the encoder; skip building an autograd graph for it.
        with torch.no_grad():
            output = self._reconstruct_encoder_states(input_ids, attention_mask)
        return self.model.generate(
            input_ids=input_ids.contiguous(),
            encoder_outputs=(output,),
            **kwargs,
        )
|