sparkles's picture
Upload data.py
bac013d
import glob
import os
import io
import random
from utils import *
class NamesDataset:
    """Per-country lists of names loaded from text files.

    Every file matched by the glob pattern ``dataset_path`` contributes one
    country: the file name without its extension is the country label, and
    each non-blank line of the file is one name, lowercased.

    Countries are ordered by sorted file path, so the one-hot index of a
    country (and therefore the meaning of each model output unit) is the same
    on every run and every machine. ``glob.glob`` alone returns files in
    arbitrary, filesystem-dependent order.
    """

    def __init__(self, dataset_path):
        """Load every file matching the glob pattern ``dataset_path``.

        Files that contain no names (empty or only blank lines) are skipped,
        because a country without names can never be sampled.
        """
        self.country_to_names = {}
        for fn in sorted(glob.glob(dataset_path)):
            country = os.path.splitext(os.path.basename(fn))[0]
            # Context manager guarantees the file handle is closed promptly.
            with io.open(fn, encoding='utf-8') as f:
                lines = f.read().strip().split('\n')
            # Drop blank lines so they do not turn into empty-string "names".
            names = [line.lower() for line in lines if line.strip()]
            if names:
                self.country_to_names[country] = names
        self.countries = list(self.country_to_names.keys())
        self.num_countries = len(self.countries)

    def country_to_tensor(self, country):
        """Return a ``(1, num_countries)`` float one-hot tensor for ``country``.

        Example: with countries ``['Greek', 'Korean']``, ``'Greek'`` maps to
        ``[[1., 0.]]`` and ``'Korean'`` to ``[[0., 1.]]``.

        Raises:
            ValueError: if ``country`` is not one of ``self.countries``.
        """
        idx = self.countries.index(country)
        country_tensor = torch.zeros((1, self.num_countries))
        country_tensor[0, idx] = 1.
        return country_tensor

    def get_random_sample(self):
        """Draw a random (country, name) pair and build its tensors.

        A country is chosen uniformly first, then a name uniformly within it,
        so each country is equally likely regardless of how many names it has.

        Returns:
            ``(country, name, country_tensor, name_tensor, target_tensor)``
            where ``country_tensor`` comes from ``country_to_tensor``,
            ``name_tensor`` is ``name_to_tensor(name)``, and ``target_tensor``
            is a LongTensor with a leading batch dim of 1 holding the letter
            indices of ``shift_name_right(name)``.

        Raises:
            ValueError: if the dataset contains no countries.
        """
        if not self.countries:
            raise ValueError("cannot sample from an empty NamesDataset")
        country = random.choice(self.countries)
        name = random.choice(self.country_to_names[country])
        name_tensor = name_to_tensor(name)
        country_tensor = self.country_to_tensor(country)  # input to our model
        # NOTE(review): shift_name_right comes from utils; presumably it yields
        # the next-letter targets for each input position -- confirm there.
        shifted_name = shift_name_right(name)
        # One index per letter of shifted_name, with a leading batch dim of 1.
        target_tensor = torch.LongTensor(
            [letter_to_index(letter) for letter in shifted_name]
        ).unsqueeze(0)
        return country, name, country_tensor, name_tensor, target_tensor

    def output_to_country(self, output):
        """Return the country label at the argmax of ``output``.

        ``torch.argmax`` without ``dim`` works over the flattened tensor;
        presumably ``output`` is ``(1, num_countries)`` -- confirm at caller.
        """
        country_idx = torch.argmax(output).item()
        return self.countries[country_idx]

    def output_to_letter(self, output):
        """Return the character of ``all_letters`` at the argmax of ``output``."""
        idx = torch.argmax(output).item()
        return all_letters[idx]