Spaces:
Sleeping
Sleeping
Create app.py
Browse files
app.py
CHANGED
@@ -1,37 +1,91 @@
|
|
1 |
-
from flask import Flask, request, jsonify
|
2 |
import torch
|
3 |
from transformers import DistilBertTokenizer, DistilBertForSequenceClassification
|
4 |
import os
|
5 |
|
6 |
app = Flask(__name__)
|
7 |
|
8 |
-
#
|
9 |
-
|
10 |
-
|
11 |
-
tokenizer = DistilBertTokenizer.from_pretrained(MODEL_NAME)
|
12 |
-
model = DistilBertForSequenceClassification.from_pretrained(MODEL_NAME)
|
13 |
-
model.eval()
|
14 |
-
print("Model and tokenizer loaded successfully!")
|
15 |
|
16 |
-
|
17 |
-
|
18 |
-
|
19 |
-
|
20 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
21 |
|
22 |
-
def classify_text(text):
|
23 |
try:
|
24 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
25 |
|
|
|
|
|
|
|
|
|
26 |
with torch.no_grad():
|
27 |
-
outputs =
|
28 |
logits = outputs.logits
|
29 |
|
|
|
30 |
probabilities = torch.nn.functional.softmax(logits, dim=1)
|
31 |
predicted_class_id = logits.argmax().item()
|
32 |
confidence = probabilities[0][predicted_class_id].item()
|
33 |
|
34 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
35 |
'category': CUSTOM_LABELS[predicted_class_id],
|
36 |
'confidence': round(confidence, 3),
|
37 |
'all_categories': {
|
@@ -39,41 +93,14 @@ def classify_text(text):
|
|
39 |
for label, prob in zip(CUSTOM_LABELS.values(), probabilities[0])
|
40 |
}
|
41 |
}
|
42 |
-
except Exception as e:
|
43 |
-
print(f"Error in classify_text: {str(e)}")
|
44 |
-
raise
|
45 |
-
|
46 |
-
@app.route('/classify', methods=['POST'])
|
47 |
-
def classify_email():
|
48 |
-
try:
|
49 |
-
data = request.get_json()
|
50 |
-
|
51 |
-
if not data or 'subject' not in data:
|
52 |
-
return jsonify({
|
53 |
-
'error': 'No subject provided. Please send a JSON with "subject" field.'
|
54 |
-
}), 400
|
55 |
|
56 |
-
subject = data['subject']
|
57 |
-
result = classify_text(subject)
|
58 |
return jsonify(result)
|
59 |
|
60 |
except Exception as e:
|
61 |
-
print(f"Error in
|
62 |
return jsonify({'error': str(e)}), 500
|
63 |
|
64 |
-
@app.route('/', methods=['GET'])
|
65 |
-
def home():
|
66 |
-
return jsonify({
|
67 |
-
'status': 'API is running',
|
68 |
-
'model_name': MODEL_NAME,
|
69 |
-
'usage': {
|
70 |
-
'endpoint': '/classify',
|
71 |
-
'method': 'POST',
|
72 |
-
'body': {'subject': 'Your email subject here'}
|
73 |
-
}
|
74 |
-
})
|
75 |
-
|
76 |
if __name__ == '__main__':
|
|
|
77 |
port = int(os.environ.get('PORT', 7860))
|
78 |
-
|
79 |
-
app.run(host='0.0.0.0', port=port, debug=True)
|
|
|
1 |
+
from flask import Flask, request, jsonify, make_response
|
2 |
import torch
|
3 |
from transformers import DistilBertTokenizer, DistilBertForSequenceClassification
|
4 |
import os
|
5 |
|
6 |
app = Flask(__name__)
|
7 |
|
8 |
+
# Module-level handles for the tokenizer/model; populated by load_model().
global_tokenizer = None
global_model = None


def load_model():
    """Fetch the DistilBERT tokenizer and classifier into the module globals.

    Returns True when both pieces are ready, False when loading failed.
    Failures are printed rather than raised so the web app can still start
    (endpoints then answer 503 until the model is available).
    """
    global global_tokenizer, global_model
    checkpoint = "distilbert-base-uncased-finetuned-sst-2-english"
    try:
        print("Loading model and tokenizer...")
        global_tokenizer = DistilBertTokenizer.from_pretrained(checkpoint)
        global_model = DistilBertForSequenceClassification.from_pretrained(checkpoint)
        # Inference only: switch off dropout / training-mode layers.
        global_model.eval()
        print("Model loaded successfully!")
    except Exception as e:
        print(f"Error loading model: {str(e)}")
        return False
    return True


# Eagerly load at import time so the first request is not delayed.
load_model()
|
29 |
+
|
30 |
+
@app.route('/', methods=['GET'])
def home():
    """Root endpoint: reports API/model status and how to call /classify."""
    model_status = 'loaded' if global_model is not None else 'not loaded'
    return jsonify({
        'status': 'API is running',
        'model_status': model_status,
        'usage': {
            'endpoint': '/classify',
            'method': 'POST',
            'body': {'subject': 'Your email subject here'},
        },
    })
|
43 |
+
|
44 |
+
@app.route('/health', methods=['GET'])
def health_check():
    """Liveness probe: 200 when the model is in memory, 503 otherwise."""
    ready = global_model is not None and global_tokenizer is not None
    if ready:
        return jsonify({'status': 'healthy'})
    return jsonify({'status': 'unhealthy', 'error': 'Model not loaded'}), 503
|
50 |
+
|
51 |
+
@app.route('/classify', methods=['POST'])
def classify_email():
    """Classify an email subject line sent as JSON ``{"subject": "..."}``.

    Responses:
        200 -> {'category', 'confidence', 'all_categories'}
        400 -> body missing or lacking a "subject" field
        503 -> model/tokenizer not loaded
        500 -> unexpected failure during classification
    """
    if global_model is None or global_tokenizer is None:
        return jsonify({'error': 'Model not loaded'}), 503

    # NOTE(review): the underlying checkpoint is SST-2 sentiment
    # (id 0 = negative, id 1 = positive); these business/personal names
    # re-brand those ids — confirm the mapping is intended.
    CUSTOM_LABELS = {
        0: "Business/Professional",
        1: "Personal/Casual"
    }

    try:
        # Get request data
        data = request.get_json()

        if not data or 'subject' not in data:
            return jsonify({
                'error': 'No subject provided. Please send a JSON with "subject" field.'
            }), 400

        subject = data['subject']

        # Tokenize; truncate long subjects to the model's 512-token limit.
        inputs = global_tokenizer(subject, return_tensors="pt", truncation=True, max_length=512)

        # Predict without tracking gradients (inference only).
        with torch.no_grad():
            outputs = global_model(**inputs)
        logits = outputs.logits

        # Convert logits to a probability distribution over the labels.
        probabilities = torch.nn.functional.softmax(logits, dim=1)
        predicted_class_id = logits.argmax().item()
        confidence = probabilities[0][predicted_class_id].item()

        result = {
            'category': CUSTOM_LABELS[predicted_class_id],
            'confidence': round(confidence, 3),
            'all_categories': {
                # Reconstructed line (elided in the diff): per-label
                # probability, rounded the same way as 'confidence'.
                label: round(prob.item(), 3)
                for label, prob in zip(CUSTOM_LABELS.values(), probabilities[0])
            }
        }

        return jsonify(result)

    except Exception as e:
        print(f"Error in classification: {str(e)}")
        return jsonify({'error': str(e)}), 500
|
102 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
103 |
if __name__ == '__main__':
    # Hugging Face Spaces serves on 7860 unless PORT overrides it.
    listen_port = int(os.environ.get('PORT', 7860))
    app.run(host='0.0.0.0', port=listen_port)
|
|