elozano commited on
Commit
4b1cd4e
1 Parent(s): bf18bb1

Custom entities color

Browse files
Files changed (2) hide show
  1. app.py +16 -3
  2. pipeline.py +2 -1
app.py CHANGED
@@ -5,6 +5,13 @@ from annotated_text import annotated_text
5
 
6
  from analyzer import NewsAnalyzer
7
 
 
 
 
 
 
 
 
8
 
9
  def run() -> None:
10
  analyzer = NewsAnalyzer(
@@ -15,7 +22,7 @@ def run() -> None:
15
  )
16
  st.title("📰 News Analyzer")
17
  headline = st.text_input("Headline:")
18
- content = st.text_input("Content:")
19
  if headline == "":
20
  st.error("Please, provide a headline.")
21
  else:
@@ -26,7 +33,7 @@ def run() -> None:
26
  button = st.button("Analyze")
27
  if button:
28
  predictions = analyzer(headline=headline, content=content)
29
- col1, _, col2 = st.columns([2, 1, 5])
30
 
31
  with col1:
32
  st.subheader("Analysis:")
@@ -64,7 +71,13 @@ def parse_entities(
64
  parsed_text = []
65
  for entity in entities:
66
  parsed_text.append(text[start : entity["start"]])
67
- parsed_text.append((entity["word"], entity["entity_group"]))
 
 
 
 
 
 
68
  start = entity["end"]
69
  parsed_text.append(text[start:])
70
  return parsed_text
 
5
 
6
  from analyzer import NewsAnalyzer
7
 
8
+ ENTITY_COLOR = {
9
+ "PER": "#b2ffff",
10
+ "LOC": "#ffffb2",
11
+ "ORG": "#adfbaf",
12
+ "MISC": "#ffb2b2",
13
+ }
14
+
15
 
16
  def run() -> None:
17
  analyzer = NewsAnalyzer(
 
22
  )
23
  st.title("📰 News Analyzer")
24
  headline = st.text_input("Headline:")
25
+ content = st.text_area("Content:", height=200)
26
  if headline == "":
27
  st.error("Please, provide a headline.")
28
  else:
 
33
  button = st.button("Analyze")
34
  if button:
35
  predictions = analyzer(headline=headline, content=content)
36
+ col1, _, col2 = st.columns([2, 1, 4])
37
 
38
  with col1:
39
  st.subheader("Analysis:")
 
71
  parsed_text = []
72
  for entity in entities:
73
  parsed_text.append(text[start : entity["start"]])
74
+ parsed_text.append(
75
+ (
76
+ entity["word"],
77
+ entity["entity_group"],
78
+ ENTITY_COLOR[entity["entity_group"]],
79
+ )
80
+ )
81
  start = entity["end"]
82
  parsed_text.append(text[start:])
83
  return parsed_text
pipeline.py CHANGED
@@ -1,6 +1,7 @@
1
- from transformers import TextClassificationPipeline
2
  from typing import Dict, Optional
3
 
 
 
4
 
5
  class NewsPipeline(TextClassificationPipeline):
6
  def __init__(self, emojis: Dict[str, str], **kwargs) -> None:
 
 
1
  from typing import Dict, Optional
2
 
3
+ from transformers import TextClassificationPipeline
4
+
5
 
6
  class NewsPipeline(TextClassificationPipeline):
7
  def __init__(self, emojis: Dict[str, str], **kwargs) -> None: