jflo's picture
Update app.py
a790da8
import functools
import string
import unicodedata

import gradio as gr
import torch
import torch.nn as nn
import torch.nn.functional as F
# Characters the one-hot encoder understands: the 52 ASCII letters plus
# space and the punctuation marks . , ; and '
all_letters = string.ascii_letters + " .,;'"
n_letters = len(all_letters)  # width of every one-hot letter vector

# Class labels; entry i names the class at index i of the model's output.
all_categories = [
    'Arabic', 'Chinese', 'Czech', 'Dutch', 'English', 'French',
    'German', 'Greek', 'Irish', 'Italian', 'Japanese', 'Korean',
    'Polish', 'Portuguese', 'Russian', 'Scottish', 'Spanish', 'Vietnamese',
]
# Turn a Unicode string to plain ASCII, thanks to https://stackoverflow.com/a/518232/2809427
def unicodeToAscii(s):
return ''.join(
c for c in unicodedata.normalize('NFD', s)
if unicodedata.category(c) != 'Mn'
and c in all_letters
)
def letterToIndex(letter):
    """Return the position of *letter* within ``all_letters``.

    For example ``"a"`` maps to 0 and ``"A"`` to 26. A character that is not
    in ``all_letters`` yields -1 (``str.find`` semantics), so callers that
    use the result as an index must check for that first.
    """
    position = all_letters.find(letter)
    return position
def lineToTensor(line):
    """Encode *line* as a stack of one-hot letter vectors.

    Returns a float tensor of shape ``(len(line), 1, n_letters)``, where
    ``tensor[i, 0]`` is the one-hot encoding of ``line[i]``. The middle
    dimension is a batch of size 1.

    Raises:
        ValueError: if *line* contains a character outside ``all_letters``.
            ``letterToIndex`` reports such characters as -1, and using that
            as an index would silently set the *last* slot (the apostrophe)
            instead of flagging the bad input.
    """
    tensor = torch.zeros(len(line), 1, n_letters)
    for li, letter in enumerate(line):
        index = letterToIndex(letter)
        if index < 0:
            raise ValueError(f"unsupported character {letter!r} in {line!r}")
        tensor[li, 0, index] = 1
    return tensor
@functools.lru_cache(maxsize=None)
def _load_model(model_path):
    """Load the TorchScript model at *model_path* once and reuse it afterwards."""
    # map_location="cpu": the hidden state in evaluate_model is created on the
    # CPU, so the model must live there too.
    model = torch.jit.load(model_path, map_location="cpu")
    model.eval()  # inference mode for dropout and similar layers
    return model


def evaluate_model(line_tensor, model_path="torchscript_classify_names_lstm.pt"):
    """Run the saved classifier on an encoded name and return its raw output.

    Args:
        line_tensor: one-hot encoded name, as produced by ``lineToTensor``.
        model_path: path of the TorchScript model file. The model is loaded
            from disk only on the first call for a given path, not on every
            request.

    Returns:
        Whatever the model's forward returns. ``classify_lastname`` indexes it
        with ``[0]`` after ``topk`` and applies ``exp``, so it is presumably
        ``(1, n_categories)`` log-probabilities; TODO confirm.
    """
    model = _load_model(model_path)
    # Fresh zero (h0, c0) state for every call. The (1, 1, 128) shape presumably
    # matches (num_layers, batch, hidden_size) of the saved LSTM; TODO confirm.
    hidden = (torch.zeros(1, 1, 128), torch.zeros(1, 1, 128))
    with torch.no_grad():  # inference only: skip building the autograd graph
        return model(line_tensor, hidden)
def classify_lastname(last_name):
    """Return the top-3 predicted languages of origin for *last_name*.

    The name is folded to the encoder's ASCII alphabet and title-cased, then
    one-hot encoded and scored by the model.

    Returns:
        dict mapping category name -> probability rounded to 2 decimals, for
        the three highest-scoring categories, best first. This is the shape
        the ``Label(type="confidences")`` output displays.

    Raises:
        ValueError: if nothing remains after ASCII folding (empty input, or
            input with no characters from ``all_letters``). Without this check
            a zero-length sequence would be passed to the model.
    """
    # Convert to ASCII and capitalize the first letter of each word.
    ascii_name = unicodeToAscii(last_name).title()
    if not ascii_name:
        raise ValueError(
            f"No usable characters in {last_name!r}; "
            "please enter a name written with Latin letters."
        )

    output = evaluate_model(lineToTensor(ascii_name))

    # Top-3 scores and category indices. exp() converts the scores back to
    # probabilities; the model output is presumably log-softmax.
    top_log_probs, top_indices = torch.topk(output, 3)
    probs = torch.exp(top_log_probs[0])
    return {
        all_categories[index.item()]: round(prob.item(), 2)
        for prob, index in zip(probs, top_indices[0])
    }
# Web UI: free-text surname in, top-3 category confidences out.
demo = gr.Interface(
    classify_lastname,
    inputs="text",
    # num_top_classes=3 matches the three entries classify_lastname returns.
    outputs=gr.outputs.Label(type="confidences", num_top_classes=3),
    title="Classify Last Name :)",
    description="Classifies last name into one of 18 languages of origin. Returns confidence % for the top three categories",
)
# inline=False: do not try to embed the UI inline (notebook-style display).
demo.launch(inline=False)