Spaces:
Sleeping
Sleeping
import gradio as gr | |
from transformers import AutoModelForSequenceClassification, AutoTokenizer | |
import torch | |
import torch.nn.functional as F | |
from huggingface_hub import login | |
import os | |
# Log in to the Hugging Face Hub using a token taken from the environment.
# This is best-effort: a missing token only prints a warning, and any error
# raised while logging in is printed instead of stopping the script.
hf_token = os.environ.get("HUGGINGFACE_TOKEN")
try:
    if not hf_token:
        print("Warning: HUGGINGFACE_TOKEN not found in environment variables")
    else:
        login(hf_token)
except Exception as auth_err:
    print(f"Authentication error: {auth_err}")
# Load the MentalBERT tokenizer and a two-class sequence-classification model.
# A failure here is fatal: the error is printed and then re-raised.
# NOTE(review): if this checkpoint ships without a trained classification
# head, transformers will initialize one randomly and the probabilities
# computed by analyze_text will not be meaningful until the model is
# fine-tuned; confirm which checkpoint is intended.
MODEL_NAME = "mental/mental-bert-base-uncased"
try:
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    model = AutoModelForSequenceClassification.from_pretrained(
        MODEL_NAME, num_labels=2, problem_type="single_label_classification"
    )
except Exception as load_err:
    print(f"Error loading model: {load_err}")
    raise
# Output classes. Each name maps to the index of its column in the model's
# softmax output, plus a human-readable description. Insertion order fixes
# the key order of the dict returned by analyze_text.
LABELS = dict(
    neutral=dict(index=0, description="Emotionally balanced or calm"),
    emotional=dict(index=1, description="Showing emotional content"),
)
def analyze_text(text):
    """Score *text* against every label in ``LABELS``.

    Args:
        text: The string to analyze, as submitted through the Gradio textbox.

    Returns:
        A dict mapping each label name in ``LABELS`` to the softmax
        probability the model assigns to that label's index. An empty dict
        is returned when *text* is ``None`` or contains only whitespace,
        because there is nothing to classify.
    """
    # Guard blank submissions: tokenizing None raises, and an empty string
    # would be scored on the tokenizer's special tokens alone.
    if text is None or not text.strip():
        return {}

    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)

    # Inference only, so skip autograd bookkeeping.
    with torch.no_grad():
        logits = model(**inputs).logits

    # A single string yields a batch of one; take that row's probabilities.
    probs = F.softmax(logits, dim=-1)[0]

    return {label: float(probs[info["index"]]) for label, info in LABELS.items()}
# Assemble the Gradio front end around analyze_text: a 3-line textbox as
# input, the returned label -> probability mapping shown as JSON output, and
# a few sample prompts users can click to try the model.
iface = gr.Interface(
    fn=analyze_text,
    inputs=gr.Textbox(label="Enter text to analyze", lines=3),
    outputs=gr.Json(label="Emotion Analysis"),
    title="MentalBERT Emotion Analysis",
    description=(
        "Analyze the emotional content of text using MentalBERT "
        "(specialized for mental health content)"
    ),
    # One single-element row per example: one value for the one input.
    examples=[
        [sample]
        for sample in (
            "I feel really anxious about my upcoming presentation",
            "I've been feeling quite depressed lately",
            "I'm managing my stress levels well today",
            "Just had a great therapy session!",
        )
    ],
    # Flagging turned off.
    # NOTE(review): newer Gradio releases rename this to `flagging_mode`;
    # confirm against the Gradio version pinned for this Space.
    allow_flagging="never",
)
# Start the web server when this file is run as a script. The guard keeps a
# plain `import` of this module from starting a server as a side effect.
# server_name="0.0.0.0" binds all network interfaces; share=True asks Gradio
# for a public share link. (Neither setting configures CORS.)
# NOTE(review): hosted Hugging Face Spaces do not support share links, so
# share=True presumably only has an effect when run locally; confirm.
if __name__ == "__main__":
    iface.launch(share=True, server_name="0.0.0.0")