asingh6's picture
Update app.py
4d3f2ae
import functools
import itertools
import tempfile

import gradio as gr
import numpy as np
import pandas as pd
import torch
from numpy import dot
from numpy.linalg import norm, multi_dot
from transformers import AutoTokenizer, AutoModelForSequenceClassification, Trainer
# Hugging Face checkpoint used for emotion classification.
_MODEL_NAME = "j-hartmann/emotion-english-distilroberta-base"

# (output column, index into the model's logits). Logit indices 1 and 6 are
# not reported, so the published probabilities do not sum to 1.
# NOTE(review): index 0 is surfaced as "stress"; presumably it is the model's
# "anger" class -- confirm against model.config.id2label.
_EMOTION_COLUMNS = (
    ("stress", 0),
    ("fear", 2),
    ("joy", 3),
    ("neutral", 4),
    ("sadness", 5),
)


@functools.lru_cache(maxsize=1)
def _load_model(model_name=_MODEL_NAME):
    """Load the tokenizer and classifier once and reuse them for every request.

    Downloading/instantiating the checkpoint is by far the slowest step, so it
    must not be repeated per Gradio call.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSequenceClassification.from_pretrained(model_name)
    model.eval()  # inference only: make sure dropout is disabled
    return tokenizer, model


def get_score(sentence1):
    """Score the emotional content of a single sentence.

    Args:
        sentence1: Text to classify (as received from the Gradio textbox).

    Returns:
        A one-row ``pandas.DataFrame`` with columns ``text``, ``maxLabel``
        (name of the highest-scoring class from ``model.config.id2label``) and
        softmax probabilities, rounded to 3 decimals, for ``stress``,
        ``fear``, ``joy``, ``neutral`` and ``sadness``.
    """
    tokenizer, model = _load_model()
    lines_s = [sentence1]

    encoded = tokenizer(lines_s, truncation=True, padding=True, return_tensors="pt")
    with torch.no_grad():  # no autograd bookkeeping needed for inference
        logits = model(**encoded).logits.cpu().numpy()

    # Numerically stable softmax: subtract the per-row max before exponentiating.
    exp = np.exp(logits - logits.max(axis=-1, keepdims=True))
    probs = (exp / exp.sum(axis=-1, keepdims=True)).tolist()

    labels = [model.config.id2label[int(idx)] for idx in logits.argmax(axis=-1)]

    rows = [
        [text, label] + [round(row[idx], 3) for _, idx in _EMOTION_COLUMNS]
        for text, label, row in zip(lines_s, labels, probs)
    ]
    columns = ["text", "maxLabel"] + [name for name, _ in _EMOTION_COLUMNS]
    return pd.DataFrame(rows, columns=columns)
# Gradio UI: one text box in, the DataFrame returned by get_score out.
# NOTE(review): gr.inputs.Textbox, default= and layout= belong to the legacy
# (pre-4.0) Gradio API; keep gradio pinned accordingly or migrate to
# gr.Textbox(value=...).
demo = gr.Interface(
    get_score,
    gr.inputs.Textbox(lines=1, placeholder="This tool is awesome!", default="", label="Text 1"),
    "dataframe",
    title="Patient Mental Health Sentiment Analysis",
    description="Input patient's verbal texts and the model returns the emotional state using this model: https://huggingface.co/j-hartmann/emotion-english-distilroberta-base.",
    layout="vertical",
)

if __name__ == "__main__":
    # Start the server only when run as a script, so importing this module
    # (e.g. from tests) does not launch a web app as a side effect.
    demo.launch(debug=True)