ongkn commited on
Commit
5fa4331
1 Parent(s): f672687

Upload tool

Browse files
Files changed (4) hide show
  1. app.py +4 -0
  2. requirements.txt +3 -0
  3. text_sentiment_analyzer.py +58 -0
  4. tool_config.json +5 -0
app.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ from transformers import launch_gradio_demo
2
+ from text_sentiment_analyzer import SentAnalClassifierTool
3
+
4
+ launch_gradio_demo(SentAnalClassifierTool)
requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ torch
2
+ transformers
3
+ trainDistilBERT
text_sentiment_analyzer.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ A script for a text sentiment analysis tool for the 🤗 Transformers Agent library.
3
+ """
4
+
5
+ from transformers import Tool
6
+ from transformers.tools.base import get_default_device
7
+ from transformers import pipeline
8
+ from transformers import DistilBertTokenizerFast
9
+ from trainDistilBERT import DistilBertForMulticlassSequenceClassification
10
+ import torch
11
+
12
+
13
+
14
+ class SentAnalClassifierTool(Tool):
15
+ """
16
+ A tool for sentiment analysis
17
+ """
18
+ ckpt = "ongknsro/ACARISBERT-DistilBERT"
19
+ name = "text_sentiment_analyzer"
20
+ description = (
21
+ "This is a tool that returns a sentiment label for a given text sequence. "
22
+ "It takes the raw text as input, and "
23
+ "returns a sentiment label as output."
24
+ )
25
+
26
+ inputs = ["text"]
27
+ outputs = ["text"]
28
+
29
+ def __init__(self, device=None, **hub_kwargs) -> None:
30
+ super().__init__()
31
+
32
+ self.device = device
33
+ self.pipeline = None
34
+ self.hub_kwargs = hub_kwargs
35
+
36
+ def setup(self):
37
+ if self.device is None:
38
+ self.device = get_default_device()
39
+
40
+ self.tokenizer = DistilBertTokenizerFast.from_pretrained(self.ckpt)
41
+
42
+ self.model = DistilBertForMulticlassSequenceClassification.from_pretrained(self.ckpt).to(self.device)
43
+
44
+ self.pipeline = pipeline("sentiment-analysis", model=self.model, tokenizer=self.tokenizer, top_k=None, device=0)
45
+
46
+ self.is_initialized = True
47
+
48
+ def __call__(self, task: str):
49
+ if not self.is_initialized:
50
+ self.setup()
51
+
52
+ outputs = self.pipeline(task)
53
+ labels = [item["label"] for item in outputs[0]]
54
+ logits = [item["score"] for item in outputs[0]]
55
+ probs = torch.softmax(torch.tensor(logits), dim=0)
56
+ label = labels[torch.argmax(probs).item()]
57
+
58
+ return label
tool_config.json ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ {
2
+ "description": "This is a tool that returns a sentiment label for a given text sequence. It takes the raw text as input, and returns a sentiment label as output.",
3
+ "name": "text_sentiment_analyzer",
4
+ "tool_class": "text_sentiment_analyzer.SentAnalClassifierTool"
5
+ }