Sunbread's picture
update model & inference
1e53095
#!/usr/bin/env python
# coding: utf-8
import torch
import torch.nn as nn
from torch import optim
from torch.utils.data import DataLoader, Dataset
from torch.optim.lr_scheduler import StepLR
from tqdm.auto import tqdm
import torch.nn.functional as F
import pandas as pd
# Fix the global RNG seed: parameter initialisation, dropout, shuffling and the
# latent samples drawn below all come from torch's seeded generators.
torch.manual_seed(114514)
# Tensors created without an explicit device are allocated on the GPU from here on.
torch.set_default_device('cuda')
# Ids of the special tokens; they must match the positions of '<sos>' / '<eos>' in `vocab`.
SOS_token = 1
EOS_token = 2
# Character inventory: appears to be the Katakana block (゠ … ヿ) followed by the
# small-katakana extension (ㇰ … ㇿ). Any character outside this list cannot be tokenized.
katakana = list('゠ァアィイゥウェエォオカガキギクグケゲコゴサザシジスズセゼソゾタダチヂッツヅテデトドナニヌネノハバパヒビピフブプヘベペホボポマミムメモャヤュユョヨラリルレロヮワヰヱヲンヴヵヶヷヸヹヺ・ーヽヾヿㇰㇱㇲㇳㇴㇵㇶㇷㇸㇹㇺㇻㇼㇽㇾㇿ')
# Id layout: 0 = '<pad>', 1 = '<sos>', 2 = '<eos>', then one id per character.
vocab = ['<pad>', '<sos>', '<eos>'] + katakana
# Reverse lookup: symbol -> id.
vocab_dict = {v: k for k, v in enumerate(vocab)}
# Training names, read as the first column of a headerless CSV (presumably one name
# per line; a name containing a comma would be split — TODO confirm the file format).
texts = pd.read_csv('rolename.txt', header=None)[0].tolist()
vocab_size=len(vocab)
h=192  # embedding size and GRU hidden size for both encoder and decoder
h_latent=64  # dimension of the latent code
max_len=40  # padded tensor length in the dataset; also the greedy decoding length at sampling time
bs=128  # mini-batch size
lr=0.2  # initial SGD learning rate
lr_step_size=32  # epochs between learning-rate decays (the scheduler is stepped once per epoch)
lr_decay=0.5  # multiplicative learning-rate decay factor
momentum=0.9  # SGD momentum
epochs=192
grad_max_norm=1  # gradient-norm clipping threshold passed to train_epoch
def tokenize(text):
    """Convert a string into the list of vocabulary ids of its characters.

    Raises:
        KeyError: if a character is not in the vocabulary.
    """
    return list(map(vocab_dict.__getitem__, text))
def detokenize(tokens):
    """Turn a sequence of token ids back into a string.

    Everything from the first EOS id onward is dropped; every remaining id
    (including '<pad>' / '<sos>') is rendered via its `vocab` entry.
    """
    kept = tokens[:tokens.index(EOS_token)] if EOS_token in tokens else tokens
    return ''.join(map(vocab.__getitem__, kept))
class BatchNormVAE(nn.Module):  # https://spaces.ac.cn/archives/7381/
    """Batch-normalise VAE posterior parameters (the BN-VAE trick from the link above).

    ``mu`` and ``sigma`` each go through a non-affine ``BatchNorm1d`` and are then
    multiplied by learnable scales whose squares sum to one:

        scale_mu ** 2    = tau + (1 - tau) * sigmoid(theta)
        scale_sigma ** 2 = (1 - tau) * sigmoid(-theta)

    In training mode each normalised feature has (approximately) unit second moment,
    so per latent dimension mean(mu**2) + mean(sigma**2) is ~1 and mean(mu**2) is
    bounded below by ~tau.

    Args:
        num_features: number of latent dimensions.
        tau: lower bound share reserved for ``mu``; must be in [0, 1). At 1 the sigma
            scale is 0, so the variance is 0 and its log in the KL term is -inf.
        **kwargs: forwarded to ``nn.BatchNorm1d``; ``affine`` is always forced to False.

    Raises:
        ValueError: if ``tau`` is outside [0, 1).
    """

    def __init__(self, num_features, tau=0.5, **kwargs):
        super(BatchNormVAE, self).__init__()
        if not 0.0 <= tau < 1.0:
            raise ValueError(f'tau must be in [0, 1), got {tau!r}')
        # The scales below replace BN's affine parameters.
        kwargs['affine'] = False
        self.TAU = tau
        self.bn_mu = nn.BatchNorm1d(num_features, **kwargs)
        self.bn_sigma = nn.BatchNorm1d(num_features, **kwargs)
        # Learnable split of the unit budget between mu and sigma.
        self.theta = nn.Parameter(torch.zeros(1))

    def forward(self, mu, sigma):
        """Return the normalised and rescaled ``(mu, sigma)``, each shaped like the input."""
        mu = self.bn_mu(mu)
        sigma = self.bn_sigma(sigma)
        scale_mu = torch.sqrt(self.TAU + (1 - self.TAU) * torch.sigmoid(self.theta))
        scale_sigma = torch.sqrt((1 - self.TAU) * torch.sigmoid(-self.theta))
        return mu * scale_mu, sigma * scale_sigma
