jskswamy committed on
Commit
5ff8440
·
verified ·
1 Parent(s): 446a732

Uploading files via huggingface api

Browse files
Files changed (4) hide show
  1. Dockerfile +10 -8
  2. README.md +4 -4
  3. app.py +352 -751
  4. requirements.txt +7 -4
Dockerfile CHANGED
@@ -1,4 +1,4 @@
1
- # Streamlit Frontend Dockerfile
2
  FROM python:3.12-slim
3
 
4
  # Set working directory
@@ -12,20 +12,22 @@ RUN pip install --no-cache-dir -r requirements.txt
12
 
13
  # Copy application code
14
  COPY app.py .
15
- COPY .streamlit/ .streamlit/
 
 
16
 
17
  # Expose port (Hugging Face Spaces uses 7860)
18
  EXPOSE 7860
19
 
20
- # Set default backend URL (can be overridden with environment variable)
21
- ENV BACKEND_URL=http://localhost:7860
 
22
 
23
- # Health check
24
  HEALTHCHECK --interval=30s \
25
  --timeout=10s \
26
  --start-period=5s \
27
  --retries=3 \
28
- CMD curl -f http://localhost:7860/_stcore/health || exit 1
29
 
30
- # Run Streamlit
31
- CMD ["streamlit", "run", "app.py"]
 
1
+ # Use Python 3.12 slim image as base
2
  FROM python:3.12-slim
3
 
4
  # Set working directory
 
12
 
13
  # Copy application code
14
  COPY app.py .
15
+
16
+ # Copy the trained model file into the working directory
17
+ COPY ./superkart_model.joblib ./superkart_model.joblib
18
 
19
  # Expose port (Hugging Face Spaces uses 7860)
20
  EXPOSE 7860
21
 
22
+ # Set environment variables
23
+ ENV FLASK_APP=app.py
24
+ ENV FLASK_ENV=production
25
 
 
26
  HEALTHCHECK --interval=30s \
27
  --timeout=10s \
28
  --start-period=5s \
29
  --retries=3 \
30
+ CMD curl -f http://localhost:7860/ || exit 1
31
 
32
+ # Run the application with gunicorn
33
+ CMD ["gunicorn", "--bind", "0.0.0.0:7860", "--workers", "4", "--timeout", "120", "app:app"]
README.md CHANGED
@@ -1,8 +1,8 @@
1
  ---
2
- title: Superkart Frontend
3
- emoji: 🌍
4
- colorFrom: yellow
5
- colorTo: blue
6
  sdk: docker
7
  pinned: false
8
  ---
 
1
  ---
2
+ title: Superkart Backend
3
+ emoji: 🛒
4
+ colorFrom: purple
5
+ colorTo: indigo
6
  sdk: docker
7
  pinned: false
8
  ---
app.py CHANGED
@@ -1,793 +1,394 @@
1
  """
2
- SuperKart Sales Prediction Frontend
3
 
4
- A Streamlit web application for predicting product sales using the SuperKart ML model.
5
- This frontend provides an intuitive interface for users to input product and store features
6
- and get sales predictions from the backend API.
7
  """
8
 
9
- import warnings
10
- import streamlit as st
11
- import requests
12
- import pandas as pd
13
- import argparse
14
  import os
15
- import sys
16
- from typing import Dict
17
-
18
- # Suppress SyntaxWarnings from Streamlit library
19
- warnings.filterwarnings("ignore", category=SyntaxWarning)
20
-
21
- # Page configuration
22
- st.set_page_config(
23
- page_title="SuperKart Sales Predictor",
24
- page_icon="🛒",
25
- layout="wide",
26
- initial_sidebar_state="expanded",
27
- )
28
-
29
- # Custom CSS for better styling
30
- st.markdown(
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
  """
32
- <style>
33
- .main-header {
34
- font-size: 3rem;
35
- color: #1f77b4;
36
- text-align: center;
37
- margin-bottom: 2rem;
38
- }
39
- .prediction-box {
40
- background-color: #f0f8ff;
41
- padding: 20px;
42
- border-radius: 10px;
43
- border-left: 5px solid #1f77b4;
44
- margin: 20px 0;
45
- }
46
- .success-box {
47
- background-color: #d4edda;
48
- padding: 15px;
49
- border-radius: 5px;
50
- border-left: 5px solid #28a745;
51
- margin: 10px 0;
52
- }
53
- .error-box {
54
- background-color: #f8d7da;
55
- padding: 15px;
56
- border-radius: 5px;
57
- border-left: 5px solid #dc3545;
58
- margin: 10px 0;
59
- }
60
- </style>
61
- """,
62
- unsafe_allow_html=True,
63
- )
64
-
65
-
66
def get_backend_url():
    """Resolve the backend API URL.

    Precedence: a ``--backend-url`` command-line option (only consulted when
    the process was started through the ``streamlit`` executable), then the
    ``BACKEND_URL`` environment variable, then ``http://localhost:7860``.
    """
    env_or_default = os.getenv("BACKEND_URL", "http://localhost:7860")

    # Only look at the command line when argv[0] is the streamlit launcher
    # and there is at least one further argument.
    started_by_streamlit = len(sys.argv) > 1 and "streamlit" in sys.argv[0]
    if not started_by_streamlit:
        return env_or_default

    cli = argparse.ArgumentParser(description="SuperKart Frontend App")
    cli.add_argument(
        "--backend-url",
        type=str,
        default=env_or_default,
        help="Backend API URL (default: http://localhost:7860)",
    )

    # parse_known_args tolerates Streamlit's own options; a malformed
    # --backend-url makes argparse exit, which is turned into the fallback.
    try:
        recognised, _ignored = cli.parse_known_args()
    except (SystemExit, argparse.ArgumentError):
        return env_or_default
    return recognised.backend_url
 
 
88
 
 
 
 
89
 
90
- # Configuration
91
- BACKEND_URL = get_backend_url()
 
92
 
 
 
 
 
93
 
