Spaces:
Running
Running
# imports
import os

import gradio as gr
import pandas as pd
from transformers import pipeline
# sentiment-analysis pipeline used by inference(); `model` is the Hugging Face Hub model id
model = "siebert/sentiment-roberta-large-english"
# built once at import time; runs on CPU by default (device=-1) -- pass
# device=0 to run on GPU
nlp = pipeline(model=model)
# perform sentiment inference on a loaded ID/Text table and export the result
def inference(df, filename, classifier=None):
    """Classify the sentiment of every text in *df* and write the results to CSV.

    Parameters
    ----------
    df : pandas.DataFrame
        Input table. Column 0 is treated as the ID and column 1 as the text;
        any further columns are ignored.
    filename : object with a ``.name`` attribute
        The uploaded file; its path (``filename.name``) decides where the
        output CSV is written.
    classifier : callable, optional
        Called once per text; must return a list whose first element is a
        dict with ``"label"`` and ``"score"`` keys. Defaults to the
        module-level ``nlp`` pipeline.

    Returns
    -------
    str
        Path of the written CSV: the input path with its final extension
        replaced by ``_csiebert_sentiment.csv``.
    """
    if classifier is None:
        classifier = nlp
    id_col, text_col = df.columns[0], df.columns[1]
    texts = df[text_col].to_list()
    # one classifier call per text; keep only the top-ranked prediction
    top_preds = [classifier(text)[0] for text in texts]
    # build the output frame in one go instead of enlarging it cell by cell
    new_df = pd.DataFrame(
        {
            id_col: df[id_col].to_list(),
            text_col: texts,
            "Label": [pred["label"] for pred in top_preds],
            "Score": [pred["score"] for pred in top_preds],
        },
        columns=[id_col, text_col, "Label", "Score"],
    )
    # splitext strips only the final extension, so dots elsewhere in the path
    # (dotted directories, "scores.v2.csv") are preserved
    n_filename = os.path.splitext(filename.name)[0] + "_csiebert_sentiment.csv"
    new_df.to_csv(n_filename, index=False)
    return n_filename
# handle file reading for both csv and excel files
def read_file(filename):
    """Load an uploaded .csv or .xlsx file and run sentiment inference on it.

    Parameters
    ----------
    filename : object with a ``.name`` attribute
        The upload handed over by the Gradio File input; ``filename.name`` is
        the path of the file on disk.

    Returns
    -------
    str or None
        The output CSV path returned by ``inference``, or ``None`` when the
        file extension is neither ``.csv`` nor ``.xlsx`` (case-insensitive).
    """
    path = filename.name
    # splitext inspects only the last suffix of the final path component, so
    # dots in directory names or in the base name ("scores.v2.csv") cannot
    # derail the type check; lower() also accepts ".CSV" / ".XLSX"
    extension = os.path.splitext(path)[1].lower()
    if extension == ".csv":
        # index_col=False: never promote the first column to the row index
        df = pd.read_csv(path, index_col=False)
    elif extension == ".xlsx":
        df = pd.read_excel(path, index_col=False)
    else:
        # unsupported file type -> no output file
        return None
    # An unnamed leading column (typically an exported pandas index) is read
    # as "Unnamed: 0"; drop it for either format so that columns 0 and 1 are
    # ID and Text, as inference() expects.
    if len(df.columns) > 0 and df.columns[0] == "Unnamed: 0":
        df = df.drop("Unnamed: 0", axis=1)
    return inference(df=df, filename=filename)
# build the Gradio UI: one file upload in, one downloadable CSV out
demo = gr.Interface(
    read_file,
    inputs=[gr.inputs.File(label="Input file")],
    outputs=[gr.outputs.File(label="Output file")],
    description="Sentiment analysis: Input a csv/xlsx of form ID, Text. App performs sentiment analysis on Texts and exports results as new csv to download.",
    allow_flagging=False,
    layout="horizontal",
)

# start the web server only when this file is executed directly, so that
# importing the module (for tests or reuse) does not block on launch()
if __name__ == "__main__":
    demo.launch()