# news-analyzer / news_pipeline.py
# Provenance: Hugging Face Hub, commit 30ad188 ("NER model added"), author elozano.
from typing import Dict
from transformers import (
AutoModelForSequenceClassification,
AutoModelForTokenClassification,
AutoTokenizer,
TextClassificationPipeline,
TokenClassificationPipeline,
)
class NewsPipeline:
    """Bundle four news-analysis models behind a single callable.

    Wraps three text-classification pipelines (topic category, fake-news
    detection, clickbait detection) and one named-entity-recognition
    pipeline, all loaded from the Hugging Face Hub.
    """

    # Minimum confidence for a named entity to be kept in the output;
    # entities with score <= this value are dropped.
    DEFAULT_NER_SCORE_THRESHOLD = 0.8

    def __init__(self) -> None:
        """Load all tokenizers and models (downloads from the Hub on first use)."""
        # Category classifier. The tokenizer is kept as an attribute because
        # __call__ needs its sep_token to join headline and content.
        self.category_tokenizer = AutoTokenizer.from_pretrained("elozano/news-category")
        self.category_pipeline = TextClassificationPipeline(
            model=AutoModelForSequenceClassification.from_pretrained(
                "elozano/news-category"
            ),
            tokenizer=self.category_tokenizer,
        )
        # Fake-news classifier; same headline+content input scheme, so its
        # tokenizer is also retained for sep_token access.
        self.fake_tokenizer = AutoTokenizer.from_pretrained("elozano/news-fake")
        self.fake_pipeline = TextClassificationPipeline(
            model=AutoModelForSequenceClassification.from_pretrained(
                "elozano/news-fake"
            ),
            tokenizer=self.fake_tokenizer,
        )
        # Clickbait classifier only sees the headline, so its tokenizer does
        # not need to be kept on the instance.
        self.clickbait_pipeline = TextClassificationPipeline(
            model=AutoModelForSequenceClassification.from_pretrained(
                "elozano/news-clickbait"
            ),
            tokenizer=AutoTokenizer.from_pretrained("elozano/news-clickbait"),
        )
        # NER pipeline; "simple" aggregation merges word-piece tokens into
        # whole-entity spans with a single score per entity.
        self.ner_pipeline = TokenClassificationPipeline(
            tokenizer=AutoTokenizer.from_pretrained("dslim/bert-base-NER"),
            model=AutoModelForTokenClassification.from_pretrained(
                "dslim/bert-base-NER"
            ),
            aggregation_strategy="simple",
        )

    def _confident_entities(self, text: str, threshold: float) -> list:
        """Run NER on *text*, keeping only entities with score > *threshold*."""
        return [entity for entity in self.ner_pipeline(text) if entity["score"] > threshold]

    def __call__(
        self,
        headline: str,
        content: str,
        ner_score_threshold: float = DEFAULT_NER_SCORE_THRESHOLD,
    ) -> Dict[str, object]:
        """Analyze one article.

        Args:
            headline: Article headline.
            content: Article body text.
            ner_score_threshold: Minimum NER confidence; entities scoring at
                or below this are dropped. Defaults to 0.8 (the original
                hard-coded cutoff), so existing callers are unaffected.

        Returns:
            Dict with label strings under "category", "fake" and "clickbait",
            and under "ner" a dict of filtered entity lists for the headline
            and the content. (The values are heterogeneous, hence
            ``Dict[str, object]`` — the previous ``Dict[str, str]`` annotation
            was incorrect for the "ner" entry.)
        """
        # The category and fake-news models were fed headline and content
        # joined by their own tokenizer's separator token, so reproduce that
        # exact input format here.
        category_article_text = f" {self.category_tokenizer.sep_token} ".join(
            [headline, content]
        )
        fake_article_text = f" {self.fake_tokenizer.sep_token} ".join(
            [headline, content]
        )
        return {
            "category": self.category_pipeline(category_article_text)[0]["label"],
            "fake": self.fake_pipeline(fake_article_text)[0]["label"],
            # Clickbait detection uses the headline alone.
            "clickbait": self.clickbait_pipeline(headline)[0]["label"],
            "ner": {
                "headline": self._confident_entities(headline, ner_score_threshold),
                "content": self._confident_entities(content, ner_score_threshold),
            },
        }