94
def make_api_request(endpoint: str, data: Dict = None, method: str = "GET") -> Dict:
    """Call the backend service and wrap the outcome in a result dict.

    Args:
        endpoint: Path appended to ``BACKEND_URL`` (for example ``"/predict"``).
        data: Payload sent as the JSON body of a POST request; unused for GET.
        method: HTTP verb, either ``"GET"`` or ``"POST"``.

    Returns:
        ``{"success": True, "data": <decoded JSON body>}`` when the request
        succeeds, otherwise ``{"success": False, "error": <message>}``.
        Request failures are reported through the dict rather than raised.
    """
    # Reject unknown verbs up front. Previously neither branch assigned
    # `response`, so any other verb crashed with UnboundLocalError.
    if method not in ("GET", "POST"):
        return {"success": False, "error": f"Unsupported HTTP method: {method}"}

    url = f"{BACKEND_URL}{endpoint}"

    try:
        if method == "GET":
            response = requests.get(url, timeout=30)
        else:
            response = requests.post(url, json=data, timeout=30)

        response.raise_for_status()
        return {"success": True, "data": response.json()}

    except requests.exceptions.ConnectionError:
        return {
            "success": False,
            "error": "Cannot connect to backend API. Please ensure the backend service is running.",
        }
    except requests.exceptions.Timeout:
        return {
            "success": False,
            "error": "Request timeout. The backend service is taking too long to respond.",
        }
    except requests.exceptions.RequestException as e:
        return {"success": False, "error": f"API request failed: {str(e)}"}
    except ValueError as e:
        # NOTE(review): a non-JSON body makes response.json() raise; depending
        # on the installed requests version that may be a plain ValueError
        # rather than a RequestException subclass, so report it the same way.
        return {"success": False, "error": f"API request failed: {str(e)}"}
119
-
120
-
121
def get_feature_info():
    """Fetch feature metadata from the backend's ``/features`` endpoint.

    Returns:
        The payload returned by the backend, or ``None`` after showing a
        Streamlit error message when the request failed.
    """
    response = make_api_request("/features")

    # Failure path first: surface the problem in the UI and signal "no data".
    if not response["success"]:
        st.error(f"Failed to get feature information: {response['error']}")
        return None

    return response["data"]
129
-
130
-
131
def create_input_form():
    """Render the single-prediction form and collect the entered features.

    Returns:
        A dict keyed by the backend's feature names once the user submits the
        form; ``None`` when feature metadata could not be fetched or the form
        has not been submitted on this run.
    """
    st.header("🔮 Product Sales Prediction")

    # The metadata itself is not used below; it only gates whether the form
    # is shown (get_feature_info reports its own error in the UI).
    if not get_feature_info():
        return None

    # Choices offered by the select boxes.
    sugar_levels = ["Low Sugar", "Regular", "No Sugar"]
    product_categories = [
        "Dairy",
        "Soft Drinks",
        "Meat",
        "Fruits and Vegetables",
        "Household",
        "Baking Goods",
        "Snack Foods",
        "Frozen Foods",
        "Breakfast",
        "Health and Hygiene",
        "Hard Drinks",
        "Canned",
        "Bread",
        "Starchy Foods",
        "Others",
        "Seafood",
    ]
    establishment_years = [1987, 1998, 1999, 2009]
    size_categories = ["Small", "Medium", "High"]
    city_tiers = ["Tier 1", "Tier 2", "Tier 3"]
    store_formats = [
        "Supermarket Type1",
        "Supermarket Type2",
        "Supermarket Type3",
        "Departmental Store",
        "Food Mart",
    ]

    with st.form("prediction_form"):
        product_col, store_col = st.columns(2)

        with product_col:
            st.subheader("📦 Product Features")

            weight = st.number_input(
                "Product Weight (kg)",
                min_value=0.1,
                max_value=100.0,
                value=12.66,
                step=0.1,
                help="Weight of the product in kilograms",
            )

            sugar = st.selectbox(
                "Sugar Content",
                options=sugar_levels,
                index=0,
                help="Sugar content level of the product",
            )

            display_area = st.number_input(
                "Allocated Display Area (Ratio)",
                min_value=0.0,
                max_value=1.0,
                value=0.027,
                step=0.001,
                format="%.3f",
                help="Ratio of allocated display area (0.0 to 1.0)",
            )

            category = st.selectbox(
                "Product Type",
                options=product_categories,
                index=7,  # Frozen Foods
                help="Category of the product",
            )

            mrp = st.number_input(
                "Maximum Retail Price ($)",
                min_value=1.0,
                max_value=1000.0,
                value=117.08,
                step=0.01,
                format="%.2f",
                help="Maximum retail price in USD",
            )

        with store_col:
            st.subheader("🏪 Store Features")

            established = st.selectbox(
                "Store Establishment Year",
                options=establishment_years,
                index=3,  # 2009
                help="Year when the store was established",
            )

            size = st.selectbox(
                "Store Size",
                options=size_categories,
                index=1,  # Medium
                help="Size category of the store",
            )

            city_tier = st.selectbox(
                "City Type",
                options=city_tiers,
                index=1,  # Tier 2
                help="Type of city where the store is located",
            )

            store_format = st.selectbox(
                "Store Type",
                options=store_formats,
                index=1,  # Supermarket Type2
                help="Type/format of the store",
            )

        clicked = st.form_submit_button("🎯 Predict Sales", type="primary")

    if not clicked:
        return None

    # Keys match the feature names the backend expects.
    return {
        "Product_Weight": weight,
        "Product_Sugar_Content": sugar,
        "Product_Allocated_Area": display_area,
        "Product_Type": category,
        "Product_MRP": mrp,
        "Store_Establishment_Year": established,
        "Store_Size": size,
        "Store_Location_City_Type": city_tier,
        "Store_Type": store_format,
    }
264
-
265
-
266
- def display_prediction_result(prediction_data: Dict):
267
- """Display the prediction result with EDA-based insights."""
268
- predicted_sales = prediction_data["predicted_sales"]
269
-
270
- # Main prediction display
271
- st.markdown('<div class="prediction-box">', unsafe_allow_html=True)
272
- col1, col2, col3 = st.columns([1, 2, 1])
273
-
274
- with col2:
275
- st.markdown(
276
- f"""
277
- <div style="text-align: center;">
278
- <h2>💰 Predicted Sales Revenue</h2>
279
- <h1 style="color: #28a745; font-size: 4rem;">${predicted_sales:,.2f}</h1>
280
- </div>
281
- """,
282
- unsafe_allow_html=True,
283
- )
284
 