class EncoderVAEBiGRU(nn.Module):
    """Bidirectional two-layer GRU encoder that outputs a BN-VAE posterior.

    Args:
        input_size: vocabulary size.
        hidden_size: embedding size and per-direction GRU hidden size.
        latent_size: dimension of the latent code.
        dropout_p: dropout probability applied to the token embeddings.
    """

    def __init__(self, input_size, hidden_size, latent_size, dropout_p=0.1):
        super().__init__()
        self.hidden_size = hidden_size
        # Creation order is kept stable: it fixes parameter-initialisation order under the seed.
        self.embedding = nn.Embedding(input_size, hidden_size)
        self.gru = nn.GRU(hidden_size, hidden_size, num_layers=2, batch_first=True, bidirectional=True)
        # Final states of 2 layers x 2 directions are concatenated.
        summary_size = 4 * hidden_size
        self.proj_mu = nn.Linear(summary_size, latent_size)
        self.proj_sigma = nn.Linear(summary_size, latent_size)
        self.dropout = nn.Dropout(dropout_p)
        self.bn = BatchNormVAE(latent_size)

    def forward(self, input, input_lengths):
        """Encode padded token ids.

        Args:
            input: (batch, seq) token ids.
            input_lengths: tensor of true sequence lengths (moved to CPU for packing).

        Returns:
            (z, mu, var), each (batch, latent_size): a reparameterised sample, the
            posterior mean, and the variance ``sigma ** 2``.
        """
        lengths_cpu = input_lengths.to('cpu')
        tokens = self.dropout(self.embedding(input))
        packed = nn.utils.rnn.pack_padded_sequence(tokens, lengths_cpu, batch_first=True, enforce_sorted=False)
        _, final_states = self.gru(packed)
        # (4, batch, hidden) -> (batch, 4 * hidden)
        summary = final_states.transpose(0, 1).reshape(final_states.size(1), -1)
        # sigma is not a standard deviation here and may be negative; only its square is used as var.
        mu, sigma = self.bn(self.proj_mu(summary), self.proj_sigma(summary))
        return self._reparameterize(mu, sigma), mu, sigma ** 2

    def _reparameterize(self, mu, sigma):
        """Sample z = mu + sigma * eps with eps ~ N(0, I); the variance is sigma ** 2."""
        return mu + sigma * torch.randn_like(sigma)
