# Hugging Face Space app (the Space page listed its status as "Paused").
import unicodedata | |
import string | |
import gradio as gr | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
# Character vocabulary: every ASCII letter plus a handful of punctuation
# characters. Each character's position in this string is its one-hot index.
all_letters = string.ascii_letters + " .,;'"

# Length of each one-hot letter vector.
n_letters = len(all_letters)

# Output classes; classify_lastname maps model output index i to
# all_categories[i], so this order must not change.
all_categories = [
    'Arabic',
    'Chinese',
    'Czech',
    'Dutch',
    'English',
    'French',
    'German',
    'Greek',
    'Irish',
    'Italian',
    'Japanese',
    'Korean',
    'Polish',
    'Portuguese',
    'Russian',
    'Scottish',
    'Spanish',
    'Vietnamese',
]
# Turn a Unicode string to plain ASCII, thanks to https://stackoverflow.com/a/518232/2809427
def unicodeToAscii(s):
    """Strip accents from *s* and keep only characters found in ``all_letters``.

    The text is NFD-decomposed so accented letters split into a base letter
    plus combining marks (Unicode category ``Mn``). The marks are dropped, and
    any remaining character not in ``all_letters`` is discarded as well.

    Returns:
        The filtered string; it may be empty.
    """
    kept = []
    for ch in unicodedata.normalize('NFD', s):
        # Combining marks (accents) are dropped.
        if unicodedata.category(ch) == 'Mn':
            continue
        # Anything outside the model's vocabulary is dropped.
        if ch not in all_letters:
            continue
        kept.append(ch)
    return ''.join(kept)
# Find letter index from all_letters, e.g. "a" = 0
def letterToIndex(letter):
    """Return the position of *letter* in ``all_letters``, or -1 if absent.

    The -1 sentinel mirrors ``str.find``; callers must check for it before
    using the result as an index.
    """
    try:
        return all_letters.index(letter)
    except ValueError:
        return -1
# Turn a line into a <line_length x 1 x n_letters>,
# or an array of one-hot letter vectors
def lineToTensor(line):
    """One-hot encode *line* as a float tensor of shape (len(line), 1, n_letters).

    Row ``i`` has a single 1 at the ``all_letters`` index of ``line[i]``; the
    middle dimension is a batch of size 1.

    Raises:
        ValueError: if *line* contains a character not in ``all_letters``.
            Previously such a character produced index -1, which silently
            set the last slot (the apostrophe) instead of failing.
    """
    tensor = torch.zeros(len(line), 1, n_letters)
    for position, letter in enumerate(line):
        index = letterToIndex(letter)
        if index < 0:
            raise ValueError(
                f"character {letter!r} at position {position} is not in all_letters"
            )
        tensor[position, 0, index] = 1
    return tensor
# Loaded TorchScript models keyed by file path, so each model is read from
# disk once per process rather than on every request.
_model_cache = {}


def _load_model(model_path):
    """Return the TorchScript module at *model_path*, loading it on first use."""
    model = _model_cache.get(model_path)
    if model is None:
        model = torch.jit.load(model_path)
        # Inference only: disable training-mode behaviour such as dropout.
        model.eval()
        _model_cache[model_path] = model
    return model


# Just return an output given a line
def evaluate_model(line_tensor, model_path="torchscript_classify_names_lstm.pt", hidden_size=128):
    """Run the name-classification model on one encoded name.

    Args:
        line_tensor: One-hot encoded name, as produced by ``lineToTensor``.
        model_path: Path of the TorchScript model file. The default is the
            file this app has always loaded.
        hidden_size: Width of the LSTM hidden/cell state. It must match the
            exported model; 128 is the value the original code hard-coded.

    Returns:
        The model's output tensor, unchanged. NOTE(review): classify_lastname
        treats it as log-probabilities of shape (1, n_categories); confirm
        against the exported model.
    """
    model = _load_model(model_path)
    # Zero initial (h0, c0), each shaped (1, 1, hidden_size): presumably one
    # layer, one direction, batch size 1.
    hidden = (torch.zeros(1, 1, hidden_size), torch.zeros(1, 1, hidden_size))
    # No gradients are needed for inference; skip autograd bookkeeping.
    with torch.no_grad():
        return model(line_tensor, hidden)
def classify_lastname(last_name, top_k=3):
    """Predict the most likely languages of origin for a surname.

    Args:
        last_name: Raw surname text. Accents are stripped, characters outside
            ``all_letters`` are dropped, and the result is title-cased.
        top_k: Number of categories to return (default 3, which the UI is
            configured to show). Clamped to ``len(all_categories)``.

    Returns:
        Dict mapping category name -> probability rounded to 2 decimals, with
        keys inserted from most to least likely.

    Raises:
        ValueError: if nothing remains after ASCII filtering (e.g. empty input
            or a name written only in non-Latin script). The model would
            otherwise be handed a zero-length sequence.
    """
    # Plain ASCII, then title-case (first letter of each word upper-cased).
    name = unicodeToAscii(last_name).title()
    if not name:
        raise ValueError("Please enter a last name using Latin letters (a-z, A-Z).")

    output = evaluate_model(lineToTensor(name))

    k = min(top_k, len(all_categories))
    # Output is treated as log-probabilities; exp() turns them back into
    # probabilities. NOTE(review): confirm the exported model ends in LogSoftmax.
    top_log_probs, top_indices = torch.topk(output, k)
    probs = torch.exp(top_log_probs[0])
    return {
        all_categories[idx.item()]: round(p.item(), 2)
        for p, idx in zip(probs, top_indices[0])
    }
# Web UI: one text box in, a label showing the top-3 category confidences out.
demo = gr.Interface(
    classify_lastname,
    inputs="text",
    # NOTE(review): gr.outputs.* is Gradio's legacy component namespace
    # (deprecated in 3.x, removed in 4.0). Switch to gr.Label(num_top_classes=3)
    # once the pinned Gradio version supports it.
    outputs=gr.outputs.Label(type="confidences", num_top_classes=3),
    title="Classify Last Name :)",
    # Category count is derived from all_categories so the text cannot drift.
    description=(
        f"Classifies a last name into one of {len(all_categories)} languages of origin. "
        "Returns confidence % for the top three categories"
    ),
)
demo.launch(inline=False)