File size: 2,182 Bytes
00fe6e7
 
 
 
 
 
 
 
 
beb81ec
 
 
00fe6e7
35c9bc8
00fe6e7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
beb81ec
00fe6e7
beb81ec
00fe6e7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
de9d075
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
from __future__ import annotations

import gradio as gr
import torch
import os
import polars as pl
import re
import json
from datetime import datetime, timezone, timedelta
from optimum.pipelines import pipeline
from optimum.onnxruntime import ORTModelForSequenceClassification
from transformers import AutoTokenizer

from hf_dataset_saver import HuggingFaceDatasetSaver

# Hugging Face access token from the environment; None when HF_TOKEN is unset.
hf_token = os.getenv('HF_TOKEN')

# Report whether CUDA is usable, then pick the device for the model and pipeline.
if not torch.cuda.is_available():
    print("GPU is not enabled.")
else:
    print("GPU is enabled.")
    print(f"device count: {torch.cuda.device_count()}, current device: {torch.cuda.current_device()}")
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Flagging logger (project-local HuggingFaceDatasetSaver), presumably writing to
# the "crowdsourced-sentiment_analysis" dataset.
# NOTE(review): it is built even though flagging is commented out in the
# Interface construction — confirm it is still needed.
hf_writer = HuggingFaceDatasetSaver(hf_token, "crowdsourced-sentiment_analysis")

# Tokenizer comes from the base checkpoint; classifier weights are the ONNX export.
tokenizer = AutoTokenizer.from_pretrained("xlm-roberta-base", token=hf_token)
model = ORTModelForSequenceClassification.from_pretrained(
    "arcleife/roberta-sentiment-id-onnx",
    num_labels=3,
    token=hf_token,
).to(device)

# Text-classification pipeline backed by ONNX Runtime; token_type_ids are not
# passed to the model.
pipe = pipeline(
    "text-classification",
    model=model,
    tokenizer=tokenizer,
    device=device,
    return_token_type_ids=False,
    accelerator="ort",
)

_LABEL_TO_SENTIMENT = {
    "LABEL_0": "POSITIVE",
    "LABEL_1": "NEUTRAL",
}


def get_label(result):
    """Translate the first pipeline prediction into a sentiment name.

    Args:
        result: Pipeline output; only ``result[0]['label']`` is read.

    Returns:
        "POSITIVE" for LABEL_0, "NEUTRAL" for LABEL_1, and "NEGATIVE" for
        every other label.
    """
    # NOTE(review): any label other than LABEL_0/LABEL_1 (not just LABEL_2)
    # falls back to "NEGATIVE" — confirm the model only emits LABEL_0..LABEL_2.
    return _LABEL_TO_SENTIMENT.get(result[0]['label'], "NEGATIVE")
    
def text_classification(text):
    """Classify the sentiment of *text* using the module-level ``pipe``.

    Args:
        text: Input string from the Gradio textbox.

    Returns:
        A ``(label, score)`` tuple: the sentiment name produced by
        ``get_label`` and the score of the first prediction returned by the
        pipeline.
    """
    # truncation=True is forwarded to the tokenizer so texts longer than its
    # maximum length are clipped instead of overflowing the model's
    # position embeddings and failing the request.
    result = pipe(text, truncation=True)
    top = result[0]
    return get_label(result), top['score']

# Sample inputs (Indonesian) offered below the textbox.
examples = ["Makanannya ga enak ini", "Nyaman ya tempatnya"]

text_input = gr.Textbox(lines=2, label="Text", placeholder="Enter text here...")

io = gr.Interface(
    fn=text_classification,
    inputs=text_input,
    outputs=["text", "number"],
    title="Text Classification",
    description="Enter a text and see the text classification result!",
    examples=examples,
    # Flagging is disabled; uncomment the three arguments below to enable it.
    # NOTE(review): the options look like leftovers from a toxicity app and do
    # not match the POSITIVE/NEUTRAL/NEGATIVE labels returned by get_label.
    # flagging_mode="manual",
    # flagging_options=["TOXIC", "NONTOXIC"],
    # flagging_callback=hf_writer
)

# Start the Gradio server (inline=False: no inline notebook rendering).
io.launch(inline=False)