class DecoderGRU(nn.Module):
    """Two-layer GRU decoder conditioned on a latent code.

    The latent vector goes through Linear -> ReLU -> Linear to produce the
    initial hidden state of both GRU layers.

    Args:
        latent_size: dimension of the latent code.
        hidden_size: embedding size and GRU hidden size.
        output_size: vocabulary size.
    """

    def __init__(self, latent_size, hidden_size, output_size):
        super(DecoderGRU, self).__init__()
        # Creation order is kept stable: it fixes parameter-initialisation order under the seed.
        self.proj1 = nn.Linear(latent_size, latent_size)
        self.proj_activation = nn.ReLU()
        self.proj2 = nn.Linear(latent_size, 2 * hidden_size)
        self.embedding = nn.Embedding(output_size, hidden_size)
        self.gru = nn.GRU(hidden_size, hidden_size, num_layers=2, batch_first=True)
        self.out = nn.Linear(hidden_size, output_size)

    def forward(self, encoder_sample, target_tensor=None, max_length=16):
        """Decode latent codes into per-step log-probabilities.

        Args:
            encoder_sample: (batch, latent_size) latent codes.
            target_tensor: optional (batch, seq) token ids used as decoder inputs
                (teacher forcing); the whole sequence is run in one GRU call.
            max_length: number of greedy steps when ``target_tensor`` is None.

        Returns:
            (log_probs, hidden): log_probs is (batch, seq, output_size), where seq is
            ``target_tensor``'s length or ``max_length``. hidden is the final
            (2, batch, hidden_size) GRU state.
        """
        batch_size = encoder_sample.size(0)
        decoder_hidden = self.proj1(encoder_sample)
        decoder_hidden = self.proj_activation(decoder_hidden)
        decoder_hidden = self.proj2(decoder_hidden)
        # (batch, 2 * hidden) -> (num_layers=2, batch, hidden), the layout nn.GRU expects.
        decoder_hidden = decoder_hidden.view(batch_size, 2, -1).permute(1, 0, 2).contiguous()
        if target_tensor is not None:
            decoder_outputs, decoder_hidden = self.forward_step(target_tensor, decoder_hidden)
        else:
            # Create the start tokens on the latent's device rather than the global default
            # device, so inference works whichever device the model and inputs live on.
            decoder_input = torch.full(
                (batch_size, 1), SOS_token, dtype=torch.long, device=encoder_sample.device
            )
            decoder_outputs = []
            for _ in range(max_length):
                decoder_output, decoder_hidden = self.forward_step(decoder_input, decoder_hidden)
                decoder_outputs.append(decoder_output)
                # Greedy decoding: feed the most likely token back in, shape (batch, 1).
                _, topi = decoder_output.topk(1)
                decoder_input = topi.squeeze(-1).detach()
            decoder_outputs = torch.cat(decoder_outputs, dim=1)
        decoder_outputs = F.log_softmax(decoder_outputs, dim=-1)
        return decoder_outputs, decoder_hidden

    def forward_step(self, input, hidden):
        """Run the GRU over (batch, seq) token ids and return (unnormalised logits, hidden).

        The embeddings pass through a ReLU before entering the GRU.
        """
        output = self.embedding(input)
        output = F.relu(output)
        output, hidden = self.gru(output, hidden)
        output = self.out(output)
        return output, hidden
class KatakanaDataset(Dataset):
    """Turn names into fixed-length tensors for the VAE.

    Each item is ``(enc_text, enc_len, input_text, target_text)``:

    * ``enc_text``: token ids, right-padded with 0 to ``max_length``.
    * ``enc_len``: number of real (non-pad) tokens in ``enc_text``, as an int.
    * ``input_text``: ``[SOS] + tokens``, padded; the teacher-forcing decoder input.
    * ``target_text``: ``tokens + [EOS]``, padded; the decoder target.

    Padding uses id 0 (``'<pad>'`` in ``vocab``). Names longer than
    ``max_length - 1`` tokens are truncated so every tensor is exactly
    ``max_length`` long.

    Args:
        texts: sequence of strings.
        tokenizer: callable mapping a string to a sequence of int token ids.
        max_length: length of every returned tensor; must be at least 2.

    Raises:
        ValueError: if ``max_length < 2``.
    """

    def __init__(self, texts, tokenizer, max_length):
        if max_length < 2:
            raise ValueError(f'max_length must be at least 2, got {max_length!r}')
        self.texts = texts
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.texts)

    def _pad(self, ids):
        """Right-pad ``ids`` with 0 to ``max_length`` and return a long tensor."""
        return torch.tensor(ids + [0] * (self.max_length - len(ids)), dtype=torch.long)

    def __getitem__(self, idx):
        tokens = list(self.tokenizer(self.texts[idx]))
        # The decoder views carry one extra token (SOS or EOS), so at most max_length - 1
        # name tokens fit. Without truncation a longer name yields tensors longer than
        # max_length, and default batch collation fails on the mismatched sizes.
        tokens = tokens[:self.max_length - 1]
        enc_len = len(tokens)
        enc_text = self._pad(tokens)
        input_text = self._pad([SOS_token] + tokens)
        target_text = self._pad(tokens + [EOS_token])
        return enc_text, enc_len, input_text, target_text
