File size: 3,875 Bytes
5c1dd23
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e1bcc63
46170b1
977de64
 
5c1dd23
 
 
a89e61f
5c1dd23
 
 
 
 
 
e1bcc63
5c1dd23
e1bcc63
5c1dd23
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
import time
import streamlit as st
from annotated_text import annotated_text

from flair.data import Sentence
from flair.models import SequenceTagger

# Model checkpoints offered in the "Choose model" selectbox; the selected
# entry is passed to SequenceTagger.load() via get_model().
checkpoints = [
    "flair/pos-english",
]

# Background colour (hex) used to highlight each part-of-speech tag when the
# tagged text is rendered with annotated_text().
colors = {
    "ADD": "#b9d9a6",
    "AFX": "#eddc92",
    "CC": "#95e9d7",
    "CD": "#e797db",
    "DT": "#9ff48b",
    "EX": "#ed92b4",
    "FW": "#decfa1",
    "HYPH": "#ada7d7",
    "IN": "#85fad8",
    "JJ": "#8ba4f4",
    "JJR": "#e7a498",
    "JJS": "#e5c79a",
    "LS": "#eb94b6",
    "MD": "#e698ae",
    "NFP": "#d9d1a6",
    "NN": "#96e89f",
    "NNP": "#e698c6",
    "NNPS": "#ddbfa2",
    "NNS": "#f788cd",
    "PDT": "#f19c8d",
    "POS": "#8ed5f0",
    "PRP": "#c4d8a6",
    "PRP$": "#e39bdc",
    "RB": "#8df1e2",
    "RBR": "#d7f787",
    "RBS": "#f986f0",
    "RP": "#878df8",
    "SYM": "#83fe80",
    "TO": "#a6d8c9",
    "UH": "#d9a6cc",
    "VB": "#a1deda",
    "VBD": "#8fefe1",
    "VBG": "#e3c79b",
    "VBN": "#fb81fe",
    "VBP": "#d5fe81",
    "VBZ": "#8084ff",
    "WDT": "#dd80fe",
    "WP": "#9ce3e3",
    "WP$": "#9fbddf",
    "WRB": "#dea1b5",
    "XX": "#93b8ec",
}

@st.cache(suppress_st_warning=True, allow_output_mutation=True)
def get_model(model_name):
    """Load the flair ``SequenceTagger`` identified by *model_name*.

    The result is memoised by ``st.cache`` so repeated Streamlit reruns with
    the same checkpoint reuse the loaded tagger.

    NOTE(review): ``st.cache`` is deprecated in newer Streamlit releases;
    ``st.cache_resource`` is the replacement for model objects — confirm the
    pinned Streamlit version before switching.
    """
    tagger = SequenceTagger.load(model_name)
    return tagger
  
def getPos(s: Sentence):
    """Extract token texts and tag values from an already-tagged sentence.

    Each token contributes one entry per annotation layer it carries, using
    the first label stored in that layer; tokens with no layers contribute
    nothing.

    Args:
        s: A ``Sentence`` that has been passed through a tagger's ``predict``.

    Returns:
        ``(texts, labels)`` — two parallel lists holding each token's text and
        the corresponding label value.
    """
    pairs = [
        (token.text, token.get_labels(layer)[0].value)
        for token in s.tokens
        for layer in token.annotation_layers
    ]
    texts = [text for text, _ in pairs]
    labels = [tag for _, tag in pairs]
    return texts, labels
  
def getDictFromPOS(texts, labels):
    """Turn parallel token/tag lists into JSON-friendly records.

    Returns a list of ``{"text": ..., "label": ...}`` dicts, one per pair,
    stopping at the shorter of *texts* and *labels*.
    """
    return [dict(text=word, label=tag) for word, tag in zip(texts, labels)]

def getAnnotatedFromPOS(texts, labels, *, palette=None, default_color="#dddddd"):
    """Build ``(text, label, colour)`` tuples for ``annotated_text``.

    Args:
        texts: Token texts, in sentence order.
        labels: Tag for each token, parallel to *texts*.
        palette: Mapping from tag to hex background colour. Defaults to the
            module-level ``colors`` table.
        default_color: Colour used for any tag that has no entry in
            *palette*. The ``colors`` table lists no punctuation tags, so
            without this fallback such tokens would raise ``KeyError``.

    Returns:
        One ``(text, label, colour)`` tuple per token, stopping at the shorter
        of *texts* and *labels*.
    """
    if palette is None:
        palette = colors
    return [
        (text, label, palette.get(label, default_color))
        for text, label in zip(texts, labels)
    ]

def main():
    """Render the part-of-speech tagging page.

    Shows the title, usage notes and tag legend, lets the user choose a
    checkpoint and enter text, and -- once "Submit" is pressed -- tags the
    text with the selected flair model, then displays the result as
    colour-coded annotations and as JSON, followed by the elapsed time.
    """
    st.title("Part of Speech Categorizer")
    st.write("Paste or type text, submit and the machine will attempt to identify parts of speech. Please note that although the machine can read apostrophes, it cannot read other punctuation marks such as commas or periods.")
    st.write(" The following is the part of speech category key: ADD	Email, AFX	Affix, CC	Coordinating conjunction, CD	Cardinal number, DT	Determiner, EX	Existential there, FW	Foreign word, HYPH	Hyphen, IN	Preposition or subordinating conjunction, JJ	Adjective, JJR	Adjective, comparative, JJS	Adjective, superlative, LS	List item marker, MD	Modal, NFP	Superfluous punctuation, NN	Noun, singular or mass, NNP	Proper noun, singular, NNPS	Proper noun, plural, NNS	Noun, plural, PDT	Predeterminer, POS	Possessive ending, PRP	Personal pronoun, PRP$	Possessive pronoun, RB	Adverb, RBR	Adverb, comparative, RBS	Adverb, superlative, RP	Particle, SYM	Symbol, TO	to, UH	Interjection, VB	Verb, base form, VBD	Verb, past tense, VBG	Verb, gerund or present participle, VBN	Verb, past participle, VBP	Verb, non-3rd person singular present, VBZ	Verb, 3rd person singular present, WDT	Wh-determiner, WP	Wh-pronoun, WP$	Possessive wh-pronoun, WRB	Wh-adverb, XX	Unknown")

    checkpoint = st.selectbox("Choose model", checkpoints)
    # get_model is decorated with st.cache, so the tagger is not reloaded on
    # every Streamlit rerun.
    model = get_model(checkpoint)

    default_text = "Please note that although the machine can read apostrophes it cannot read other punctuation marks such as commas or periods"
    input_text = st.text_area(
        label="Original text",
        value=default_text,
    )

    # Nothing below is rendered until the user submits.
    if not st.button("Submit"):
        return

    if not input_text.strip():
        st.warning("Please enter some text to tag.")
        return

    # perf_counter is monotonic and intended for measuring elapsed time.
    start = time.perf_counter()
    with st.spinner("Computing"):
        try:
            # Building and tagging the sentence sit inside the try so model
            # failures are reported in the UI like any other error.
            sentence = Sentence(input_text)
            model.predict(sentence)

            texts, labels = getPos(sentence)

            st.header("Labels:")
            annotated_text(*getAnnotatedFromPOS(texts, labels))

            st.header("JSON:")
            st.json(getDictFromPOS(texts, labels))
        except Exception as e:  # top-level UI boundary: report, then halt
            st.error(f"Some error occurred: {e}")
            st.stop()

    st.write("---")
    st.text(f"prediction took {time.perf_counter() - start:.2f}s")
	
	
# Script entry point: render the page when this file is executed directly.
if __name__ == "__main__":
    main()