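# NOTE(review): the lines below look like web file-viewer page residue, not
# program text; kept verbatim as comments so the module can be parsed.
# qanastek's picture
# Update
# 869f68c
# raw history blame
# No virus
# 2.76 kB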
import time
import random
import streamlit as st
from annotated_text import annotated_text
from colour import Color
from flair.data import Sentence
from flair.models import SequenceTagger
# Model identifiers offered in the "Choose model" selectbox; the selected
# entry is passed to get_model().
checkpoints = [
    "qanastek/pos-french",
]
# NOTE(review): newer Streamlit releases deprecate st.cache in favour of
# st.cache_resource / st.cache_data — confirm against the pinned version.
@st.cache(suppress_st_warning=True, allow_output_mutation=True)
def get_model(model_name):
    """Load the flair sequence tagger identified by ``model_name``.

    The call is wrapped in Streamlit's cache decorator, so the loader is
    keyed on ``model_name``.

    Args:
        model_name: Identifier handed to ``SequenceTagger.load``.

    Returns:
        The object returned by ``SequenceTagger.load(model_name)``.
    """
    tagger = SequenceTagger.load(model_name)
    return tagger
def getPos(s: Sentence):
    """Extract token texts and their tag values from a predicted sentence.

    One entry is produced per (token, annotation layer) pair, so a token
    with several annotation layers appears several times and a token with
    none is skipped. Only the first label of each layer is used.

    Args:
        s: Sentence whose tokens have already been tagged.

    Returns:
        A ``(texts, labels)`` tuple of two equally long lists.
    """
    pairs = [
        (token.text, token.get_labels(layer)[0].value)
        for token in s.tokens
        for layer in token.annotation_layers
    ]
    words = [word for word, _ in pairs]
    tags = [tag for _, tag in pairs]
    return words, tags
def getDictFromPOS(texts, labels):
    """Pair each text with its label as a JSON-friendly record.

    Args:
        texts: Token texts.
        labels: Tag values, aligned with ``texts``.

    Returns:
        A list of ``{"text": ..., "label": ...}`` dicts, one per aligned
        pair; extra items in the longer input are ignored.
    """
    records = []
    for word, tag in zip(texts, labels):
        records.append({"text": word, "label": tag})
    return records
def randomColor():
    """Return the string form of a randomly generated ``Color``.

    Three uniform RGB components are drawn first, then a hue in [0, 1] and
    a saturation in [0.35, 1]; lightness is fixed at 0.75.
    """
    # NOTE(review): both rgb and hsl are passed to Color; presumably the hsl
    # value ends up defining the colour — confirm against the colour package.
    red = random.uniform(0.0, 1.0)
    green = random.uniform(0.0, 1.0)
    blue = random.uniform(0.0, 1.0)
    hue = random.uniform(0.0, 1.0)
    saturation = random.uniform(0.35, 1.0)
    swatch = Color(rgb=(red, green, blue), hsl=(hue, saturation, 0.75))
    return str(swatch)
def get_colors(model):
    """Assign a random colour to each tag in the model's tag dictionary.

    Only entries of ``model.tag_dictionary.idx2item`` that are upper-case
    and longer than one character are kept; each is decoded as UTF-8 and
    used as the key.

    Args:
        model: Tagger exposing ``tag_dictionary.idx2item``.

    Returns:
        Dict mapping decoded tag name to a colour string from
        ``randomColor()``.
    """
    palette = {}
    for raw_tag in model.tag_dictionary.idx2item:
        if raw_tag.isupper() and len(raw_tag) > 1:
            palette[raw_tag.decode("utf-8")] = randomColor()
    return palette
def getAnnotatedFromPOS(texts, labels, colors):
    """Build the ``(text, label, colour)`` tuples passed to ``annotated_text``.

    Args:
        texts: Token texts.
        labels: Tag values, aligned with ``texts``.
        colors: Mapping from tag label to colour string.

    Returns:
        A list of ``(text, label, colour)`` tuples, one per aligned pair.

    Raises:
        KeyError: If a label has no entry in ``colors``.
    """
    # The palette is keyed by tag label, so look up the label, not the
    # token text.
    return [(text, label, colors[label]) for text, label in zip(texts, labels)]
def main():
    """Render the Streamlit page for French part-of-speech tagging.

    Lets the user pick a checkpoint and enter text. When the compute button
    is pressed, the text is tagged and shown both as colour-annotated text
    and as JSON, followed by the credits footer and the elapsed time.
    """
    st.title("🥖 French Part-Of-Speech Tagging")

    checkpoint = st.selectbox("Choose model", checkpoints)
    model = get_model(checkpoint)
    colors = get_colors(model)

    default_text = "George Washington est allé à Washington"
    input_text = st.text_area(
        label="Original text",
        value=default_text,
    )

    # Stays None unless the button was pressed; used below to decide whether
    # to print the timing line.
    start = None
    if st.button("🧠 Compute"):
        # perf_counter is monotonic, so the elapsed time cannot be skewed by
        # system clock adjustments.
        start = time.perf_counter()

        with st.spinner("Search for Part-Of-Speech Tags 🔍"):
            # Build Sentence
            s = Sentence(input_text)

            # Predict tags
            model.predict(s)

            try:
                texts, labels = getPos(s)

                st.header("Labels:")
                anns = getAnnotatedFromPOS(texts, labels, colors)
                annotated_text(*anns)

                st.header("JSON:")
                st.json(getDictFromPOS(texts, labels))
            except Exception as e:
                # UI boundary: report the failure in the page and halt this
                # run instead of letting the traceback escape.
                st.error(f"Some error occurred! {e}")
                st.stop()

    st.write("---")
    st.markdown(
        "Built by [Yanis Labrak](https://www.linkedin.com/in/yanis-labrak-8a7412145/) 🚀"
    )
    st.markdown(
        "_Source code made with [FlairNLP](https://github.com/flairNLP/flair)_"
    )

    if start is not None:
        st.text(f"prediction took {time.perf_counter() - start:.2f}s")
# Run the app only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()