285
- st.markdown("</div>", unsafe_allow_html=True)
286
-
287
- # EDA-based insights and business metrics
288
- st.subheader("📊 Sales Analysis & Business Insights")
289
-
290
- # Based on EDA: Sales range $33-$8,000, Mean: $3,464, Median: $3,452, Std: $1,066
291
- sales_mean = 3464
292
- sales_median = 3452
293
- sales_std = 1066
294
- sales_q1 = 2762
295
- sales_q3 = 4145
296
-
297
- col1, col2, col3, col4 = st.columns(4)
298
-
299
- with col1:
300
- # Performance vs Mean
301
- vs_mean = ((predicted_sales - sales_mean) / sales_mean) * 100
302
- delta_color = "normal" if abs(vs_mean) < 10 else "inverse"
303
- st.metric(
304
- label="📊 vs Dataset Mean",
305
- value=f"${predicted_sales:,.2f}",
306
- delta=f"{vs_mean:+.1f}%",
307
- delta_color=delta_color,
308
- )
309
 
310
- with col2:
311
- # Performance vs Median
312
- vs_median = ((predicted_sales - sales_median) / sales_median) * 100
313
- delta_color = "normal" if abs(vs_median) < 10 else "inverse"
314
- st.metric(
315
- label="📈 vs Dataset Median",
316
- value=f"${sales_median:,.2f}",
317
- delta=f"{vs_median:+.1f}%",
318
- delta_color=delta_color,
319
- )
320
 
321
- with col3:
322
- # Percentile ranking based on EDA quartiles
323
- if predicted_sales <= sales_q1:
324
- percentile = "Bottom 25%"
325
- percentile_color = "🔴"
326
- elif predicted_sales <= sales_median:
327
- percentile = "25th-50th"
328
- percentile_color = "🟡"
329
- elif predicted_sales <= sales_q3:
330
- percentile = "50th-75th"
331
- percentile_color = "🟠"
332
- else:
333
- percentile = "Top 25%"
334
- percentile_color = "🟢"
335
-
336
- st.metric(
337
- label="🎯 Performance Percentile",
338
- value=f"{percentile_color} {percentile}",
339
- delta=None,
340
- )
341
 
342
- with col4:
343
- # Standard deviation analysis
344
- z_score = (predicted_sales - sales_mean) / sales_std
345
- if abs(z_score) <= 1:
346
- volatility = "Normal"
347
- vol_color = "🟢"
348
- elif abs(z_score) <= 2:
349
- volatility = "Moderate"
350
- vol_color = "🟡"
351
- else:
352
- volatility = "High"
353
- vol_color = "🔴"
354
-
355
- st.metric(
356
- label="📉 Sales Volatility",
357
- value=f"{vol_color} {volatility}",
358
- delta=f"σ: {z_score:+.1f}",
359
- )
360
 
361
- # Business insights section
362
- st.subheader("💼 Business Recommendations & Next Steps")
 
 
363
 
364
- # Performance Summary Box
365
- if predicted_sales >= sales_q3: # Top 25%
366
- performance_level = "⭐ Excellent"
367
- performance_color = "#28a745"
368
- summary_message = (
369
- "This product is predicted to perform in the top 25% of SuperKart sales!"
370
- )
371
- elif predicted_sales >= sales_median: # Above median
372
- performance_level = "✅ Good"
373
- performance_color = "#17a2b8"
374
- summary_message = (
375
- "This product is predicted to perform above the historical average."
376
- )
377
- elif predicted_sales >= sales_q1: # Above bottom quartile
378
- performance_level = "⚠️ Below Average"
379
- performance_color = "#ffc107"
380
- summary_message = (
381
- "This product may underperform compared to typical SuperKart sales."
382
- )
383
- else: # Bottom 25%
384
- performance_level = "🔴 Needs Attention"
385
- performance_color = "#dc3545"
386
- summary_message = (
387
- "This product is predicted to be in the bottom 25% of sales performance."
388
- )
389
 
390
- # Performance summary box
391
- st.markdown(
392
- f"""
393
- <div style="background-color: {performance_color}20; padding: 20px; border-radius: 10px;
394
- border-left: 5px solid {performance_color}; margin: 15px 0;">
395
- <h4 style="color: {performance_color}; margin: 0 0 10px 0;">
396
- {performance_level} Performance Expected
397
- </h4>
398
- <p style="margin: 0; font-size: 16px;">{summary_message}</p>
399
- </div>
400
- """,
401
- unsafe_allow_html=True,
402
- )
403
 
