File size: 2,760 Bytes
049266b
49c3a11
049266b
2c5db9f
 
9a31420
049266b
 
 
 
 
 
 
 
 
 
2c5db9f
049266b
64f92ba
 
 
 
 
 
 
 
 
 
 
049266b
9a31420
 
869f68c
9a31420
 
869f68c
 
 
 
 
 
204126b
2c5db9f
049266b
 
e6c30c5
049266b
 
 
869f68c
049266b
 
 
 
 
 
 
 
 
 
 
 
 
601925f
049266b
 
601925f
049266b
 
64f92ba
 
 
2c5db9f
869f68c
c5643ee
64f92ba
dbcd12b
8dffc3b
64f92ba
049266b
 
 
 
 
64f92ba
049266b
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
import time
import random
import streamlit as st
from annotated_text import annotated_text

from colour import Color

from flair.data import Sentence
from flair.models import SequenceTagger

# Model identifiers offered in the "Choose model" selector; the chosen one is
# passed to `SequenceTagger.load` through `get_model`.
checkpoints = ["qanastek/pos-french"]

@st.cache(suppress_st_warning=True, allow_output_mutation=True)
def get_model(model_name):
    """Load the Flair ``SequenceTagger`` identified by ``model_name``.

    The result is memoised by Streamlit's ``st.cache`` decorator, so repeated
    calls with the same ``model_name`` can reuse the already-loaded tagger.

    Args:
        model_name: Identifier understood by ``SequenceTagger.load``.

    Returns:
        The loaded ``SequenceTagger``.
    """
    tagger = SequenceTagger.load(model_name)
    return tagger

def getPos(s: Sentence):
    """Flatten the token labels of an already-tagged Flair sentence.

    For every token of ``s`` and every annotation layer present on that
    token, one pair is emitted: the token text and the value of the first
    label stored in that layer.

    Args:
        s: A ``Sentence`` that has been passed through a tagger's ``predict``.

    Returns:
        A ``(texts, labels)`` tuple of two parallel lists.
    """
    pairs = [
        (token.text, token.get_labels(layer_name)[0].value)
        for token in s.tokens
        for layer_name in token.annotation_layers.keys()
    ]
    texts = [text for text, _ in pairs]
    labels = [label for _, label in pairs]
    return texts, labels

def getDictFromPOS(texts, labels):
    """Pair each token text with its tag as a JSON-friendly record.

    Args:
        texts: Token strings.
        labels: Tag values parallel to ``texts``; items past the end of the
            shorter sequence are ignored (``zip`` semantics).

    Returns:
        A list of ``{"text": ..., "label": ...}`` dicts, one per pair.
    """
    records = []
    for text, label in zip(texts, labels):
        records.append({"text": text, "label": label})
    return records

def randomColor():
    """Return a random light colour as the string form of a ``colour.Color``.

    Hue is uniform over the whole colour wheel, saturation is drawn from
    [0.35, 1.0] and HSL lightness is fixed at 0.75, giving light shades.

    Returns:
        ``str(Color(...))`` for the randomly chosen HSL colour.
    """
    # Only HSL is passed: Color applies keyword arguments in order, so an
    # rgb= value placed before hsl= would simply be overwritten (dead work).
    hsl = (random.uniform(0.0, 1.0), random.uniform(0.35, 1.0), 0.75)
    return str(Color(hsl=hsl))

def get_colors(model):
    """Assign a random display colour to the tags in the model's dictionary.

    Only entries of ``model.tag_dictionary.idx2item`` that are entirely
    upper-case (per ``str.isupper``) and longer than one character receive a
    colour; every other entry is left out of the result.

    Args:
        model: A loaded tagger exposing ``tag_dictionary.idx2item``.

    Returns:
        A dict mapping each kept tag (as ``str``) to a colour string produced
        by ``randomColor``.
    """
    colors = {}
    # Single pass: decode once, filter once, draw one colour per kept tag.
    for item in model.tag_dictionary.idx2item:
        # Entries are decoded as UTF-8 when they are bytes; already-decoded
        # str entries are accepted unchanged.
        tag = item.decode("utf-8") if isinstance(item, bytes) else item
        if tag.isupper() and len(tag) > 1:
            colors[tag] = randomColor()
    return colors

def getAnnotatedFromPOS(texts, labels, colors, default_color="#ddd"):
    """Build ``(text, label, colour)`` tuples for ``annotated_text``.

    Args:
        texts: Token strings.
        labels: Tag values parallel to ``texts``; items past the end of the
            shorter sequence are ignored (``zip`` semantics).
        colors: Mapping from tag to colour string.
        default_color: Colour used for any tag missing from ``colors``, so an
            uncovered tag renders in a neutral grey instead of raising
            ``KeyError``.

    Returns:
        A list of ``(text, label, colour)`` tuples.
    """
    return [
        (text, label, colors.get(label, default_color))
        for text, label in zip(texts, labels)
    ]

def main():
    """Render the French part-of-speech tagging demo page.

    The user picks a checkpoint and enters text. On "Compute", the text is
    wrapped in a ``Sentence`` and tagged, and the result is shown as inline
    annotations and as JSON. Any exception raised while tagging or rendering
    is reported with ``st.error``, and the script run is then stopped with
    ``st.stop``.
    """

    st.title("🥖 French Part-Of-Speech Tagging")

    checkpoint = st.selectbox("Choose model", checkpoints)
    model = get_model(checkpoint)
    colors = get_colors(model)

    default_text = "George Washington est allé à Washington"
    input_text = st.text_area(
        label="Original text",
        value=default_text,
    )

    start = None
    if st.button("🧠 Compute"):
        # perf_counter is monotonic and intended for measuring durations,
        # unlike time.time(), which follows the adjustable system clock.
        start = time.perf_counter()
        with st.spinner("Search for Part-Of-Speech Tags 🔍"):

            try:
                # Sentence construction and prediction live inside the try so
                # their failures get the same st.error handling as rendering.
                s = Sentence(input_text)
                model.predict(s)

                texts, labels = getPos(s)

                st.header("Labels:")
                anns = getAnnotatedFromPOS(texts, labels, colors)
                annotated_text(*anns)

                st.header("JSON:")
                st.json(getDictFromPOS(texts, labels))

            except Exception as e:
                st.error("Some error occurred! " + str(e))
                st.stop()

    st.write("---")

    st.markdown(
        "Built by [Yanis Labrak](https://www.linkedin.com/in/yanis-labrak-8a7412145/) 🚀"
    )
    st.markdown(
        "_Source code made with [FlairNLP](https://github.com/flairNLP/flair)_"
    )

    if start is not None:
        st.text(f"prediction took {time.perf_counter() - start:.2f}s")


# Build the page only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()