# email_classify / app.py
# Last commit cb536ff ("Update app.py") by aideveloper24 (verified).
from flask import Flask, request, jsonify
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch
import os
app = Flask(__name__)

# Hugging Face Hub identifier of the checkpoint served by this API. Kept in a
# single constant so the tokenizer and the model are always loaded from the
# same checkpoint (a mismatched tokenizer would silently produce bad inputs).
MODEL_NAME = "distilbert/distilbert-base-uncased-finetuned-sst-2-english"

# Load the tokenizer and model once at import time; every request reuses them.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME)
model.eval()  # Evaluation mode: switches dropout layers off for inference.
@app.route('/', methods=['GET'])
def home():
    """Report that the service is up and describe how to call the classifier.

    Returns:
        A JSON response holding a status string and a short usage guide
        (endpoint, HTTP method, example body) for the ``/classify`` route.
    """
    usage_info = dict(
        endpoint='/classify',
        method='POST',
        body=dict(subject='Your email subject here'),
    )
    return jsonify({'status': 'API is running', 'usage': usage_info})
@app.route('/health', methods=['GET'])
def health_check():
    """Liveness probe.

    Returns:
        JSON ``{"status": "healthy"}`` whenever the Flask process is serving
        requests; the model is not invoked.
    """
    payload = dict(status='healthy')
    return jsonify(payload)
# Human-readable names for the model's output class ids (Modify this as needed).
# Built once at import time instead of on every request.
CUSTOM_LABELS = {
    0: "Negative",
    1: "Positive",
}


def _label_for(class_id):
    """Map a class id to its display label, falling back to the model's own label."""
    return CUSTOM_LABELS.get(class_id, model.config.id2label.get(class_id, str(class_id)))


@app.route('/classify', methods=['POST'])
def classify_email():
    """Classify an email subject line.

    Expects a JSON object body of the form ``{"subject": "<text>"}``.

    Returns:
        200 with ``category`` (label of the highest-scoring class),
        ``confidence`` (that class's softmax probability, rounded to 3
        places) and ``all_categories`` (label -> rounded probability for
        every class).
        400 when the body is not a JSON object, ``subject`` is missing, or
        ``subject`` is not a non-blank string.
        500 with the exception message if tokenization/inference fails.
    """
    # silent=True: malformed JSON or a non-JSON content type yields None
    # (-> 400 below) instead of raising, which the broad handler further down
    # would otherwise have reported as a server-side 500.
    data = request.get_json(silent=True)
    # isinstance check: for a JSON string body, `'subject' in data` would be a
    # substring test and `data['subject']` would then raise.
    if not isinstance(data, dict) or 'subject' not in data:
        return jsonify({
            'error': 'No subject provided. Please send a JSON with "subject" field.'
        }), 400

    subject = data['subject']
    # The tokenizer would batch a list of strings, breaking the single-row
    # indexing below; accept exactly one non-blank string.
    if not isinstance(subject, str) or not subject.strip():
        return jsonify({
            'error': 'The "subject" field must be a non-empty string.'
        }), 400

    try:
        inputs = tokenizer(subject, return_tensors="pt", truncation=True, max_length=512)
        with torch.no_grad():
            logits = model(**inputs).logits

        # One input string -> row 0 holds one score per class id.
        row_logits = logits[0]
        probabilities = torch.nn.functional.softmax(row_logits, dim=-1).tolist()
        predicted_class_id = row_logits.argmax().item()

        result = {
            'category': _label_for(predicted_class_id),
            'confidence': round(probabilities[predicted_class_id], 3),
            # Keyed by class id (not dict insertion order) so labels and
            # probabilities can never be mis-paired.
            'all_categories': {
                _label_for(class_id): round(prob, 3)
                for class_id, prob in enumerate(probabilities)
            },
        }
        return jsonify(result)
    except Exception as e:
        # Top-level request boundary: log the full traceback, report the error.
        app.logger.exception("Error in classification: %s", e)
        return jsonify({'error': str(e)}), 500
if __name__ == '__main__':
    # Hugging Face Spaces serves on port 7860; set PORT to use another port
    # (e.g. for local testing).
    listen_port = int(os.getenv('PORT', 7860))
    app.run(host='0.0.0.0', port=listen_port)