from __future__ import annotations
import gradio as gr
import torch
import os
import polars as pl
import re
import json
from datetime import datetime, timezone, timedelta
from optimum.pipelines import pipeline
from optimum.onnxruntime import ORTModelForSequenceClassification
from transformers import AutoTokenizer
from hf_dataset_saver import HuggingFaceDatasetSaver
# Read the Hugging Face access token from the environment
hf_token = os.getenv('HF_TOKEN')
if torch.cuda.is_available():
    print("GPU is enabled.")
    print("device count: {}, current device: {}".format(torch.cuda.device_count(), torch.cuda.current_device()))
else:
    print("GPU is not enabled.")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Prepare logger for flagging
hf_writer = HuggingFaceDatasetSaver(hf_token, "crowdsourced-sentiment_analysis")
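# Note: hf_writer is assumed to behave like Gradio's HuggingFaceDatasetSaver,
# appending flagged examples to the "crowdsourced-sentiment_analysis" dataset repo.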
# Prepare model
tokenizer = AutoTokenizer.from_pretrained("xlm-roberta-base", token=hf_token)
model = ORTModelForSequenceClassification.from_pretrained("arcleife/roberta-sentiment-id-onnx", num_labels=3, token=hf_token).to(device)
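# .to(device) switches the ONNX Runtime session to the CUDA execution provider when a GPU is available.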
pipe = pipeline("text-classification", model=model, tokenizer=tokenizer, device=device, return_token_type_ids=False, accelerator="ort")
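# The text-classification pipeline returns a list of dicts, e.g. [{'label': 'LABEL_0', 'score': 0.97}];
# get_label() maps the raw LABEL_0/1/2 ids to POSITIVE/NEUTRAL/NEGATIVE.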
def get_label(result):
    if result[0]['label'] == "LABEL_0":
        return "POSITIVE"
    elif result[0]['label'] == "LABEL_1":
        return "NEUTRAL"
    else:
        return "NEGATIVE"
def text_classification(text):
    result = pipe(text)
    sentiment_label = get_label(result)
    sentiment_score = result[0]['score']
    return sentiment_label, sentiment_score
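# Example usage (score shown is illustrative): text_classification("Nyaman ya tempatnya") might return ("POSITIVE", 0.98)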
# Indonesian example inputs: "The food is not good" and "The place is comfortable"
examples = ["Makanannya ga enak ini", "Nyaman ya tempatnya"]
io = gr.Interface(
    fn=text_classification,
    inputs=gr.Textbox(lines=2, label="Text", placeholder="Enter text here..."),
    outputs=["text", "number"],
    title="Text Classification",
    description="Enter a text and see the text classification result!",
    examples=examples,
    # flagging_mode="manual",
    # flagging_options=["TOXIC", "NONTOXIC"],
    # flagging_callback=hf_writer,
)
io.launch(inline=False)