sparkles's picture
Upload utils.py
1dc7ff6
raw
history blame contribute delete
No virus
592 Bytes
import torch
# Sentinel character appended after a name to mark where the sequence ends.
END_CHAR = '|'
# Vocabulary: the 26 lowercase ASCII letters, the space character, then END_CHAR.
all_letters = 'abcdefghijklmnopqrstuvwxyz' + ' ' + END_CHAR
# Vocabulary size, i.e. the width of each one-hot row vector.
NUM_LETTERS = len(all_letters)
def letter_to_index(letter):
    """Return the position of ``letter`` within ``all_letters``.

    Args:
        letter: A single character that appears in ``all_letters``.

    Returns:
        The integer index of ``letter`` in ``all_letters``.

    Raises:
        ValueError: If ``letter`` is not exactly one character long, or if it
            does not appear in ``all_letters``.
    """
    # str.index performs a substring search: '' would silently map to 0 and a
    # multi-character string such as 'bc' to 1, so reject those explicitly.
    if len(letter) != 1:
        raise ValueError(f"expected a single character, got {letter!r}")
    return all_letters.index(letter)
def letter_to_tensor(letter):
    """Return a ``(1, NUM_LETTERS)`` one-hot row vector for ``letter``.

    The entry in column ``letter_to_index(letter)`` is 1; every other entry
    is 0.
    """
    position = letter_to_index(letter)
    encoded = torch.zeros(1, NUM_LETTERS)
    encoded[0][position] = 1
    return encoded
def name_to_tensor(name):
    """Encode ``name`` as a stack of one-hot rows, one per character.

    Each character is converted with ``letter_to_tensor`` (a
    ``(1, NUM_LETTERS)`` tensor), and the results are stacked along a new
    leading dimension.

    Args:
        name: A string whose characters all appear in ``all_letters``.

    Returns:
        A tensor of shape ``(len(name), 1, NUM_LETTERS)``. An empty ``name``
        yields an empty ``(0, 1, NUM_LETTERS)`` tensor.

    Raises:
        ValueError: Propagated from ``letter_to_index`` for a character that
            is not in ``all_letters``.
    """
    if not name:
        # torch.stack rejects an empty list, so build the zero-length result directly.
        return torch.zeros((0, 1, NUM_LETTERS))
    return torch.stack([letter_to_tensor(letter) for letter in name])
def shift_name_right(name):
    """Drop the first character of ``name`` and append ``END_CHAR``.

    Examples (with ``END_CHAR == '|'``): ``'chi' -> 'hi|'``, ``'woo' -> 'oo|'``.
    For non-empty input, the result has the same length as ``name``.
    """
    # NOTE(review): presumably builds next-character targets aligned with
    # name_to_tensor(name); confirm against the training code.
    without_first = name[1:]
    return without_first + END_CHAR