# Shuffled mini-batches of (enc_text, enc_len, input_text, target_text) tensors.
dataloader = DataLoader(
    KatakanaDataset(texts, tokenize, max_len),
    batch_size=bs,
    shuffle=True,
    # Presumably required because the default device was set to 'cuda' above: the
    # shuffling sampler draws its permutation with this generator. TODO confirm.
    generator=torch.Generator(device='cuda'),
)
def train_epoch(dataloader, encoder, decoder, optimizer, max_norm, norm_p=2, pad_idx=0):
    """Run one training epoch of the VAE and return the mean per-batch loss.

    The loss per batch is the token-level NLL of the teacher-forced decoder plus
    the Gaussian KL term ``0.5 * mean(mu**2 + var - log(var) - 1)``. The KL mean is
    taken over both batch and latent dimensions.

    Args:
        dataloader: iterable of (enc_text, enc_len, input_text, target_text) batches.
        encoder: module returning (sample, mu, var) for (enc_text, enc_len).
        decoder: module returning (log_probs, hidden) for (sample, input_text).
        optimizer: optimizer over the encoder and decoder parameters.
        max_norm: gradient-norm clipping threshold.
        norm_p: norm type used for clipping.
        pad_idx: target id excluded from the reconstruction loss (the dataset pads
            with 0, the '<pad>' id). Pass None to score padding positions as well.

    Returns:
        float: average loss over the batches, or 0.0 if ``dataloader`` is empty.
    """
    # Ignoring padding keeps the reconstruction mean about real tokens and EOS only;
    # otherwise trivially predicted pads dominate it and its scale depends on max_len.
    # -100 is NLLLoss's default ignore_index, so no real id is ignored.
    nll = nn.NLLLoss(ignore_index=-100 if pad_idx is None else pad_idx)
    params = list(encoder.parameters()) + list(decoder.parameters())
    total_loss = 0.0
    num_batches = 0
    for enc_text, enc_len, input_text, target_text in dataloader:
        optimizer.zero_grad()
        encoder_sample, mu, var = encoder(enc_text, enc_len)
        decoder_outputs, _ = decoder(encoder_sample, input_text)
        loss_recons = nll(
            decoder_outputs.reshape(-1, decoder_outputs.size(-1)), target_text.reshape(-1)
        )
        loss_kld = 0.5 * torch.mean(mu ** 2 + var - var.log() - 1)
        loss = loss_recons + loss_kld
        loss.backward()
        # gradient clipping by norm
        nn.utils.clip_grad_norm_(params, max_norm, norm_type=norm_p)
        optimizer.step()
        total_loss += loss.item()
        num_batches += 1
    return total_loss / num_batches if num_batches else 0.0
# Build both halves of the VAE and put them in training mode. Parameter
# initialisation draws from the seeded global RNG, so construction order is kept.
encoder = EncoderVAEBiGRU(vocab_size, h, h_latent)
encoder.train()
decoder = DecoderGRU(h_latent, h, vocab_size)
decoder.train()
# One SGD-with-momentum optimizer updates encoder and decoder jointly; the learning
# rate is multiplied by lr_decay every lr_step_size epochs.
optimizer = optim.SGD([*encoder.parameters(), *decoder.parameters()], lr=lr, momentum=momentum)
scheduler = StepLR(optimizer, step_size=lr_step_size, gamma=lr_decay)
with tqdm(range(epochs), desc='Training') as pbar:
    for i in pbar:
        # Show the epoch's mean loss on the progress bar, then advance the LR schedule.
        pbar.set_postfix(loss=train_epoch(dataloader, encoder, decoder, optimizer, grad_max_norm))
        scheduler.step()
decoder.eval()
# Decode 8 latent codes drawn from the N(0, I) prior. This is inference only, so no
# autograd graph is built.
with torch.no_grad():
    sample_log_probs, _ = decoder(torch.randn(8, h_latent), max_length=max_len)
# Greedy token ids: (8, max_len, 1) -> (8, max_len). squeeze(-1), unlike a bare squeeze(),
# keeps the batch axis, so detokenize always receives one id list per sample.
sample_ids = sample_log_probs.topk(1)[1].squeeze(-1).tolist()
for seq in sample_ids:
    print(detokenize(seq))
# Pickles the whole module (not a state_dict); loading it needs DecoderGRU importable.
torch.save(decoder, 'decoder.pt')