qanastek commited on
Commit
869f68c
1 Parent(s): 9a31420
Files changed (1) hide show
  1. app.py +10 -4
app.py CHANGED
@@ -30,11 +30,16 @@ def getDictFromPOS(texts, labels):
30
 
31
  def randomColor():
32
  rgb = (random.uniform(0.0,1.0), random.uniform(0.0,1.0), random.uniform(0.0,1.0))
33
- hsl = (random.uniform(0.0,1.0), random.uniform(0.35,1.0), 0.62)
34
  return str(Color(rgb=rgb, hsl=hsl))
35
 
36
- def getAnnotatedFromPOS(texts, labels):
37
- return [(t,l,randomColor()) for t, l in zip(texts, labels)]
 
 
 
 
 
38
 
39
  def main():
40
 
@@ -42,6 +47,7 @@ def main():
42
 
43
  checkpoint = st.selectbox("Choose model", checkpoints)
44
  model = get_model(checkpoint)
 
45
 
46
  default_text = "George Washington est allé à Washington"
47
  input_text = st.text_area(
@@ -65,7 +71,7 @@ def main():
65
  texts, labels = getPos(s)
66
 
67
  st.header("Labels:")
68
- anns = getAnnotatedFromPOS(texts, labels)
69
  annotated_text(*anns)
70
 
71
  st.header("JSON:")
30
 
31
  def randomColor():
32
  rgb = (random.uniform(0.0,1.0), random.uniform(0.0,1.0), random.uniform(0.0,1.0))
33
+ hsl = (random.uniform(0.0,1.0), random.uniform(0.35,1.0), 0.75)
34
  return str(Color(rgb=rgb, hsl=hsl))
35
 
36
+ def get_colors(model):
37
+ labels = [t.decode("utf-8") for t in model.tag_dictionary.idx2item if t.isupper() and len(t) > 1]
38
+ colors = [randomColor() for t in model.tag_dictionary.idx2item if t.isupper() and len(t) > 1]
39
+ return dict(zip(labels, colors))
40
+
41
+ def getAnnotatedFromPOS(texts, labels, colors):
42
+ return [(t,l,colors[t]) for t, l in zip(texts, labels)]
43
 
44
  def main():
45
 
47
 
48
  checkpoint = st.selectbox("Choose model", checkpoints)
49
  model = get_model(checkpoint)
50
+ colors = get_colors(model)
51
 
52
  default_text = "George Washington est allé à Washington"
53
  input_text = st.text_area(
71
  texts, labels = getPos(s)
72
 
73
  st.header("Labels:")
74
+ anns = getAnnotatedFromPOS(texts, labels, colors)
75
  annotated_text(*anns)
76
 
77
  st.header("JSON:")