from model import *
from utils import *
from data import *


class NameGenerator:
    """Train a country-conditioned character RNN and sample names from it.

    Relies on names provided by the star imports above: ``NamesDataset``,
    ``RNN_GEN``, ``NUM_LETTERS``, ``END_CHAR``, ``letter_to_tensor``,
    ``torch`` and ``nn``.
    """

    # (country, starting letter) pairs printed by ``print_sample_names``
    # when no explicit list is given.
    DEFAULT_SAMPLES = (
        ('Greek', 'I'),
        ('Greek', 'Z'),
        ('Greek', 'B'),
        ('Korean', 'I'),
        ('Korean', 'X'),
        ('Korean', 'G'),
    )

    def __init__(self, dataset_path, num_iters, save_model_path=None):
        """Build the dataset, model, loss function and optimizer.

        Args:
            dataset_path: Passed straight to ``NamesDataset`` (the script
                below uses a glob such as ``'data/names/*.txt'``).
            num_iters: Number of single-sample optimizer steps ``train`` runs.
            save_model_path: Optional file path used by ``save_model`` and
                ``load_model``. When ``None`` (the default), those methods
                raise ``ValueError``.
        """
        self.data = NamesDataset(dataset_path)
        self.model = RNN_GEN(self.data.num_countries, NUM_LETTERS, NUM_LETTERS)
        self.loss_func = nn.NLLLoss()
        self.learning_rate = 0.0005
        self.optimizer = torch.optim.SGD(self.model.parameters(), lr=self.learning_rate)
        self.num_iters = num_iters
        self.n_print = 5000  # report progress every n_print iterations
        # Cap on sampled name length, counting the starting letter.
        self.max_name_length = 15
        self.save_model_path = save_model_path

    def _require_save_path(self):
        """Return the configured save path, or raise if none was given."""
        if self.save_model_path is None:
            raise ValueError("save_model_path was not set on this NameGenerator")
        return self.save_model_path

    def save_model(self):
        """Write the model's ``state_dict`` to ``self.save_model_path``."""
        torch.save(self.model.state_dict(), self._require_save_path())

    def load_model(self):
        """Load a ``state_dict`` from ``self.save_model_path`` into the model."""
        self.model.load_state_dict(torch.load(self._require_save_path()))

    def train_step(self, name_tensor, country_tensor, target_tensor):
        """Run one teacher-forced forward/backward pass over a single name.

        Args:
            name_tensor: Model inputs indexed along dim 0; ``name_tensor[i]``
                is fed at step ``i``.
            country_tensor: Conditioning input, passed unchanged at each step.
            target_tensor: Targets; ``target_tensor[:, i]`` is the expected
                output at step ``i``.

        Returns:
            ``(output, mean_loss)``: the model output of the final step and the
            summed NLL loss divided by the number of letters.

        Raises:
            ValueError: If ``name_tensor`` has zero letters. There would be
                nothing to back-propagate, and the mean would divide by zero.
        """
        num_letters = name_tensor.size()[0]
        if num_letters == 0:
            raise ValueError("cannot train on an empty name")

        loss = 0
        hidden_state = self.model.init_hidden()
        for i in range(num_letters):
            # 1. Forward pass, carrying the hidden state between letters.
            output, hidden_state = self.model(name_tensor[i], country_tensor, hidden_state)
            # 2. Accumulate the per-letter loss.
            loss += self.loss_func(output, target_tensor[:, i])

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        return output, loss.item() / num_letters

    def train(self):
        """Train on ``self.num_iters`` random samples, reporting periodically.

        Every ``self.n_print`` iterations, starting with the first, prints the
        current sample's per-letter loss and a few generated names.
        """
        for i in range(self.num_iters):
            _, _, country_tensor, name_tensor, target_tensor = self.data.get_random_sample()
            _, loss = self.train_step(name_tensor, country_tensor, target_tensor)

            if i % self.n_print == 0:
                print(f"Iter {i+1}: Loss = {loss:.4f}")
                print("Neural Network Generated Names: ")
                self.print_sample_names()
                print("-" * 40)

    def sample_neural_network(self, country, starting_letter):
        """Generate a name for ``country`` that begins with ``starting_letter``.

        The starting letter is lower-cased. The model is then fed its own
        previous output until ``self.data.output_to_letter`` returns
        ``END_CHAR`` or the name reaches ``self.max_name_length`` characters.
        ``END_CHAR`` itself is not part of the returned name.

        The model is switched to eval mode for the duration of sampling, and
        its previous train/eval mode is restored afterwards.

        Raises:
            ValueError: If ``starting_letter`` is empty.
        """
        if not starting_letter:
            raise ValueError("starting_letter must be a non-empty string")
        output_name = starting_letter.lower()

        was_training = self.model.training
        self.model.eval()
        try:
            with torch.no_grad():
                country_tensor = self.data.country_to_tensor(country)
                hidden_state = self.model.init_hidden()
                for i in range(self.max_name_length - 1):
                    # output_name holds i + 1 characters at this point, so for a
                    # single starting letter this feeds the most recent letter.
                    output, hidden_state = self.model(
                        letter_to_tensor(output_name[i]), country_tensor, hidden_state
                    )
                    output_letter = self.data.output_to_letter(output)
                    # Stop before appending, so the end marker never leaks
                    # into the generated name.
                    if output_letter == END_CHAR:
                        break
                    output_name += output_letter
        finally:
            self.model.train(was_training)
        return output_name

    def print_sample_names(self, samples=DEFAULT_SAMPLES):
        """Print one generated name per ``(country, starting_letter)`` pair.

        Args:
            samples: Iterable of ``(country, starting_letter)`` pairs. Defaults
                to ``NameGenerator.DEFAULT_SAMPLES``.
        """
        for country, letter in samples:
            print(self.sample_neural_network(country, letter))


if __name__ == '__main__':
    name_generator = NameGenerator('data/names/*.txt', 100000)
    name_generator.train()
    # To persist weights, construct with save_model_path=... and call
    # name_generator.save_model() here (or load_model() instead of train()).
    print('Final Sampling: ')
    name_generator.print_sample_names()