Spaces:
Runtime error
Runtime error
Create app.py
Browse files
app.py
ADDED
@@ -0,0 +1,82 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import time
|
2 |
+
import streamlit as st
|
3 |
+
from annotated_text import annotated_text
|
4 |
+
|
5 |
+
from flair.data import Sentence
|
6 |
+
from flair.models import SequenceTagger
|
7 |
+
|
8 |
+
checkpoints = [
|
9 |
+
"flair/pos-english",
|
10 |
+
]
|
11 |
+
|
12 |
+
colors = {'ADD': '#b9d9a6', 'AFX': '#eddc92', 'CC': '#95e9d7', 'CD': '#e797db', 'DT': '#9ff48b', 'EX': '#ed92b4', 'FW': '#decfa1', 'HYPH': '#ada7d7', 'IN': '#85fad8', 'JJ': '#8ba4f4', 'JJR': '#e7a498', 'JJS': '#e5c79a', 'LS': '#eb94b6', 'MD': '#e698ae', 'NFP': '#d9d1a6', 'NN': '#96e89f', 'NNP': '#e698c6', 'NNPS': '#ddbfa2', 'NNS': '#f788cd', 'PDT': '#f19c8d', 'POS': '#8ed5f0', 'PRP': '#c4d8a6', 'PRP$': '#e39bdc', 'RB': '#8df1e2', 'RBR': '#d7f787', 'RBS': '#f986f0', 'RP': '#878df8', 'SYM': '#83fe80', 'TO': '#a6d8c9', 'UH': '#d9a6cc', 'VB': '#a1deda', 'VBD': '#8fefe1', 'VBG': '#e3c79b', 'VBN': '#fb81fe', 'VBP': '#d5fe81', 'VBZ': '#8084ff', 'WDT': '#dd80fe', 'WP': '#9ce3e3', 'WP$': '#9fbddf', 'WRB': '#dea1b5', 'XX': '#93b8ec'}
|
13 |
+
|
14 |
+
@st.cache(suppress_st_warning=True, allow_output_mutation=True)
|
15 |
+
def get_model(model_name):
|
16 |
+
return SequenceTagger.load(model_name) # Load the model
|
17 |
+
|
18 |
+
def getPos(s: Sentence):
|
19 |
+
texts = []
|
20 |
+
labels = []
|
21 |
+
for t in s.tokens:
|
22 |
+
for label in t.annotation_layers.keys():
|
23 |
+
texts.append(t.text)
|
24 |
+
labels.append(t.get_labels(label)[0].value)
|
25 |
+
return texts, labels
|
26 |
+
|
27 |
+
def getDictFromPOS(texts, labels):
|
28 |
+
return [{ "text": t, "label": l } for t, l in zip(texts, labels)]
|
29 |
+
|
30 |
+
def getAnnotatedFromPOS(texts, labels):
|
31 |
+
return [(t,l,colors[l]) for t, l in zip(texts, labels)]
|
32 |
+
|
33 |
+
def main():
|
34 |
+
|
35 |
+
st.title("🥖 French Part-Of-Speech Tagging")
|
36 |
+
|
37 |
+
checkpoint = st.selectbox("Choose model", checkpoints)
|
38 |
+
model = get_model(checkpoint)
|
39 |
+
|
40 |
+
default_text = "This is an example sentence."
|
41 |
+
input_text = st.text_area(
|
42 |
+
label="Original text",
|
43 |
+
value=default_text,
|
44 |
+
)
|
45 |
+
|
46 |
+
start = None
|
47 |
+
if st.button("🧠 Compute"):
|
48 |
+
start = time.time()
|
49 |
+
with st.spinner("Search for Part-Of-Speech Tags 🔍"):
|
50 |
+
|
51 |
+
|
52 |
+
# Build Sentence
|
53 |
+
s = Sentence(input_text)
|
54 |
+
|
55 |
+
# predict tags
|
56 |
+
model.predict(s)
|
57 |
+
|
58 |
+
try:
|
59 |
+
|
60 |
+
texts, labels = getPos(s)
|
61 |
+
|
62 |
+
st.header("Labels:")
|
63 |
+
anns = getAnnotatedFromPOS(texts, labels)
|
64 |
+
annotated_text(*anns)
|
65 |
+
|
66 |
+
st.header("JSON:")
|
67 |
+
st.json(getDictFromPOS(texts, labels))
|
68 |
+
|
69 |
+
except Exception as e:
|
70 |
+
st.error("Some error occured!" + str(e))
|
71 |
+
st.stop()
|
72 |
+
|
73 |
+
st.write("---")
|
74 |
+
|
75 |
+
|
76 |
+
|
77 |
+
if start is not None:
|
78 |
+
st.text(f"prediction took {time.time() - start:.2f}s")
|
79 |
+
|
80 |
+
|
81 |
+
if __name__ == "__main__":
|
82 |
+
main()
|