404
- # Three-column layout for insights
405
- col1, col2, col3 = st.columns(3)
406
-
407
- with col1:
408
- st.markdown("#### 💰 Financial Impact")
409
-
410
- # Revenue tier classification (moved to top for consistency)
411
- if predicted_sales >= 5000:
412
- tier = "🏆 Premium Tier"
413
- elif predicted_sales >= 3000:
414
- tier = "🥈 Standard Tier"
415
- else:
416
- tier = "🥉 Value Tier"
417
- st.info(f"**Revenue Classification:** {tier}")
418
-
419
- # Financial metrics with clear labels
420
- profit_margin = 0.2 # 20% profit margin
421
- estimated_profit = predicted_sales * profit_margin
422
- st.metric("Predicted Revenue", f"${predicted_sales:,.0f}")
423
- st.metric("Estimated Profit (20%)", f"${estimated_profit:,.0f}")
424
-
425
- with col2:
426
- st.markdown("#### 📊 Market Position")
427
-
428
- # Clear market positioning
429
- vs_mean_pct = ((predicted_sales - sales_mean) / sales_mean) * 100
430
- if vs_mean_pct > 10:
431
- position = "🚀 Above Market Average"
432
- elif vs_mean_pct > -10:
433
- position = "📊 Market Average"
434
- else:
435
- position = "📉 Below Market Average"
436
-
437
- st.success(position)
438
- st.write(f"**vs Historical Mean:** {vs_mean_pct:+.1f}%")
439
- st.write("**Market Range:** \\$33 - \\$8,000")
440
- st.write(f"**Your Prediction:** ${predicted_sales:,.0f}")
441
-
442
- with col3:
443
- st.markdown("#### 🎯 Action Items")
444
-
445
- # Clear, actionable recommendations
446
- if predicted_sales < sales_q1:
447
- st.warning("**Low Performance Risk**")
448
- st.write("**Immediate Actions:**")
449
- st.write("• Launch promotional campaign")
450
- st.write("• Review pricing strategy")
451
- st.write("• Optimize product placement")
452
- st.write("• Analyze competitor offerings")
453
- elif predicted_sales > sales_q3:
454
- st.success("**High Performance Opportunity**")
455
- st.write("**Recommended Actions:**")
456
- st.write("• Ensure adequate stock levels")
457
- st.write("• Consider premium pricing")
458
- st.write("• Expand to similar products")
459
- st.write("• Allocate prime shelf space")
460
- else:
461
- st.info("**Standard Performance Expected**")
462
- st.write("**Monitor & Optimize:**")
463
- st.write("• Track actual vs predicted")
464
- st.write("• A/B test marketing approaches")
465
- st.write("• Monitor competitor activity")
466
- st.write("• Adjust inventory as needed")
467
-
468
-
469
def create_input_summary(input_data: Dict):
    """Show the submitted feature values as two bulleted columns.

    Args:
        input_data: Mapping holding the nine product/store feature values
            under the same keys that are sent to the backend.
    """

    def _bullet(label, value):
        # One bullet line per feature, written immediately so the display
        # order matches the call order.
        st.write(f"• {label}: {value}")

    st.subheader("📋 Input Summary")

    product_col, store_col = st.columns(2)

    with product_col:
        st.markdown("**Product Information:**")
        _bullet("Weight", f"{input_data['Product_Weight']} kg")
        _bullet("Sugar Content", input_data["Product_Sugar_Content"])
        _bullet("Display Area", f"{input_data['Product_Allocated_Area']:.3f}")
        _bullet("Type", input_data["Product_Type"])
        _bullet("MRP", f"${input_data['Product_MRP']:.2f}")

    with store_col:
        st.markdown("**Store Information:**")
        _bullet("Establishment Year", input_data["Store_Establishment_Year"])
        _bullet("Size", input_data["Store_Size"])
        _bullet("City Type", input_data["Store_Location_City_Type"])
        _bullet("Store Type", input_data["Store_Type"])
490
-
491
-
492
- def create_batch_prediction():
493
- """Create batch prediction interface."""
494
- st.header("📊 Batch Prediction")
495
-
496
- st.markdown("""
497
- Upload a CSV file with multiple products to get batch predictions.
498
- The CSV should contain all required columns with the same names as in the single prediction form.
499
- """)
500
-
501
- # File uploader
502
- uploaded_file = st.file_uploader(
503
- "Choose a CSV file",
504
- type="csv",
505
- help="Upload a CSV file with product and store features",
506
  )
507
 
508
- if uploaded_file is not None:
509
- try:
510
- # Read the CSV file
511
- df = pd.read_csv(uploaded_file)
512
-
513
- # Display the uploaded data
514
- st.subheader("📂 Uploaded Data")
515
- st.dataframe(df.head(10))
516
 
517
- if st.button("🚀 Run Batch Prediction", type="primary"):
518
- # Convert DataFrame to list of dictionaries
519
- predictions_data = df.to_dict("records")
520
 
521
- # Make batch prediction request
522
- result = make_api_request(
523
- "/predict/batch", {"predictions": predictions_data}, "POST"
524
- )
525
 
526
- if result["success"]:
527
- batch_results = result["data"]
528
-
529
- # Display results
530
- st.subheader("📈 Batch Prediction Results")
531
-
532
- col1, col2, col3 = st.columns(3)
533
- with col1:
534
- st.metric(
535
- "✅ Successful", batch_results["successful_predictions"]
536
- )
537
- with col2:
538
- st.metric("❌ Failed", batch_results["failed_predictions"])
539
- with col3:
540
- st.metric("📊 Total", len(predictions_data))
541
-
542
- # Show successful predictions
543
- if batch_results["results"]:
544
- st.subheader("🎯 Successful Predictions")
545
-
546
- # Create a user-friendly results DataFrame
547
- display_results = []
548
- for result in batch_results["results"]:
549
- # Extract readable product info
550
- input_features = result["input_features"]
551
-
552
- # Determine performance category
553
- sales = result["predicted_sales"]
554
- if sales >= 4145: # Top 25% (Q3)
555
- category = "🟢 High"
556
- elif sales >= 3452: # Above median
557
- category = "🟡 Good"
558
- elif sales >= 2762: # Above Q1
559
- category = "🟠 Average"
560
- else:
561
- category = "🔴 Low"
562
-
563
- display_row = {
564
- "Row": result["index"] + 1,
565
- "Product Type": input_features["Product_Type"],
566
- "Weight (kg)": input_features["Product_Weight"],
567
- "MRP ($)": f"${input_features['Product_MRP']:.2f}",
568
- "Store Size": input_features["Store_Size"],
569
- "Store Type": input_features["Store_Type"],
570
- "Predicted Sales": f"${sales:,.2f}",
571
- "Performance": category,
572
- }
573
- display_results.append(display_row)
574
-
575
- display_df = pd.DataFrame(display_results)
576
-
577
- # Show the clean results table
578
- st.dataframe(
579
- display_df, use_container_width=True, hide_index=True
580
- )
581
-
582
- # Summary statistics
583
- sales_values = [
584
- result["predicted_sales"]
585
- for result in batch_results["results"]
586
- ]
587
-
588
- col1, col2, col3, col4 = st.columns(4)
589
- with col1:
590
- st.metric("💰 Total Revenue", f"${sum(sales_values):,.0f}")
591
- with col2:
592
- st.metric(
593
- "📊 Average Sale",
594
- f"${sum(sales_values) / len(sales_values):,.0f}",
595
- )
596
- with col3:
597
- high_performers = len(
598
- [s for s in sales_values if s >= 4145]
599
- )
600
- st.metric("🟢 High Performers", f"{high_performers}")
601
- with col4:
602
- low_performers = len([s for s in sales_values if s < 2762])
603
- st.metric("🔴 Needs Attention", f"{low_performers}")
604
-
605
- # Download options
606
- col1, col2 = st.columns(2)
607
- with col1:
608
- # Download user-friendly results
609
- csv_display = display_df.to_csv(index=False)
610
- st.download_button(
611
- label="📥 Download Summary Results",
612
- data=csv_display,
613
- file_name="batch_predictions_summary.csv",
614
- mime="text/csv",
615
- )
616
-
617
- with col2:
618
- # Download detailed results for technical users
619
- detailed_results = []
620
- for result in batch_results["results"]:
621
- detailed_row = {
622
- "row_index": result["index"],
623
- "predicted_sales": result["predicted_sales"],
624
- **result["input_features"],
625
- }
626
- detailed_results.append(detailed_row)
627
-
628
- detailed_df = pd.DataFrame(detailed_results)
629
- csv_detailed = detailed_df.to_csv(index=False)
630
- st.download_button(
631
- label="🔧 Download Detailed Results",
632
- data=csv_detailed,
633
- file_name="batch_predictions_detailed.csv",
634
- mime="text/csv",
635
- )
636
-
637
- # Show errors if any
638
- if batch_results["errors"]:
639
- st.subheader("⚠️ Prediction Errors")
640
- errors_df = pd.DataFrame(batch_results["errors"])
641
- st.dataframe(errors_df)
642
-
643
- else:
644
- st.error(f"Batch prediction failed: {result['error']}")
645
-
646
- except Exception as e:
647
- st.error(f"Error processing file: {str(e)}")
648
-
649
-
650
- def main():
651
- """Main application function."""
652
- # Title and description
653
- st.markdown(
654
- '<h1 class="main-header">🛒 SuperKart Sales Predictor</h1>',
655
- unsafe_allow_html=True,
656
- )
657
 
