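# Spaces: Runtime error / Runtime error
# (The line above appears to be page-status residue captured together with
# the source; it is not part of the program.)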
import string

import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForTokenClassification
# Maps a predicted class index (as a string) to its BIO entity tag.
# NOTE(review): presumably mirrors the label order of the fine-tuned
# checkpoint's config -- confirm against the model card before editing.
id2label = {
    "0": "B-LOC",
    "1": "B-MISC",
    "2": "B-ORG",
    "3": "B-PER",
    "4": "I-LOC",
    "5": "I-MISC",
    "6": "I-ORG",
    "7": "I-PER",
    "8": "O"
}
# Checkpoint name shared by the tokenizer and the token-classification model,
# kept in one place so the two can never drift apart.
MODEL_NAME = 'mrm8488/TinyBERT-spanish-uncased-finetuned-ner'

# Loaded once at import time and reused by every request.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForTokenClassification.from_pretrained(MODEL_NAME)
def get_objects(text):
    """Return an editable list of the characters and locations found in *text*.

    The text is split on whitespace, run through the module-level
    ``tokenizer``/``model`` pair, and every word that has at least one
    sub-token tagged ``*-PER`` or ``*-LOC`` (per ``id2label``) is reported.

    Args:
        text: The story to analyse. ``None`` or blank input is accepted.

    Returns:
        A string with one ``"<word>=> "`` line per detected character,
        followed by the literal header ``"Ubicaciones:"`` and one
        ``"<word>=> "`` line per detected location. The user is expected to
        type a replacement after each ``=>``.
    """
    words = text.split() if text else []
    if not words:
        # Nothing to tag; keep the same skeleton the normal path produces.
        return "Ubicaciones:\n"

    # Tokenize the pre-split words so each sub-token can be mapped back to
    # the word it came from. Indexing token positions straight into the word
    # list is wrong as soon as any word is split into several sub-tokens.
    # NOTE: ``word_ids`` needs a "fast" tokenizer, which AutoTokenizer
    # returns by default when one is available for the checkpoint.
    encoding = tokenizer(
        words,
        is_split_into_words=True,
        return_tensors="pt",
        truncation=True,
    )
    with torch.no_grad():  # inference only: skip autograd bookkeeping
        logits = model(**encoding)[0]
    predicted_ids = logits[0].argmax(dim=-1).tolist()

    # Punctuation glued to a word ("Madrid,") must not become part of the
    # name that is later searched for in the text.
    strip_chars = string.punctuation + "¡¿«»"

    personajes = []
    locaciones = []
    for position, word_index in enumerate(encoding.word_ids(batch_index=0)):
        if word_index is None:
            # Special tokens (e.g. sequence start/end) belong to no word.
            continue
        etiqueta = id2label[str(predicted_ids[position])]
        if 'LOC' in etiqueta:
            destino = locaciones
        elif 'PER' in etiqueta:
            destino = personajes
        else:
            continue
        palabra = words[word_index].strip(strip_chars)
        entrada = palabra + "=> \n"
        # Skip pure-punctuation words and avoid listing the same name twice
        # (several sub-tokens of one word, or a name repeated in the text).
        if palabra and entrada not in destino:
            destino.append(entrada)
    return ''.join(personajes) + "Ubicaciones:\n" + ''.join(locaciones)
def change_objects(text, objetos):
    """Apply the ``name=> replacement`` substitutions in *objetos* to *text*.

    Each line of *objetos* that contains ``=>`` is read as
    ``<original>=> <replacement>``; surrounding whitespace on both sides is
    ignored. Lines without ``=>`` (such as the ``Ubicaciones:`` header) are
    skipped, as are lines whose original or replacement is blank, so entries
    the user left unfilled keep the original word in the text.

    Args:
        text: The story in which names are replaced.
        objetos: Newline-separated substitution list.

    Returns:
        *text* with every occurrence of each original replaced, applied in
        the order the lines appear. *text* is returned unchanged when either
        argument is empty.
    """
    if not text or not objetos:
        return text
    for linea in objetos.split('\n'):
        partes = linea.split('=>')
        if len(partes) < 2:
            continue
        original = partes[0].strip()
        reemplazo = partes[1].strip()
        # An empty search string would make str.replace insert the
        # replacement between every character; an empty replacement means
        # the user did not fill this entry in.
        if not original or not reemplazo:
            continue
        text = text.replace(original, reemplazo)
    return text
# Two-step UI: first detect the entities, then apply the replacements the
# user typed after each "=>".
demo = gr.Blocks()

with demo:
    cuento = gr.Textbox(lines=2)  # input story
    objetos = gr.Textbox()        # editable "name=> replacement" list
    label = gr.Label()            # rewritten story

    b1 = gr.Button("Identificar Personajes y Ubicaciones")
    b2 = gr.Button("Cambiar objetos")

    # Step 1: fill the editable list from the story.
    b1.click(get_objects, inputs=cuento, outputs=objetos)
    # Step 2: rewrite the story using the (edited) list.
    b2.click(change_objects, inputs=[cuento, objetos], outputs=label)

demo.launch()