File size: 1,926 Bytes
bac013d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
import glob
import os
import io
import random
from utils import *

class NamesDataset:
    """Names grouped by country, read from one text file per country.

    Every file matched by the glob pattern given to ``__init__`` becomes one
    country, named after the file's basename with its extension removed. Each
    file is read as UTF-8 text with one name per line; names are lower-cased.

    Attributes:
        country_to_names: dict mapping country -> list of lower-cased names.
        countries: list of country names; a country's position in this list
            is its one-hot index (see ``country_to_tensor``).
        num_countries: ``len(countries)``.
    """

    def __init__(self, dataset_path):
        """Load every file matching ``dataset_path`` (a ``glob`` pattern)."""
        self.country_to_names = {}
        # glob.glob() yields matches in arbitrary, filesystem-dependent order.
        # The order of self.countries fixes each country's one-hot index (and
        # so the meaning of each model output position); sort the matches to
        # keep that mapping reproducible across runs and machines.
        for fn in sorted(glob.glob(dataset_path)):
            country = os.path.splitext(os.path.basename(fn))[0]
            # Context manager so the file handle is closed immediately.
            with io.open(fn, encoding='utf-8') as f:
                lines = f.read().strip().split('\n')
            # Skip blank lines so they never turn into empty "names"
            # (previously an empty file produced the name '').
            self.country_to_names[country] = [
                line.lower() for line in lines if line.strip()
            ]

        self.countries = list(self.country_to_names.keys())
        self.num_countries = len(self.countries)

    def country_to_tensor(self, country):
        """Return a one-hot float tensor of shape (1, num_countries).

        Example: with countries ['Greek', 'Korean'], 'Greek' -> [[1., 0.]]
        and 'Korean' -> [[0., 1.]].

        Raises:
            ValueError: if ``country`` is not in ``self.countries``.
        """
        idx = self.countries.index(country)
        country_tensor = torch.zeros((1, self.num_countries))
        country_tensor[0, idx] = 1.
        return country_tensor

    def get_random_sample(self):
        """Draw a random (country, name) training example.

        A country is chosen uniformly at random, then a name uniformly from
        that country's names, so the country pick does not depend on how many
        names each country has.

        Returns:
            Tuple ``(country, name, country_tensor, name_tensor,
            target_tensor)`` where ``country_tensor`` comes from
            ``country_to_tensor``, ``name_tensor`` is ``name_to_tensor(name)``,
            and ``target_tensor`` is an int64 tensor of shape
            (1, len(shifted_name)) holding ``letter_to_index`` of each letter
            of ``shift_name_right(name)``.

        Raises:
            IndexError: if there are no countries, or the chosen country has
                no names.
        """
        country = random.choice(self.countries)
        name = random.choice(self.country_to_names[country])

        name_tensor = name_to_tensor(name)
        country_tensor = self.country_to_tensor(country)  # input to our model
        # shift_name_right / letter_to_index live in utils; the shifted name
        # is presumably the next-letter sequence to predict — TODO confirm.
        shifted_name = shift_name_right(name)
        indices = [letter_to_index(letter) for letter in shifted_name]
        target_tensor = torch.tensor(indices, dtype=torch.long).unsqueeze(0)

        return country, name, country_tensor, name_tensor, target_tensor

    def output_to_country(self, output):
        """Return the country at the position of the largest value in ``output``.

        ``torch.argmax`` without ``dim`` flattens ``output``, so this assumes a
        single example (e.g. shape (1, num_countries)) — TODO confirm callers.
        """
        country_idx = torch.argmax(output).item()
        return self.countries[country_idx]

    def output_to_letter(self, output):
        """Return the entry of ``all_letters`` at the argmax of ``output``.

        ``all_letters`` is provided by ``utils``. As in ``output_to_country``,
        ``output`` is flattened by ``torch.argmax``.
        """
        idx = torch.argmax(output).item()
        return all_letters[idx]