changeCharacter / app.py
dquisi's picture
Create new file
51254c8
raw
history blame contribute delete
No virus
1.68 kB
import re

import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForTokenClassification
# Maps the NER model's output class index to its BIO tag.
# Keys are strings because predictions are looked up with str(class_index).
# NOTE(review): presumably mirrors the checkpoint's config.id2label — verify.
id2label = {
    str(class_index): tag
    for class_index, tag in enumerate(
        (
            "B-LOC",
            "B-MISC",
            "B-ORG",
            "B-PER",
            "I-LOC",
            "I-MISC",
            "I-ORG",
            "I-PER",
            "O",
        )
    )
}
tokenizer = AutoTokenizer.from_pretrained('mrm8488/TinyBERT-spanish-uncased-finetuned-ner')
model = AutoModelForTokenClassification.from_pretrained('mrm8488/TinyBERT-spanish-uncased-finetuned-ner')
def get_objects(text):
    """Detect person and location names in ``text`` with the NER model.

    The text is split on single spaces. Each resulting word is labeled with
    the model's prediction for its *first* sub-word token. WordPiece may split
    one word into several tokens, so token positions cannot be mapped 1:1 onto
    word positions.

    Args:
        text: The story to analyse.

    Returns:
        A string with one ``"word=> "`` line per distinct word tagged as a
        person (``*-PER``), then an ``"Ubicaciones:"`` header line, then one
        ``"word=> "`` line per distinct word tagged as a location
        (``*-LOC``). Words appear in order of first occurrence. The user fills
        in the right-hand side, and the result is consumed by
        ``change_objects``.
    """
    words = text.split(" ")

    # Tokenize word by word so every sub-token can be traced back to the word
    # it came from. A word may yield several tokens, or none (e.g. "" from a
    # double space).
    token_ids = []
    token_word = []  # token_word[i] = index in `words` of sub-token i
    for word_index, word in enumerate(words):
        pieces = tokenizer.encode(word, add_special_tokens=False)
        token_ids.extend(pieces)
        token_word.extend([word_index] * len(pieces))

    # Stay within the model's position-embedding limit (assumes a BERT-style
    # config). Words past the limit are left unlabeled instead of crashing
    # the forward pass.
    budget = (
        model.config.max_position_embeddings
        - tokenizer.num_special_tokens_to_add()
    )
    token_ids = token_ids[:budget]
    token_word = token_word[:budget]

    input_ids = torch.tensor(
        [tokenizer.build_inputs_with_special_tokens(token_ids)]
    )
    with torch.no_grad():  # inference only; no autograd graph needed
        logits = model(input_ids)[0][0]  # (sequence_length, num_labels)

    # Position 0 is the leading special token ([CLS] for BERT); the word
    # sub-tokens follow it in order.
    predictions = logits.argmax(dim=-1).tolist()[1:1 + len(token_ids)]

    # Keep only the prediction of each word's first sub-token.
    word_label = {}
    for word_index, label_id in zip(token_word, predictions):
        word_label.setdefault(word_index, label_id)

    # Dicts act as insertion-ordered sets, so repeated names are listed once.
    personajes = {}
    locaciones = {}
    for word_index, label_id in word_label.items():
        label = id2label[str(label_id)]
        if "PER" in label:
            personajes[words[word_index]] = None
        elif "LOC" in label:
            locaciones[words[word_index]] = None

    return (
        "".join(f"{word}=> \n" for word in personajes)
        + "Ubicaciones:\n"
        + "".join(f"{word}=> \n" for word in locaciones)
    )
def change_objects(text, objetos):
    """Rewrite ``text`` using the user-edited ``original=> replacement`` list.

    ``objetos`` has the format produced by ``get_objects``: one entry per
    line, ``word=> replacement``. The following lines are ignored:

    * lines without ``=>``, such as the ``Ubicaciones:`` header;
    * entries whose original or replacement is blank, so names the user did
      not edit are kept instead of being erased.

    Surrounding whitespace on both sides of ``=>`` is stripped.

    All replacements are applied in a single pass and only to whole words.
    So ``Ana=> Eva`` leaves ``Mariana`` alone, and ``Juan=> Pedro`` plus
    ``Pedro=> Luis`` does not chain Juan into Luis.

    Args:
        text: The story to rewrite.
        objetos: Newline-separated ``original=> replacement`` lines.

    Returns:
        ``text`` with every mapped word replaced, or ``text`` unchanged when
        there is nothing to replace.
    """
    mapping = {}
    for line in (objetos or "").splitlines():
        original, separator, replacement = line.partition("=>")
        original = original.strip()
        replacement = replacement.strip()
        if separator and original and replacement:
            mapping[original] = replacement
    if not mapping:
        return text

    # Longest keys first, so a multi-word key wins over its own prefix.
    alternation = "|".join(
        re.escape(key) for key in sorted(mapping, key=len, reverse=True)
    )
    # Lookarounds rather than \b, so keys ending in punctuation
    # (e.g. "Madrid.", as produced by splitting on spaces) still match.
    pattern = re.compile(rf"(?<!\w)(?:{alternation})(?!\w)")
    return pattern.sub(lambda match: mapping[match.group(0)], text)
# Two-step UI:
#   1. detect people/places in the story and show them as an editable
#      "original=> replacement" list;
#   2. rewrite the story with whatever replacements the user typed in.
with gr.Blocks() as demo:
    cuento = gr.Textbox(lines=2, label="Cuento")
    objetos = gr.Textbox(label="Personajes y ubicaciones")
    # A Textbox rather than gr.Label: the output is a whole rewritten story,
    # and gr.Label is meant for short classification labels.
    resultado = gr.Textbox(label="Cuento modificado")
    b1 = gr.Button("Identificar Personajes y Ubicaciones")
    b2 = gr.Button("Cambiar objetos")
    # Step 1: fill the editable mapping from the story.
    b1.click(get_objects, inputs=cuento, outputs=objetos)
    # Step 2: apply the (possibly edited) mapping to the story.
    b2.click(change_objects, inputs=[cuento, objetos], outputs=resultado)

demo.launch()