# Wootang01 — Update app.py (977de64)
import time
import streamlit as st
from annotated_text import annotated_text
from flair.data import Sentence
from flair.models import SequenceTagger
# Model checkpoints offered in the "Choose model" select box.
checkpoints = ["flair/pos-english"]

# Highlight colour for each POS tag, used when rendering annotated text.
# Tag meanings follow the category key shown on the page.
colors = {
    "ADD": "#b9d9a6",   # Email
    "AFX": "#eddc92",   # Affix
    "CC": "#95e9d7",    # Coordinating conjunction
    "CD": "#e797db",    # Cardinal number
    "DT": "#9ff48b",    # Determiner
    "EX": "#ed92b4",    # Existential there
    "FW": "#decfa1",    # Foreign word
    "HYPH": "#ada7d7",  # Hyphen
    "IN": "#85fad8",    # Preposition or subordinating conjunction
    "JJ": "#8ba4f4",    # Adjective
    "JJR": "#e7a498",   # Adjective, comparative
    "JJS": "#e5c79a",   # Adjective, superlative
    "LS": "#eb94b6",    # List item marker
    "MD": "#e698ae",    # Modal
    "NFP": "#d9d1a6",   # Superfluous punctuation
    "NN": "#96e89f",    # Noun, singular or mass
    "NNP": "#e698c6",   # Proper noun, singular
    "NNPS": "#ddbfa2",  # Proper noun, plural
    "NNS": "#f788cd",   # Noun, plural
    "PDT": "#f19c8d",   # Predeterminer
    "POS": "#8ed5f0",   # Possessive ending
    "PRP": "#c4d8a6",   # Personal pronoun
    "PRP$": "#e39bdc",  # Possessive pronoun
    "RB": "#8df1e2",    # Adverb
    "RBR": "#d7f787",   # Adverb, comparative
    "RBS": "#f986f0",   # Adverb, superlative
    "RP": "#878df8",    # Particle
    "SYM": "#83fe80",   # Symbol
    "TO": "#a6d8c9",    # to
    "UH": "#d9a6cc",    # Interjection
    "VB": "#a1deda",    # Verb, base form
    "VBD": "#8fefe1",   # Verb, past tense
    "VBG": "#e3c79b",   # Verb, gerund or present participle
    "VBN": "#fb81fe",   # Verb, past participle
    "VBP": "#d5fe81",   # Verb, non-3rd person singular present
    "VBZ": "#8084ff",   # Verb, 3rd person singular present
    "WDT": "#dd80fe",   # Wh-determiner
    "WP": "#9ce3e3",    # Wh-pronoun
    "WP$": "#9fbddf",   # Possessive wh-pronoun
    "WRB": "#dea1b5",   # Wh-adverb
    "XX": "#93b8ec",    # Unknown
}
@st.cache(suppress_st_warning=True, allow_output_mutation=True)
def get_model(model_name):
    """Return the flair ``SequenceTagger`` for checkpoint *model_name*.

    Decorated with ``st.cache`` so the tagger is memoised per *model_name*
    instead of being reloaded on every Streamlit rerun.
    """
    tagger = SequenceTagger.load(model_name)
    return tagger
def getPos(s: Sentence):
    """Collect parallel ``(texts, labels)`` lists from a tagged flair Sentence.

    For each token and each annotation layer it carries, the token's text and
    the value of the first label in that layer are recorded. A token with
    several layers therefore appears once per layer, and a token with no
    layers is omitted.

    Returns:
        ``(texts, labels)`` -- two lists of equal length.
    """
    pairs = [
        (token.text, token.get_labels(layer)[0].value)
        for token in s.tokens
        for layer in token.annotation_layers
    ]
    texts = [text for text, _ in pairs]
    labels = [tag for _, tag in pairs]
    return texts, labels
