import glob
import io
import os
import random

# torch is used directly below; import it explicitly instead of relying on
# the star import from utils to leak it into this namespace.
import torch

# Kept last, as in the original import order. Expected to provide
# name_to_tensor, shift_name_right, letter_to_index and all_letters,
# none of which are defined in this file.
from utils import *


class NamesDataset:
    """Per-country collection of names loaded from text files.

    Each file matched by ``dataset_path`` holds one name per line. The file
    name without its extension is used as the country label, and every name
    is lower-cased on load.

    Attributes set by ``__init__``:
        country_to_names: dict mapping country label -> list of names.
        countries: list of country labels; a country's position in this list
            is its index in the one-hot encoding.
        num_countries: ``len(countries)``.
    """

    def __init__(self, dataset_path: str):
        """Load every file matching the glob pattern ``dataset_path``.

        Args:
            dataset_path: glob pattern passed to ``glob.glob``, e.g.
                ``"data/names/*.txt"``.
        """
        self.country_to_names = {}

        # glob.glob yields matches in filesystem-dependent order; sorting
        # keeps each country's one-hot index stable across runs and machines.
        for fn in sorted(glob.glob(dataset_path)):
            basename = os.path.basename(fn)
            country = os.path.splitext(basename)[0]

            # Context manager so the file handle is closed deterministically.
            with io.open(fn, encoding='utf-8') as names_file:
                names = names_file.read().strip().split('\n')

            self.country_to_names[country] = [name.lower() for name in names]

        self.countries = list(self.country_to_names.keys())
        self.num_countries = len(self.countries)

    def country_to_tensor(self, country: str) -> torch.Tensor:
        """Return the one-hot encoding of ``country``.

        Example with ``countries == ['Greek', 'Korean']``::

            'Greek'  -> [[1, 0]]
            'Korean' -> [[0, 1]]

        Args:
            country: a label present in ``self.countries``.

        Returns:
            Tensor of shape ``(1, num_countries)``, all zeros except a 1 at
            the country's index.

        Raises:
            ValueError: if ``country`` is not in ``self.countries``.
        """
        idx = self.countries.index(country)
        country_tensor = torch.zeros((1, self.num_countries))
        country_tensor[0, idx] = 1.
        return country_tensor

    def get_random_sample(self):
        """Draw a uniformly random country, then a uniformly random name.

        Returns:
            Tuple ``(country, name, country_tensor, name_tensor,
            target_tensor)`` where

            * ``country_tensor`` is ``self.country_to_tensor(country)``,
            * ``name_tensor`` is ``name_to_tensor(name)``,
            * ``target_tensor`` is a LongTensor of shape
              ``(1, len(shift_name_right(name)))`` holding
              ``letter_to_index`` of each letter of the shifted name.

        Raises:
            ValueError: if the dataset contains no countries.
        """
        if self.num_countries == 0:
            # random.randint(0, -1) would raise ValueError anyway; fail with
            # a message that points at the real cause.
            raise ValueError(
                "NamesDataset is empty: no files matched dataset_path")

        rand_country_idx = random.randint(0, self.num_countries - 1)
        country = self.countries[rand_country_idx]

        names = self.country_to_names[country]
        rand_name_idx = random.randint(0, len(names) - 1)
        name = names[rand_name_idx]

        # Inputs to the model.
        name_tensor = name_to_tensor(name)
        country_tensor = self.country_to_tensor(country)

        # Target: one index per letter of the shifted name.
        # NOTE(review): shift_name_right is defined elsewhere; presumably it
        # offsets the name by one position for next-letter prediction -
        # confirm against utils.
        shifted_name = shift_name_right(name)
        indices = [letter_to_index(letter) for letter in shifted_name]
        # Leading dimension of size 1 added, giving shape (1, len(indices)).
        target_tensor = torch.LongTensor(indices).unsqueeze(0)

        return country, name, country_tensor, name_tensor, target_tensor

    def output_to_country(self, output) -> str:
        """Return the country at the argmax (over all elements) of ``output``."""
        country_idx = torch.argmax(output).item()
        return self.countries[country_idx]

    def output_to_letter(self, output):
        """Return the entry of ``all_letters`` at the argmax of ``output``."""
        # .item() yields a plain Python int, matching output_to_country,
        # rather than indexing with a 0-dim tensor.
        idx = torch.argmax(output).item()
        return all_letters[idx]