Spaces:
Sleeping
Sleeping
import os | |
import gradio as gr | |
import subprocess | |
import sys | |
def install(package): | |
subprocess.check_call([sys.executable, "-m", "pip", "install", package]) | |
install("numpy") | |
install("torch") | |
install("transformers") | |
install("unidecode") | |
import numpy as np | |
import torch | |
from transformers import AutoTokenizer | |
from transformers import DistilBertForTokenClassification | |
from collections import Counter | |
from unidecode import unidecode | |
import string | |
import re | |
tokenizer = AutoTokenizer.from_pretrained("osiria/blaze-it-ner") | |
model = DistilBertForTokenClassification.from_pretrained("osiria/blaze-it-ner", num_labels = 5) | |
device = torch.device("cpu") | |
model = model.to(device) | |
model.eval() | |
from transformers import pipeline | |
ner = pipeline('ner', model=model, tokenizer=tokenizer, device=-1) | |
header = '''-------------------------------------------------------------------------------------------------- | |
<style> | |
.vertical-text { | |
writing-mode: vertical-lr; | |
text-orientation: upright; | |
background-color:red; | |
} | |
</style> | |
<center> | |
<body> | |
<span class="vertical-text" style="background-color:lightgreen;border-radius: 3px;padding: 3px;"> </span> | |
<span class="vertical-text" style="background-color:orange;border-radius: 3px;padding: 3px;"> D</span> | |
<span class="vertical-text" style="background-color:lightblue;border-radius: 3px;padding: 3px;"> E</span> | |
<span class="vertical-text" style="background-color:tomato;border-radius: 3px;padding: 3px;"> M</span> | |
<span class="vertical-text" style="background-color:lightgrey;border-radius: 3px;padding: 3px;"> O</span> | |
<span class="vertical-text" style="background-color:#CF9FFF;border-radius: 3px;padding: 3px;"> </span> | |
</body> | |
</center> | |
<br> | |
--------------------------------------------------------------------------------------------------''' | |
paragraph = '''<b>What's BLAZE-IT?</b> | |
This app is a demo of [BLAZE-IT](https://huggingface.co/osiria/blaze-it), a <b>lightweight</b> and <b>uncased</b> italian language model (<b>55M parameters</b> and <b>220MB</b> size). The model is here fine-tuned for named entity recognition on WikiNER (cross-validated F1 score of 89.53%) plus a custom, hand-crafted dataset of 3.500 manually annotated Wikipedia paragraphs. | |
It can recognize entities of the following types (in order to make the most of the color-coding, it is recommended to use the light theme for the interface): | |
- <span style="background-color:lightgreen;border-radius: 3px;padding: 3px;"><b>ᴘᴇʀ</b> person</span>: names of persons | |
- <span style="background-color:orange;border-radius: 3px;padding: 3px;"><b>ʟᴏᴄ</b> location</span>: names of places | |
- <span style="background-color:lightblue;border-radius: 3px;padding: 3px;"><b>ᴏʀɢ</b> organization</span>: names of organizations | |
- <span style="background-color:tomato;border-radius: 3px;padding: 3px;"><b>ᴍɪsᴄ</b> miscellanea</span>: mixed type entities | |
- <span style="background-color:lightgrey;border-radius: 3px;padding: 3px;"><b>ᴅᴀᴛᴇ</b> date</span>: regex-based dates | |
- <span style="background-color:#CF9FFF;border-radius: 3px;padding: 3px;"><b>ᴛᴀɢ</b> tag</span>: most relevant entities, of any type | |
The <b>ᴍɪsᴄ</b> class has mixed nature, and it mainly covers names of events or products. Occasionally, entities of other classes might be labeled as <b>ᴍɪsᴄ</b> if the model is not confident enough about their identification. | |
The execution time in this app depends on the availability of the underlying cloud instance, and is not a reflection of the model inference time. | |
If unknown tokens are present in the text, they will interfere with the prediction, and the model may behave erratically. In that case, a warning sign will be displayed. | |
''' | |
maps = {"O": "NONE", "PER": "PER", "LOC": "LOC", "ORG": "ORG", "MISC": "MISC", "DATE": "DATE"} | |
reg_month = "(?:gennaio|febbraio|marzo|aprile|maggio|giugno|luglio|agosto|settembre|ottobre|novembre|dicembre|january|february|march|april|may|june|july|august|september|october|november|december)" | |
reg_date = "(?:\d{1,2}\°{0,1}|primo|\d{1,2}\º{0,1})" + " " + reg_month + " " + "\d{4}|" | |
reg_date = reg_date + reg_month + " " + "\d{4}|" | |
reg_date = reg_date + "\d{1,2}" + " " + reg_month | |
reg_date = reg_date + "\d{1,2}" + "(?:\/|\.)\d{1,2}(?:\/|\.)" + "\d{4}|" | |
reg_date = reg_date + "(?<=dal )\d{4}|(?<=al )\d{4}|(?<=nel )\d{4}|(?<=anno )\d{4}|(?<=del )\d{4}|" | |
reg_date = reg_date + "\d{1,5} a\.c\.|\d{1,5} d\.c\." | |
map_punct = {"’": "'", "«": '"', "»": '"', "”": '"', "“": '"', "–": "-", "$": ""} | |
unk_tok = 9005 | |
merge_th_1 = 0.8 | |
merge_th_2 = 0.4 | |
min_th = 0.6 | |
def extract(text): | |
text = text.strip() | |
for mp in map_punct: | |
text = text.replace(mp, map_punct[mp]) | |
text = re.sub("\[\d+\]", "", text) | |
warn_flag = False | |
res_total = [] | |
out_text = "" | |
for p_text in text.split("\n"): | |
if p_text: | |
toks = tokenizer.encode(p_text) | |
if unk_tok in toks: | |
warn_flag = True | |
res_orig = ner(p_text, aggregation_strategy = "first") | |
res_orig = [el for r, el in enumerate(res_orig) if len(el["word"].strip()) > 1] | |
res = [] | |
for r, ent in enumerate(res_orig): | |
if r > 0 and ent["score"] < merge_th_1 and ent["start"] <= res[-1]["end"] + 1 and ent["score"] <= res[-1]["score"]: | |
res[-1]["word"] = res[-1]["word"] + " " + ent["word"] | |
res[-1]["score"] = merge_th_1*(res[-1]["score"] > merge_th_2) | |
res[-1]["end"] = ent["end"] | |
elif r < len(res_orig) - 1 and ent["score"] < merge_th_1 and res_orig[r+1]["start"] <= ent["end"] + 1 and res_orig[r+1]["score"] > ent["score"]: | |
res_orig[r+1]["word"] = ent["word"] + " " + res_orig[r+1]["word"] | |
res_orig[r+1]["score"] = merge_th_1*(res_orig[r+1]["score"] > merge_th_2) | |
res_orig[r+1]["start"] = ent["start"] | |
else: | |
res.append(ent) | |
res = [el for r, el in enumerate(res) if el["score"] >= min_th] | |
dates = [{"entity_group": "DATE", "score": 1.0, "word": p_text[el.span()[0]:el.span()[1]], "start": el.span()[0], "end": el.span()[1]} for el in re.finditer(reg_date, p_text, flags = re.IGNORECASE)] | |
res.extend(dates) | |
res = sorted(res, key = lambda t: t["start"]) | |
res_total.extend(res) | |
chunks = [("", "", 0, "NONE")] | |
for el in res: | |
if maps[el["entity_group"]] != "NONE": | |
tag = maps[el["entity_group"]] | |
chunks.append((p_text[el["start"]: el["end"]], p_text[chunks[-1][2]:el["end"]], el["end"], tag)) | |
if chunks[-1][2] < len(p_text): | |
chunks.append(("END", p_text[chunks[-1][2]:], -1, "NONE")) | |
chunks = chunks[1:] | |
n_text = [] | |
for i, chunk in enumerate(chunks): | |
rep = chunk[0] | |
if chunk[3] == "PER": | |
rep = '<span style="background-color:lightgreen;border-radius: 3px;padding: 3px;"><b>ᴘᴇʀ</b> ' + chunk[0] + '</span>' | |
elif chunk[3] == "LOC": | |
rep = '<span style="background-color:orange;border-radius: 3px;padding: 3px;"><b>ʟᴏᴄ</b> ' + chunk[0] + '</span>' | |
elif chunk[3] == "ORG": | |
rep = '<span style="background-color:lightblue;border-radius: 3px;padding: 3px;"><b>ᴏʀɢ</b> ' + chunk[0] + '</span>' | |
elif chunk[3] == "MISC": | |
rep = '<span style="background-color:tomato;border-radius: 3px;padding: 3px;"><b>ᴍɪsᴄ</b> ' + chunk[0] + '</span>' | |
elif chunk[3] == "DATE": | |
rep = '<span style="background-color:lightgrey;border-radius: 3px;padding: 3px;"><b>ᴅᴀᴛᴇ</b> ' + chunk[0] + '</span>' | |
n_text.append(chunk[1].replace(chunk[0], rep)) | |
n_text = "".join(n_text) | |
if out_text: | |
out_text = out_text + "<br>" + n_text | |
else: | |
out_text = n_text | |
tags = [el["word"] for el in res_total if el["entity_group"] not in ['DATE', None]] | |
cnt = Counter(tags) | |
tags = sorted(list(set([el for el in tags if cnt[el] > 1])), key = lambda t: cnt[t]*np.exp(-tags.index(t)))[::-1] | |
tags = [" ".join(re.sub("[^A-Za-z0-9\s]", "", unidecode(tag)).split()) for tag in tags] | |
tags = ['<span style="background-color:#CF9FFF;border-radius: 3px;padding: 3px;"><b>ᴛᴀɢ </b> ' + el + '</span>' for el in tags] | |
tags = " ".join(tags) | |
if tags: | |
out_text = out_text + "<br><br><b>Tags:</b> " + tags | |
if warn_flag: | |
out_text = out_text + "<br><br><b>Warning ⚠️:</b> Unknown tokens detected in text. The model might behave erratically" | |
return out_text | |
init_text = '''l'agenzia spaziale europea, nota internazionalmente con l'acronimo esa dalla denominazione inglese european space agency, è un'agenzia internazionale fondata nel 1975 incaricata di coordinare i progetti spaziali di 22 paesi europei. il suo quartier generale si trova a parigi in francia, con uffici a mosca, bruxelles, washington e houston. | |
attualmente il direttore generale dell'agenzia è l'austriaco josef aschbacher, il quale ha sostituito il tedesco johann-dietrich wörner il primo marzo 2021. | |
lo spazioporto dell'esa è il centre spatial guyanais a kourou, nella guyana francese, un sito scelto, come tutte le basi di lancio, per via della sua vicinanza con l'equatore. durante gli ultimi anni il lanciatore ariane 5 ha consentito all'esa di raggiungere una posizione di primo piano nei lanci commerciali e l'esa è il principale concorrente della nasa nell'esplorazione spaziale. | |
le missioni scientifiche dell'esa hanno le loro basi al centro europeo per la ricerca e la tecnologia spaziale (estec) di noordwijk, nei paesi bassi. il centro europeo per le operazioni spaziali (esoc), di darmstadt in germania, è responsabile del controllo dei satelliti esa in orbita. [...] | |
l'agenzia spaziale italiana (asi) venne fondata nel 1988 per promuovere, coordinare e condurre le attività spaziali in italia. opera in collaborazione con il ministero dell'università e della ricerca scientifica e coopera in numerosi progetti con entità attive nella ricerca scientifica e nelle attività commerciali legate allo spazio. internazionalmente l'asi fornisce la delegazione italiana per l'agenzia spaziale europea e le sue sussidiarie.''' | |
init_output = extract(init_text) | |
with gr.Blocks(theme=gr.themes.Default(text_size="lg", spacing_size="lg")) as interface: | |
with gr.Row(): | |
gr.Markdown(header) | |
with gr.Row(): | |
with gr.Column(): | |
gr.Markdown(paragraph) | |
with gr.Column(): | |
incipit = gr.Markdown("<b>Highlighted entities<b>") | |
entities = gr.Markdown(init_output) | |
with gr.Row(): | |
with gr.Column(): | |
text = gr.Text(label="Extract entities", lines = 10, value = init_text) | |
with gr.Column(): | |
gr.Examples([["aristotele nacque nel 384 a.c. o nel 383 a.c. a stagira, l'attuale stavro, colonia greca situata nella parte nord-orientale della penisola calcidica della tracia. si dice che il padre, nicomaco, sia vissuto presso aminta iii, re dei macedoni, prestandogli i servigi di medico e di amico. aristotele, come figlio del medico reale, doveva pertanto risiedere nella capitale del regno di macedonia"], | |
["mi chiamo edoardo, vivo a roma e lavoro per l'agenzia spaziale italiana, nella missione prisma"], | |
["wikipedia è un'enciclopedia online a contenuto libero, collaborativa, multilingue e gratuita, nata nel 2001, sostenuta e ospitata dalla wikimedia foundation, un'organizzazione non a scopo di lucro statunitense. lanciata da jimmy wales e larry sanger il 15 gennaio 2001, inizialmente nell'edizione in lingua inglese, nei mesi successivi ha aggiunto edizioni in numerose altre lingue"]], | |
inputs=[text]) | |
with gr.Row(): | |
button = gr.Button("Extract").style(full_width=False) | |
with gr.Row(): | |
with gr.Column(): | |
gr.Markdown("<center>The input examples in this demo are extracted from https://it.wikipedia.org</center>") | |
button.click(extract, inputs=[text], outputs = [entities]) | |
interface.launch() |