Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -1,55 +1,100 @@
|
|
1 |
-
import
|
2 |
-
import torch.nn.functional as F
|
3 |
import string
|
|
|
4 |
import gradio as gr
|
5 |
|
|
|
|
|
|
|
|
|
6 |
all_letters = string.ascii_letters + " .,;'"
|
7 |
n_letters = len(all_letters)
|
8 |
|
9 |
-
|
10 |
'Irish','Italian','Japanese','Korean','Polish','Portuguese','Russian','Scottish',
|
11 |
'Spanish','Vietnamese']
|
12 |
-
|
13 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
14 |
def letterToIndex(letter):
|
15 |
return all_letters.find(letter)
|
16 |
-
|
17 |
-
|
|
|
|
|
18 |
def lineToTensor(line):
|
19 |
-
tensor = torch.zeros(len(line),1,n_letters)
|
20 |
-
for li,letter in enumerate(line):
|
21 |
tensor[li][0][letterToIndex(letter)] = 1
|
|
|
22 |
|
23 |
-
return tensor
|
24 |
-
|
25 |
-
# Loading in torchscript model
|
26 |
-
my_model = torch.jit.load('name_classifier_ts.ptl')
|
27 |
|
28 |
-
|
29 |
-
|
30 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
31 |
|
32 |
-
|
33 |
-
|
34 |
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
39 |
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
51 |
return model_output
|
52 |
|
|
|
53 |
demo = gr.Interface(classify_lastname,
|
54 |
inputs = "text",
|
55 |
outputs = gr.outputs.Label(type="confidences",num_top_classes=3),
|
|
|
1 |
+
import unicodedata
|
|
|
2 |
import string
|
3 |
+
|
4 |
import gradio as gr
|
5 |
|
6 |
+
import torch
|
7 |
+
import torch.nn as nn
|
8 |
+
import torch.nn.functional as F
|
9 |
+
|
10 |
all_letters = string.ascii_letters + " .,;'"
|
11 |
n_letters = len(all_letters)
|
12 |
|
13 |
+
all_categories = ['Arabic','Chinese','Czech','Dutch','English','French','German','Greek',
|
14 |
'Irish','Italian','Japanese','Korean','Polish','Portuguese','Russian','Scottish',
|
15 |
'Spanish','Vietnamese']
|
16 |
+
|
17 |
+
def unicodeToAscii(s):
|
18 |
+
return ''.join(
|
19 |
+
c for c in unicodedata.normalize('NFD', s)
|
20 |
+
if unicodedata.category(c) != 'Mn'
|
21 |
+
and c in all_letters
|
22 |
+
)
|
23 |
+
|
24 |
+
# Find letter index from all_letters, e.g. "a" = 0
|
25 |
def letterToIndex(letter):
|
26 |
return all_letters.find(letter)
|
27 |
+
|
28 |
+
|
29 |
+
# Turn a line into a <line_length x 1 x n_letters>,
|
30 |
+
# or an array of one-hot letter vectors
|
31 |
def lineToTensor(line):
|
32 |
+
tensor = torch.zeros(len(line), 1, n_letters)
|
33 |
+
for li, letter in enumerate(line):
|
34 |
tensor[li][0][letterToIndex(letter)] = 1
|
35 |
+
return tensor
|
36 |
|
|
|
|
|
|
|
|
|
37 |
|
38 |
+
class RNN(nn.Module):
|
39 |
+
"""LSTM class"""
|
40 |
+
def __init__(self, input_size, hidden_size, output_size):
|
41 |
+
'''
|
42 |
+
:param input_size: number of input coming in
|
43 |
+
:param hidden_size: number of he hidden units
|
44 |
+
:param output_size: size of the output
|
45 |
+
'''
|
46 |
+
super(RNN, self).__init__()
|
47 |
|
48 |
+
self.hidden_size = hidden_size
|
49 |
+
self.input_size = input_size
|
50 |
|
51 |
+
#LSTM
|
52 |
+
self.lstm = nn.LSTM(input_size, hidden_size)
|
53 |
+
self.hidden2Cat = nn.Linear(hidden_size, output_size)
|
54 |
+
self.hidden = self.init_hidden()
|
55 |
+
|
56 |
+
def forward(self, input, hidden):
|
57 |
+
|
58 |
+
lstm_out, self.hidden = self.lstm(input, hidden)
|
59 |
+
output = self.hidden2Cat(lstm_out[-1]) #many to one
|
60 |
+
output = F.log_softmax(output, dim=1)
|
61 |
+
|
62 |
+
return output
|
63 |
+
|
64 |
+
def init_hidden(self):
|
65 |
+
return (torch.zeros(1, 1, self.hidden_size),
|
66 |
+
torch.zeros(1, 1, self.hidden_size))
|
67 |
+
|
68 |
+
# Just return an output given a line
|
69 |
+
def evaluate_model(line_tensor):
|
70 |
+
n_hidden = 128
|
71 |
+
n_categories = len(all_categories)
|
72 |
+
model = RNN(n_letters, n_hidden, n_categories)
|
73 |
+
model.load_state_dict(torch.load('classify_names_lstm.pt'))
|
74 |
+
model.eval()
|
75 |
|
76 |
+
hidden = (torch.zeros(1, 1, 128),
|
77 |
+
torch.zeros(1, 1, 128))
|
78 |
+
output = model(line_tensor,hidden)
|
79 |
+
|
80 |
+
return output
|
81 |
+
|
82 |
+
def classify_lastname(last_name):
|
83 |
+
last_name = unicodeToAscii(last_name)
|
84 |
+
line_tensor = lineToTensor(last_name)
|
85 |
+
output = evaluate_model(line_tensor)
|
86 |
+
|
87 |
+
top3_prob, top3_cat = torch.topk(output,3)
|
88 |
+
probs = torch.exp(top3_prob[0])
|
89 |
+
cats = top3_cat[0]
|
90 |
+
|
91 |
+
model_output = {}
|
92 |
+
for i in range(3):
|
93 |
+
print(probs[i].item())
|
94 |
+
model_output[all_categories[cats[i].item()]] = round(probs[i].item(),2)
|
95 |
return model_output
|
96 |
|
97 |
+
|
98 |
demo = gr.Interface(classify_lastname,
|
99 |
inputs = "text",
|
100 |
outputs = gr.outputs.Label(type="confidences",num_top_classes=3),
|