qanastek commited on
Commit
9a31420
1 Parent(s): 49c3a11
Files changed (2) hide show
  1. app.py +7 -4
  2. requirements.txt +1 -1
app.py CHANGED
@@ -3,7 +3,7 @@ import random
3
  import streamlit as st
4
  from annotated_text import annotated_text
5
 
6
- import matplotlib
7
 
8
  from flair.data import Sentence
9
  from flair.models import SequenceTagger
@@ -12,8 +12,6 @@ checkpoints = [
12
  "qanastek/pos-french",
13
  ]
14
 
15
- colors = list(matplotlib.colors.cnames.values())
16
-
17
  @st.cache(suppress_st_warning=True, allow_output_mutation=True)
18
  def get_model(model_name):
19
  return SequenceTagger.load(model_name) # Load the model
@@ -30,8 +28,13 @@ def getPos(s: Sentence):
30
  def getDictFromPOS(texts, labels):
31
  return [{ "text": t, "label": l } for t, l in zip(texts, labels)]
32
 
 
 
 
 
 
33
  def getAnnotatedFromPOS(texts, labels):
34
- return [(t,l,random.choice(colors)) for t, l in zip(texts, labels)]
35
 
36
  def main():
37
 
3
  import streamlit as st
4
  from annotated_text import annotated_text
5
 
6
+ from colour import Color
7
 
8
  from flair.data import Sentence
9
  from flair.models import SequenceTagger
12
  "qanastek/pos-french",
13
  ]
14
 
 
 
15
  @st.cache(suppress_st_warning=True, allow_output_mutation=True)
16
  def get_model(model_name):
17
  return SequenceTagger.load(model_name) # Load the model
28
  def getDictFromPOS(texts, labels):
29
  return [{ "text": t, "label": l } for t, l in zip(texts, labels)]
30
 
31
+ def randomColor():
32
+ rgb = (random.uniform(0.0,1.0), random.uniform(0.0,1.0), random.uniform(0.0,1.0))
33
+ hsl = (random.uniform(0.0,1.0), random.uniform(0.35,1.0), 0.62)
34
+ return str(Color(rgb=rgb, hsl=hsl))
35
+
36
  def getAnnotatedFromPOS(texts, labels):
37
+ return [(t,l,randomColor()) for t, l in zip(texts, labels)]
38
 
39
  def main():
40
 
requirements.txt CHANGED
@@ -1,3 +1,3 @@
1
  flair==0.8.0.post1
2
  st-annotated-text
3
- matplotlib
1
  flair==0.8.0.post1
2
  st-annotated-text
3
+ colour