SaherMuhamed committed on
Commit
11576c7
·
1 Parent(s): 69797a7

Add the fine-tuned BERT model with FastAPI integrated in the Flask app

Dockerfile CHANGED
@@ -1,20 +1,22 @@
1
- FROM python:3.9-slim
2
-
3
- WORKDIR /code
4
-
5
- # Install system dependencies
6
- RUN apt-get update && apt-get install -y git && rm -rf /var/lib/apt/lists/*
7
-
8
- # Copy requirements and install
9
- COPY requirements.txt .
10
- RUN pip install --no-cache-dir -r requirements.txt
11
-
12
- # Copy the rest of the code
13
- COPY . .
14
-
15
- # Expose FastAPI port
16
- EXPOSE 8000
17
-
18
- # Hugging Face Spaces expects the app to run on 0.0.0.0:8000
19
- ENV FLASK_APP=src.app
20
- CMD ["flask", "run", "--host=0.0.0.0", "--port=5000", "--no-debugger", "--no-reload"]
 
 
 
1
+ FROM python:3.9-slim
2
+
3
+ WORKDIR /app
4
+
5
+ # Copy requirements first for better caching
6
+ COPY requirements.txt .
7
+ RUN pip install --no-cache-dir -r requirements.txt
8
+
9
+ # Copy the entire project
10
+ COPY . .
11
+
12
+ # Create necessary directories if they don't exist
13
+ RUN mkdir -p /app/model /app/intent_classifier_model /app/intent_classifier_tokenizer
14
+
15
+ # Expose the port
16
+ EXPOSE 7860
17
+
18
+ # Set environment variables
19
+ ENV PYTHONPATH=/app
20
+
21
+ # Run the FastAPI application
22
+ CMD ["uvicorn", "model.api.api:app", "--host", "0.0.0.0", "--port", "7860"]
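
The new Dockerfile exposes port 7860 and starts uvicorn on the same port, and the README front matter in this commit declares `app_port: 7860`. The following is a minimal sketch, assuming both files are available as plain text, of how one might check that the three values agree. The function name `declared_ports` and the regular expressions are illustrative assumptions, not part of the commit.

```python
import re


def declared_ports(dockerfile_text: str, readme_text: str) -> dict:
    """Collect the ports declared by EXPOSE, the uvicorn CMD, and README app_port."""
    expose = re.search(r"^EXPOSE\s+(\d+)", dockerfile_text, re.MULTILINE)
    cmd = re.search(r'"--port",\s*"(\d+)"', dockerfile_text)
    app_port = re.search(r"^app_port:\s*(\d+)", readme_text, re.MULTILINE)
    return {
        "expose": int(expose.group(1)) if expose else None,
        "cmd": int(cmd.group(1)) if cmd else None,
        "app_port": int(app_port.group(1)) if app_port else None,
    }


# Inline copies of the relevant lines from this commit.
dockerfile = '''EXPOSE 7860
CMD ["uvicorn", "model.api.api:app", "--host", "0.0.0.0", "--port", "7860"]
'''
readme = "---\nsdk: docker\napp_port: 7860\n---\n"

ports = declared_ports(dockerfile, readme)
# True only when every value was found and all three are identical.
ports_agree = None not in ports.values() and len(set(ports.values())) == 1
```

A check like this could run in CI; a mismatch between the values is the kind of error the removed Dockerfile had, where the comment mentioned port 8000 while the command used 5000.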
README.md CHANGED
@@ -1,10 +1,11 @@
1
  ---
2
  title: Intent Classifier Chatbot
3
  emoji: 🤖
4
- colorFrom: green
5
  colorTo: purple
6
  sdk: docker
7
  pinned: false
 
8
  license: apache-2.0
9
  short_description: Intent Detection API using BERT and Flask
10
  ---
@@ -68,6 +69,50 @@ The project uses the **CLINC150 dataset**, a benchmark dataset for intent classi
68
 
69
  ---
70
 
71
  ## 🤗 Hugging Face Spaces Configuration
72
 
73
  To deploy this project on [Hugging Face Spaces](https://huggingface.co/spaces), you can use a `README.md` and a `config.json` file to configure your Space for inference.
@@ -87,6 +132,4 @@ Example `config.json` for inference API:
87
 
88
  - Make sure your `requirements.txt` lists all dependencies.
89
  - The `entrypoint` should point to your main app file (e.g., `app.py` or `main.py`).
90
- - For more details and advanced configuration, see the [Spaces config reference](https://huggingface.co/docs/hub/spaces-config-reference).
91
-
92
- ---
 
1
  ---
2
  title: Intent Classifier Chatbot
3
  emoji: 🤖
4
+ colorFrom: blue
5
  colorTo: purple
6
  sdk: docker
7
  pinned: false
8
+ app_port: 7860
9
  license: apache-2.0
10
  short_description: Intent Detection API using BERT and Flask
11
  ---
 
69
 
70
  ---
71
 
72
+ # Intent Classifier Chatbot
73
+
74
+ A sophisticated intent classification system built with BERT and FastAPI that can predict user intents from natural language text.
75
+
76
+ ## Features
77
+
78
+ - **Advanced NLP**: Uses BERT-based transformer model for accurate intent classification
79
+ - **150+ Intent Classes**: Trained on the CLINC150 dataset with comprehensive intent coverage
80
+ - **Real-time Prediction**: FastAPI backend for fast inference
81
+ - **Clean UI**: Simple and intuitive web interface
82
+ - **Production Ready**: Dockerized for easy deployment
83
+
84
+ ## How to Use
85
+
86
+ 1. Enter your message in the text area
87
+ 2. Click "Predict Intent"
88
+ 3. See the AI's prediction of your intent
89
+
90
+ Try examples like:
91
+ - "Set an alarm for 7am" → Alarm
92
+ - "Transfer money to John" → Transfer
93
+ - "What's the weather like?" → Weather
94
+ - "Book a flight to Paris" → Book Flight
95
+
96
+ ## Model Details
97
+
98
+ - **Architecture**: BERT for Sequence Classification
99
+ - **Dataset**: CLINC150 (151 intent classes including out-of-scope)
100
+ - **Accuracy**: High performance on intent classification tasks
101
+ - **Preprocessing**: Advanced tokenization and text normalization
102
+
103
+ ## Tech Stack
104
+
105
+ - **Backend**: FastAPI, PyTorch, Transformers
106
+ - **Frontend**: HTML, CSS, JavaScript
107
+ - **Model**: BERT-base fine-tuned on CLINC150
108
+ - **Deployment**: Docker, Hugging Face Spaces
109
+
110
+ ## Author
111
+
112
+ **Saher Muhamed**
113
+ - GitHub: [@sahermuhamed1](https://github.com/sahermuhamed1)
114
+ - Email: sahermuhamed176@gmail.com
115
+
116
  ## 🤗 Hugging Face Spaces Configuration
117
 
118
  To deploy this project on [Hugging Face Spaces](https://huggingface.co/spaces), you can use a `README.md` and a `config.json` file to configure your Space for inference.
 
132
 
133
  - Make sure your `requirements.txt` lists all dependencies.
134
  - The `entrypoint` should point to your main app file (e.g., `app.py` or `main.py`).
135
+ - For more details and advanced configuration, see the [Spaces config reference](https://huggingface.co/docs/hub/spaces-config-reference).
 
 
model/api/__pycache__/api.cpython-39.pyc CHANGED
Binary files a/model/api/__pycache__/api.cpython-39.pyc and b/model/api/__pycache__/api.cpython-39.pyc differ
 
model/api/api.py CHANGED
@@ -1,10 +1,12 @@
1
- from fastapi import FastAPI
 
 
2
  from pydantic import BaseModel
3
  from transformers import BertForSequenceClassification, BertTokenizer
4
  import torch
5
  import os
6
 
7
- app = FastAPI()
8
 
9
  # Get the absolute path to the model directory
10
  BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
@@ -13,7 +15,7 @@ BASE_DIR = os.path.dirname(BASE_DIR)
13
  MODEL_DIR = os.path.join(BASE_DIR, "intent_classifier_model")
14
  TOKENIZER_DIR = os.path.join(BASE_DIR, "intent_classifier_tokenizer")
15
 
16
- # Ensure model and tokenizer directories exist
17
  if not os.path.isdir(MODEL_DIR):
18
  raise FileNotFoundError(f"Model directory not found: {MODEL_DIR}")
19
  if not os.path.isdir(TOKENIZER_DIR):
@@ -23,23 +25,355 @@ if not os.path.isdir(TOKENIZER_DIR):
23
  model = BertForSequenceClassification.from_pretrained(MODEL_DIR, local_files_only=True)
24
  tokenizer = BertTokenizer.from_pretrained(TOKENIZER_DIR, local_files_only=True)
25
 
26
- # Load intent label mapping
27
- from datasets import load_dataset
28
- dataset = load_dataset("clinc_oos", "small")
29
- int2str = dataset["train"].features["intent"].int2str
30
 
31
  class Query(BaseModel):
32
- text: str
 
33
 
 
34
  @app.post("/predict")
35
- def predict_intent(query: Query):
36
- inputs = tokenizer(query.text, return_tensors="pt", truncation=True, padding=True, max_length=128)
37
- with torch.no_grad():
38
- outputs = model(**inputs)
39
- prediction = outputs.logits.argmax(dim=-1).item()
40
- intent = int2str(prediction)
41
- if intent == "oos":
42
- return {"intent": "out of scope (OOS)"}
43
- else:
44
- intent = intent.replace("_", " ").title()
45
- return {"intent": intent}
1
+ from fastapi import FastAPI, Request
2
+ from fastapi.responses import HTMLResponse
3
+ from fastapi.staticfiles import StaticFiles
4
  from pydantic import BaseModel
5
  from transformers import BertForSequenceClassification, BertTokenizer
6
  import torch
7
  import os
8
 
9
+ app = FastAPI(title="Intent Classifier API", description="BERT-based intent classification system")
10
 
11
  # Get the absolute path to the model directory
12
  BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
 
15
  MODEL_DIR = os.path.join(BASE_DIR, "intent_classifier_model")
16
  TOKENIZER_DIR = os.path.join(BASE_DIR, "intent_classifier_tokenizer")
17
 
18
+ # Ensure model and tokenizer directories exist
19
  if not os.path.isdir(MODEL_DIR):
20
  raise FileNotFoundError(f"Model directory not found: {MODEL_DIR}")
21
  if not os.path.isdir(TOKENIZER_DIR):
 
25
  model = BertForSequenceClassification.from_pretrained(MODEL_DIR, local_files_only=True)
26
  tokenizer = BertTokenizer.from_pretrained(TOKENIZER_DIR, local_files_only=True)
27
 
28
+ # Complete CLINC150 intent labels in exact order (151 total)
29
+ INTENT_LABELS = ['restaurant_reviews',
30
+ 'nutrition_info',
31
+ 'account_blocked',
32
+ 'oil_change_how',
33
+ 'time',
34
+ 'weather',
35
+ 'redeem_rewards',
36
+ 'interest_rate',
37
+ 'gas_type',
38
+ 'accept_reservations',
39
+ 'smart_home',
40
+ 'user_name',
41
+ 'report_lost_card',
42
+ 'repeat',
43
+ 'whisper_mode',
44
+ 'what_are_your_hobbies',
45
+ 'order',
46
+ 'jump_start',
47
+ 'schedule_meeting',
48
+ 'meeting_schedule',
49
+ 'freeze_account',
50
+ 'what_song',
51
+ 'meaning_of_life',
52
+ 'restaurant_reservation',
53
+ 'traffic',
54
+ 'make_call',
55
+ 'text',
56
+ 'bill_balance',
57
+ 'improve_credit_score',
58
+ 'change_language',
59
+ 'no',
60
+ 'measurement_conversion',
61
+ 'timer',
62
+ 'flip_coin',
63
+ 'do_you_have_pets',
64
+ 'balance',
65
+ 'tell_joke',
66
+ 'last_maintenance',
67
+ 'exchange_rate',
68
+ 'uber',
69
+ 'car_rental',
70
+ 'credit_limit',
71
+ 'oos',
72
+ 'shopping_list',
73
+ 'expiration_date',
74
+ 'routing',
75
+ 'meal_suggestion',
76
+ 'tire_change',
77
+ 'todo_list',
78
+ 'card_declined',
79
+ 'rewards_balance',
80
+ 'change_accent',
81
+ 'vaccines',
82
+ 'reminder_update',
83
+ 'food_last',
84
+ 'change_ai_name',
85
+ 'bill_due',
86
+ 'who_do_you_work_for',
87
+ 'share_location',
88
+ 'international_visa',
89
+ 'calendar',
90
+ 'translate',
91
+ 'carry_on',
92
+ 'book_flight',
93
+ 'insurance_change',
94
+ 'todo_list_update',
95
+ 'timezone',
96
+ 'cancel_reservation',
97
+ 'transactions',
98
+ 'credit_score',
99
+ 'report_fraud',
100
+ 'spending_history',
101
+ 'directions',
102
+ 'spelling',
103
+ 'insurance',
104
+ 'what_is_your_name',
105
+ 'reminder',
106
+ 'where_are_you_from',
107
+ 'distance',
108
+ 'payday',
109
+ 'flight_status',
110
+ 'find_phone',
111
+ 'greeting',
112
+ 'alarm',
113
+ 'order_status',
114
+ 'confirm_reservation',
115
+ 'cook_time',
116
+ 'damaged_card',
117
+ 'reset_settings',
118
+ 'pin_change',
119
+ 'replacement_card_duration',
120
+ 'new_card',
121
+ 'roll_dice',
122
+ 'income',
123
+ 'taxes',
124
+ 'date',
125
+ 'who_made_you',
126
+ 'pto_request',
127
+ 'tire_pressure',
128
+ 'how_old_are_you',
129
+ 'rollover_401k',
130
+ 'pto_request_status',
131
+ 'how_busy',
132
+ 'application_status',
133
+ 'recipe',
134
+ 'calendar_update',
135
+ 'play_music',
136
+ 'yes',
137
+ 'direct_deposit',
138
+ 'credit_limit_change',
139
+ 'gas',
140
+ 'pay_bill',
141
+ 'ingredients_list',
142
+ 'lost_luggage',
143
+ 'goodbye',
144
+ 'what_can_i_ask_you',
145
+ 'book_hotel',
146
+ 'are_you_a_bot',
147
+ 'next_song',
148
+ 'change_speed',
149
+ 'plug_type',
150
+ 'maybe',
151
+ 'w2',
152
+ 'oil_change_when',
153
+ 'thank_you',
154
+ 'shopping_list_update',
155
+ 'pto_balance',
156
+ 'order_checks',
157
+ 'travel_alert',
158
+ 'fun_fact',
159
+ 'sync_device',
160
+ 'schedule_maintenance',
161
+ 'apr',
162
+ 'transfer',
163
+ 'ingredient_substitution',
164
+ 'calories',
165
+ 'current_location',
166
+ 'international_fees',
167
+ 'calculator',
168
+ 'definition',
169
+ 'next_holiday',
170
+ 'update_playlist',
171
+ 'mpg',
172
+ 'min_payment',
173
+ 'change_user_name',
174
+ 'restaurant_suggestion',
175
+ 'travel_notification',
176
+ 'cancel',
177
+ 'pto_used',
178
+ 'travel_suggestion',
179
+ 'change_volume']
180
+ def int2str(idx):
181
+     return INTENT_LABELS[idx] if 0 <= idx < len(INTENT_LABELS) else "unknown"
182
 
183
  class Query(BaseModel):
184
+ text: str = None
185
+ message: str = None
186
 
187
+ # Add compatibility endpoint for both 'message' and 'text' fields
188
  @app.post("/predict")
189
+ def predict_intent_compat(request: Query):
190
+     """Compatibility endpoint that handles both text and message fields"""
191
+     try:
192
+         # Handle both 'text' and 'message' fields for compatibility
193
+         text = request.message or request.text or ""
194
+
195
+         if not text:
196
+             return {"error": "No text or message provided"}
197
+
198
+         inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=128)
199
+         with torch.no_grad():
200
+             outputs = model(**inputs)
201
+         prediction = outputs.logits.argmax(dim=-1).item()
202
+
203
+         # Debug information
204
+         print(f"Input: {text}")
205
+         print(f"Raw prediction index: {prediction}")
206
+         print(f"Total labels available: {len(INTENT_LABELS)}")
207
+
208
+         intent = int2str(prediction)
209
+         print(f"Mapped intent: {intent}")
210
+
211
+         if intent == "oos":
212
+             return {"intent": "out of scope (OOS)"}
213
+         else:
214
+             intent = intent.replace("_", " ").title()
215
+             return {"intent": intent}
216
+     except Exception as e:
217
+         print(f"Error in prediction: {e}")
218
+         return {"intent": "Error", "error": str(e)}
219
+
220
+
221
+
222
+ @app.get("/", response_class=HTMLResponse)
223
+ async def read_root():
224
+ """Serve the main HTML interface"""
225
+ html_content = """
226
+ <!DOCTYPE html>
227
+ <html lang="en">
228
+ <head>
229
+ <meta charset="UTF-8">
230
+ <title>Intent Classifier Chatbot</title>
231
+ <meta name="viewport" content="width=device-width, initial-scale=1">
232
+ <style>
233
+ body {
234
+ font-family: 'Segoe UI', Arial, sans-serif;
235
+ margin: 0;
236
+ background: #f7f9fa;
237
+ color: #222;
238
+ }
239
+ .container {
240
+ max-width: 600px;
241
+ margin: 60px auto 30px auto;
242
+ background: #fff;
243
+ border-radius: 12px;
244
+ box-shadow: 0 4px 24px rgba(0,0,0,0.08);
245
+ padding: 32px 28px 24px 28px;
246
+ }
247
+ h1 {
248
+ text-align: center;
249
+ color: #2d6cdf;
250
+ margin-bottom: 18px;
251
+ }
252
+ h2 {
253
+ text-align: center;
254
+ color: #2d6cdf;
255
+ margin-bottom: 18px;
256
+ font-size: 1.5em;
257
+ }
258
+ label {
259
+ font-weight: 500;
260
+ margin-bottom: 8px;
261
+ display: block;
262
+ }
263
+ textarea {
264
+ width: 100%;
265
+ height: 100px;
266
+ padding: 12px;
267
+ border: 1px solid #d2d6dc;
268
+ border-radius: 6px;
269
+ font-size: 1em;
270
+ margin-bottom: 18px;
271
+ box-sizing: border-box;
272
+ transition: border 0.2s;
273
+ }
274
+ textarea:focus {
275
+ border: 1.5px solid #2d6cdf;
276
+ outline: none;
277
+ }
278
+ button {
279
+ width: 100%;
280
+ padding: 12px;
281
+ background: linear-gradient(90deg, #2d6cdf 60%, #4e9cff 100%);
282
+ color: #fff;
283
+ border: none;
284
+ border-radius: 6px;
285
+ font-size: 1.1em;
286
+ font-weight: 600;
287
+ cursor: pointer;
288
+ transition: background 0.2s;
289
+ }
290
+ button:hover {
291
+ background: linear-gradient(90deg, #1b4e9b 60%, #3578c7 100%);
292
+ }
293
+ .result {
294
+ margin-top: 24px;
295
+ font-size: 1.15em;
296
+ background: #eaf3ff;
297
+ border-left: 4px solid #2d6cdf;
298
+ padding: 14px 18px;
299
+ border-radius: 6px;
300
+ color: #1a3a5d;
301
+ word-break: break-word;
302
+ }
303
+ .info {
304
+ margin-top: 18px;
305
+ font-size: 0.98em;
306
+ color: #555;
307
+ background: #f3f6fa;
308
+ border-radius: 6px;
309
+ padding: 10px 14px;
310
+ }
311
+ footer {
312
+ margin-top: 40px;
313
+ text-align: center;
314
+ color: #888;
315
+ font-size: 0.97em;
316
+ padding-bottom: 18px;
317
+ }
318
+ @media (max-width: 600px) {
319
+ .container { padding: 18px 6px 18px 6px; }
320
+ }
321
+ </style>
322
+ </head>
323
+ <body>
324
+ <div class="container">
325
+ <h1>Intent Classifier Chatbot</h1>
326
+ <h2>Predict User Intent</h2>
327
+ <div class="info">
328
+ Enter a message below and click <b>Predict Intent</b> to see what the AI thinks your intent is.<br>
329
+ <span style="color:#2d6cdf;">Try: <i>"Set an alarm for 7am"</i> or <i>"Transfer money to John"</i></span>
330
+ </div>
331
+ <div class="form-group">
332
+ <label for="message">Your Message:</label>
333
+ <textarea id="message" placeholder="Type your message here..."></textarea>
334
+ </div>
335
+ <button onclick="predictIntent()">Predict Intent</button>
336
+ <div id="result" class="result" style="display: none;"></div>
337
+ </div>
338
+ <footer>
339
+ Made by <b>Saher Muhamed</b><br>
340
+ <a href="https://github.com/sahermuhamed1" target="_blank" style="color:#2d6cdf;text-decoration:none;">GitHub</a> &middot;
341
+ <a href="mailto:sahermuhamed176@gmail.com" style="color:#2d6cdf;text-decoration:none;">Contact</a>
342
+ </footer>
343
+ <script>
344
+ function predictIntent() {
345
+ const message = document.getElementById('message').value.trim();
346
+ const resultDiv = document.getElementById('result');
347
+
348
+ if (!message) {
349
+ alert('Please enter a message first!');
350
+ return;
351
+ }
352
+
353
+ resultDiv.style.display = 'block';
354
+ resultDiv.innerHTML = 'Predicting...';
355
+
356
+ fetch('/predict', {
357
+ method: 'POST',
358
+ headers: {
359
+ 'Content-Type': 'application/json',
360
+ },
361
+ body: JSON.stringify({message: message})
362
+ })
363
+ .then(response => response.json())
364
+ .then(data => {
365
+ if (data.error) {
366
+ resultDiv.innerHTML = `<span style="color: red;">Error: ${data.error}</span>`;
367
+ } else {
368
+ resultDiv.innerHTML = `<span style="color: green;">Predicted Intent: ${data.intent || 'Unknown'}</span>`;
369
+ }
370
+ })
371
+ .catch(error => {
372
+ resultDiv.innerHTML = `<span style="color: red;">Error: ${error.message}</span>`;
373
+ });
374
+ }
375
+ </script>
376
+ </body>
377
+ </html>
378
+ """
379
+ return HTMLResponse(content=html_content)
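
The `/predict` handler in this file takes the argmax of the logits, maps the index through `INTENT_LABELS`, and title-cases the result. Below is a minimal, dependency-free sketch of that mapping. The softmax confidence is an added assumption (the committed endpoint returns only the intent string), and `format_prediction` and the three-item label list are hypothetical stand-ins for the real model output and the 151-label list.

```python
import math
from typing import Dict, List, Union


def format_prediction(logits: List[float], labels: List[str]) -> Dict[str, Union[str, float]]:
    """Map raw logits to a display label plus a softmax confidence."""
    # Index of the largest logit, equivalent to logits.argmax(dim=-1) for one row.
    idx = max(range(len(logits)), key=lambda i: logits[i])

    # Numerically stable softmax: subtract the peak before exponentiating.
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]
    confidence = exps[idx] / sum(exps)

    # Same bounds check and formatting as int2str and the handler.
    raw = labels[idx] if 0 <= idx < len(labels) else "unknown"
    intent = "out of scope (OOS)" if raw == "oos" else raw.replace("_", " ").title()
    return {"intent": intent, "confidence": round(confidence, 4)}


demo_labels = ["book_flight", "oos", "alarm"]
result = format_prediction([0.1, 0.2, 3.0], demo_labels)
```

Returning a confidence alongside the label would let a client decide when to treat a low-scoring prediction as uncertain, which the index-only response cannot express.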
model/api/start_server.py CHANGED
@@ -2,3 +2,4 @@ import uvicorn
2
 
3
  if __name__ == "__main__":
4
  uvicorn.run("api:app", host="0.0.0.0", port=8000, reload=True)
 
 
2
 
3
  if __name__ == "__main__":
4
  uvicorn.run("api:app", host="0.0.0.0", port=8000, reload=True)
5
+
src/fastapi_server.py CHANGED
@@ -14,4 +14,4 @@ def predict(req: PredictRequest):
14
  intent = "set_alarm"
15
  else:
16
  intent = "unknown"
17
- return {"intent": intent}
 
14
  intent = "set_alarm"
15
  else:
16
  intent = "unknown"
17
+ return {"intent": intent}
src/{app.py → main.py} RENAMED
@@ -1,5 +1,5 @@
1
  # NOTE: Make sure the FastAPI server is running at http://localhost:8000 before starting this Flask app.
2
- from flask import Flask, render_template, request
3
  import requests
4
 
5
  app = Flask(__name__)
@@ -14,7 +14,7 @@ def index():
14
  user_text = request.form.get("user_text", "")
15
  if user_text:
16
  try:
17
- response = requests.post(FASTAPI_URL, json={"text": user_text})
18
  if response.status_code == 200:
19
  prediction = response.json().get("intent", "Unknown")
20
  else:
@@ -30,5 +30,28 @@ def index():
30
  prediction = f"Error: {str(e)}"
31
  return render_template("index.html", prediction=prediction, user_text=user_text)
32
 
33
  if __name__ == "__main__":
34
- app.run(debug=True)
 
1
  # NOTE: Make sure the FastAPI server is running at http://localhost:8000 before starting this Flask app.
2
+ from flask import Flask, render_template, request, jsonify
3
  import requests
4
 
5
  app = Flask(__name__)
 
14
  user_text = request.form.get("user_text", "")
15
  if user_text:
16
  try:
17
+ response = requests.post(FASTAPI_URL, json={"message": user_text})
18
  if response.status_code == 200:
19
  prediction = response.json().get("intent", "Unknown")
20
  else:
 
30
  prediction = f"Error: {str(e)}"
31
  return render_template("index.html", prediction=prediction, user_text=user_text)
32
 
33
+ @app.route('/predict', methods=['POST'])
34
+ def predict():
35
+ try:
36
+ data = request.get_json()
37
+ message = data.get('message', '')
38
+
39
+ if not message:
40
+ return jsonify({'error': 'No message provided'}), 400
41
+
42
+ # Call FastAPI model server
43
+ response = requests.post(FASTAPI_URL, json={'message': message})
44
+
45
+ if response.status_code == 200:
46
+ result = response.json()
47
+ return jsonify({'intent': result.get('intent', 'Unknown')})
48
+ else:
49
+ return jsonify({'error': 'Model server error'}), 500
50
+
51
+ except requests.exceptions.ConnectionError:
52
+ return jsonify({'error': 'Could not connect to FastAPI server. Make sure it\'s running on port 8000.'}), 500
53
+ except Exception as e:
54
+ return jsonify({'error': str(e)}), 500
55
+
56
  if __name__ == "__main__":
57
+ app.run(host="0.0.0.0", port=5000, debug=False)
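
Both the Flask proxy and the browser page send a JSON body with a `message` key to `/predict`. Here is a minimal sketch of an equivalent client that uses only the standard library, so the endpoint can be exercised without Flask. `build_predict_request` is a hypothetical helper, and the base URL assumes the FastAPI server from `start_server.py` is listening on port 8000.

```python
import json
import urllib.request


def build_predict_request(base_url: str, message: str) -> urllib.request.Request:
    """Build (but do not send) a POST request carrying {"message": ...}."""
    body = json.dumps({"message": message}).encode("utf-8")
    return urllib.request.Request(
        url=f"{base_url.rstrip('/')}/predict",
        data=body,
        headers={"Content-Type": "application/json"},
        method="POST",
    )


req = build_predict_request("http://localhost:8000", "Set an alarm for 7am")

# Sending is left commented out because it needs a running server:
# with urllib.request.urlopen(req, timeout=10) as resp:
#     print(json.loads(resp.read()))
```

The explicit `timeout` in the commented call is a deliberate choice: the `requests.post` calls in this commit pass no timeout, so a stalled model server would block the Flask worker for as long as the connection stays open.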
start.sh ADDED
@@ -0,0 +1,12 @@
1
+ #!/bin/bash
2
+
3
+ # Warn if the model files are missing (this script does not download them)
4
+ if [ ! -d "/app/intent_classifier_model" ]; then
5
+ echo "Model files not found. Please upload your trained model to the Space."
6
+ echo "Create the following directories and upload your model files:"
7
+ echo "- intent_classifier_model/ (containing the trained BERT model)"
8
+ echo "- intent_classifier_tokenizer/ (containing the tokenizer)"
9
+ fi
10
+
11
+ # Start the FastAPI application
12
+ exec uvicorn model.api.api:app --host 0.0.0.0 --port 7860
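
`start.sh` and `api.py` only test whether the model directories exist, and the Dockerfile in this commit creates them with `mkdir -p`, so an empty directory would pass both checks and the failure would presumably surface only when `from_pretrained` runs. Below is a minimal sketch of a stricter check. It assumes the model was saved with `save_pretrained` and therefore contains `config.json` plus a weights file (`model.safetensors` or `pytorch_model.bin`); the helper name `missing_model_files` is hypothetical.

```python
import os
import tempfile
from typing import List

# Assumed weight file names written by save_pretrained (either is acceptable).
WEIGHT_FILES = ("model.safetensors", "pytorch_model.bin")


def missing_model_files(model_dir: str) -> List[str]:
    """Return a description of each required file that is absent from model_dir."""
    missing = []
    if not os.path.isfile(os.path.join(model_dir, "config.json")):
        missing.append("config.json")
    if not any(os.path.isfile(os.path.join(model_dir, w)) for w in WEIGHT_FILES):
        missing.append(" or ".join(WEIGHT_FILES))
    return missing


with tempfile.TemporaryDirectory() as model_dir:
    # An empty directory exists, yet it is not a usable model directory.
    report_empty = missing_model_files(model_dir)

    # Create placeholder files to simulate an uploaded model.
    for name in ("config.json", "model.safetensors"):
        open(os.path.join(model_dir, name), "w").close()
    report_full = missing_model_files(model_dir)
```

Calling such a helper at import time in `api.py` would turn an empty Space into an explicit message that names the missing files.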
training/workspace.ipynb CHANGED
@@ -5,16 +5,7 @@
5
  "execution_count": 1,
6
  "id": "27ee9040",
7
  "metadata": {},
8
- "outputs": [
9
- {
10
- "name": "stderr",
11
- "output_type": "stream",
12
- "text": [
13
- "/home/saher/miniconda3/envs/AI/lib/python3.9/site-packages/requests/__init__.py:86: RequestsDependencyWarning: Unable to find acceptable character detection dependency (chardet or charset_normalizer).\n",
14
- " warnings.warn(\n"
15
- ]
16
- }
17
- ],
18
  "source": [
19
  "from datasets import load_dataset\n",
20
  "\n",
@@ -29,6 +20,190 @@
29
  "test_labels = dataset[\"test\"][\"intent\"]\n"
30
  ]
31
  },
32
  {
33
  "cell_type": "code",
34
  "execution_count": 2,
 
5
  "execution_count": 1,
6
  "id": "27ee9040",
7
  "metadata": {},
8
+ "outputs": [],
9
  "source": [
10
  "from datasets import load_dataset\n",
11
  "\n",
 
20
  "test_labels = dataset[\"test\"][\"intent\"]\n"
21
  ]
22
  },
23
+ {
24
+ "cell_type": "code",
25
+ "execution_count": null,
26
+ "id": "78405e22",
27
+ "metadata": {},
28
+ "outputs": [
29
+ {
30
+ "data": {
31
+ "text/plain": [
32
+ "['restaurant_reviews',\n",
33
+ " 'nutrition_info',\n",
34
+ " 'account_blocked',\n",
35
+ " 'oil_change_how',\n",
36
+ " 'time',\n",
37
+ " 'weather',\n",
38
+ " 'redeem_rewards',\n",
39
+ " 'interest_rate',\n",
40
+ " 'gas_type',\n",
41
+ " 'accept_reservations',\n",
42
+ " 'smart_home',\n",
43
+ " 'user_name',\n",
44
+ " 'report_lost_card',\n",
45
+ " 'repeat',\n",
46
+ " 'whisper_mode',\n",
47
+ " 'what_are_your_hobbies',\n",
48
+ " 'order',\n",
49
+ " 'jump_start',\n",
50
+ " 'schedule_meeting',\n",
51
+ " 'meeting_schedule',\n",
52
+ " 'freeze_account',\n",
53
+ " 'what_song',\n",
54
+ " 'meaning_of_life',\n",
55
+ " 'restaurant_reservation',\n",
56
+ " 'traffic',\n",
57
+ " 'make_call',\n",
58
+ " 'text',\n",
59
+ " 'bill_balance',\n",
60
+ " 'improve_credit_score',\n",
61
+ " 'change_language',\n",
62
+ " 'no',\n",
63
+ " 'measurement_conversion',\n",
64
+ " 'timer',\n",
65
+ " 'flip_coin',\n",
66
+ " 'do_you_have_pets',\n",
67
+ " 'balance',\n",
68
+ " 'tell_joke',\n",
69
+ " 'last_maintenance',\n",
70
+ " 'exchange_rate',\n",
71
+ " 'uber',\n",
72
+ " 'car_rental',\n",
73
+ " 'credit_limit',\n",
74
+ " 'oos',\n",
75
+ " 'shopping_list',\n",
76
+ " 'expiration_date',\n",
77
+ " 'routing',\n",
78
+ " 'meal_suggestion',\n",
79
+ " 'tire_change',\n",
80
+ " 'todo_list',\n",
81
+ " 'card_declined',\n",
82
+ " 'rewards_balance',\n",
83
+ " 'change_accent',\n",
84
+ " 'vaccines',\n",
85
+ " 'reminder_update',\n",
86
+ " 'food_last',\n",
87
+ " 'change_ai_name',\n",
88
+ " 'bill_due',\n",
89
+ " 'who_do_you_work_for',\n",
90
+ " 'share_location',\n",
91
+ " 'international_visa',\n",
92
+ " 'calendar',\n",
93
+ " 'translate',\n",
94
+ " 'carry_on',\n",
95
+ " 'book_flight',\n",
96
+ " 'insurance_change',\n",
97
+ " 'todo_list_update',\n",
98
+ " 'timezone',\n",
99
+ " 'cancel_reservation',\n",
100
+ " 'transactions',\n",
101
+ " 'credit_score',\n",
102
+ " 'report_fraud',\n",
103
+ " 'spending_history',\n",
104
+ " 'directions',\n",
105
+ " 'spelling',\n",
106
+ " 'insurance',\n",
107
+ " 'what_is_your_name',\n",
108
+ " 'reminder',\n",
109
+ " 'where_are_you_from',\n",
110
+ " 'distance',\n",
111
+ " 'payday',\n",
112
+ " 'flight_status',\n",
113
+ " 'find_phone',\n",
114
+ " 'greeting',\n",
115
+ " 'alarm',\n",
116
+ " 'order_status',\n",
117
+ " 'confirm_reservation',\n",
118
+ " 'cook_time',\n",
119
+ " 'damaged_card',\n",
120
+ " 'reset_settings',\n",
121
+ " 'pin_change',\n",
122
+ " 'replacement_card_duration',\n",
123
+ " 'new_card',\n",
124
+ " 'roll_dice',\n",
125
+ " 'income',\n",
126
+ " 'taxes',\n",
127
+ " 'date',\n",
128
+ " 'who_made_you',\n",
129
+ " 'pto_request',\n",
130
+ " 'tire_pressure',\n",
131
+ " 'how_old_are_you',\n",
132
+ " 'rollover_401k',\n",
133
+ " 'pto_request_status',\n",
134
+ " 'how_busy',\n",
135
+ " 'application_status',\n",
136
+ " 'recipe',\n",
137
+ " 'calendar_update',\n",
138
+ " 'play_music',\n",
139
+ " 'yes',\n",
140
+ " 'direct_deposit',\n",
141
+ " 'credit_limit_change',\n",
142
+ " 'gas',\n",
143
+ " 'pay_bill',\n",
144
+ " 'ingredients_list',\n",
145
+ " 'lost_luggage',\n",
146
+ " 'goodbye',\n",
147
+ " 'what_can_i_ask_you',\n",
148
+ " 'book_hotel',\n",
149
+ " 'are_you_a_bot',\n",
150
+ " 'next_song',\n",
151
+ " 'change_speed',\n",
152
+ " 'plug_type',\n",
153
+ " 'maybe',\n",
154
+ " 'w2',\n",
155
+ " 'oil_change_when',\n",
156
+ " 'thank_you',\n",
157
+ " 'shopping_list_update',\n",
158
+ " 'pto_balance',\n",
159
+ " 'order_checks',\n",
160
+ " 'travel_alert',\n",
161
+ " 'fun_fact',\n",
162
+ " 'sync_device',\n",
163
+ " 'schedule_maintenance',\n",
164
+ " 'apr',\n",
165
+ " 'transfer',\n",
166
+ " 'ingredient_substitution',\n",
167
+ " 'calories',\n",
168
+ " 'current_location',\n",
169
+ " 'international_fees',\n",
170
+ " 'calculator',\n",
171
+ " 'definition',\n",
172
+ " 'next_holiday',\n",
173
+ " 'update_playlist',\n",
174
+ " 'mpg',\n",
175
+ " 'min_payment',\n",
176
+ " 'change_user_name',\n",
177
+ " 'restaurant_suggestion',\n",
178
+ " 'travel_notification',\n",
179
+ " 'cancel',\n",
180
+ " 'pto_used',\n",
181
+ " 'travel_suggestion',\n",
182
+ " 'change_volume']"
183
+ ]
184
+ },
185
+ "execution_count": 5,
186
+ "metadata": {},
187
+ "output_type": "execute_result"
188
+ },
189
+ {
190
+ "ename": "",
191
+ "evalue": "",
192
+ "output_type": "error",
193
+ "traceback": [
194
+ "\u001b[1;31mThe Kernel crashed while executing code in the current cell or a previous cell. \n",
195
+ "\u001b[1;31mPlease review the code in the cell(s) to identify a possible cause of the failure. \n",
196
+ "\u001b[1;31mClick <a href='https://aka.ms/vscodeJupyterKernelCrash'>here</a> for more info. \n",
197
+ "\u001b[1;31mView Jupyter <a href='command:jupyter.viewOutput'>log</a> for further details."
198
+ ]
199
+ }
200
+ ],
201
+ "source": [
202
+ "# List all intents by their label strings\n",
203
+ "intents = dataset[\"train\"].features[\"intent\"].names \n",
204
+ "intents"
205
+ ]
206
+ },
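
The notebook cell above prints the label names that also appear in `api.py` as the hard-coded `INTENT_LABELS` list. A minimal sketch of an alternative follows, assuming one would rather write the list to a JSON file from the notebook and load it in the API. The file name `intent_labels.json` and both helper names are hypothetical, and the three-item list stands in for the 151 names returned by `dataset["train"].features["intent"].names`.

```python
import json
import os
import tempfile
from typing import List


def save_labels(labels: List[str], path: str) -> None:
    """Write the label names to disk in index order."""
    with open(path, "w", encoding="utf-8") as fh:
        json.dump(labels, fh, indent=2)


def load_labels(path: str) -> List[str]:
    """Read the label names back; list position is the class index."""
    with open(path, "r", encoding="utf-8") as fh:
        return json.load(fh)


# Stand-in for the full list produced in the notebook.
sample_labels = ["restaurant_reviews", "nutrition_info", "account_blocked"]

with tempfile.TemporaryDirectory() as tmp:
    label_path = os.path.join(tmp, "intent_labels.json")
    save_labels(sample_labels, label_path)
    restored = load_labels(label_path)
```

Keeping the labels in one generated file would avoid the copy-and-paste step and keep the index order tied to the dataset version used for training.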
207
  {
208
  "cell_type": "code",
209
  "execution_count": 2,