NERmembert / app.py
bourdoiscatie's picture
Update app.py
c2845f7 verified
import streamlit as st
from annotated_text import annotated_text
from transformers import pipeline
from PIL import Image
import re
st.sidebar.header("**Instructions**")
st.sidebar.markdown("Démonstrateur des modèles [NERmembert](https://hf.co/collections/CATIE-AQ/french-ner-pack-658aefafe3f7a2dcf0e4dbb4) entraînés sur 385 000 à 420 000 données en français en fonction de la configuration. Les modèles sont capables d'étiquetter les entités LOC (Localisations), PER (Personnalités), ORG (Organisations) et MISC (Divers). Ils sont disponibles en version *base* (110M de paramètres) et *large* (335M de paramètres). Pour les essayer, sélectionnez la version de votre choix ci-dessous, puis renseignez un texte. Enfin appuyez sur le bouton « Appliquer le modèle » pour observer la réponse trouvée par le modèle. Pour en savoir plus sur ces modèles, vous pouvez lire l'[article de blog](https://blog.vaniila.ai/NER/) détaillant la démarche suvie.")
version = st.sidebar.radio("Choix de la version du modèle :", ["version 4 entités base", "version 4 entités large", "version 3 entités base", "version 3 entités large"])
st.sidebar.markdown("---")
st.sidebar.markdown("Ce modèle a été entraîné via la plateforme [*Vaniila*](https://www.vaniila.ai/) du [CATIE](https://www.catie.fr/).")
image_path = 'Vaniila.png'
image = Image.open(image_path)
st.sidebar.image(image, caption=None, width=None, use_column_width=None, clamp=False, channels="RGB", output_format="auto")
@st.cache_resource
def load_model(version,text):
if version == "version 3 entités base":
ner = pipeline('token-classification', model='CATIE-AQ/NERmembert-base-3entities', tokenizer='CATIE-AQ/NERmembert-base-3entities', aggregation_strategy="simple")
result = ner(text)
return result
if version == "version 3 entités large":
ner = pipeline('token-classification', model='CATIE-AQ/NERmembert-large-3entities', tokenizer='CATIE-AQ/NERmembert-large-3entities', aggregation_strategy="simple")
result = ner(text)
return result
if version == "version 4 entités base":
ner = pipeline('token-classification', model='CATIE-AQ/NERmembert-base-4entities', tokenizer='CATIE-AQ/NERmembert-base-4entities', aggregation_strategy="simple")
result = ner(text)
return result
else:
ner = pipeline('token-classification', model='CATIE-AQ/NERmembert-large-4entities', tokenizer='CATIE-AQ/NERmembert-large-4entities', aggregation_strategy="simple")
result = ner(text)
return result
def getcolor(texts, labels):
colors = {'LOC': '#38419D', 'PER': '#BF3131', 'ORG': '#597E52', 'MISC':'#F1C232'}
return [(t,l,colors[l]) for t, l in zip(texts, labels)]
def color_annotation(to_print,text) :
text_ner = []
label_ner = []
for i in range(len(to_print)) :
text_ner.append(to_print[i]["word"])
label_ner.append(to_print[i]["entity_group"])
anns = getcolor(text_ner, label_ner)
anns = list(set(anns))
text_ner = list(set(text_ner))
text_ner = list(sorted(text_ner, key = len))
for i in range(len(anns)):
for j in range(len(text_ner)):
if text_ner[j] == anns[i][0]:
text = text.replace(text_ner[j],str(anns[i]))
for i in re.findall(r"\((.*?)\)", text) : # pour gérer les cas de mots inclus dans des n_grams
if "(" in i:
text = text.replace(i+")",i.split(', ')[0][2:-1])
text = text.replace(")",')","').replace(')","","',')","').replace("(",'","(').replace('","","(','","(').replace("'-","-")
return text
st.markdown("<h2 style='text-align: center'>NERmembert", unsafe_allow_html=True)
st.markdown("<h4 style='text-align: center'>"+version, unsafe_allow_html=True)
option = st.selectbox(
'Choix du mode',
('Texte libre', 'Exemple 1', 'Exemple 2'))
if option == "Exemple 1":
text = st.text_area("Votre texte", value="Le dévoilement du logo officiel des JO s'est déroulé le 21 octobre 2019 au Grand Rex. Ce nouvel emblème et cette nouvelle typographie ont été conçus par le designer Sylvain Boyer avec les agences Royalties & Ecobranding. Rond, il rassemble trois symboles : une médaille d'or, la flamme olympique et Marianne, symbolisée par un visage de femme mais privée de son bonnet phrygien caractéristique. La typographie dessinée fait référence à l'Art déco, mouvement artistique des années 1920, décennie pendant laquelle ont eu lieu pour la dernière fois les Jeux olympiques à Paris en 1924. Pour la première fois, ce logo sera unique pour les Jeux olympiques et les Jeux paralympiques.",height=175)
if text:
to_print = load_model(version,text)
display = color_annotation(to_print,text)
list_to_display = [] # pour pouvoir afficher la couleur, on doit passer les mots à colorier de str en tuple
for i in range(len(display.split('","'))):
if "#" in display.split('","')[i]:
list_to_display.append(eval(display.split('","')[i]))
else :
list_to_display.append(display.split('","')[i])
annotated_text(*list_to_display)
st.write("\n")
with st.expander("Afficher le score pour chacune des entitées trouvées :"):
for i in range(len(to_print)) :
st.write("- Score pour que ",to_print[i]["word"]," soit de type", to_print[i]["entity_group"]," : ",round(to_print[i]["score"],3))
elif option == "Exemple 2":
text = st.text_area("Votre texte", value="Assurés de disputer l'Euro 2024 en Allemagne l'été prochain (du 14 juin au 14 juillet) depuis leur victoire aux Pays-Bas, les Bleus ont fait le nécessaire pour avoir des certitudes. Avec six victoires en six matchs officiels et un seul but encaissé, Didier Deschamps a consolidé les acquis de la dernière Coupe du monde de football. Les joueurs clés sont connus : Kylian Mbappé, Aurélien Tchouameni, Antoine Griezmann, Ibrahima Konaté ou encore Mike Maignan.",height=175)
if text:
to_print = load_model(version,text)
display = color_annotation(to_print,text)
list_to_display = [] # pour pouvoir afficher la couleur, on doit passer les mots à colorier de str en tuple
for i in range(len(display.split('","'))):
if "#" in display.split('","')[i]:
list_to_display.append(eval(display.split('","')[i]))
else :
list_to_display.append(display.split('","')[i])
annotated_text(*list_to_display)
st.write("\n")
with st.expander("Afficher le score pour chacune des entitées trouvées :"):
for i in range(len(to_print)) :
st.write("- Score pour que ",to_print[i]["word"]," soit de type", to_print[i]["entity_group"]," : ",round(to_print[i]["score"],3))
else:
text = st.text_area("Votre texte", value="",height=175)
if text:
col1, col2, col3 = st.columns(3)
if col2.button('Appliquer le modèle'):
to_print = load_model(version,text)
display = color_annotation(to_print,text)
list_to_display = [] # pour pouvoir afficher la couleur, on doit passer les mots à colorier de str en tuple
for i in range(len(display.split('","'))):
if "#" in display.split('","')[i]:
list_to_display.append(eval(display.split('","')[i]))
else :
list_to_display.append(display.split('","')[i])
annotated_text(*list_to_display)
st.write("\n")
with st.expander("Afficher le score pour chacune des entitées trouvées :"):
for i in range(len(to_print)) :
st.write("- Score pour que ",to_print[i]["word"]," soit de type", to_print[i]["entity_group"]," : ",round(to_print[i]["score"],3))