# text_sentiment_analyzer / text_sentiment_analyzer.py
# Last updated by ongkn in commit fa999a3 ("Update text_sentiment_analyzer.py").
"""
A script for a text sentiment analysis tool for the 🤗 Transformers Agent library.
"""
from transformers import Tool
from transformers.tools.base import get_default_device
from transformers import pipeline
from transformers import DistilBertTokenizerFast
from trainDistilBERT import DistilBertForMulticlassSequenceClassification
import torch
class SentAnalClassifierTool(Tool):
    """
    A 🤗 Transformers Agents tool that labels the sentiment of a text.

    The DistilBERT checkpoint named by ``ckpt`` is loaded lazily on the first
    call. The tool then returns the highest-scoring label produced by a
    ``sentiment-analysis`` pipeline built around that checkpoint.

    Args:
        device: Device to run the model on. When ``None``, the device returned
            by ``get_default_device()`` is used when ``setup`` runs.
        **hub_kwargs: Extra keyword arguments forwarded to both
            ``from_pretrained`` calls (tokenizer and model), e.g. a revision
            or auth token. Presumably only Hub-loading options belong here.
    """

    ckpt = "ongknsro/ACARISBERT-DistilBERT"
    name = "text_sentiment_analyzer"
    description = (
        "This is a tool that returns a sentiment label for a given text sequence. "
        "It takes raw text as input, and "
        "returns a sentiment label as output."
    )
    inputs = ["text"]
    outputs = ["text"]

    def __init__(self, device=None, **hub_kwargs) -> None:
        # NOTE(review): ``is_initialized`` (read in ``__call__``) is presumably
        # initialised to False by ``Tool.__init__``; confirm against the base class.
        super().__init__()
        self.device = device
        self.pipeline = None
        self.hub_kwargs = hub_kwargs

    def setup(self):
        """Load the tokenizer, model and pipeline onto ``self.device``."""
        if self.device is None:
            self.device = get_default_device()
        self.tokenizer = DistilBertTokenizerFast.from_pretrained(
            self.ckpt, **self.hub_kwargs
        )
        self.model = DistilBertForMulticlassSequenceClassification.from_pretrained(
            self.ckpt, **self.hub_kwargs
        ).to(self.device)
        # Run the pipeline on the same device the model was moved to. The
        # previous hard-coded ``device=0`` forced CUDA:0, which fails on
        # CPU- or MPS-only machines.
        # ``top_k=None`` asks the pipeline for a score for every label.
        self.pipeline = pipeline(
            "sentiment-analysis",
            model=self.model,
            tokenizer=self.tokenizer,
            top_k=None,
            device=self.device,
        )
        self.is_initialized = True

    def __call__(self, task: str) -> str:
        """
        Return the most likely sentiment label for ``task``.

        Args:
            task: The raw text to classify.

        Returns:
            The label with the highest score reported by the pipeline.
        """
        if not self.is_initialized:
            self.setup()
        predictions = self.pipeline(task)
        return self._top_label(predictions)

    @staticmethod
    def _top_label(predictions):
        """Return the label of the highest-scoring entry in a pipeline result."""
        # With ``top_k=None``, a single string yields either a nested list
        # ``[[{"label", "score"}, ...]]`` or a flat ``[{"label", "score"}, ...]``,
        # depending on the transformers version. Normalise to the flat form.
        if predictions and isinstance(predictions[0], list):
            predictions = predictions[0]
        # Softmax is strictly increasing, so the softmax the previous code
        # re-applied to these scores could not change the arg-max. Compare the
        # scores directly; ``max`` keeps the first entry on ties, as
        # ``torch.argmax`` did.
        return max(predictions, key=lambda item: item["score"])["label"]