658
- st.markdown(
659
- """
660
- <div style="text-align: center; margin-bottom: 2rem;">
661
- <p style="font-size: 1.2rem; color: #666;">
662
- Predict product sales revenue using machine learning based on product and store characteristics
663
- </p>
664
- </div>
665
- """,
666
- unsafe_allow_html=True,
667
- )
668
 
669
- # Check backend health
670
- health_result = make_api_request("/")
671
- if not health_result["success"]:
672
- st.error(
673
- f"⚠️ Backend API is not available at `{BACKEND_URL}`. Please ensure the backend service is running."
674
- )
675
- st.info(
676
- """
677
- **How to specify a different backend URL:**
678
-
679
- 1. **Command line argument:**
680
- ```
681
- streamlit run app.py -- --backend-url http://your-backend:5050
682
- ```
683
-
684
- 2. **Environment variable:**
685
- ```
686
- export BACKEND_URL=http://your-backend:5050
687
- streamlit run app.py
688
- ```
689
- """
690
- )
691
- st.stop()
692
-
693
- # Sidebar navigation
694
- st.sidebar.title("🧭 Navigation")
695
-
696
- # Display current backend URL and connection status
697
- st.sidebar.markdown("---")
698
- st.sidebar.markdown("**🔗 Backend Configuration**")
699
- st.sidebar.code(BACKEND_URL, language=None)
700
-
701
- # Show connection status
702
- if health_result["success"]:
703
- st.sidebar.success("🟢 Connected")
704
- if "data" in health_result and "model_loaded" in health_result["data"]:
705
- model_status = (
706
- "🤖 Model Loaded"
707
- if health_result["data"]["model_loaded"]
708
- else "⚠️ Model Not Loaded"
709
- )
710
- st.sidebar.info(model_status)
711
- else:
712
- st.sidebar.error("🔴 Disconnected")
713
-
714
- st.sidebar.markdown("---")
715
-
716
- app_mode = st.sidebar.selectbox(
717
- "Choose App Mode",
718
- ["Single Prediction", "Batch Prediction", "API Documentation"],
719
- )
720
 
721
- if app_mode == "Single Prediction":
722
- # Single prediction interface
723
- input_data = create_input_form()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
724
 
725
- if input_data:
726
- # Make prediction
727
- result = make_api_request("/predict", input_data, "POST")
728
 
729
- if result["success"]:
730
- prediction_data = result["data"]
731
 
732
- # Display results
733
- display_prediction_result(prediction_data)
 
734
 
735
- # Show input summary
736
- with st.expander("📋 View Input Details", expanded=False):
737
- create_input_summary(input_data)
738
 
739
- # Success message
740
- st.markdown(
741
- '<div class="success-box">✅ Prediction completed successfully!</div>',
742
- unsafe_allow_html=True,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
743
  )
744
 
