File size: 628 Bytes
685ba0e
 
4b1cd4e
 
685ba0e
 
 
 
 
 
 
 
 
 
 
e567dcc
685ba0e
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
from typing import Any, Dict, Optional

from transformers import TextClassificationPipeline


class NewsPipeline(TextClassificationPipeline):
    """Text-classification pipeline for news items that also attaches an emoji.

    The headline and the optional body are joined into a single input string.
    That string is classified by the parent ``TextClassificationPipeline``.
    The first prediction is returned with an extra ``"emoji"`` entry looked up
    in the ``emojis`` mapping.
    """

    def __init__(self, emojis: Dict[str, str], **kwargs) -> None:
        """Create the pipeline.

        Args:
            emojis: Mapping from each label the model can emit to the emoji
                returned alongside that label.
            **kwargs: Forwarded unchanged to ``TextClassificationPipeline``.
        """
        self.emojis = emojis
        super().__init__(**kwargs)

    def __call__(
        self, headline: str, content: Optional[str] = None
    ) -> Dict[str, Any]:
        """Classify a news item and return the prediction with its emoji.

        Args:
            headline: The item's headline.
            content: Optional article body. When empty or ``None``, only the
                headline is classified.

        Returns:
            The first prediction dict produced by the parent pipeline
            (presumably ``{"label": ..., "score": ...}`` — confirm against the
            installed transformers version), extended with an ``"emoji"`` key.

        Raises:
            KeyError: If the predicted label has no entry in ``emojis``.
        """
        if content:
            sep_token = self.tokenizer.sep_token
            # sep_token can be None for tokenizers that define no separator.
            # Formatting it directly would inject the literal text "None"
            # between headline and body, so fall back to a plain space.
            separator = f" {sep_token} " if sep_token else " "
            text = separator.join([headline, content])
        else:
            text = headline
        # For a single string input the parent returns a list of predictions;
        # keep only the first one.
        prediction = super().__call__(text, padding=True, truncation=True)[0]
        return {**prediction, "emoji": self.emojis[prediction["label"]]}