|
import torch |
|
|
|
# Character appended after the last character of a name (see shift_name_right).
END_CHAR = '|'

# Base vocabulary: the 26 lowercase ASCII letters followed by a single space.
_BASE_CHARS = 'abcdefghijklmnopqrstuvwxyz '

# Full vocabulary. A character's position in this string is its one-hot index,
# so END_CHAR occupies the final slot.
all_letters = _BASE_CHARS + END_CHAR

# Vocabulary size, i.e. the width of every one-hot vector built from it.
NUM_LETTERS = len(all_letters)
|
|
|
def letter_to_index(letter):
    """Return the position of a single character in ``all_letters``.

    Args:
        letter: A one-character string that appears in ``all_letters``.

    Returns:
        The integer index of ``letter`` within ``all_letters``.

    Raises:
        ValueError: If ``letter`` is not exactly one character long, or if it
            is not part of the vocabulary.
    """
    # str.index performs a substring search, so without this guard '' would
    # silently map to 0 and a multi-character string such as 'bc' would map to
    # the index of its first character.
    if len(letter) != 1:
        raise ValueError(f"expected a single character, got {letter!r}")
    return all_letters.index(letter)
|
|
|
def letter_to_tensor(letter):
    """One-hot encode a single character.

    Args:
        letter: A character from ``all_letters``.

    Returns:
        A float tensor of shape ``(1, NUM_LETTERS)`` that is all zeros except
        for a 1 at column ``letter_to_index(letter)``.
    """
    position = letter_to_index(letter)
    encoded = torch.zeros(1, NUM_LETTERS)
    encoded[0][position] = 1
    return encoded
|
|
|
def name_to_tensor(name):
    """Encode a string as a stack of per-character one-hot tensors.

    Args:
        name: Iterable of characters, each of which must appear in
            ``all_letters``.

    Returns:
        A float tensor of shape ``(len(name), 1, NUM_LETTERS)`` where
        ``result[i]`` equals ``letter_to_tensor(name[i])``. An empty name
        yields a tensor of shape ``(0, 1, NUM_LETTERS)``.

    Raises:
        ValueError: If any character is not in ``all_letters`` (propagated
            from ``letter_to_index``).
    """
    letter_tensors = [letter_to_tensor(letter) for letter in name]
    # torch.stack rejects an empty sequence, so build the empty result
    # explicitly with the same trailing dimensions a non-empty name would have.
    if not letter_tensors:
        return torch.zeros((0, 1, NUM_LETTERS))
    return torch.stack(letter_tensors)
|
|
|
|
|
|
|
def shift_name_right(name):
    """Drop the first character of ``name`` and append ``END_CHAR``.

    Position ``i`` of the result holds ``name[i + 1]``, and ``END_CHAR`` takes
    the place after the last character (e.g. ``'abc'`` -> ``'bc|'``). Despite
    the function's name, the content moves one position to the left. For a
    non-empty name the result has the same length as ``name``; an empty name
    yields just ``END_CHAR``.

    NOTE(review): this looks like the next-character target for a
    character-level model; confirm against callers.
    """
    remainder = name[1:]
    return remainder + END_CHAR
|
|
|
|