File size: 3,745 Bytes
88ccef7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
import time
import streamlit as st
from annotated_text import annotated_text

from flair.data import Sentence
from flair.models import SequenceTagger

checkpoints = [
    "qanastek/pos-french",
]

colors = {'DET': '#b9d9a6', 'NFP': '#eddc92', 'ADJFP': '#95e9d7', 'AUX': '#e797db', 'VPPMS': '#9ff48b', 'ADV': '#ed92b4', 'PREP': '#decfa1', 'PDEMMS': '#ada7d7', 'NMS': '#85fad8', 'COSUB': '#8ba4f4', 'PINDMS': '#e7a498', 'PPOBJMS': '#e5c79a', 'VERB': '#eb94b6', 'DETFS': '#e698ae', 'NFS': '#d9d1a6', 'YPFOR': '#96e89f', 'VPPFS': '#e698c6', 'PUNCT': '#ddbfa2', 'DETMS': '#f788cd', 'PROPN': '#f19c8d', 'ADJMS': '#8ed5f0', 'PPER3FS': '#c4d8a6', 'ADJFS': '#e39bdc', 'COCO': '#8df1e2', 'NMP': '#d7f787', 'PREL': '#f986f0', 'PPER1S': '#878df8', 'ADJMP': '#83fe80', 'VPPMP': '#a6d8c9', 'DINTMS': '#d9a6cc', 'PPER3MS': '#a1deda', 'PPER3MP': '#8fefe1', 'PREF': '#e3c79b', 'ADJ': '#fb81fe', 'DINTFS': '#d5fe81', 'CHIF': '#8084ff', 'XFAMIL': '#dd80fe', 'PRELFS': '#9ce3e3', 'SYM': '#9fbddf', 'NOUN': '#dea1b5', 'MOTINC': '#93b8ec', 'PINDFS': '#f787a5', 'PPOBJMP': '#dca3d2', 'NUM': '#b2e897', 'PREFP': '#e39cd0', 'PDEMFS': '#d8a7cb', 'VPPFP': '#83d9fb', 'PPER3FP': '#a1ddaa', 'PPOBJFS': '#e9ca95', 'PINDMP': '#e897e3', 'PRON': '#e29dcc', 'PPOBJFP': '#86f9dc', 'PART': '#aa96e8', 'PDEMMP': '#b2d7a8', 'PRELMS': '#e39bde', 'PDEMFP': '#b1e599', 'PRELFP': '#bbe39b', 'INTJ': '#bde996', 'PREFS': '#b39be4', 'PINDFP': '#e2e897', 'PRELMP': '#a5c0da', 'PINTFS': '#ceff80', 'PPER2S': '#d5a2dd', 'VPPRE': '#e78af4', '<START>': '#e6a899', '<STOP>': '#9adde5'}

@st.cache(suppress_st_warning=True, allow_output_mutation=True)
def get_model(model_name):
    return SequenceTagger.load(model_name) # Charge le modèle

def getPos(s: Sentence):
    texts = []
    labels = []
    for t in s.tokens:
        for label in t.annotation_layers.keys():
            texts.append(t.text)
            labels.append(t.get_labels(label)[0].value)          
    return texts, labels

def getDictFromPOS(texts, labels):
    return [{ "texte": t, "étiquette": l } for t, l in zip(texts, labels)]

def getAnnotatedFromPOS(texts, labels):
    return [(t,l,colors[l]) for t, l in zip(texts, labels)]

def main():

    st.title("🥖 Étiqueteur morphosyntaxique étendu pour le français")

    checkpoint = st.selectbox("Choix du modèle", checkpoints)
    model = get_model(checkpoint)
    
    default_text = "George Washington est allé à Washington"
    input_text = st.text_area(
        label="Texte",
        value=default_text,
    )

    start = None
    if st.button("🧠 Calculer"):
        start = time.time()
        with st.spinner("Calcul des étiquettes morphosyntaxiques en cours... 🔍"):
            
            # Build Sentence
            s = Sentence(input_text)

            # predict tags
            model.predict(s)

            try:

                texts, labels = getPos(s)
                
                st.header("Étiquettes:")
                anns = getAnnotatedFromPOS(texts, labels)
                annotated_text(*anns)

                st.header("JSON:")
                st.json(getDictFromPOS(texts, labels))

            except Exception as e:
                st.error("Une erreur s'est produite!" + str(e))
                st.stop()

    st.write("---")

    st.markdown(
        "Construit par [Yanis Labrak](https://www.linkedin.com/in/yanis-labrak-8a7412145/) & [Richard Dufour](https://cv.archives-ouvertes.fr/richard-dufour) avec [FlairNLP](https://github.com/flairNLP/flair) 🚀"
    )
    st.markdown(
        "_Ce travail a été soutenu financièrement par [Zenidoc](https://zenidoc.fr/)_"
    )

    if start is not None:
        st.text(f"La prédiction a prise {time.time() - start:.2f}s")

if __name__ == "__main__":
    main()