# Author: destroyer795
# Commit b45c7de — fix: make model more robust and aware of the tone of the comment.
from flask import Flask, request, jsonify
from transformers import pipeline
import torch
app = Flask(__name__)

# 1. MODEL CONFIGURATION
# Ensure this path matches your unzipped folder name exactly.
MODEL_PATH = "./sentiment_analyzer_pro"

# Load the DistilBERT pipeline once at import time; /predict reuses this
# module-level `classifier` for every request.
# device=-1 pins inference to the CPU, which is standard for free
# Hugging Face Spaces.
print("Loading DistilBERT 3-class model...")
try:
    classifier = pipeline(
        "sentiment-analysis",
        model=MODEL_PATH,
        tokenizer=MODEL_PATH,
        device=-1,
    )
except Exception as e:
    # Deliberately best-effort: keep the server up so the health check still
    # answers, but bind `classifier` explicitly. Without this assignment the
    # name would be undefined and every /predict call would crash with a
    # NameError instead of reporting that the model is unavailable.
    classifier = None
    print(f"Error loading model: {e}")
else:
    print("Model loaded successfully!")
# 2. PREDICTION ENDPOINT
# "Sureness" threshold: when the model's top-class score falls below this
# value the prediction is reported as "Neutral / Mixed" instead of trusting
# the label. The intent is to catch inputs with conflicting emotional signals
# (e.g. sarcasm); this is a heuristic, not a dedicated sarcasm detector.
CONFIDENCE_THRESHOLD = 0.70


def _interpret_prediction(label, score, threshold=CONFIDENCE_THRESHOLD):
    """Map a raw pipeline (label, score) pair to (sentiment, confidence_flag).

    Args:
        label: Label string emitted by the pipeline, e.g. 'POSITIVE'.
        score: Probability of that label.
        threshold: Scores strictly below this are treated as low confidence.

    Returns:
        ("Neutral / Mixed", "Low") when score < threshold, otherwise the
        label capitalized (e.g. 'POSITIVE' -> 'Positive') paired with "High".
    """
    if score < threshold:
        return "Neutral / Mixed", "Low"
    return label.capitalize(), "High"


@app.route('/predict', methods=['POST'])
def predict_endpoint():
    """
    Receives JSON input: {"text": "Your review here"}
    Returns JSON: {"sentiment": "Label", "score": 0.99, "confidence_flag": "High/Low"}

    Error responses (JSON body with an 'error' key):
        400 - body is not a JSON object, or 'text' is missing, not a string,
              or blank.
        503 - the model failed to load at startup.
    """
    # silent=True: malformed JSON or a non-JSON content type yields None, so
    # the client gets our JSON 400 instead of Flask's HTML error page.
    data = request.get_json(silent=True)

    # Validate input. The isinstance check guards against JSON arrays/scalars,
    # on which `'text' in data` would raise or misbehave.
    if not isinstance(data, dict) or 'text' not in data:
        return jsonify({'error': 'No text provided'}), 400
    sentence = data['text']
    # A list would be batch-classified with only the first result returned;
    # an empty string has no sentiment to report.
    if not isinstance(sentence, str) or not sentence.strip():
        return jsonify({'error': 'No text provided'}), 400

    # The startup loader is best-effort; report a missing model explicitly
    # rather than failing with a NameError/TypeError.
    model = globals().get('classifier')
    if model is None:
        return jsonify({'error': 'Model not loaded'}), 503

    # Perform inference. truncation=True clips inputs longer than the
    # tokenizer's maximum length instead of letting the model raise.
    # Result is a list: [{'label': 'POSITIVE', 'score': 0.98}]
    result = model(sentence, truncation=True)[0]
    score = float(result['score'])

    # 3. INTELLIGENT SARCASM/MIXED LOGIC
    final_sentiment, confidence_flag = _interpret_prediction(result['label'], score)

    return jsonify({
        'sentiment': final_sentiment,
        'score': round(score, 4),
        'confidence_flag': confidence_flag,
    })
# 4. HEALTH CHECK
@app.route('/', methods=['GET'])
def health_check():
    """Answer GET / with a fixed plain-text banner confirming the app is serving."""
    banner = "Sentiment Analyzer Pro API is online."
    return banner
if __name__ == '__main__':
    # Bind to every interface (0.0.0.0) so clients outside the container,
    # such as the Chrome extension, can connect; Hugging Face Spaces expects
    # the app to listen on port 7860.
    listen_port = 7860
    listen_host = '0.0.0.0'
    app.run(port=listen_port, host=listen_host)