"""Text sentiment analysis tool for the 🤗 Transformers Agent library."""

import torch
from transformers import DistilBertTokenizerFast
from transformers import Tool
from transformers import pipeline
from transformers.tools.base import get_default_device

from trainDistilBERT import DistilBertForMulticlassSequenceClassification


class SentAnalClassifierTool(Tool):
    """Agent tool that maps a piece of raw text to a sentiment label.

    The tokenizer, model and classification pipeline are created lazily by
    ``setup()`` the first time the tool is called, so constructing the tool
    itself performs no model loading.
    """

    # Checkpoint identifier handed to ``from_pretrained`` for both the
    # tokenizer and the model.
    ckpt = "ongknsro/ACARISBERT-DistilBERT"
    name = "text_sentiment_analyzer"
    description = (
        "This is a tool that returns a sentiment label for a given text sequence. "
        "It takes raw text as input, and "
        "returns a sentiment label as output."
    )
    inputs = ["text"]
    outputs = ["text"]

    def __init__(self, device=None, **hub_kwargs) -> None:
        """Store configuration; heavy objects are built later in ``setup()``.

        Args:
            device: Device to run inference on. When ``None``, ``setup()``
                falls back to ``get_default_device()``.
            **hub_kwargs: Extra keyword arguments, stored on the instance.
        """
        super().__init__()
        self.device = device
        self.tokenizer = None
        self.model = None
        self.pipeline = None
        # NOTE(review): these are stored but not forwarded to
        # ``from_pretrained`` -- confirm whether they should be.
        self.hub_kwargs = hub_kwargs

    def setup(self):
        """Load the tokenizer and model and build the classification pipeline."""
        if self.device is None:
            self.device = get_default_device()

        self.tokenizer = DistilBertTokenizerFast.from_pretrained(self.ckpt)
        self.model = DistilBertForMulticlassSequenceClassification.from_pretrained(
            self.ckpt
        ).to(self.device)

        # Build the pipeline on the same device the model was just moved to.
        # A hard-coded ``device=0`` breaks on hosts without a CUDA device and
        # silently overrides an explicitly requested device.
        self.pipeline = pipeline(
            "sentiment-analysis",
            model=self.model,
            tokenizer=self.tokenizer,
            top_k=None,  # request a score for every label, not just the best
            device=self.device,
        )
        self.is_initialized = True

    def __call__(self, task: str):
        """Return the highest-scoring sentiment label for ``task``.

        Args:
            task: The raw text to classify.

        Returns:
            The ``"label"`` value of the best-scoring pipeline result.
        """
        if not self.is_initialized:
            self.setup()

        outputs = self.pipeline(task)

        # For a single string the pipeline result has been observed as a
        # one-element list wrapping the per-label dicts; accept a flat list
        # of dicts as well so either shape works.
        if outputs and isinstance(outputs[0], list):
            candidates = outputs[0]
        else:
            candidates = outputs

        # Softmax is strictly monotonic, so the entry with the highest score
        # is exactly the one a softmax + argmax over the scores would select;
        # picking it directly avoids the redundant tensor round-trip.
        best = max(candidates, key=lambda item: item["score"])
        return best["label"]