aideveloper24 commited on
Commit
efa6633
·
verified ·
1 Parent(s): b3df9a0

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +74 -47
app.py CHANGED
@@ -1,37 +1,91 @@
1
- from flask import Flask, request, jsonify
2
  import torch
3
  from transformers import DistilBertTokenizer, DistilBertForSequenceClassification
4
  import os
5
 
6
  app = Flask(__name__)
7
 
8
- # Initialize model and tokenizer globally
9
- print("Loading model and tokenizer...")
10
- MODEL_NAME = "distilbert-base-uncased-finetuned-sst-2-english"
11
- tokenizer = DistilBertTokenizer.from_pretrained(MODEL_NAME)
12
- model = DistilBertForSequenceClassification.from_pretrained(MODEL_NAME)
13
- model.eval()
14
- print("Model and tokenizer loaded successfully!")
15
 
16
- # Custom labels
17
- CUSTOM_LABELS = {
18
- 0: "Business/Professional",
19
- 1: "Personal/Casual"
20
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
 
22
- def classify_text(text):
23
  try:
24
- inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=512)
 
 
 
 
 
 
 
 
 
25
 
 
 
 
 
26
  with torch.no_grad():
27
- outputs = model(**inputs)
28
  logits = outputs.logits
29
 
 
30
  probabilities = torch.nn.functional.softmax(logits, dim=1)
31
  predicted_class_id = logits.argmax().item()
32
  confidence = probabilities[0][predicted_class_id].item()
33
 
34
- return {
 
 
 
 
 
 
35
  'category': CUSTOM_LABELS[predicted_class_id],
36
  'confidence': round(confidence, 3),
37
  'all_categories': {
@@ -39,41 +93,14 @@ def classify_text(text):
39
  for label, prob in zip(CUSTOM_LABELS.values(), probabilities[0])
40
  }
41
  }
42
- except Exception as e:
43
- print(f"Error in classify_text: {str(e)}")
44
- raise
45
-
46
- @app.route('/classify', methods=['POST'])
47
- def classify_email():
48
- try:
49
- data = request.get_json()
50
-
51
- if not data or 'subject' not in data:
52
- return jsonify({
53
- 'error': 'No subject provided. Please send a JSON with "subject" field.'
54
- }), 400
55
 
56
- subject = data['subject']
57
- result = classify_text(subject)
58
  return jsonify(result)
59
 
60
  except Exception as e:
61
- print(f"Error in classify_email: {str(e)}")
62
  return jsonify({'error': str(e)}), 500
63
 
64
- @app.route('/', methods=['GET'])
65
- def home():
66
- return jsonify({
67
- 'status': 'API is running',
68
- 'model_name': MODEL_NAME,
69
- 'usage': {
70
- 'endpoint': '/classify',
71
- 'method': 'POST',
72
- 'body': {'subject': 'Your email subject here'}
73
- }
74
- })
75
-
76
  if __name__ == '__main__':
 
77
  port = int(os.environ.get('PORT', 7860))
78
- print(f"Starting server on port {port}...")
79
- app.run(host='0.0.0.0', port=port, debug=True)
 
1
+ from flask import Flask, request, jsonify, make_response
2
  import torch
3
  from transformers import DistilBertTokenizer, DistilBertForSequenceClassification
4
  import os
5
 
6
  app = Flask(__name__)
7
 
8
+ # Global variables to store model and tokenizer
9
+ global_tokenizer = None
10
+ global_model = None
 
 
 
 
11
 
12
+ def load_model():
13
+ """Load the model and tokenizer"""
14
+ global global_tokenizer, global_model
15
+ try:
16
+ print("Loading model and tokenizer...")
17
+ MODEL_NAME = "distilbert-base-uncased-finetuned-sst-2-english"
18
+ global_tokenizer = DistilBertTokenizer.from_pretrained(MODEL_NAME)
19
+ global_model = DistilBertForSequenceClassification.from_pretrained(MODEL_NAME)
20
+ global_model.eval()
21
+ print("Model loaded successfully!")
22
+ return True
23
+ except Exception as e:
24
+ print(f"Error loading model: {str(e)}")
25
+ return False
26
+
27
+ # Load model at startup
28
+ load_model()
29
+
30
+ @app.route('/', methods=['GET'])
31
+ def home():
32
+ """Home endpoint to check if API is running"""
33
+ response = {
34
+ 'status': 'API is running',
35
+ 'model_status': 'loaded' if global_model is not None else 'not loaded',
36
+ 'usage': {
37
+ 'endpoint': '/classify',
38
+ 'method': 'POST',
39
+ 'body': {'subject': 'Your email subject here'}
40
+ }
41
+ }
42
+ return jsonify(response)
43
+
44
+ @app.route('/health', methods=['GET'])
45
+ def health_check():
46
+ """Health check endpoint"""
47
+ if global_model is None or global_tokenizer is None:
48
+ return jsonify({'status': 'unhealthy', 'error': 'Model not loaded'}), 503
49
+ return jsonify({'status': 'healthy'})
50
+
51
+ @app.route('/classify', methods=['POST'])
52
+ def classify_email():
53
+ """Classify email subject"""
54
+ if global_model is None or global_tokenizer is None:
55
+ return jsonify({'error': 'Model not loaded'}), 503
56
 
 
57
  try:
58
+ # Get request data
59
+ data = request.get_json()
60
+
61
+ if not data or 'subject' not in data:
62
+ return jsonify({
63
+ 'error': 'No subject provided. Please send a JSON with "subject" field.'
64
+ }), 400
65
+
66
+ # Get the subject
67
+ subject = data['subject']
68
 
69
+ # Tokenize
70
+ inputs = global_tokenizer(subject, return_tensors="pt", truncation=True, max_length=512)
71
+
72
+ # Predict
73
  with torch.no_grad():
74
+ outputs = global_model(**inputs)
75
  logits = outputs.logits
76
 
77
+ # Get probabilities
78
  probabilities = torch.nn.functional.softmax(logits, dim=1)
79
  predicted_class_id = logits.argmax().item()
80
  confidence = probabilities[0][predicted_class_id].item()
81
 
82
+ # Map to custom labels
83
+ CUSTOM_LABELS = {
84
+ 0: "Business/Professional",
85
+ 1: "Personal/Casual"
86
+ }
87
+
88
+ result = {
89
  'category': CUSTOM_LABELS[predicted_class_id],
90
  'confidence': round(confidence, 3),
91
  'all_categories': {
 
93
  for label, prob in zip(CUSTOM_LABELS.values(), probabilities[0])
94
  }
95
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
96
 
 
 
97
  return jsonify(result)
98
 
99
  except Exception as e:
100
+ print(f"Error in classification: {str(e)}")
101
  return jsonify({'error': str(e)}), 500
102
 
 
 
 
 
 
 
 
 
 
 
 
 
103
  if __name__ == '__main__':
104
+ # Use port 7860 for Hugging Face Spaces
105
  port = int(os.environ.get('PORT', 7860))
106
+ app.run(host='0.0.0.0', port=port)