#!/usr/bin/env python
# coding: utf-8
import torch
import torch.nn as nn
from torch import optim
from torch.utils.data import DataLoader, Dataset
from torch.optim.lr_scheduler import StepLR
from tqdm.auto import tqdm
import torch.nn.functional as F
import pandas as pd

torch.manual_seed(114514)
torch.set_default_device('cuda')

SOS_token = 1
EOS_token = 2

katakana = list('゠ァアィイゥウェエォオカガキギクグケゲコゴサザシジスズセゼソゾタダチヂッツヅテデトドナニヌネノハバパヒビピフブプヘベペホボポマミムメモャヤュユョヨラリルレロヮワヰヱヲンヴヵヶヷヸヹヺ・ーヽヾヿㇰㇱㇲㇳㇴㇵㇶㇷㇸㇹㇺㇻㇼㇽㇾㇿ')
# Indices 0, 1, 2 are reserved for PAD, SOS and EOS; they render as empty strings when detokenizing.
vocab = ['', '', ''] + katakana
vocab_dict = {v: k for k, v in enumerate(vocab)}
texts = pd.read_csv('rolename.txt', header=None)[0].tolist()  # one katakana name per line

# Hyperparameters
vocab_size = len(vocab)
h = 192            # GRU hidden size
h_latent = 64      # latent dimension
max_len = 40       # maximum (padded) sequence length
bs = 128
lr = 0.2
lr_step_size = 32
lr_decay = 0.5
momentum = 0.9
epochs = 192
grad_max_norm = 1


def tokenize(text):
    return [vocab_dict[ch] for ch in text]


def detokenize(tokens):
    if EOS_token in tokens:
        tokens = tokens[:tokens.index(EOS_token)]
    return ''.join(vocab[token] for token in tokens)


class BatchNormVAE(nn.Module):
    # BN-VAE regularizer (https://spaces.ac.cn/archives/7381/): batch-normalize mu and sigma,
    # then rescale them so that scale_mu**2 + scale_sigma**2 == 1 with scale_mu**2 >= TAU.
    # This lower-bounds the KL term and counters posterior collapse.
    def __init__(self, num_features, **kwargs):
        super(BatchNormVAE, self).__init__()
        kwargs['affine'] = False
        self.TAU = 0.5
        self.bn_mu = nn.BatchNorm1d(num_features, **kwargs)
        self.bn_sigma = nn.BatchNorm1d(num_features, **kwargs)
        self.theta = nn.Parameter(torch.zeros(1))

    def forward(self, mu, sigma):
        mu = self.bn_mu(mu)
        sigma = self.bn_sigma(sigma)
        scale_mu = torch.sqrt(self.TAU + (1 - self.TAU) * torch.sigmoid(self.theta))
        scale_sigma = torch.sqrt((1 - self.TAU) * torch.sigmoid(-self.theta))
        return mu * scale_mu, sigma * scale_sigma


class EncoderVAEBiGRU(nn.Module):
    # Bidirectional 2-layer GRU encoder; the four final hidden states (2 layers x 2 directions)
    # are concatenated and projected to mu and sigma of the approximate posterior.
    def __init__(self, input_size, hidden_size, latent_size, dropout_p=0.1):
        super(EncoderVAEBiGRU, self).__init__()
        self.hidden_size = hidden_size
        self.embedding = nn.Embedding(input_size, hidden_size)
        self.gru = nn.GRU(hidden_size, hidden_size, num_layers=2, batch_first=True, bidirectional=True)
        self.proj_mu = nn.Linear(4 * hidden_size, latent_size)
        self.proj_sigma = nn.Linear(4 * hidden_size, latent_size)
        self.dropout = nn.Dropout(dropout_p)
        self.bn = BatchNormVAE(latent_size)

    def forward(self, input, input_lengths):
        input_lengths = input_lengths.to('cpu')  # pack_padded_sequence requires CPU lengths
        embedded = self.dropout(self.embedding(input))
        embedded = nn.utils.rnn.pack_padded_sequence(embedded, input_lengths, batch_first=True, enforce_sorted=False)
        _, hidden = self.gru(embedded)
        hidden = hidden.permute(1, 0, 2).flatten(1, 2)  # (4, B, h) -> (B, 4h)
        mu = self.proj_mu(hidden)
        sigma = self.proj_sigma(hidden)  # not the std: it can be negative; only sigma**2 is used as the variance
        mu, sigma = self.bn(mu, sigma)
        return self._reparameterize(mu, sigma), mu, sigma ** 2

    def _reparameterize(self, mu, sigma):
        eps = torch.randn_like(sigma)
        return eps * sigma + mu  # sample from N(mu, sigma**2)


class DecoderGRU(nn.Module):
    # The latent sample is mapped through a small MLP to the initial hidden states of a
    # 2-layer unidirectional GRU. Decoding is teacher-forced when target_tensor is given,
    # otherwise greedy, starting from SOS.
    def __init__(self, latent_size, hidden_size, output_size):
        super(DecoderGRU, self).__init__()
        self.proj1 = nn.Linear(latent_size, latent_size)
        self.proj_activation = nn.ReLU()
        self.proj2 = nn.Linear(latent_size, 2 * hidden_size)
        self.embedding = nn.Embedding(output_size, hidden_size)
        self.gru = nn.GRU(hidden_size, hidden_size, num_layers=2, batch_first=True)
        self.out = nn.Linear(hidden_size, output_size)

    def forward(self, encoder_sample, target_tensor=None, max_length=16):
        batch_size = encoder_sample.size(0)
        decoder_hidden = self.proj1(encoder_sample)
        decoder_hidden = self.proj_activation(decoder_hidden)
        decoder_hidden = self.proj2(decoder_hidden)
        decoder_hidden = decoder_hidden.view(batch_size, 2, -1).permute(1, 0, 2).contiguous()  # (2, B, h)
        if target_tensor is not None:
            # Teacher forcing: feed the whole target sequence in one pass.
            decoder_input = target_tensor
            decoder_outputs, decoder_hidden = self.forward_step(decoder_input, decoder_hidden)
        else:
            # Greedy decoding, one token at a time.
            decoder_input = torch.empty(batch_size, 1, dtype=torch.long).fill_(SOS_token)
            decoder_outputs = []
            for i in range(max_length):
                decoder_output, decoder_hidden = self.forward_step(decoder_input, decoder_hidden)
                decoder_outputs.append(decoder_output)
                _, topi = decoder_output.topk(1)
                decoder_input = topi.squeeze(-1).detach()
            decoder_outputs = torch.cat(decoder_outputs, dim=1)
        decoder_outputs = F.log_softmax(decoder_outputs, dim=-1)
        return decoder_outputs, decoder_hidden

    def forward_step(self, input, hidden):
        output = self.embedding(input)
        output = F.relu(output)
        output, hidden = self.gru(output, hidden)
        output = self.out(output)
        return output, hidden


class KatakanaDataset(Dataset):
    # Pads the encoder input (tokens), decoder input (SOS + tokens) and decoder target
    # (tokens + EOS) to max_length with PAD (0); assumes every name fits within max_length.
    def __init__(self, texts, tokenizer, max_length):
        self.texts = texts
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        text = self.texts[idx]
        tokens = self.tokenizer(text)
        enc_text = tokens
        enc_len = len(enc_text)
        input_text = [SOS_token] + tokens
        target_text = tokens + [EOS_token]
        enc_text = torch.tensor(enc_text + [0] * (self.max_length - len(enc_text)), dtype=torch.long)
        input_text = torch.tensor(input_text + [0] * (self.max_length - len(input_text)), dtype=torch.long)
        target_text = torch.tensor(target_text + [0] * (self.max_length - len(target_text)), dtype=torch.long)
        return enc_text, enc_len, input_text, target_text


dataloader = DataLoader(
    KatakanaDataset(texts, tokenize, max_len),
    batch_size=bs,
    shuffle=True,
    generator=torch.Generator(device='cuda'),  # the sampler's generator must live on the default device
)


def train_epoch(dataloader, encoder, decoder, optimizer, max_norm, norm_p=2):
    total_loss = 0
    nll = nn.NLLLoss()
    for enc_text, enc_len, input_text, target_text in dataloader:
        optimizer.zero_grad()
        encoder_sample, mu, var = encoder(enc_text, enc_len)
        decoder_outputs, _ = decoder(encoder_sample, input_text)
        # Reconstruction loss; padding positions are included, so the model also learns to emit PAD after EOS.
        loss_recons = nll(decoder_outputs.view(-1, decoder_outputs.size(-1)), target_text.view(-1))
        # Analytic KL(N(mu, var) || N(0, 1)), averaged over batch and latent dimensions.
        loss_kld = 0.5 * torch.mean(mu ** 2 + var - var.log() - 1)
        loss = loss_recons + loss_kld
        loss.backward()
        # Gradient clipping by norm.
        nn.utils.clip_grad_norm_(list(encoder.parameters()) + list(decoder.parameters()), max_norm, norm_type=norm_p)
        optimizer.step()
        total_loss += loss.item()
    return total_loss / len(dataloader)


encoder = EncoderVAEBiGRU(vocab_size, h, h_latent).train()
decoder = DecoderGRU(h_latent, h, vocab_size).train()
optimizer = optim.SGD(list(encoder.parameters()) + list(decoder.parameters()), lr=lr, momentum=momentum)  # SGD with momentum
scheduler = StepLR(optimizer, step_size=lr_step_size, gamma=lr_decay)

with tqdm(range(epochs), desc='Training') as pbar:
    for i in pbar:
        pbar.set_postfix(loss=train_epoch(dataloader, encoder, decoder, optimizer, grad_max_norm))
        scheduler.step()

# Sample a few names by greedily decoding random latent vectors drawn from the prior.
decoder.eval()
with torch.no_grad():
    samples = decoder(torch.randn(8, h_latent), max_length=max_len)[0].topk(1)[1].squeeze().tolist()
for name in [detokenize(seq) for seq in samples]:
    print(name)

torch.save(decoder, 'decoder.pt')
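
# --- Optional latent-space interpolation sketch (added illustration, not part of the original
# training run). It reuses the trained `decoder` still in scope; the pickled 'decoder.pt' could
# also be reloaded later with torch.load (passing weights_only=False on recent PyTorch, since a
# full nn.Module was saved) as long as DecoderGRU is importable. Interpolating between two random
# latent vectors shows how smoothly one generated name morphs into another; the step count (8)
# is arbitrary.
with torch.no_grad():
    z0, z1 = torch.randn(2, h_latent)
    alphas = torch.linspace(0, 1, steps=8).unsqueeze(1)   # (8, 1) mixing weights
    z = (1 - alphas) * z0 + alphas * z1                   # (8, h_latent) interpolated latents
    log_probs, _ = decoder(z, max_length=max_len)         # greedy decoding from each latent
    for seq in log_probs.topk(1)[1].squeeze(-1).tolist():
        print(detokenize(seq))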