jflo commited on
Commit
5e108cc
1 Parent(s): a3b5dec

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +61 -0
app.py ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+ import string
4
+ import gradio as gr
5
+
6
+ all_letters = string.ascii_letters + " .,;'"
7
+ n_letters = len(all_letters)
8
+
9
+ all_categories = ['Arabic','Chinese','Czech','Dutch','English','French','German','Greek',
10
+ 'Irish','Italian','Japanese','Korean','Polish','Portuguese','Russian','Scottish',
11
+ 'Spanish','Vietnamese']
12
+
13
+ # Find letter index from all_letters: Ex: "a" = 0
14
+ def letterToIndex(letter):
15
+ return all_letters.find(letter)
16
+
17
+ # Giving each charachter in name a one hot vector
18
+ def lineToTensor(line):
19
+ tensor = torch.zeros(len(line),1,n_letters)
20
+ for li,letter in enumerate(line):
21
+ tensor[li][0][letterToIndex(letter)] = 1
22
+
23
+ return tensor
24
+
25
+ # Loading in torchscript model
26
+ my_model = torch.jit.load('name_classifier_ts.ptl')
27
+
28
+ # Return output given a line_tensor
29
+ def evaluate(line_tensor):
30
+ hidden = torch.zeros(1,128)
31
+
32
+ for i in range(line_tensor.size()[0]):
33
+ output, hidden = my_model(line_tensor[i], hidden)
34
+
35
+ return output
36
+
37
+ # Feeding in a name and number of top predictions you want to output
38
+ def predict(last_name,n_predictions=3):
39
+
40
+ last_name = last_name.title()
41
+ with torch.no_grad():
42
+ output = evaluate(lineToTensor(last_name))
43
+ output = F.softmax(output,dim=1)
44
+
45
+ topv,topi = output.topk(n_predictions,1,True)
46
+
47
+ top_3_countries = ''
48
+ for i in range(n_predictions):
49
+ value = topv[0]
50
+ category_index = topi[0][i].item()
51
+ top_3_countries += f'{all_categories[category_index]}: {round(value[i].item()*100,2)}%'
52
+ top_3_countries += '\n'
53
+ return top_3_countries
54
+
55
+ demo = gr.Interface(predict,
56
+ inputs = "text",
57
+ outputs = "text",
58
+ description="Classify name into language of origin. Returns top 3 languages of origin"
59
+ )
60
+
61
+ demo.launch(inline=False)