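# Character-level VAE that learns to generate katakana role names.
# The latent posterior statistics are batch-normalized (BatchNormVAE below),
# which appears to follow the BN-VAE recipe for avoiding KL/posterior collapse.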
import torch
import torch.nn as nn
from torch import optim
from torch.utils.data import DataLoader, Dataset
from torch.optim.lr_scheduler import StepLR
from tqdm.auto import tqdm
import torch.nn.functional as F
import pandas as pd

torch.manual_seed(114514)
torch.set_default_device('cuda')
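# Vocabulary: index 0 is <pad>, 1 is <sos>, 2 is <eos>, followed by the katakana block.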
SOS_token = 1
EOS_token = 2
katakana = list('゠ァアィイゥウェエォオカガキギクグケゲコゴサザシジスズセゼソゾタダチヂッツヅテデトドナニヌネノハバパヒビピフブプヘベペホボポマミムメモャヤュユョヨラリルレロヮワヰヱヲンヴヵヶヷヸヹヺ・ーヽヾヿㇰㇱㇲㇳㇴㇵㇶㇷㇸㇹㇺㇻㇼㇽㇾㇿ')
vocab = ['<pad>', '<sos>', '<eos>'] + katakana
vocab_dict = {v: k for k, v in enumerate(vocab)}

texts = pd.read_csv('rolename.txt', header=None)[0].tolist()
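# Hyperparameters: h is the GRU hidden size, h_latent the latent dimension,
# max_len the padded sequence length; the rest configure SGD training.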
vocab_size = len(vocab)
h = 192
h_latent = 64
max_len = 40
bs = 128
lr = 0.2
lr_step_size = 32
lr_decay = 0.5
momentum = 0.9
epochs = 192
grad_max_norm = 1
def tokenize(text):
    return [vocab_dict[ch] for ch in text]


def detokenize(tokens):
    if EOS_token in tokens:
        tokens = tokens[:tokens.index(EOS_token)]
    return ''.join(vocab[token] for token in tokens)
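# BN-VAE-style regularizer: batch-normalize the posterior mean and scale, then
# split a fixed unit budget between them with a learned gate theta, so the KL
# term cannot collapse to zero (TAU sets the minimum share given to the mean).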
class BatchNormVAE(nn.Module):
    def __init__(self, num_features, **kwargs):
        super(BatchNormVAE, self).__init__()
        kwargs['affine'] = False
        self.TAU = 0.5
        self.bn_mu = nn.BatchNorm1d(num_features, **kwargs)
        self.bn_sigma = nn.BatchNorm1d(num_features, **kwargs)
        self.theta = nn.Parameter(torch.zeros(1))

    def forward(self, mu, sigma):
        mu = self.bn_mu(mu)
        sigma = self.bn_sigma(sigma)
        scale_mu = torch.sqrt(self.TAU + (1 - self.TAU) * torch.sigmoid(self.theta))
        scale_sigma = torch.sqrt((1 - self.TAU) * torch.sigmoid(-self.theta))
        return mu * scale_mu, sigma * scale_sigma
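# Encoder: embed, pack, and run a 2-layer bidirectional GRU; the four final hidden
# states (2 layers x 2 directions) are flattened and projected to the posterior
# mean and scale, which are normalized by BatchNormVAE before reparameterization.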
class EncoderVAEBiGRU(nn.Module):
    def __init__(self, input_size, hidden_size, latent_size, dropout_p=0.1):
        super(EncoderVAEBiGRU, self).__init__()
        self.hidden_size = hidden_size
        self.embedding = nn.Embedding(input_size, hidden_size)
        self.gru = nn.GRU(hidden_size, hidden_size, num_layers=2, batch_first=True, bidirectional=True)
        self.proj_mu = nn.Linear(4 * hidden_size, latent_size)
        self.proj_sigma = nn.Linear(4 * hidden_size, latent_size)
        self.dropout = nn.Dropout(dropout_p)
        self.bn = BatchNormVAE(latent_size)

    def forward(self, input, input_lengths):
        input_lengths = input_lengths.to('cpu')
        embedded = self.dropout(self.embedding(input))
        embedded = nn.utils.rnn.pack_padded_sequence(embedded, input_lengths, batch_first=True, enforce_sorted=False)
        _, hidden = self.gru(embedded)
        hidden = hidden.permute(1, 0, 2).flatten(1, 2)
        mu = self.proj_mu(hidden)
        sigma = self.proj_sigma(hidden)
        mu, sigma = self.bn(mu, sigma)
        return self._reparameterize(mu, sigma), mu, sigma ** 2

    def _reparameterize(self, mu, sigma):
        eps = torch.randn_like(sigma)
        return eps * sigma + mu
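# Decoder: an MLP maps the latent sample to the initial hidden state of a 2-layer
# GRU. Given a target tensor it decodes the whole sequence teacher-forced in one
# call; otherwise it decodes greedily from <sos>, feeding back its own argmax.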
class DecoderGRU(nn.Module):
    def __init__(self, latent_size, hidden_size, output_size):
        super(DecoderGRU, self).__init__()
        self.proj1 = nn.Linear(latent_size, latent_size)
        self.proj_activation = nn.ReLU()
        self.proj2 = nn.Linear(latent_size, 2 * hidden_size)
        self.embedding = nn.Embedding(output_size, hidden_size)
        self.gru = nn.GRU(hidden_size, hidden_size, num_layers=2, batch_first=True)
        self.out = nn.Linear(hidden_size, output_size)

    def forward(self, encoder_sample, target_tensor=None, max_length=16):
        batch_size = encoder_sample.size(0)
        decoder_hidden = self.proj1(encoder_sample)
        decoder_hidden = self.proj_activation(decoder_hidden)
        decoder_hidden = self.proj2(decoder_hidden)
        decoder_hidden = decoder_hidden.view(batch_size, 2, -1).permute(1, 0, 2).contiguous()
        if target_tensor is not None:
            decoder_input = target_tensor
            decoder_outputs, decoder_hidden = self.forward_step(decoder_input, decoder_hidden)
        else:
            decoder_input = torch.empty(batch_size, 1, dtype=torch.long).fill_(SOS_token)
            decoder_outputs = []
            for i in range(max_length):
                decoder_output, decoder_hidden = self.forward_step(decoder_input, decoder_hidden)
                decoder_outputs.append(decoder_output)
                _, topi = decoder_output.topk(1)
                decoder_input = topi.squeeze(-1).detach()
            decoder_outputs = torch.cat(decoder_outputs, dim=1)
        decoder_outputs = F.log_softmax(decoder_outputs, dim=-1)
        return decoder_outputs, decoder_hidden

    def forward_step(self, input, hidden):
        output = self.embedding(input)
        output = F.relu(output)
        output, hidden = self.gru(output, hidden)
        output = self.out(output)
        return output, hidden
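# Dataset: every name becomes (encoder tokens, true length, <sos>-prefixed decoder
# input, <eos>-terminated target), each right-padded with <pad> (0) to max_length.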
class KatakanaDataset(Dataset):
    def __init__(self, texts, tokenizer, max_length):
        self.texts = texts
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        text = self.texts[idx]
        tokens = self.tokenizer(text)
        enc_text = tokens
        enc_len = len(enc_text)
        input_text = [SOS_token] + tokens
        target_text = tokens + [EOS_token]
        enc_text = torch.tensor(enc_text + [0] * (self.max_length - len(enc_text)), dtype=torch.long)
        input_text = torch.tensor(input_text + [0] * (self.max_length - len(input_text)), dtype=torch.long)
        target_text = torch.tensor(target_text + [0] * (self.max_length - len(target_text)), dtype=torch.long)
        return enc_text, enc_len, input_text, target_text
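# Note: with torch.set_default_device('cuda') active, the DataLoader's shuffle
# generator must itself be a cuda generator, otherwise sampling the random
# permutation can fail with a generator/device mismatch error.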
dataloader = DataLoader(
    KatakanaDataset(texts, tokenize, max_len),
    batch_size=bs,
    shuffle=True,
    generator=torch.Generator(device='cuda'),
)
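# One epoch of training: teacher-forced reconstruction NLL plus the closed-form
# KL divergence of N(mu, var) from the standard normal prior (averaged over the
# batch and latent dimensions), with gradient clipping before each SGD step.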
def train_epoch(dataloader, encoder, decoder, optimizer, max_norm, norm_p=2):
    total_loss = 0
    nll = nn.NLLLoss()
    for enc_text, enc_len, input_text, target_text in dataloader:
        optimizer.zero_grad()

        encoder_sample, mu, var = encoder(enc_text, enc_len)
        decoder_outputs, _ = decoder(encoder_sample, input_text)

        loss_recons = nll(decoder_outputs.view(-1, decoder_outputs.size(-1)), target_text.view(-1))
        loss_kld = 0.5 * torch.mean(mu ** 2 + var - var.log() - 1)
        loss = loss_recons + loss_kld
        loss.backward()

        nn.utils.clip_grad_norm_(list(encoder.parameters()) + list(decoder.parameters()), max_norm, norm_type=norm_p)

        optimizer.step()

        total_loss += loss.item()
    return total_loss / len(dataloader)
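# Train encoder and decoder jointly with SGD + momentum, halving the learning
# rate every lr_step_size epochs.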
encoder = EncoderVAEBiGRU(vocab_size, h, h_latent).train()
decoder = DecoderGRU(h_latent, h, vocab_size).train()
optimizer = optim.SGD(list(encoder.parameters()) + list(decoder.parameters()), lr=lr, momentum=momentum)
scheduler = StepLR(optimizer, step_size=lr_step_size, gamma=lr_decay)
with tqdm(range(epochs), desc='Training') as pbar:
    for i in pbar:
        pbar.set_postfix(loss=train_epoch(dataloader, encoder, decoder, optimizer, grad_max_norm))
        scheduler.step()
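# Sampling: draw z ~ N(0, I) for 8 names, decode greedily up to max_len steps,
# take the argmax token at each step, and cut each sequence at the first <eos>.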
decoder.eval()
samples = decoder(torch.randn(8, h_latent), max_length=max_len)[0].topk(1)[1].squeeze().tolist()
for seq in samples:
    print(detokenize(seq))
torch.save(decoder, 'decoder.pt')