|
from flask import Flask, request, jsonify |
|
from transformers import AutoTokenizer, AutoModelForSequenceClassification |
|
import torch |
|
import os |
|
|
|
app = Flask(__name__)

# Hugging Face Hub id of the sentiment checkpoint served by /classify.
# Defined once so the tokenizer and the model are always loaded from the
# same checkpoint.
MODEL_NAME = "distilbert/distilbert-base-uncased-finetuned-sst-2-english"

# Loaded once at module import time and shared by every request handler.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME)

# Switch to inference mode (disables dropout) so predictions are repeatable.
model.eval()
|
|
|
@app.route('/', methods=['GET'])
def home():
    """Report that the API is up and describe how to call the classifier.

    Returns:
        JSON with a ``status`` string and a ``usage`` object naming the
        classification endpoint, its HTTP method, and an example body.
    """
    example_usage = {
        'endpoint': '/classify',
        'method': 'POST',
        'body': {'subject': 'Your email subject here'},
    }
    return jsonify({'status': 'API is running', 'usage': example_usage})
|
|
|
@app.route('/health', methods=['GET'])
def health_check():
    """Return a fixed ``{"status": "healthy"}`` JSON body.

    This does not exercise the tokenizer or model; it only shows that the
    web server is answering requests.
    """
    payload = {'status': 'healthy'}
    return jsonify(payload)
|
|
|
# Display labels for the classifier's output classes, keyed by class id.
# NOTE(review): assumes the checkpoint's head has exactly two classes with
# id 0 = negative and id 1 = positive -- confirm against model.config.id2label.
CUSTOM_LABELS = {
    0: "Negative",
    1: "Positive",
}


@app.route('/classify', methods=['POST'])
def classify_email():
    """Classify an email subject with the sentiment model.

    Expects a JSON object body ``{"subject": "<text>"}``. The subject is
    truncated to at most 512 tokens before inference.

    Returns:
        200 with ``category`` (predicted label), ``confidence`` (that
        label's probability) and ``all_categories`` (label -> probability);
        every probability is rounded to 3 decimals.
        400 with ``error`` if the body is not a JSON object, or ``subject``
        is missing, not a string, or blank.
        500 with ``error`` if tokenization or inference fails; the details
        are logged server-side rather than returned to the client.
    """
    # silent=True returns None for a missing/malformed JSON body or a
    # non-JSON Content-Type instead of raising, so such requests get a 400
    # below rather than being reported as a server error.
    data = request.get_json(silent=True)

    # Valid JSON may still be a list, string or number; only an object can
    # carry the "subject" field (and `in` on a string is a substring test).
    if not isinstance(data, dict) or 'subject' not in data:
        return jsonify({
            'error': 'No subject provided. Please send a JSON with "subject" field.'
        }), 400

    subject = data['subject']
    if not isinstance(subject, str) or not subject.strip():
        return jsonify({'error': '"subject" must be a non-empty string.'}), 400

    try:
        inputs = tokenizer(subject, return_tensors="pt", truncation=True, max_length=512)
        with torch.no_grad():
            logits = model(**inputs).logits

        # A single subject was tokenized, so row 0 holds its class scores.
        scores = logits[0]
        probabilities = torch.nn.functional.softmax(scores, dim=-1)
        predicted_class_id = int(scores.argmax())

        result = {
            'category': CUSTOM_LABELS[predicted_class_id],
            'confidence': round(probabilities[predicted_class_id].item(), 3),
            # Index by class id rather than relying on dict value order.
            'all_categories': {
                label: round(probabilities[class_id].item(), 3)
                for class_id, label in CUSTOM_LABELS.items()
            },
        }
    except Exception:
        # Record the full traceback for operators; do not echo internal
        # exception text (paths, tensor shapes, library internals) to clients.
        app.logger.exception("Error in classification")
        return jsonify({'error': 'Classification failed.'}), 500

    return jsonify(result)
|
|
|
if __name__ == '__main__':
    # Development server: listen on every interface, on $PORT if set
    # (otherwise 7860).
    listen_port = int(os.getenv('PORT', 7860))
    app.run(host='0.0.0.0', port=listen_port)
|
|