File size: 2,672 Bytes
685ba0e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
from typing import Any, Dict, Optional, Union

from transformers import (
    AutoModelForSequenceClassification,
    AutoModelForTokenClassification,
    AutoTokenizer,
    TokenClassificationPipeline,
)

from pipeline import NewsPipeline

# Emoji tables handed to the corresponding NewsPipeline (see NewsAnalyzer),
# keyed by classifier label. The characters are spelled as ``\N{...}``
# escapes so this file stays pure ASCII: an earlier copy of these literals
# had been re-decoded with the wrong text encoding and turned into mojibake.
CATEGORY_EMOJIS = {
    "Automobile": "\N{AUTOMOBILE}",
    "Entertainment": "\N{POPCORN}",
    # U+2696 followed by VARIATION SELECTOR-16 to request emoji presentation.
    "Politics": "\N{SCALES}\N{VARIATION SELECTOR-16}",
    "Science": "\N{TEST TUBE}",
    "Sports": "\N{BASKETBALL AND HOOP}",
    "Technology": "\N{PERSONAL COMPUTER}",
    "World": "\N{EARTH GLOBE EUROPE-AFRICA}",
}
FAKE_EMOJIS = {"Fake": "\N{GHOST}", "Real": "\N{THUMBS UP SIGN}"}
CLICKBAIT_EMOJIS = {
    "Clickbait": "\N{FISHING POLE AND FISH}",
    "Normal": "\N{WHITE HEAVY CHECK MARK}",
}


class NewsAnalyzer:
    """Run several news-analysis models over a headline and optional body.

    Three sequence classifiers (category, fake-news, clickbait) are each
    wrapped in a ``NewsPipeline`` together with their own emoji table, and a
    Hugging Face ``TokenClassificationPipeline`` extracts named entities.
    """

    def __init__(
        self,
        category_model_name: str,
        fake_model_name: str,
        clickbait_model_name: str,
        ner_model_name: str,
    ) -> None:
        """Load every model and tokenizer the analyzer needs.

        Args:
            category_model_name: Model id or path for the category classifier.
            fake_model_name: Model id or path for the fake/real classifier.
            clickbait_model_name: Model id or path for the clickbait classifier.
            ner_model_name: Model id or path for the token-classification model.
        """
        self.category_pipe = self._build_classifier(
            category_model_name, CATEGORY_EMOJIS
        )
        self.fake_pipe = self._build_classifier(fake_model_name, FAKE_EMOJIS)
        self.clickbait_pipe = self._build_classifier(
            clickbait_model_name, CLICKBAIT_EMOJIS
        )
        self.ner_pipe = TokenClassificationPipeline(
            model=AutoModelForTokenClassification.from_pretrained(ner_model_name),
            tokenizer=AutoTokenizer.from_pretrained(ner_model_name),
            # "simple" groups tokens into entity spans (see HF pipeline docs).
            aggregation_strategy="simple",
        )

    @staticmethod
    def _build_classifier(model_name: str, emojis: Dict[str, str]) -> "NewsPipeline":
        """Load a sequence-classification model + tokenizer into a NewsPipeline.

        The model is loaded before the tokenizer, matching keyword-argument
        evaluation order.
        """
        return NewsPipeline(
            model=AutoModelForSequenceClassification.from_pretrained(model_name),
            tokenizer=AutoTokenizer.from_pretrained(model_name),
            emojis=emojis,
        )

    def __call__(
        self, headline: str, content: Optional[str] = None
    ) -> Dict[str, Any]:
        """Analyze one article.

        Args:
            headline: Article headline; given to every model.
            content: Optional article body. Given to the category and
                fake-news classifiers and to NER, but not to the clickbait
                classifier.

        Returns:
            A dict with keys ``"category"``, ``"fake"`` and ``"clickbait"``
            (each holding whatever the corresponding ``NewsPipeline`` returns)
            and ``"ner"``, a dict holding the entities found in
            ``"headline"`` and in ``"content"``. ``"content"`` is ``None``
            when *content* is missing or empty.
        """
        return {
            "category": self.category_pipe(headline=headline, content=content),
            "fake": self.fake_pipe(headline=headline, content=content),
            # The clickbait classifier only ever sees the headline.
            "clickbait": self.clickbait_pipe(headline=headline, content=None),
            "ner": {
                "headline": self.ner_pipe(headline),
                # Skip NER entirely for a missing or empty body.
                "content": self.ner_pipe(content) if content else None,
            },
        }


if __name__ == "__main__":
    # Demo: build an analyzer from these model ids and analyze one headline.
    model_names = {
        "category_model_name": "elozano/news-category",
        "fake_model_name": "elozano/news-fake",
        "clickbait_model_name": "elozano/news-clickbait",
        "ner_model_name": "dslim/bert-base-NER",
    }
    news_analyzer = NewsAnalyzer(**model_names)
    print(news_analyzer(headline="Lakers Won!"))