elozano committed on
Commit
30ad188
1 Parent(s): deefcae

NER model added

Browse files
Files changed (3) hide show
  1. app.py +29 -10
  2. news_pipeline.py +17 -0
  3. requirements.txt +1 -0
app.py CHANGED
@@ -1,5 +1,5 @@
1
  import streamlit as st
2
-
3
  from news_pipeline import NewsPipeline
4
 
5
  CATEGORY_EMOJIS = {
@@ -34,15 +34,34 @@ def app():
34
 
35
  with st.spinner("Analyzing article..."):
36
  prediction = news_pipe(headline, content)
37
- st.markdown(
38
- f"{CATEGORY_EMOJIS[prediction['category']]} **Category**: {prediction['category']}"
39
- )
40
- st.markdown(
41
- f"{FAKE_EMOJIS[prediction['fake']]} **Fake**: {'Yes' if prediction['fake'] == 'Fake' else 'No'}"
42
- )
43
- st.markdown(
44
- f"{CLICKBAIT_EMOJIS[prediction['clickbait']]} **Clickbait**: {'Yes' if prediction['clickbait'] == 'Clickbait' else 'No'}"
45
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46
 
47
 
48
  if __name__ == "__main__":
 
1
  import streamlit as st
2
+ from annotated_text import annotated_text
3
  from news_pipeline import NewsPipeline
4
 
5
  CATEGORY_EMOJIS = {
 
34
 
35
  with st.spinner("Analyzing article..."):
36
  prediction = news_pipe(headline, content)
37
+ col1, _, col2 = st.columns([2, 1, 6])
38
+ with col1:
39
+ st.subheader("Analysis:")
40
+ st.markdown(
41
+ f"{CATEGORY_EMOJIS[prediction['category']]} **Category**: {prediction['category']}"
42
+ )
43
+ st.markdown(
44
+ f"{FAKE_EMOJIS[prediction['fake']]} **Fake**: {'Yes' if prediction['fake'] == 'Fake' else 'No'}"
45
+ )
46
+ st.markdown(
47
+ f"{CLICKBAIT_EMOJIS[prediction['clickbait']]} **Clickbait**: {'Yes' if prediction['clickbait'] == 'Clickbait' else 'No'}"
48
+ )
49
+ with col2:
50
+ st.subheader("Headline")
51
+ annotated_text(*parse_text(headline, prediction["ner"]["headline"]))
52
+ st.subheader("Content")
53
+ annotated_text(*parse_text(content, prediction["ner"]["content"]))
54
+
55
+
56
def parse_text(text, prediction):
    """Split ``text`` into plain and entity-tagged pieces.

    The result is meant to be unpacked into ``annotated_text``: plain
    strings are the untagged stretches and ``(span, label)`` tuples are the
    recognised entities.

    Args:
        text: The original string the entities were extracted from.
        prediction: Iterable of entity dicts, each providing ``start`` and
            ``end`` character offsets into ``text`` and an ``entity_group``
            label. Entities are consumed in the order given, so they are
            assumed to be ascending and non-overlapping.

    Returns:
        A list that alternates plain ``str`` segments with
        ``(entity_text, entity_group)`` tuples and ends with the trailing
        remainder of ``text`` (possibly an empty string).
    """
    start = 0
    parsed_text = []
    for p in prediction:
        # Untagged text between the previous entity and this one.
        parsed_text.append(text[start : p["start"]])
        # Use the exact source span instead of p["word"]: the aggregated
        # word may not match the input character-for-character (e.g.
        # spacing around punctuation), which would alter the displayed text.
        parsed_text.append((text[p["start"] : p["end"]], p["entity_group"]))
        start = p["end"]
    # Whatever follows the last entity (or the whole text if none matched).
    parsed_text.append(text[start:])
    return parsed_text
65
 
66
 
67
  if __name__ == "__main__":
news_pipeline.py CHANGED
@@ -2,8 +2,10 @@ from typing import Dict
2
 
3
  from transformers import (
4
  AutoModelForSequenceClassification,
 
5
  AutoTokenizer,
6
  TextClassificationPipeline,
 
7
  )
8
 
9
 
@@ -29,6 +31,13 @@ class NewsPipeline:
29
  ),
30
  tokenizer=AutoTokenizer.from_pretrained("elozano/news-clickbait"),
31
  )
 
 
 
 
 
 
 
32
 
33
  def __call__(self, headline: str, content: str) -> Dict[str, str]:
34
  category_article_text = f" {self.category_tokenizer.sep_token} ".join(
@@ -41,4 +50,12 @@ class NewsPipeline:
41
  "category": self.category_pipeline(category_article_text)[0]["label"],
42
  "fake": self.fake_pipeline(fake_article_text)[0]["label"],
43
  "clickbait": self.clickbait_pipeline(headline)[0]["label"],
 
 
 
 
 
 
 
 
44
  }
 
2
 
3
  from transformers import (
4
  AutoModelForSequenceClassification,
5
+ AutoModelForTokenClassification,
6
  AutoTokenizer,
7
  TextClassificationPipeline,
8
+ TokenClassificationPipeline,
9
  )
10
 
11
 
 
31
  ),
32
  tokenizer=AutoTokenizer.from_pretrained("elozano/news-clickbait"),
33
  )
34
+ self.ner_pipeline = TokenClassificationPipeline(
35
+ tokenizer=AutoTokenizer.from_pretrained("dslim/bert-base-NER"),
36
+ model=AutoModelForTokenClassification.from_pretrained(
37
+ "dslim/bert-base-NER"
38
+ ),
39
+ aggregation_strategy="simple",
40
+ )
41
 
42
  def __call__(self, headline: str, content: str) -> Dict[str, str]:
43
  category_article_text = f" {self.category_tokenizer.sep_token} ".join(
 
50
  "category": self.category_pipeline(category_article_text)[0]["label"],
51
  "fake": self.fake_pipeline(fake_article_text)[0]["label"],
52
  "clickbait": self.clickbait_pipeline(headline)[0]["label"],
53
+ "ner": {
54
+ "headline": list(
55
+ filter(lambda x: x["score"] > 0.8, self.ner_pipeline(headline))
56
+ ),
57
+ "content": list(
58
+ filter(lambda x: x["score"] > 0.8, self.ner_pipeline(content))
59
+ ),
60
+ },
61
  }
requirements.txt CHANGED
@@ -1,2 +1,3 @@
1
  transformers
2
  torch
 
 
1
  transformers
2
  torch
3
+ st-annotated-text