osiria's picture
Update app.py
015ae8b
import os
import gradio as gr
import subprocess
import sys
def install(package):
subprocess.check_call([sys.executable, "-m", "pip", "install", package])
install("numpy")
install("torch")
install("transformers")
install("unidecode")
import numpy as np
import torch
from transformers import AutoTokenizer
from transformers import BertForTokenClassification
from collections import Counter
from unidecode import unidecode
import string
import re
tokenizer = AutoTokenizer.from_pretrained("osiria/bert-italian-uncased-ner")
model = BertForTokenClassification.from_pretrained("osiria/bert-italian-uncased-ner", num_labels = 5)
device = torch.device("cpu")
model = model.to(device)
model.eval()
from transformers import pipeline
ner = pipeline('ner', model=model, tokenizer=tokenizer, device=-1)
header = '''--------------------------------------------------------------------------------------------------
<style>
.vertical-text {
writing-mode: vertical-lr;
text-orientation: upright;
background-color:red;
}
</style>
<center>
<body>
<span class="vertical-text" style="background-color:lightgreen;border-radius: 3px;padding: 3px;"> </span>
<span class="vertical-text" style="background-color:orange;border-radius: 3px;padding: 3px;"> D</span>
<span class="vertical-text" style="background-color:lightblue;border-radius: 3px;padding: 3px;">    E</span>
<span class="vertical-text" style="background-color:tomato;border-radius: 3px;padding: 3px;">    M</span>
<span class="vertical-text" style="background-color:lightgrey;border-radius: 3px;padding: 3px;"> O</span>
<span class="vertical-text" style="background-color:#CF9FFF;border-radius: 3px;padding: 3px;"> </span>
</body>
</center>
<br>
'''
maps = {"O": "NONE", "PER": "PER", "LOC": "LOC", "ORG": "ORG", "MISC": "MISC", "DATE": "DATE"}
reg_month = "(?:gennaio|febbraio|marzo|aprile|maggio|giugno|luglio|agosto|settembre|ottobre|novembre|dicembre|january|february|march|april|may|june|july|august|september|october|november|december)"
reg_date = "(?:\d{1,2}\°{0,1}|primo|\d{1,2}\º{0,1})" + " " + reg_month + " " + "\d{4}|"
reg_date = reg_date + reg_month + " " + "\d{4}|"
reg_date = reg_date + "\d{1,2}" + " " + reg_month
reg_date = reg_date + "\d{1,2}" + "(?:\/|\.)\d{1,2}(?:\/|\.)" + "\d{4}|"
reg_date = reg_date + "(?<=dal )\d{4}|(?<=al )\d{4}|(?<=nel )\d{4}|(?<=anno )\d{4}|(?<=del )\d{4}|"
reg_date = reg_date + "\d{1,5} a\.c\.|\d{1,5} d\.c\."
map_punct = {"’": "'", "«": '"', "»": '"', "”": '"', "“": '"', "–": "-", "$": ""}
unk_tok = 9005
merge_th_1 = 0.8
merge_th_2 = 0.4
min_th = 0.55
def extract(text):
text = text.strip()
for mp in map_punct:
text = text.replace(mp, map_punct[mp])
text = re.sub("\[\d+\]", "", text)
warn_flag = False
res_total = []
out_text = ""
for p_text in text.split("\n"):
if p_text:
toks = tokenizer.encode(p_text)
if unk_tok in toks:
warn_flag = True
res_orig = ner(p_text, aggregation_strategy = "first")
res_orig = [el for r, el in enumerate(res_orig) if len(el["word"].strip()) > 1]
res = []
for r, ent in enumerate(res_orig):
if r > 0 and ent["score"] < merge_th_1 and ent["start"] <= res[-1]["end"] + 1 and ent["score"] <= res[-1]["score"]:
res[-1]["word"] = res[-1]["word"] + " " + ent["word"]
res[-1]["score"] = merge_th_1*(res[-1]["score"] > merge_th_2)
res[-1]["end"] = ent["end"]
elif r < len(res_orig) - 1 and ent["score"] < merge_th_1 and res_orig[r+1]["start"] <= ent["end"] + 1 and res_orig[r+1]["score"] > ent["score"]:
res_orig[r+1]["word"] = ent["word"] + " " + res_orig[r+1]["word"]
res_orig[r+1]["score"] = merge_th_1*(res_orig[r+1]["score"] > merge_th_2)
res_orig[r+1]["start"] = ent["start"]
else:
res.append(ent)
res = [el for r, el in enumerate(res) if el["score"] >= min_th]
dates = [{"entity_group": "DATE", "score": 1.0, "word": p_text[el.span()[0]:el.span()[1]], "start": el.span()[0], "end": el.span()[1]} for el in re.finditer(reg_date, p_text, flags = re.IGNORECASE)]
res.extend(dates)
res = sorted(res, key = lambda t: t["start"])
res_total.extend(res)
chunks = [("", "", 0, "NONE")]
for el in res:
if maps[el["entity_group"]] != "NONE":
tag = maps[el["entity_group"]]
chunks.append((p_text[el["start"]: el["end"]], p_text[chunks[-1][2]:el["end"]], el["end"], tag))
if chunks[-1][2] < len(p_text):
chunks.append(("END", p_text[chunks[-1][2]:], -1, "NONE"))
chunks = chunks[1:]
n_text = []
for i, chunk in enumerate(chunks):
rep = chunk[0]
if chunk[3] == "PER":
rep = '<span style="background-color:lightgreen;border-radius: 3px;padding: 3px;"><b>ᴘᴇʀ</b> ' + chunk[0] + '</span>'
elif chunk[3] == "LOC":
rep = '<span style="background-color:orange;border-radius: 3px;padding: 3px;"><b>ʟᴏᴄ</b> ' + chunk[0] + '</span>'
elif chunk[3] == "ORG":
rep = '<span style="background-color:lightblue;border-radius: 3px;padding: 3px;"><b>ᴏʀɢ</b> ' + chunk[0] + '</span>'
elif chunk[3] == "MISC":
rep = '<span style="background-color:tomato;border-radius: 3px;padding: 3px;"><b>ᴍɪsᴄ</b> ' + chunk[0] + '</span>'
elif chunk[3] == "DATE":
rep = '<span style="background-color:lightgrey;border-radius: 3px;padding: 3px;"><b>ᴅᴀᴛᴇ</b> ' + chunk[0] + '</span>'
n_text.append(chunk[1].replace(chunk[0], rep))
n_text = "".join(n_text)
if out_text:
out_text = out_text + "<br>" + n_text
else:
out_text = n_text
tags = [el["word"] for el in res_total if el["entity_group"] not in ['DATE', None]]
cnt = Counter(tags)
tags = sorted(list(set([el for el in tags if cnt[el] > 1])), key = lambda t: cnt[t]*np.exp(-tags.index(t)))[::-1]
tags = [" ".join(re.sub("[^A-Za-z0-9\s]", "", unidecode(tag)).split()) for tag in tags]
tags = ['<span style="background-color:#CF9FFF;border-radius: 3px;padding: 3px;"><b>ᴛᴀɢ </b> ' + el + '</span>' for el in tags]
tags = " ".join(tags)
if tags:
out_text = out_text + "<br><br><b>Tags:</b> " + tags
if warn_flag:
out_text = out_text + "<br><br><b>Warning ⚠️:</b> Unknown tokens detected in text. The model might behave erratically"
return out_text
init_text = '''l'agenzia spaziale europea, nota internazionalmente con l'acronimo esa dalla denominazione inglese european space agency, è un'agenzia internazionale fondata nel 1975 incaricata di coordinare i progetti spaziali di 22 paesi europei. il suo quartier generale si trova a parigi in francia, con uffici a mosca, bruxelles, washington e houston. il personale dell'esa del 2016 ammontava a 2 200 persone (esclusi sub-appaltatori e le agenzie nazionali) e il budget del 2022 è di 7,15 miliardi di euro. attualmente il direttore generale dell'agenzia è l'austriaco josef aschbacher, il quale ha sostituito il tedesco johann-dietrich wörner il primo marzo 2021.
lo spazioporto dell'esa è il centre spatial guyanais a kourou, nella guyana francese, un sito scelto, come tutte le basi di lancio, per via della sua vicinanza con l'equatore. durante gli ultimi anni il lanciatore ariane 5 ha consentito all'esa di raggiungere una posizione di primo piano nei lanci commerciali e l'esa è il principale concorrente della nasa nell'esplorazione spaziale.
le missioni scientifiche dell'esa hanno le loro basi al centro europeo per la ricerca e la tecnologia spaziale (estec) di noordwijk, nei paesi bassi. il centro europeo per le operazioni spaziali (esoc), di darmstadt in germania, è responsabile del controllo dei satelliti esa in orbita. le responsabilità del centro europeo per l'osservazione della terra (esrin) di frascati, in italia, includono la raccolta, l'archiviazione e la distribuzione di dati satellitari ai partner dell'esa; oltre a ciò, la struttura agisce come centro di informazione tecnologica per l'intera agenzia. [...]
l'agenzia spaziale italiana (asi) venne fondata nel 1988 per promuovere, coordinare e condurre le attività spaziali in italia. opera in collaborazione con il ministero dell'università e della ricerca scientifica e coopera in numerosi progetti con entità attive nella ricerca scientifica e nelle attività commerciali legate allo spazio. internazionalmente l'asi fornisce la delegazione italiana per l'agenzia spaziale europea e le sue sussidiarie.'''
init_output = extract(init_text)
with gr.Blocks(css="footer {visibility: hidden}", theme=gr.themes.Default(text_size="lg", spacing_size="lg")) as interface:
with gr.Row():
gr.Markdown(header)
with gr.Row():
text = gr.Text(label="Extract entities", lines = 10, value = init_text)
with gr.Row():
with gr.Column():
button = gr.Button("Extract").style(full_width=False)
with gr.Row():
with gr.Column():
entities = gr.Markdown(init_output)
with gr.Row():
with gr.Column():
gr.Markdown("<center>The input examples in this demo are extracted from https://it.wikipedia.org</center>")
button.click(extract, inputs=[text], outputs = [entities])
interface.launch()