"""Streamlit app that tags the parts of speech in user-supplied text with Flair."""

import time

import streamlit as st
from annotated_text import annotated_text
from flair.data import Sentence
from flair.models import SequenceTagger

# Flair model checkpoints offered in the "Choose model" select box.
checkpoints = [
    "flair/pos-english",
]

# Background colour passed to `annotated_text` for each POS tag.
colors = {
    'ADD': '#b9d9a6', 'AFX': '#eddc92', 'CC': '#95e9d7', 'CD': '#e797db',
    'DT': '#9ff48b', 'EX': '#ed92b4', 'FW': '#decfa1', 'HYPH': '#ada7d7',
    'IN': '#85fad8', 'JJ': '#8ba4f4', 'JJR': '#e7a498', 'JJS': '#e5c79a',
    'LS': '#eb94b6', 'MD': '#e698ae', 'NFP': '#d9d1a6', 'NN': '#96e89f',
    'NNP': '#e698c6', 'NNPS': '#ddbfa2', 'NNS': '#f788cd', 'PDT': '#f19c8d',
    'POS': '#8ed5f0', 'PRP': '#c4d8a6', 'PRP$': '#e39bdc', 'RB': '#8df1e2',
    'RBR': '#d7f787', 'RBS': '#f986f0', 'RP': '#878df8', 'SYM': '#83fe80',
    'TO': '#a6d8c9', 'UH': '#d9a6cc', 'VB': '#a1deda', 'VBD': '#8fefe1',
    'VBG': '#e3c79b', 'VBN': '#fb81fe', 'VBP': '#d5fe81', 'VBZ': '#8084ff',
    'WDT': '#dd80fe', 'WP': '#9ce3e3', 'WP$': '#9fbddf', 'WRB': '#dea1b5',
    'XX': '#93b8ec',
}

# Colour used for any tag that has no entry in `colors`. Previously such a tag
# raised KeyError and aborted rendering of the whole result.
DEFAULT_TAG_COLOR = "#dddddd"


def _resource_cache():
    """Return a decorator that caches a loaded model across Streamlit reruns.

    Prefers ``st.cache_resource`` (the non-deprecated cache intended for
    global resources such as ML models). On Streamlit versions that lack it,
    falls back to the original ``st.cache`` configuration.
    """
    cache_resource = getattr(st, "cache_resource", None)
    if cache_resource is not None:
        return cache_resource
    return st.cache(suppress_st_warning=True, allow_output_mutation=True)


@_resource_cache()
def get_model(model_name):
    """Load (once, then from cache) the Flair ``SequenceTagger`` named *model_name*."""
    return SequenceTagger.load(model_name)


def getPos(s: Sentence):
    """Extract parallel lists of token texts and tag values from a tagged sentence.

    For every annotation layer present on a token, the token text and the value
    of the first label in that layer are appended. A token therefore contributes
    one entry per non-empty layer, and none if it carries no labels.

    Args:
        s: A Sentence that has already been passed through ``model.predict``.

    Returns:
        A ``(texts, labels)`` tuple of equal-length lists.
    """
    texts = []
    labels = []
    for token in s.tokens:
        for layer in token.annotation_layers:
            layer_labels = token.get_labels(layer)
            # An empty layer would make `[0]` raise IndexError; skip it instead.
            if not layer_labels:
                continue
            texts.append(token.text)
            labels.append(layer_labels[0].value)
    return texts, labels


def getDictFromPOS(texts, labels):
    """Pair texts with labels as ``[{"text": ..., "label": ...}, ...]`` for JSON display."""
    return [{"text": t, "label": l} for t, l in zip(texts, labels)]


def getAnnotatedFromPOS(texts, labels):
    """Build ``(text, label, colour)`` tuples for ``annotated_text``.

    Tags without an entry in ``colors`` get ``DEFAULT_TAG_COLOR`` rather than
    raising KeyError.
    """
    return [(t, l, colors.get(l, DEFAULT_TAG_COLOR)) for t, l in zip(texts, labels)]


def main():
    """Render the app: choose a model, tag the entered text, and show the results."""
    st.title("Part of Speech Categorizer")
    # NOTE(review): this caveat may be outdated now that unknown tags (e.g.
    # punctuation tags, if the model emits them) no longer raise KeyError in
    # getAnnotatedFromPOS. Confirm against the model's actual output.
    st.write(
        "Paste or type text, submit and the machine will attempt to identify parts of speech. \n"
        "Please note that although the machine can read apostrophes, it cannot read other "
        "punctuation marks such as commas or periods."
    )
    # `$` is escaped so Streamlit's markdown does not render the text between
    # "PRP$" and "WP$" as inline LaTeX; `\$` displays as a literal `$`.
    st.write(
        " The following is the part of speech category key: "
        "ADD Email, AFX Affix, CC Coordinating conjunction, CD Cardinal number, "
        "DT Determiner, EX Existential there, FW Foreign word, HYPH Hyphen, "
        "IN Preposition or subordinating conjunction, JJ Adjective, "
        "JJR Adjective, comparative, JJS Adjective, superlative, "
        "LS List item marker, MD Modal, NFP Superfluous punctuation, "
        "NN Noun, singular or mass, NNP Proper noun, singular, "
        "NNPS Proper noun, plural, NNS Noun, plural, PDT Predeterminer, "
        "POS Possessive ending, PRP Personal pronoun, PRP\\$ Possessive pronoun, "
        "RB Adverb, RBR Adverb, comparative, RBS Adverb, superlative, "
        "RP Particle, SYM Symbol, TO to, UH Interjection, VB Verb, base form, "
        "VBD Verb, past tense, VBG Verb, gerund or present participle, "
        "VBN Verb, past participle, VBP Verb, non-3rd person singular present, "
        "VBZ Verb, 3rd person singular present, WDT Wh-determiner, "
        "WP Wh-pronoun, WP\\$ Possessive wh-pronoun, WRB Wh-adverb, XX Unknown"
    )

    checkpoint = st.selectbox("Choose model", checkpoints)
    model = get_model(checkpoint)

    default_text = (
        "Please note that although the machine can read apostrophes it cannot "
        "read other punctuation marks such as commas or periods"
    )
    input_text = st.text_area(
        label="Original text",
        value=default_text,
    )

    start = None
    if st.button("Submit"):
        # perf_counter is monotonic and meant for measuring elapsed time.
        start = time.perf_counter()
        with st.spinner("Computing"):
            sentence = Sentence(input_text)
            model.predict(sentence)
            try:
                texts, labels = getPos(sentence)
                st.header("Labels:")
                annotated_text(*getAnnotatedFromPOS(texts, labels))
                st.header("JSON:")
                st.json(getDictFromPOS(texts, labels))
            except Exception as e:  # UI boundary: surface the error, then halt this run
                st.error(f"Some error occurred! {e}")
                st.stop()

    st.write("---")
    if start is not None:
        st.text(f"prediction took {time.perf_counter() - start:.2f}s")


if __name__ == "__main__":
    main()