File size: 6,661 Bytes
8ebda9e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 |
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from transformers.modeling_utils import PreTrainedModel
from transformers.configuration_utils import PretrainedConfig
from fengshen.models.DAVAE.DAVAEModel import DAVAEModel
from fengshen.models.PPVAE.utils import *
# Default device for module-level tensor creation: the first CUDA device when one is visible, else CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
class Encoder(nn.Module):
    """Variational encoder of the plugin VAE: compresses a latent code into a small bottleneck.

    Architecture: latent_dim -> latent_dim // 2 -> latent_dim // 4 (leaky-ReLU after each),
    then two linear heads producing the mean and log-variance of a diagonal Gaussian
    over ``bottle_dim`` dimensions.

    Args:
        latent_dim: size of the last dimension of the incoming code ``z``.
        bottle_dim: size of the bottleneck (last dimension of the returned code).
    """

    def __init__(self, latent_dim=128, bottle_dim=20) -> None:
        super().__init__()
        self.fc1 = nn.Linear(latent_dim, latent_dim // 2)
        self.fc2 = nn.Linear(latent_dim // 2, latent_dim // 4)
        self.mean = nn.Linear(latent_dim // 4, bottle_dim)
        self.log_var = nn.Linear(latent_dim // 4, bottle_dim)

    def kl_loss(self, mean, log_var):
        """KL(N(mean, exp(log_var)) || N(0, I)), summed over the last dim and averaged over the rest."""
        return (-0.5 * (1 + log_var - mean ** 2 - log_var.exp()).sum(-1)).mean()

    def sampling(self, mean, log_var):
        """Draw a reparameterised sample ``mean + exp(log_var / 2) * eps`` with ``eps ~ N(0, I)``.

        The noise has exactly the shape of ``mean``. The sample therefore has the same shape as
        ``mean`` whatever the number of leading dimensions. Noise shaped (b, 1, d) would broadcast
        a 2-D ``mean`` of shape (b, d) up to (b, b, d).
        """
        epsilon = torch.randn_like(mean)
        return mean + (log_var / 2).exp() * epsilon

    def forward(self, z):
        """Encode ``z`` into a bottleneck code.

        :param z: tensor whose last dimension is ``latent_dim``, e.g. shape (b, latent_dim)
        :return: ``(enc_z, kl_loss)``. ``enc_z`` has the shape of ``z`` with the last dim replaced
            by ``bottle_dim``. It is a random sample in training mode and the posterior mean in
            eval mode. ``kl_loss`` is a scalar tensor.
        """
        h = F.leaky_relu(self.fc1(z))
        h = F.leaky_relu(self.fc2(h))
        z_mean = self.mean(h)
        z_log_var = self.log_var(h)
        kl_loss = self.kl_loss(z_mean, z_log_var)
        # Only draw noise when it is actually used; eval mode is deterministic.
        enc_z = self.sampling(z_mean, z_log_var) if self.training else z_mean
        return enc_z, kl_loss
class Decoder(nn.Module):
    """MLP that expands a bottleneck code back to the latent size.

    Layer widths: bottle_dim -> latent_dim // 4 -> latent_dim // 2 -> latent_dim,
    with leaky-ReLU after the first two layers and a plain linear output layer.
    """

    def __init__(self, latent_dim=128, bottle_dim=20) -> None:
        super().__init__()
        quarter, half = latent_dim // 4, latent_dim // 2
        self.fc1 = nn.Linear(bottle_dim, quarter)
        self.fc2 = nn.Linear(quarter, half)
        self.fc3 = nn.Linear(half, latent_dim)

    def forward(self, enc_z):
        """Map ``enc_z`` (last dim ``bottle_dim``) to a tensor whose last dim is ``latent_dim``."""
        hidden = enc_z
        for layer in (self.fc1, self.fc2):
            hidden = F.leaky_relu(layer(hidden))
        return self.fc3(hidden)
class PluginVAE(nn.Module):
    """Small VAE that operates on pre-computed latent codes (the "plugin" VAE).

    Reads ``kl_weight``, ``beta``, ``latent_dim`` and ``bottle_dim`` from ``config``.
    The training objective is ``MSE(recon, z) + kl_weight * |KL - beta|``. The KL term is
    pulled towards the target value ``beta`` rather than towards zero.
    """

    def __init__(self, config) -> None:
        super().__init__()
        self.kl_weight = config.kl_weight
        self.beta = config.beta
        # Encoder is built before Decoder, so parameter registration (and init RNG draws) keep that order.
        sizes = (config.latent_dim, config.bottle_dim)
        self.encoder = Encoder(*sizes)
        self.decoder = Decoder(*sizes)

    def set_beta(self, beta):
        """Replace the KL target ``beta`` used by ``loss``."""
        self.beta = beta

    def forward(self, z):
        """Encode ``z`` and decode it again; returns ``(reconstruction, kl_loss)``."""
        code, kl = self.encoder(z)
        return self.decoder(code), kl

    def loss(self, z):
        """Return ``(total_loss, kl_loss)`` for a batch of latent codes ``z``."""
        recon, kl = self.forward(z)
        mse = (recon - z).pow(2).mean()
        kl_penalty = self.kl_weight * (kl - self.beta).abs()
        return mse + kl_penalty, kl
class PPVAEPretrainedModel(PreTrainedModel):
    """Thin ``PreTrainedModel`` base for PPVAE models.

    Its only job is to provide an ``_init_weights`` implementation. Without one, the
    Hugging Face base class raises its not-implemented error.
    """

    def _init_weights(self, module):
        """Initialize the weights: deliberately a no-op, so submodules keep the weights they were built with."""
        return None
class PPVAEModel(PPVAEPretrainedModel):
    """A DAVAE text VAE plus a small plugin VAE trained on its latent space.

    ``train_plugin`` fits ``self.pluginvae`` on DAVAE latent codes of example texts.
    ``generate`` samples bottleneck noise, maps it through the plugin decoder and decodes the
    resulting latent codes to text with ``self.vae_model``.
    """

    config_class = PretrainedConfig

    def __init__(self, config: PretrainedConfig) -> None:
        super().__init__(config=config)
        self.config = config
        self.pluginvae = PluginVAE(self.config)
        self.vae_model = DAVAEModel(self.config)

    @staticmethod
    def _make_dataloader(latents, batch_size):
        """Wrap latent codes in a shuffled ``DataLoader`` (``CustomDataset`` comes from the PPVAE utils)."""
        return DataLoader(CustomDataset(latents), batch_size=batch_size, shuffle=True)

    def train_plugin(self, encoder_tokenizer, decoder_tokenizer, input_texts, negative_samples=None):
        """Train ``self.pluginvae`` on the DAVAE latent codes of ``input_texts``.

        Input: the plugin VAE, the labelled (positive) texts and optional negative texts.
        Output: ``self.pluginvae`` is updated in place and left in training mode; returns None.

        Args:
            encoder_tokenizer, decoder_tokenizer: passed to ``self.vae_model.set_tokenizers``.
            input_texts: positive examples whose latents the plugin VAE should reconstruct.
            negative_samples: optional texts whose reconstruction loss is subtracted from the
                objective (scaled by ``config.gamma``).

        Config fields read: ``batch_size``, ``total_epoch``, ``ppvae_lr``, ``mu``/``nu`` (Adam
        betas), ``gamma``, ``get_dymanic_beta``, ``beta``, ``beta_total_step``,
        ``neg_loss_threshold``. Early stopping is delegated to ``EarlyStopping`` from the PPVAE
        utils. Its patience semantics are not visible here.
        """
        self.vae_model.set_tokenizers(encoder_tokenizer, decoder_tokenizer)
        pos = self.get_latent(input_texts)
        pos_batch_size = self.config.batch_size
        pos_dataloader = self._make_dataloader(pos, pos_batch_size)

        neg_dataloader = None
        if negative_samples is not None:
            neg = self.get_latent(negative_samples)
            # Size negative batches so that one pass over the negatives takes as many steps as
            # one pass over the positives. int() truncates to 0 when negatives are much rarer,
            # and DataLoader rejects batch_size=0, so clamp to at least 1.
            neg_batch_size = max(1, int(pos_batch_size * (neg.shape[0] / pos.shape[0])))
            neg_dataloader = self._make_dataloader(neg, neg_batch_size)

        optimizer = torch.optim.Adam(
            params=self.pluginvae.parameters(),
            lr=self.config.ppvae_lr,
            betas=(self.config.mu, self.config.nu),
        )
        gamma = self.config.gamma
        iter_num = 0
        early_stopper = EarlyStopping()
        # Start from +inf so the first epoch always sets the baseline. A fixed cap would never
        # reset the stopper's counter while the loss is still above that cap.
        min_loss = float("inf")
        num_batches = len(pos_dataloader)

        for _ in range(self.config.total_epoch):
            self.pluginvae.train()
            total_loss = 0.0
            for data in pos_dataloader:
                if self.config.get_dymanic_beta:
                    self.pluginvae.set_beta(
                        self.get_beta_weight(iter_num, self.config.beta, self.config.beta_total_step)
                    )
                iter_num += 1
                pos_loss, _ = self.pluginvae.loss(data)

                neg_loss = 0.0
                if neg_dataloader is not None:
                    # A fresh shuffled iterator per step yields one random negative batch per step.
                    neg_data = next(iter(neg_dataloader))
                    neg_loss, _ = self.pluginvae.loss(neg_data)
                    # `pos_loss - gamma * neg_loss` is unbounded below. Once the negative loss
                    # exceeds the threshold, stop its gradient so it cannot dominate training.
                    if neg_loss.item() > self.config.neg_loss_threshold * pos_loss.item():
                        neg_loss = neg_loss.detach()

                loss = pos_loss - gamma * neg_loss
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                total_loss += loss.item()

            avg_loss = total_loss / num_batches
            if avg_loss < min_loss:
                min_loss = avg_loss
                early_stopper.counter = 0
            early_stopper(avg_loss, min_loss)
            if early_stopper.early_stop:
                break

    def generate(self, n):
        """Sample ``n`` latent codes from the plugin decoder and decode them to text with DAVAE."""
        latent_z = self.gen_latent(n)
        text_analogy = self.vae_model.text_from_latent_code_batch(latent_z)
        return text_analogy

    def get_latent(self, texts):
        """Return the DAVAE latent codes for ``texts`` (presumably one row per text — confirm in DAVAEModel)."""
        latent = self.vae_model.latent_code_from_text_batch(texts)
        return latent

    def gen_latent(self, gen_num=5):
        """Decode ``gen_num`` standard-normal bottleneck vectors into latent codes (no gradients).

        The noise is drawn with the CPU generator and then moved to the plugin decoder's own
        device. The global ``device`` may be cuda while the model sits on the CPU.
        """
        decoder = self.pluginvae.decoder
        target_device = next(decoder.parameters()).device
        random_vec = torch.randn((gen_num, self.config.bottle_dim)).to(target_device)
        with torch.no_grad():
            g_vec = decoder(random_vec)
        return g_vec

    def get_beta_weight(self, iter_num, beta, total_step):
        """Linear KL-target warm-up: ramp from 0 to ``beta`` over ``total_step`` iterations, then hold.

        A non-positive ``total_step`` means no warm-up, so ``beta`` is returned directly
        (this also avoids a division by zero).
        """
        if total_step <= 0:
            return beta
        return min((beta / total_step) * iter_num, beta)
|