# Source: paraphrase_text_generation / modeling_detime.py
# Uploaded by xwjzds ("Upload DeTiME", commit d8bd233, verified)
"""
This module includes all the classes and functions for the nested autoencoder.
"""
from transformers import PreTrainedModel
from transformers import T5ForConditionalGeneration, AutoModelForSeq2SeqLM
import datasets
import torch
import torch.nn.functional as F
from torch import nn
import random
import os
from .configuration_detime import DeTiMEAutoConfig
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# Define the CNN encoder and decoder model
class CNNEncoder(nn.Module):
    """1-D convolutional encoder that compresses the channel dimension.

    Maps an input of shape (batch, hidden_size1, length) to
    (batch, hidden_size3, length) through intermediate 128- and
    16-channel conv layers; kernel_size=3 with padding=1 leaves the
    trailing length dimension unchanged.
    """

    def __init__(self, hidden_size1, hidden_size3):
        super().__init__()
        # Channel sizes at each stage of the funnel.
        channel_plan = [hidden_size1, 128, 16, hidden_size3]
        stages = []
        for step, (c_in, c_out) in enumerate(zip(channel_plan, channel_plan[1:])):
            stages.append(nn.Conv1d(in_channels=c_in, out_channels=c_out,
                                    kernel_size=3, stride=1, padding=1))
            # ReLU between convs, but none after the final projection.
            if step < len(channel_plan) - 2:
                stages.append(nn.ReLU())
        self.encoder = nn.Sequential(*stages)

    def forward(self, x):
        """Encode ``x`` (channels-first layout: batch, channels, length)."""
        return self.encoder(x)
class CNNDecoder(nn.Module):
    """1-D convolutional decoder mirroring :class:`CNNEncoder`.

    Expands (batch, hidden_size3, length) back to
    (batch, hidden_size1, length) via 16- and 128-channel conv layers,
    finishing with a Sigmoid so every output lies in (0, 1).
    """

    def __init__(self, hidden_size1, hidden_size3) -> None:
        super().__init__()
        # Shared conv settings: length-preserving 3-tap convolutions.
        conv_kwargs = dict(kernel_size=3, stride=1, padding=1)
        self.decoder = nn.Sequential(
            nn.Conv1d(hidden_size3, 16, **conv_kwargs),
            nn.ReLU(),
            nn.Conv1d(16, 128, **conv_kwargs),
            nn.ReLU(),
            nn.Conv1d(128, hidden_size1, **conv_kwargs),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Decode ``x`` (channels-first layout: batch, channels, length)."""
        return self.decoder(x)
class DeTiME(PreTrainedModel):
    """Nested autoencoder: a pretrained seq2seq backbone whose encoder
    hidden states are squeezed through a CNN bottleneck (encoder +
    decoder) before being handed to the backbone's decoder.

    The backbone checkpoint name is read from ``config.model``; the
    bottleneck widths come from ``config.hidden_size1`` /
    ``config.hidden_size3``.
    """

    config_class = DeTiMEAutoConfig

    def __init__(self, config):
        super().__init__(config)
        # Seq2seq backbone (e.g. a T5 checkpoint) named in the config.
        self.model = AutoModelForSeq2SeqLM.from_pretrained(config.model)
        # Bottleneck flavour. Only 'CNN' is constructed here; generate()
        # also carries an 'RNN' branch, presumably for other variants.
        self.config_model = 'CNN'
        if self.config_model == 'CNN':
            self.encoder = CNNEncoder(config.hidden_size1, config.hidden_size3)
            self.decoder = CNNDecoder(config.hidden_size1, config.hidden_size3)
            # Propagate the backbone's main input name to both bottleneck
            # submodules. BUGFIX: the original assigned the encoder's twice
            # (copy-paste) and never set the decoder's.
            self.encoder.main_input_name = self.model.main_input_name
            self.decoder.main_input_name = self.model.main_input_name
        self.main_input_name = self.model.main_input_name

    def forward(self, input_ids, attention_mask, labels, **kwargs):
        """Training/teacher-forcing pass.

        Runs the backbone encoder, reconstructs its hidden states through
        the CNN bottleneck, and feeds them to the backbone as
        ``encoder_outputs``. NOTE(review): the backbone encoder output is
        (batch, seq_len, d_model) while Conv1d treats dim 1 as channels,
        so hidden_size1 appears to stand for a fixed sequence length —
        confirm against the training pipeline (the permute calls in the
        CNN modules are commented out upstream).
        """
        hidden = self.model.encoder(
            input_ids=input_ids, attention_mask=attention_mask
        ).last_hidden_state
        if self.config_model == 'CNN':
            bottleneck = self.encoder(hidden)
            hidden = self.decoder(bottleneck)
        return self.model.forward(
            input_ids=input_ids.contiguous(),
            encoder_outputs=(hidden.contiguous(),),
            labels=labels.contiguous(),
            **kwargs,
        )

    def generate(self, input_ids, attention_mask, **kwargs):
        """Generation pass: same bottleneck substitution as forward(),
        then delegates to the backbone's ``generate``.
        """
        hidden = self.model.encoder(
            input_ids=input_ids, attention_mask=attention_mask
        ).last_hidden_state
        if self.config_model == 'CNN':
            hidden = self.decoder(self.encoder(hidden))
        elif self.config_model == 'RNN':
            # RNN variant feeds the bottleneck encoding directly,
            # with no decoder pass (matches the original code path).
            hidden = self.encoder(hidden)
        return self.model.generate(
            input_ids=input_ids.contiguous(),
            encoder_outputs=(hidden.contiguous(),),
            **kwargs,
        )