File size: 2,602 Bytes
5c1dd23
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e1bcc63
5c1dd23
 
 
 
 
 
 
 
 
 
 
e1bcc63
5c1dd23
e1bcc63
5c1dd23
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
import time
import streamlit as st
from annotated_text import annotated_text

from flair.data import Sentence
from flair.models import SequenceTagger

# Model checkpoints offered in the "Choose model" selectbox.
checkpoints = ["flair/pos-english"]

# Hex highlight colour for each POS tag, used when rendering annotated tokens.
# NOTE(review): some tags (for example '.') have no entry in this table;
# confirm the full tag set emitted by the selected model.
colors = {
    "ADD": "#b9d9a6",
    "AFX": "#eddc92",
    "CC": "#95e9d7",
    "CD": "#e797db",
    "DT": "#9ff48b",
    "EX": "#ed92b4",
    "FW": "#decfa1",
    "HYPH": "#ada7d7",
    "IN": "#85fad8",
    "JJ": "#8ba4f4",
    "JJR": "#e7a498",
    "JJS": "#e5c79a",
    "LS": "#eb94b6",
    "MD": "#e698ae",
    "NFP": "#d9d1a6",
    "NN": "#96e89f",
    "NNP": "#e698c6",
    "NNPS": "#ddbfa2",
    "NNS": "#f788cd",
    "PDT": "#f19c8d",
    "POS": "#8ed5f0",
    "PRP": "#c4d8a6",
    "PRP$": "#e39bdc",
    "RB": "#8df1e2",
    "RBR": "#d7f787",
    "RBS": "#f986f0",
    "RP": "#878df8",
    "SYM": "#83fe80",
    "TO": "#a6d8c9",
    "UH": "#d9a6cc",
    "VB": "#a1deda",
    "VBD": "#8fefe1",
    "VBG": "#e3c79b",
    "VBN": "#fb81fe",
    "VBP": "#d5fe81",
    "VBZ": "#8084ff",
    "WDT": "#dd80fe",
    "WP": "#9ce3e3",
    "WP$": "#9fbddf",
    "WRB": "#dea1b5",
    "XX": "#93b8ec",
}

# ``st.cache`` is deprecated in newer Streamlit releases; ``st.cache_resource``
# is the replacement intended for global resources such as loaded models.
# Prefer it when available and fall back to the legacy decorator (with the
# original options) on Streamlit versions that predate it.
if hasattr(st, "cache_resource"):
    _model_cache = st.cache_resource
else:
    _model_cache = st.cache(suppress_st_warning=True, allow_output_mutation=True)


@_model_cache
def get_model(model_name):
    """Load the flair ``SequenceTagger`` identified by ``model_name``.

    Args:
        model_name: Checkpoint identifier handed to ``SequenceTagger.load``
            (in this app, an entry of ``checkpoints``).

    Returns:
        The loaded tagger. The cache decorator returns the same object for
        repeated calls with the same name instead of reloading the model.
    """
    return SequenceTagger.load(model_name)
  
def getPos(s: Sentence):
    """Flatten a tagged sentence into parallel token-text and tag lists.

    For every token, one ``(text, label)`` pair is produced per annotation
    layer the token carries, using the first label stored in that layer.

    Args:
        s: A flair ``Sentence`` that has already been passed to a tagger's
            ``predict``.

    Returns:
        A ``(texts, labels)`` tuple of two equally long lists.
    """
    pairs = [
        (token.text, token.get_labels(layer)[0].value)
        for token in s.tokens
        for layer in token.annotation_layers
    ]
    texts = [text for text, _ in pairs]
    labels = [value for _, value in pairs]
    return texts, labels
  
def getDictFromPOS(texts, labels):
    """Pair each token with its tag as a JSON-friendly record.

    Args:
        texts: Token strings.
        labels: Tags parallel to ``texts``; pairing stops at the shorter input.

    Returns:
        A list of ``{"text": ..., "label": ...}`` dicts in input order.
    """
    return [dict(text=word, label=tag) for word, tag in zip(texts, labels)]

def getAnnotatedFromPOS(texts, labels, palette=None, default_color="#dddddd"):
    """Build ``(text, label, colour)`` tuples for ``annotated_text``.

    Args:
        texts: Token strings.
        labels: Tags parallel to ``texts``; pairing stops at the shorter input.
        palette: Optional mapping from tag to hex colour. Defaults to the
            module-level ``colors`` table.
        default_color: Colour used for any tag that ``palette`` lacks.

    Returns:
        A list of ``(text, label, colour)`` tuples in input order.
    """
    if palette is None:
        palette = colors
    # The colour table has no entry for some tags (for example '.'), so use a
    # fallback colour instead of raising KeyError on an unknown tag.
    return [
        (text, label, palette.get(label, default_color))
        for text, label in zip(texts, labels)
    ]

def main():
    """Render the POS-tagging page and tag the submitted text on demand.

    The user picks a checkpoint and enters text. On "Submit", the tokens are
    shown as colour-coded annotations plus a JSON dump, followed by the time
    the request took. Any failure while tagging or rendering is shown as an
    error and the script run is stopped.
    """
    st.title("Part of Speech Categorizer")

    checkpoint = st.selectbox("Choose model", checkpoints)
    model = get_model(checkpoint)

    default_text = "This is an example sentence."
    input_text = st.text_area(
        label="Original text",
        value=default_text,
    )

    if not st.button("Submit"):
        return

    # Nothing to tag: tell the user instead of predicting on an empty Sentence.
    if not input_text.strip():
        st.warning("Please enter some text to tag.")
        return

    start = time.time()
    with st.spinner("Computing"):
        try:
            # Building the Sentence and predicting are inside the try so model
            # failures are reported in the UI just like rendering failures.
            s = Sentence(input_text)
            model.predict(s)

            texts, labels = getPos(s)

            st.header("Labels:")
            anns = getAnnotatedFromPOS(texts, labels)
            annotated_text(*anns)

            st.header("JSON:")
            st.json(getDictFromPOS(texts, labels))

        except Exception as e:
            # Top-level UI boundary: surface any failure to the user, then halt.
            st.error(f"Some error occurred: {e}")
            st.stop()

    st.write("---")
    st.text(f"prediction took {time.time() - start:.2f}s")
	
	
# Render the app only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()