def getDictFromPOS(texts, labels):
    """Pair tokens with their tags as JSON-friendly records.

    Returns a list of ``{"text": ..., "label": ...}`` dicts. As with ``zip``,
    surplus items in the longer input are ignored.
    """
    return [dict(text=word, label=tag) for word, tag in zip(texts, labels)]
def getAnnotatedFromPOS(texts, labels, palette=None, default_color="#dddddd"):
    """Build ``(text, label, color)`` tuples for ``annotated_text``.

    Args:
        texts: token strings.
        labels: POS tag for each token, parallel to *texts*.
        palette: mapping of tag -> hex colour; defaults to the module-level
            ``colors`` table.
        default_color: colour used for any tag missing from *palette* (for
            example ``","`` or ``"."``, which the ``colors`` table does not
            contain), so an unexpected tag no longer raises ``KeyError``.

    Returns:
        A list of ``(text, label, color)`` tuples, truncated to the shorter of
        *texts* and *labels*.
    """
    if palette is None:
        palette = colors
    return [
        (text, tag, palette.get(tag, default_color))
        for text, tag in zip(texts, labels)
    ]
def main():
    """Render the Streamlit part-of-speech tagging page.

    Shows the tag key, lets the user choose a checkpoint and enter text, and
    on "Submit" tags the text. The result is displayed as colour-highlighted
    tokens and as JSON, followed by the elapsed prediction time.
    """
    st.title("Part of Speech Categorizer")
    st.write("Paste or type text, submit and the machine will attempt to identify parts of speech. Please note that although the machine can read apostrophes, it cannot read other punctuation marks such as commas or periods.")
    st.write(" The following is the part of speech category key: ADD Email, AFX Affix, CC Coordinating conjunction, CD Cardinal number, DT Determiner, EX Existential there, FW Foreign word, HYPH Hyphen, IN Preposition or subordinating conjunction, JJ Adjective, JJR Adjective, comparative, JJS Adjective, superlative, LS List item marker, MD Modal, NFP Superfluous punctuation, NN Noun, singular or mass, NNP Proper noun, singular, NNPS Proper noun, plural, NNS Noun, plural, PDT Predeterminer, POS Possessive ending, PRP Personal pronoun, PRP$ Possessive pronoun, RB Adverb, RBR Adverb, comparative, RBS Adverb, superlative, RP Particle, SYM Symbol, TO to, UH Interjection, VB Verb, base form, VBD Verb, past tense, VBG Verb, gerund or present participle, VBN Verb, past participle, VBP Verb, non-3rd person singular present, VBZ Verb, 3rd person singular present, WDT Wh-determiner, WP Wh-pronoun, WP$ Possessive wh-pronoun, WRB Wh-adverb, XX Unknown")
    checkpoint = st.selectbox("Choose model", checkpoints)
    model = get_model(checkpoint)
    default_text = "Please note that although the machine can read apostrophes it cannot read other punctuation marks such as commas or periods"
    input_text = st.text_area(
        label="Original text",
        value=default_text,
    )
    start = None
    if st.button("Submit"):
        if not input_text.strip():
            # Nothing to tag: tell the user instead of running the model on "".
            st.warning("Please enter some text to analyse.")
            st.stop()
        # perf_counter is monotonic, so it is the right clock for durations.
        start = time.perf_counter()
        with st.spinner("Computing"):
            try:
                # Sentence construction and prediction are inside the try so
                # model failures are reported in the UI, not as a traceback.
                s = Sentence(input_text)
                model.predict(s)
                texts, labels = getPos(s)
                st.header("Labels:")
                annotated_text(*getAnnotatedFromPOS(texts, labels))
                st.header("JSON:")
                st.json(getDictFromPOS(texts, labels))
            except Exception as e:
                # Top-level UI boundary: surface the error and halt this run.
                st.error(f"Some error occurred! {e}")
                st.stop()
    st.write("---")
    if start is not None:
        st.text(f"prediction took {time.perf_counter() - start:.2f}s")


if __name__ == "__main__":
    main()