from typing import Any, Dict, Tuple

from transformers import (
    AutoModelForSequenceClassification,
    AutoModelForTokenClassification,
    AutoTokenizer,
    TextClassificationPipeline,
    TokenClassificationPipeline,
)


class NewsPipeline:
    """Bundle of four Hugging Face pipelines for analyzing a news article.

    Runs topic-category classification, fake-news detection, clickbait
    detection (headline only), and named-entity recognition over both the
    headline and the content.
    """

    def __init__(self, ner_score_threshold: float = 0.8) -> None:
        """Load every model and tokenizer (downloads weights on first use).

        Args:
            ner_score_threshold: Minimum confidence score for a recognized
                entity to be kept in the output. Defaults to 0.8, matching
                the previously hard-coded behavior.
        """
        self.ner_score_threshold = ner_score_threshold
        # Category and fake-news tokenizers are kept as attributes because
        # __call__ needs their sep_token to join headline and content.
        self.category_tokenizer, self.category_pipeline = self._load_classifier(
            "elozano/news-category"
        )
        self.fake_tokenizer, self.fake_pipeline = self._load_classifier(
            "elozano/news-fake"
        )
        # Clickbait detection only looks at the headline, so its tokenizer's
        # sep_token is never needed outside the pipeline itself.
        _, self.clickbait_pipeline = self._load_classifier("elozano/news-clickbait")
        self.ner_pipeline = TokenClassificationPipeline(
            tokenizer=AutoTokenizer.from_pretrained("dslim/bert-base-NER"),
            model=AutoModelForTokenClassification.from_pretrained(
                "dslim/bert-base-NER"
            ),
            # Merge word-piece tokens back into whole-entity spans.
            aggregation_strategy="simple",
        )

    @staticmethod
    def _load_classifier(
        model_name: str,
    ) -> Tuple[Any, TextClassificationPipeline]:
        """Load the tokenizer and sequence-classification pipeline for *model_name*."""
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        pipeline = TextClassificationPipeline(
            model=AutoModelForSequenceClassification.from_pretrained(model_name),
            tokenizer=tokenizer,
        )
        return tokenizer, pipeline

    def __call__(self, headline: str, content: str) -> Dict[str, Any]:
        """Run all four analyses on one article.

        Args:
            headline: The article headline.
            content: The article body text.

        Returns:
            Dict with keys ``"category"``, ``"fake"``, ``"clickbait"``
            (predicted label strings) and ``"ner"`` (dict with
            ``"headline"``/``"content"`` lists of entity dicts whose score
            exceeds the configured threshold).
        """

        def join_with_sep(tokenizer: Any) -> str:
            # The classification models expect "headline [SEP] content"
            # style inputs, using each model's own separator token.
            return f" {tokenizer.sep_token} ".join([headline, content])

        def confident_entities(text: str) -> list:
            return [
                entity
                for entity in self.ner_pipeline(text)
                if entity["score"] > self.ner_score_threshold
            ]

        return {
            "category": self.category_pipeline(
                join_with_sep(self.category_tokenizer)
            )[0]["label"],
            "fake": self.fake_pipeline(join_with_sep(self.fake_tokenizer))[0]["label"],
            "clickbait": self.clickbait_pipeline(headline)[0]["label"],
            "ner": {
                "headline": confident_entities(headline),
                "content": confident_entities(content),
            },
        }