745
- else:
746
- st.markdown(
747
- f'<div class="error-box">❌ Prediction failed: {result["error"]}</div>',
748
- unsafe_allow_html=True,
749
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
750
 
751
- elif app_mode == "Batch Prediction":
752
- create_batch_prediction()
753
-
754
- elif app_mode == "API Documentation":
755
- st.header("📚 API Documentation")
756
-
757
- # Get feature information
758
- feature_info = get_feature_info()
759
-
760
- if feature_info:
761
- st.subheader("🔧 Required Features")
762
-
763
- features_df = pd.DataFrame(
764
- [
765
- {"Feature": k, "Description": v}
766
- for k, v in feature_info["feature_descriptions"].items()
767
- ]
768
- )
769
- st.table(features_df)
770
-
771
- st.subheader("📝 Example Input")
772
- st.json(feature_info["example_input"])
773
-
774
- st.subheader("🌐 API Endpoints")
775
- st.markdown("""
776
- - **GET /**: Health check
777
- - **POST /predict**: Single prediction
778
- - **POST /predict/batch**: Batch prediction
779
- - **GET /features**: Get feature information
780
- """)
781
-
782
- # Footer
783
- st.markdown("---")
784
- st.markdown(
785
- "<div style='text-align: center; color: #666;'>"
786
- "SuperKart Sales Prediction System | Krishnaswamy Subramanian"
787
- "</div>",
788
- unsafe_allow_html=True,
789
- )
790
 
 
 
 
791
 
792
  if __name__ == "__main__":
793
- main()
 
 
 
1
  """
2
+ SuperKart Sales Prediction Flask API
3
 
4
+ This Flask application provides a REST API for predicting product sales using a pre-trained
5
+ Random Forest model. The API accepts product and store features and returns predicted sales revenue.
 
6
  """
7
 
 
 
 
 
 
8
  import os
9
+ import joblib
10
+ import pandas as pd
11
+ from flask import Flask, request, jsonify
12
+ from flask_cors import CORS
13
+ import logging
14
+ from typing import Any, Dict
15
+ from pydantic import BaseModel, ValidationError, field_validator
16
+ from datetime import datetime
17
+
18
+ # Configure logging
19
+ logging.basicConfig(level=logging.INFO)
20
+ logger = logging.getLogger(__name__)
21
+
22
+ # Initialize Flask app
23
+ app = Flask(__name__)
24
+ CORS(app) # Enable CORS for frontend integration
25
+
26
+ # Global variables for the loaded model and the feature columns it expects
27
+ model = None
28
+ feature_columns = None
29
+
30
+ # Define user input features (what user provides)
31
+ USER_INPUT_FEATURES = [
32
+ "Product_Weight",
33
+ "Product_Sugar_Content",
34
+ "Product_Allocated_Area",
35
+ "Product_Type",
36
+ "Product_MRP",
37
+ "Store_Establishment_Year",
38
+ "Store_Size",
39
+ "Store_Location_City_Type",
40
+ "Store_Type",
41
+ ]
42
+
43
+ # Define model features (what model expects after preprocessing)
44
+ MODEL_FEATURES = [
45
+ "Product_Weight",
46
+ "Product_Sugar_Content",
47
+ "Product_Allocated_Area",
48
+ "Product_Type",
49
+ "Product_MRP",
50
+ "Store_Size",
51
+ "Store_Location_City_Type",
52
+ "Store_Type",
53
+ "Store_Age",
54
+ ]
55
+
56
+
57
+ # Pydantic model for input validation
58
def _ensure_allowed(field: str, value: Any, allowed: list) -> Any:
    """Return *value* if it is listed in *allowed*, otherwise raise ValueError."""
    if value not in allowed:
        raise ValueError(f"{field} must be one of: {allowed}")
    return value


def _ensure_in_range(field: str, value: float, low: float, high: float) -> float:
    """Return *value* if ``low <= value <= high``, otherwise raise ValueError."""
    # Written as a positive chained comparison on purpose: every comparison
    # with NaN is False, so NaN is rejected here. The previous
    # ``value < low or value > high`` form let NaN through.
    if not low <= value <= high:
        raise ValueError(f"{field} must be between {low} and {high}")
    return value


class PredictionInput(BaseModel):
    """Validated payload for a single sales prediction.

    Numeric fields are range-checked and categorical fields must match one of
    the listed values exactly; any violation raises ``ValueError`` inside the
    corresponding validator.
    """

    Product_Weight: float  # must lie in [4.0, 22.0]
    Product_Sugar_Content: str
    Product_Allocated_Area: float  # must lie in [0, 1]
    Product_Type: str
    Product_MRP: float  # must lie in [31.0, 266.0]
    Store_Establishment_Year: int  # one of 1987, 1998, 1999, 2009
    Store_Size: str
    Store_Location_City_Type: str
    Store_Type: str

    @field_validator("Product_Weight")
    @classmethod
    def validate_product_weight(cls, v: float) -> float:
        """Reject non-positive weights, then weights outside 4.0-22.0."""
        if v <= 0:
            raise ValueError("Product_Weight must be greater than 0")
        return _ensure_in_range("Product_Weight", v, 4.0, 22.0)

    @field_validator("Product_Allocated_Area")
    @classmethod
    def validate_allocated_area(cls, v: float) -> float:
        """Require the display-area ratio to lie between 0 and 1."""
        return _ensure_in_range("Product_Allocated_Area", v, 0, 1)

    @field_validator("Product_MRP")
    @classmethod
    def validate_mrp(cls, v: float) -> float:
        """Reject non-positive prices, then prices outside 31.0-266.0."""
        if v <= 0:
            raise ValueError("Product_MRP must be greater than 0")
        return _ensure_in_range("Product_MRP", v, 31.0, 266.0)

    @field_validator("Store_Establishment_Year")
    @classmethod
    def validate_establishment_year(cls, v: int) -> int:
        """Accept only the listed establishment years."""
        return _ensure_allowed(
            "Store_Establishment_Year", v, [1987, 1998, 1999, 2009]
        )

    @field_validator("Product_Sugar_Content")
    @classmethod
    def validate_sugar_content(cls, v: str) -> str:
        """Accept only the listed sugar-content labels."""
        return _ensure_allowed(
            "Product_Sugar_Content", v, ["Low Sugar", "Regular", "No Sugar"]
        )

    @field_validator("Product_Type")
    @classmethod
    def validate_product_type(cls, v: str) -> str:
        """Accept only the listed product categories."""
        return _ensure_allowed(
            "Product_Type",
            v,
            [
                "Dairy",
                "Soft Drinks",
                "Meat",
                "Fruits and Vegetables",
                "Household",
                "Baking Goods",
                "Snack Foods",
                "Frozen Foods",
                "Breakfast",
                "Health and Hygiene",
                "Hard Drinks",
                "Canned",
                "Bread",
                "Starchy Foods",
                "Others",
                "Seafood",
            ],
        )

    @field_validator("Store_Size")
    @classmethod
    def validate_store_size(cls, v: str) -> str:
        """Accept only the listed store sizes."""
        return _ensure_allowed("Store_Size", v, ["Small", "Medium", "High"])

    @field_validator("Store_Location_City_Type")
    @classmethod
    def validate_city_type(cls, v: str) -> str:
        """Accept only the listed city tiers."""
        return _ensure_allowed(
            "Store_Location_City_Type", v, ["Tier 1", "Tier 2", "Tier 3"]
        )

    @field_validator("Store_Type")
    @classmethod
    def validate_store_type(cls, v: str) -> str:
        """Accept only the listed store formats."""
        return _ensure_allowed(
            "Store_Type",
            v,
            [
                "Supermarket Type1",
                "Supermarket Type2",
                "Supermarket Type3",
                "Departmental Store",
                "Food Mart",
            ],
        )
164
+
165
+
166
def load_model(model_path: str) -> bool:
    """Load the trained model from disk into the module-level ``model``.

    On success the ``model`` global holds the deserialised object and
    ``feature_columns`` is set to ``MODEL_FEATURES``. Failures are logged and
    reported through the return value; nothing is raised to the caller.

    Args:
        model_path (str): Path to the model file.

    Returns:
        bool: True if model loaded successfully, False otherwise.
    """
    global model, feature_columns

    # A missing file is an expected failure: log it and bail out directly
    # instead of raising an exception only to catch it a few lines later.
    if not os.path.exists(model_path):
        logger.error(f"❌ Error loading model: Model file not found at: {model_path}")
        return False

    try:
        # SECURITY: joblib.load unpickles arbitrary Python objects, so
        # model_path must only ever point at a trusted file.
        # NOTE(review): the original comment says the artifact includes its
        # preprocessing pipeline; that cannot be confirmed from this file.
        model = joblib.load(model_path)
    except Exception as e:
        # logger.exception keeps the traceback, which the previous
        # logger.error(str(e)) call discarded.
        logger.exception(f"❌ Error loading model: {str(e)}")
        return False

    logger.info(f"✅ Model loaded successfully from: {model_path}")

    # Set feature columns
    feature_columns = MODEL_FEATURES
    logger.info(f"📋 Model features: {MODEL_FEATURES}")
    logger.info(f"📋 User input features: {USER_INPUT_FEATURES}")

    return True
 
196
 
 
 
197
 
198
def convert_establishment_year_to_age(data: Dict[str, Any]) -> Dict[str, Any]:
    """Convert Store_Establishment_Year to Store_Age.

    Args:
        data: Mapping of feature names to values. It is not modified.

    Returns:
        A shallow copy of ``data`` in which ``Store_Establishment_Year`` (when
        present) has been removed and replaced by ``Store_Age``, computed as
        the current calendar year minus the establishment year. When the key
        is absent the copy is returned unchanged.
    """
    # Work on a copy so the caller's dict is left untouched.
    converted_data = data.copy()

    if "Store_Establishment_Year" in converted_data:
        establishment_year = converted_data.pop("Store_Establishment_Year")
        # NOTE(review): the age is relative to the wall-clock year, so the same
        # input yields a different Store_Age each year -- confirm this matches
        # the reference year used when the model was trained.
        converted_data["Store_Age"] = datetime.now().year - establishment_year

    return converted_data
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
212
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
213
 
214
def preprocess_input(data: Dict[str, Any]) -> pd.DataFrame:
    """Convert input data to DataFrame format expected by the model.

    Args:
        data: Mapping of user-facing feature names to values.

    Returns:
        A one-row DataFrame whose columns are exactly ``MODEL_FEATURES``, in
        that order.

    Raises:
        KeyError: If any name in ``MODEL_FEATURES`` is missing after the
            establishment-year conversion.
    """
    # First convert establishment year to age
    converted_data = convert_establishment_year_to_age(data)

    # One row, then select/reorder columns to the model's feature list.
    df = pd.DataFrame([converted_data])
    return df[MODEL_FEATURES]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
223
 
 
 
 
 
 
 
 
 
 
 
 
 
 
224
 
225
@app.route("/", methods=["GET"])
def health_check():
    """Health check endpoint.

    Returns:
        A JSON response with a fixed ``status``/``message`` pair and a
        ``model_loaded`` flag that is true when the module-level ``model``
        is not ``None``.
    """
    return jsonify(
        {
            "status": "healthy",
            "message": "SuperKart Sales Prediction API is running",
            "model_loaded": model is not None,
        }
    )
235
 
 
 
 
 
 
 
 
 
236
 
237
@app.route("/predict", methods=["POST"])
def predict():
    """Predict sales for given product and store features.

    Expects a JSON object in the request body whose keys are the fields of
    ``PredictionInput``.

    Returns:
        * 200 with ``predicted_sales`` (rounded to 2 decimals), ``currency``,
          the validated ``input_features`` and ``status`` on success.
        * 400 when the body is missing, is not valid JSON, is not a JSON
          object, or fails ``PredictionInput`` validation.
        * 500 when the model is not loaded or prediction itself fails.
    """

    if model is None:
        return jsonify({"error": "Model not loaded. Please check server logs."}), 500

    try:
        # silent=True makes a wrong content type or malformed JSON yield None
        # instead of raising an HTTP exception that the broad handler below
        # would turn into a misleading 500.
        data = request.get_json(silent=True)

        if not data:
            return jsonify(
                {
                    "error": "No data provided. Please send JSON data in the request body."
                }
            ), 400

        # A JSON array/string/number cannot be unpacked into the model; report
        # it as a client error rather than letting **data raise TypeError.
        if not isinstance(data, dict):
            return jsonify(
                {
                    "error": "Input validation failed",
                    "details": [{"msg": "Request body must be a JSON object."}],
                }
            ), 400

        # Validate input using Pydantic
        try:
            validated = PredictionInput(**data)
        except ValidationError as ve:
            # include_context=False drops the "ctx" entry, which carries the
            # raw ValueError raised by the custom validators and is not JSON
            # serializable.
            return jsonify(
                {
                    "error": "Input validation failed",
                    "details": ve.errors(include_context=False),
                }
            ), 400

        # Dump once; the same dict feeds preprocessing and the response.
        input_features = validated.model_dump()

        # Preprocess input data
        input_df = preprocess_input(input_features)

        # Make prediction
        prediction = model.predict(input_df)
        predicted_sales = float(prediction[0])

        # Prepare response
        response = {
            "predicted_sales": round(predicted_sales, 2),
            "currency": "USD",
            "input_features": input_features,
            "status": "success",
        }

        logger.info(f"✅ Prediction successful: ${predicted_sales:.2f}")
        return jsonify(response)

    except Exception as e:
        # Request boundary: log with traceback and answer with a 500.
        logger.exception(f"❌ Prediction error: {str(e)}")
        return jsonify({"error": f"Prediction failed: {str(e)}"}), 500
284
+
285
+
286
@app.route("/features", methods=["GET"])
def get_features():
    """Get information about expected input features.

    Returns:
        A JSON response containing ``required_features`` (the
        ``USER_INPUT_FEATURES`` list), a human-readable description for each
        feature, and one complete example request payload.
    """

    feature_info = {
        "required_features": USER_INPUT_FEATURES,
        "feature_descriptions": {
            "Product_Weight": "Weight of the product (4.0-22.0 kg)",
            "Product_Sugar_Content": "Sugar content (Low Sugar, Regular, No Sugar)",
            "Product_Allocated_Area": "Allocated display area ratio (0.0-1.0)",
            "Product_Type": "Product category (16 types: Dairy, Soft Drinks, Meat, etc.)",
            "Product_MRP": "Maximum retail price (31.0-266.0 USD)",
            "Store_Establishment_Year": "Year store was established (1987, 1998, 1999, 2009)",
            "Store_Size": "Store size (Small, Medium, High)",
            "Store_Location_City_Type": "City type (Tier 1, Tier 2, Tier 3)",
            "Store_Type": "Store type (Supermarket Type1/2/3, Departmental Store, Food Mart)",
        },
        "example_input": {
            "Product_Weight": 12.66,
            "Product_Sugar_Content": "Low Sugar",
            "Product_Allocated_Area": 0.027,
            "Product_Type": "Frozen Foods",
            "Product_MRP": 117.08,
            "Store_Establishment_Year": 2009,
            "Store_Size": "Medium",
            "Store_Location_City_Type": "Tier 2",
            "Store_Type": "Supermarket Type2",
        },
    }

    return jsonify(feature_info)
 
 
317
 
 
 
318
 
319
@app.route("/predict/batch", methods=["POST"])
def predict_batch():
    """Predict sales for multiple products at once.

    Expects a JSON object of the form ``{"predictions": [ {...}, ... ]}``
    where each element has the fields of ``PredictionInput``. Items are
    processed independently: a failing item is recorded under ``errors`` with
    its index and does not abort the rest of the batch.

    Returns:
        * 200 with ``results``, ``errors`` and their counts once every item
          has been attempted.
        * 400 when the body is missing, is not valid JSON, is not a JSON
          object with a ``predictions`` key, or ``predictions`` is not a list.
        * 500 when the model is not loaded or an unexpected error occurs.
    """

    if model is None:
        return jsonify({"error": "Model not loaded. Please check server logs."}), 500

    try:
        # silent=True makes a wrong content type or malformed JSON yield None
        # instead of raising an HTTP exception that the broad handler below
        # would turn into a misleading 500.
        data = request.get_json(silent=True)

        # The isinstance check also rejects a bare JSON string that merely
        # contains the text "predictions", which would otherwise pass the
        # membership test and then fail on indexing.
        if not data or not isinstance(data, dict) or "predictions" not in data:
            return jsonify(
                {
                    "error": 'No data provided. Please send JSON with "predictions" array.'
                }
            ), 400

        predictions_data = data["predictions"]
        if not isinstance(predictions_data, list):
            return jsonify({"error": "Predictions must be an array of objects."}), 400

        results = []
        errors = []

        for i, item in enumerate(predictions_data):
            try:
                # Validate input using Pydantic
                try:
                    validated = PredictionInput(**item)
                except ValidationError as ve:
                    # include_context=False drops the "ctx" entry, which
                    # carries the raw ValueError raised by the custom
                    # validators; leaving it in makes the final jsonify()
                    # fail and the whole batch return 500.
                    errors.append(
                        {
                            "index": i,
                            "error": ve.errors(include_context=False),
                            "input": item,
                        }
                    )
                    continue

                # Dump once; the same dict feeds preprocessing and the result.
                input_features = validated.model_dump()

                # Preprocess and predict
                input_df = preprocess_input(input_features)
                prediction = model.predict(input_df)
                predicted_sales = float(prediction[0])

                results.append(
                    {
                        "index": i,
                        "predicted_sales": round(predicted_sales, 2),
                        "input_features": input_features,
                    }
                )

            except Exception as e:
                # Deliberate per-item isolation: record the failure and go on.
                errors.append({"index": i, "error": str(e), "input": item})

        response = {
            "successful_predictions": len(results),
            "failed_predictions": len(errors),
            "results": results,
            "errors": errors,
            "status": "completed",
        }

        logger.info(
            f"✅ Batch prediction completed: {len(results)} successful, {len(errors)} failed"
        )
        return jsonify(response)

    except Exception as e:
        # Request boundary: log with traceback and answer with a 500.
        logger.exception(f"❌ Batch prediction error: {str(e)}")
        return jsonify({"error": f"Batch prediction failed: {str(e)}"}), 500
385
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
386
 
387
# Load model on module import (for Gunicorn compatibility): Gunicorn imports
# this module as "app:app" and never runs the __main__ branch below.
if not load_model("./superkart_model.joblib"):
    logger.error("❌ Failed to load model. Application may not work properly.")

if __name__ == "__main__":
    # This runs only when script is executed directly (not imported by Gunicorn)
    logger.info("🚀 Starting SuperKart Sales Prediction API...")
    # The Werkzeug debugger allows arbitrary code execution from the browser,
    # and this server binds to every interface, so debug mode is opt-in via
    # FLASK_DEBUG instead of being hard-coded on.
    debug_enabled = os.environ.get("FLASK_DEBUG", "").strip().lower() in {
        "1",
        "true",
        "yes",
    }
    app.run(host="0.0.0.0", port=7860, debug=debug_enabled)
requirements.txt CHANGED
@@ -1,5 +1,8 @@
1
- streamlit==1.29.0
2
- requests==2.32.3
 
3
  pandas==2.2.2
4
- plotly==5.17.0
5
- watchdog==6.0.0
 
 
 
1
+ Flask==3.0.0
2
+ flask-cors==4.0.0
3
+ joblib==1.4.2
4
  pandas==2.2.2
5
+ numpy==2.0.2
6
+ scikit-learn==1.6.1
7
+ gunicorn==21.2.0
8
+ pydantic==2.5.0