|
import glob |
|
import os |
|
import io |
|
import random |
|
from utils import * |
|
|
|
class NamesDataset:
    """In-memory collection of lower-cased names grouped by country.

    Every file matched by the glob pattern contributes one country: the
    file's base name without its extension is the country label, and each
    non-blank line of the file is one name.

    Attributes:
        country_to_names: Maps a country label to its list of names.
        countries: Country labels; a label's position in this list is the
            index used for its one-hot encoding.
        num_countries: Number of loaded countries.
    """

    def __init__(self, dataset_path):
        """Load all name files matching ``dataset_path``.

        Args:
            dataset_path: Glob pattern selecting the per-country name files.
                Files are read as UTF-8.
        """
        self.country_to_names = {}

        # glob.glob returns matches in arbitrary (filesystem-dependent)
        # order. Sorting keeps the country -> index mapping reproducible
        # across runs and machines.
        for fn in sorted(glob.glob(dataset_path)):
            country = os.path.splitext(os.path.basename(fn))[0]

            # The context manager closes the handle deterministically.
            with open(fn, encoding='utf-8') as handle:
                stripped = (line.strip() for line in handle)
                names = [name.lower() for name in stripped if name]

            # A file without any usable line would otherwise register a
            # country that can only ever yield an empty name.
            if not names:
                continue

            self.country_to_names[country] = names

        self.countries = list(self.country_to_names)
        self.num_countries = len(self.countries)

    def country_to_tensor(self, country):
        """Return the one-hot encoding of ``country``.

        Args:
            country: A label present in ``self.countries``.

        Returns:
            A tensor of shape ``(1, num_countries)`` that is zero everywhere
            except for a ``1.`` at the country's index.

        Raises:
            ValueError: If ``country`` is not a known label.
        """
        idx = self.countries.index(country)
        country_tensor = torch.zeros((1, self.num_countries))
        country_tensor[0, idx] = 1.
        return country_tensor

    def get_random_sample(self):
        """Draw a random country, then a random name of that country.

        Returns:
            A tuple ``(country, name, country_tensor, name_tensor,
            target_tensor)``. ``name_tensor`` is whatever ``name_to_tensor``
            produces for the name; ``target_tensor`` is a ``LongTensor`` of
            shape ``(1, len(shifted_name))`` holding ``letter_to_index`` of
            each element of ``shift_name_right(name)``.

        Raises:
            ValueError: If the dataset holds no countries.
        """
        if not self.countries:
            raise ValueError("cannot sample from an empty dataset")

        country = random.choice(self.countries)
        name = random.choice(self.country_to_names[country])

        name_tensor = name_to_tensor(name)
        country_tensor = self.country_to_tensor(country)
        shifted_name = shift_name_right(name)

        indices = [letter_to_index(letter) for letter in shifted_name]
        # Add a leading dimension of size 1 in front of the index sequence.
        target_tensor = torch.LongTensor(indices).unsqueeze(0)

        return country, name, country_tensor, name_tensor, target_tensor

    def output_to_country(self, output):
        """Return the country label at the argmax position of ``output``.

        Args:
            output: Tensor of scores; the index of its largest element is
                looked up in ``self.countries``.
        """
        country_idx = torch.argmax(output).item()
        return self.countries[country_idx]

    def output_to_letter(self, output):
        """Return the entry of ``all_letters`` at the argmax of ``output``.

        Args:
            output: Tensor of scores; the index of its largest element is
                looked up in ``all_letters``.
        """
        # .item() converts to a plain int, matching output_to_country.
        idx = torch.argmax(output).item()
        return all_letters[idx]
|
|
|
|