jflo commited on
Commit
a790da8
1 Parent(s): 835b42d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -39
app.py CHANGED
@@ -35,58 +35,25 @@ def lineToTensor(line):
35
  tensor[li][0][letterToIndex(letter)] = 1
36
  return tensor
37
 
38
-
39
- class RNN(nn.Module):
40
- """LSTM class"""
41
- def __init__(self, input_size, hidden_size, output_size):
42
- '''
43
- :param input_size: number of input coming in
44
- :param hidden_size: number of hidden units
45
- :param output_size: size of the output
46
- '''
47
- super(RNN, self).__init__()
48
-
49
- self.hidden_size = hidden_size
50
- self.input_size = input_size
51
-
52
- #LSTM
53
- self.lstm = nn.LSTM(input_size, hidden_size)
54
- self.hidden2Cat = nn.Linear(hidden_size, output_size)
55
- self.hidden = self.init_hidden()
56
-
57
- def forward(self, input, hidden):
58
-
59
- lstm_out, self.hidden = self.lstm(input, hidden)
60
- output = self.hidden2Cat(lstm_out[-1]) #many to one
61
- output = F.log_softmax(output, dim=1)
62
-
63
- return output
64
-
65
- def init_hidden(self):
66
- return (torch.zeros(1, 1, self.hidden_size),
67
- torch.zeros(1, 1, self.hidden_size))
68
-
69
  # Just return an output given a line
70
  def evaluate_model(line_tensor):
71
- n_hidden = 128
72
- n_categories = len(all_categories)
73
- model = RNN(n_letters, n_hidden, n_categories)
74
- model.load_state_dict(torch.load('classify_names_lstm.pt'))
75
- model.eval()
76
 
77
- hidden = (torch.zeros(1, 1, 128),
78
- torch.zeros(1, 1, 128))
79
  output = model(line_tensor,hidden)
80
 
81
  return output
82
 
83
  def classify_lastname(last_name):
 
84
  last_name = unicodeToAscii(last_name)
85
  last_name = last_name.title()
86
-
 
87
  line_tensor = lineToTensor(last_name)
88
  output = evaluate_model(line_tensor)
89
 
 
90
  top3_prob, top3_cat = torch.topk(output,3)
91
  probs = torch.exp(top3_prob[0])
92
  cats = top3_cat[0]
35
  tensor[li][0][letterToIndex(letter)] = 1
36
  return tensor
37
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
  # Just return an output given a line
39
  def evaluate_model(line_tensor):
40
+ model = torch.jit.load("torchscript_classify_names_lstm.pt")
 
 
 
 
41
 
42
+ hidden = (torch.zeros(1, 1, 128),torch.zeros(1, 1, 128))
 
43
  output = model(line_tensor,hidden)
44
 
45
  return output
46
 
47
  def classify_lastname(last_name):
48
+ # Converting to Ascii and capitalizing first letter
49
  last_name = unicodeToAscii(last_name)
50
  last_name = last_name.title()
51
+
52
+ # Converting name to tensor
53
  line_tensor = lineToTensor(last_name)
54
  output = evaluate_model(line_tensor)
55
 
56
+ # Grabbing top3 probabilities and categories
57
  top3_prob, top3_cat = torch.topk(output,3)
58
  probs = torch.exp(top3_prob[0])
59
  cats = top3_cat[0]