GLAkavya commited on
Commit
c6937ca
·
verified ·
1 Parent(s): 38ddd40

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +150 -633
app.py CHANGED
@@ -1,644 +1,161 @@
1
- """
2
- 🌟 PROFESSIONAL CUSTOMER FEEDBACK RATING PREDICTOR
3
- Complete Dashboard with CSV, URL, and Text Input
4
- """
5
 
6
- import gradio as gr
7
- from transformers import AutoTokenizer, AutoModelForSequenceClassification
8
- import torch
9
  import pandas as pd
 
 
 
 
 
 
10
  import plotly.graph_objects as go
11
- import plotly.express as px
12
- from collections import Counter
13
- import requests
14
- from bs4 import BeautifulSoup
15
- import json
16
-
17
- # ============================================================================
18
- # MODEL LOADING
19
- # ============================================================================
20
-
21
- # 🔴 CHANGE THIS TO YOUR MODEL
22
- MODEL_NAME = "nlptown/bert-base-multilingual-uncased-sentiment" # Demo model
23
- # MODEL_NAME = "YOUR_USERNAME/feedback-rating-predictor" # Your trained model
24
-
25
- try:
26
- tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
27
- model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME)
28
- print("✅ Model loaded successfully!")
29
- except Exception as e:
30
- print(f"❌ Error: {e}")
31
-
32
- # ============================================================================
33
- # PREDICTION FUNCTIONS
34
- # ============================================================================
35
-
36
- def predict_single(text):
37
- """Predict rating for single text"""
38
- if not text or len(text.strip()) < 3:
39
- return None
40
-
41
- try:
42
- inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=512)
43
- with torch.no_grad():
44
- outputs = model(**inputs)
45
- probs = torch.nn.functional.softmax(outputs.logits, dim=-1)
46
- pred_class = torch.argmax(probs).item()
47
- confidence = probs[0][pred_class].item()
48
-
49
- rating = pred_class + 1
50
- all_probs = probs[0].cpu().numpy()
51
-
52
- return {
53
- 'text': text,
54
- 'rating': rating,
55
- 'confidence': confidence,
56
- 'probabilities': all_probs,
57
- 'sentiment': 'Negative' if rating <= 2 else ('Neutral' if rating == 3 else 'Positive')
58
- }
59
- except Exception as e:
60
- print(f"Error in prediction: {e}")
61
- return None
62
 
63
- def predict_batch(texts):
64
- """Predict ratings for multiple texts"""
65
- results = []
66
- for text in texts:
67
- result = predict_single(text)
68
- if result:
69
- results.append(result)
70
- return results
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
71
 
72
- # ============================================================================
73
- # DATA PROCESSING FUNCTIONS
74
- # ============================================================================
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
75
 
76
- def process_csv(file):
77
- """Process uploaded CSV file"""
78
- try:
79
- df = pd.read_csv(file.name)
80
-
81
- # Try to find text column
82
- text_columns = ['feedback', 'review', 'text', 'comment', 'Review Text', 'Feedback']
83
- text_col = None
84
-
85
- for col in text_columns:
86
- if col in df.columns:
87
- text_col = col
88
- break
89
-
90
- if text_col is None:
91
- text_col = df.columns[0] # Use first column
92
-
93
- texts = df[text_col].dropna().astype(str).tolist()[:100] # Limit to 100 for performance
94
-
95
- return texts
96
- except Exception as e:
97
- return [f"Error reading CSV: {str(e)}"]
98
 
99
- def fetch_from_url(url):
100
- """Fetch reviews from URL (basic scraping)"""
101
  try:
102
- headers = {'User-Agent': 'Mozilla/5.0'}
103
- response = requests.get(url, headers=headers, timeout=10)
104
- soup = BeautifulSoup(response.content, 'html.parser')
105
-
106
- # Try to find review-like content
107
- reviews = []
108
-
109
- # Look for common review patterns
110
- for tag in soup.find_all(['p', 'div', 'span'], class_=lambda x: x and any(
111
- word in str(x).lower() for word in ['review', 'comment', 'feedback']
112
- )):
113
- text = tag.get_text().strip()
114
- if len(text) > 20 and len(text) < 1000:
115
- reviews.append(text)
116
-
117
- if not reviews:
118
- # Fallback: get all paragraph texts
119
- reviews = [p.get_text().strip() for p in soup.find_all('p') if len(p.get_text().strip()) > 20]
120
-
121
- return reviews[:50] # Limit to 50
122
  except Exception as e:
123
- return [f"Error fetching URL: {str(e)}"]
124
-
125
- # ============================================================================
126
- # VISUALIZATION FUNCTIONS
127
- # ============================================================================
128
-
129
- def create_rating_pie_chart(results):
130
- """Create pie chart for rating distribution"""
131
- ratings = [r['rating'] for r in results]
132
- rating_counts = Counter(ratings)
133
-
134
- fig = go.Figure(data=[go.Pie(
135
- labels=[f"{i}⭐" for i in range(1, 6)],
136
- values=[rating_counts.get(i, 0) for i in range(1, 6)],
137
- hole=0.4,
138
- marker=dict(colors=['#e74c3c', '#e67e22', '#f39c12', '#2ecc71', '#27ae60']),
139
- textinfo='label+percent+value',
140
- textfont=dict(size=14, color='white'),
141
- hovertemplate='<b>%{label}</b><br>Count: %{value}<br>Percentage: %{percent}<extra></extra>'
142
- )])
143
-
144
- fig.update_layout(
145
- title=dict(
146
- text="Rating Distribution",
147
- font=dict(size=20, color='#2c3e50', family='Arial Black')
148
- ),
149
- showlegend=True,
150
- height=400,
151
- paper_bgcolor='rgba(0,0,0,0)',
152
- plot_bgcolor='rgba(0,0,0,0)',
153
- font=dict(size=12)
154
- )
155
-
156
- return fig
157
-
158
- def create_sentiment_bar_chart(results):
159
- """Create bar chart for sentiment distribution"""
160
- sentiments = [r['sentiment'] for r in results]
161
- sentiment_counts = Counter(sentiments)
162
-
163
- colors = {
164
- 'Positive': '#27ae60',
165
- 'Neutral': '#f39c12',
166
- 'Negative': '#e74c3c'
167
- }
168
-
169
- fig = go.Figure(data=[go.Bar(
170
- x=list(sentiment_counts.keys()),
171
- y=list(sentiment_counts.values()),
172
- marker=dict(
173
- color=[colors.get(s, '#3498db') for s in sentiment_counts.keys()],
174
- line=dict(color='white', width=2)
175
- ),
176
- text=list(sentiment_counts.values()),
177
- textposition='outside',
178
- textfont=dict(size=16, color='#2c3e50', family='Arial Black'),
179
- hovertemplate='<b>%{x}</b><br>Count: %{y}<extra></extra>'
180
- )])
181
-
182
- fig.update_layout(
183
- title=dict(
184
- text="Sentiment Analysis",
185
- font=dict(size=20, color='#2c3e50', family='Arial Black')
186
- ),
187
- xaxis=dict(title="Sentiment", titlefont=dict(size=14)),
188
- yaxis=dict(title="Count", titlefont=dict(size=14)),
189
- height=400,
190
- paper_bgcolor='rgba(0,0,0,0)',
191
- plot_bgcolor='rgba(240,240,240,0.5)',
192
- font=dict(size=12),
193
- showlegend=False
194
- )
195
-
196
- return fig
197
-
198
- def create_confidence_histogram(results):
199
- """Create histogram for confidence scores"""
200
- confidences = [r['confidence'] * 100 for r in results]
201
-
202
- fig = go.Figure(data=[go.Histogram(
203
- x=confidences,
204
- nbinsx=20,
205
- marker=dict(
206
- color='#3498db',
207
- line=dict(color='white', width=1)
208
- ),
209
- hovertemplate='Confidence: %{x:.1f}%<br>Count: %{y}<extra></extra>'
210
- )])
211
-
212
- fig.update_layout(
213
- title=dict(
214
- text="Confidence Distribution",
215
- font=dict(size=20, color='#2c3e50', family='Arial Black')
216
- ),
217
- xaxis=dict(title="Confidence (%)", titlefont=dict(size=14)),
218
- yaxis=dict(title="Frequency", titlefont=dict(size=14)),
219
- height=400,
220
- paper_bgcolor='rgba(0,0,0,0)',
221
- plot_bgcolor='rgba(240,240,240,0.5)',
222
- font=dict(size=12)
223
- )
224
-
225
- return fig
226
-
227
- def create_detailed_table(results):
228
- """Create detailed results table"""
229
- df = pd.DataFrame([{
230
- 'Feedback': r['text'][:100] + '...' if len(r['text']) > 100 else r['text'],
231
- 'Rating': '⭐' * r['rating'],
232
- 'Stars': r['rating'],
233
- 'Sentiment': r['sentiment'],
234
- 'Confidence': f"{r['confidence']*100:.1f}%"
235
- } for r in results])
236
-
237
- return df
238
-
239
- def create_summary_stats(results):
240
- """Create summary statistics"""
241
- if not results:
242
- return "No data to analyze"
243
-
244
- total = len(results)
245
- avg_rating = sum(r['rating'] for r in results) / total
246
- avg_confidence = sum(r['confidence'] for r in results) / total * 100
247
-
248
- sentiments = Counter(r['sentiment'] for r in results)
249
- ratings = Counter(r['rating'] for r in results)
250
-
251
- summary = f"""
252
- ## 📊 Analysis Summary
253
-
254
- **Total Reviews Analyzed:** {total}
255
-
256
- **Average Rating:** {'⭐' * int(avg_rating)} ({avg_rating:.2f}/5.0)
257
-
258
- **Average Confidence:** {avg_confidence:.1f}%
259
-
260
- **Sentiment Breakdown:**
261
- - 😊 Positive: {sentiments.get('Positive', 0)} ({sentiments.get('Positive', 0)/total*100:.1f}%)
262
- - 😐 Neutral: {sentiments.get('Neutral', 0)} ({sentiments.get('Neutral', 0)/total*100:.1f}%)
263
- - 😞 Negative: {sentiments.get('Negative', 0)} ({sentiments.get('Negative', 0)/total*100:.1f}%)
264
-
265
- **Rating Breakdown:**
266
- - 5⭐: {ratings.get(5, 0)} reviews
267
- - 4⭐: {ratings.get(4, 0)} reviews
268
- - 3⭐: {ratings.get(3, 0)} reviews
269
- - 2⭐: {ratings.get(2, 0)} reviews
270
- - 1⭐: {ratings.get(1, 0)} reviews
271
- """
272
-
273
- return summary
274
-
275
- # ============================================================================
276
- # MAIN PROCESSING FUNCTION
277
- # ============================================================================
278
-
279
- def analyze_feedbacks(input_type, text_input, csv_file, url_input):
280
- """Main function to analyze feedbacks from different sources"""
281
-
282
- texts = []
283
-
284
- # Get texts based on input type
285
- if input_type == "✍️ Manual Text":
286
- if text_input:
287
- texts = [t.strip() for t in text_input.split('\n') if t.strip()]
288
-
289
- elif input_type == "📄 CSV Upload":
290
- if csv_file:
291
- texts = process_csv(csv_file)
292
-
293
- elif input_type == "🌐 URL Fetch":
294
- if url_input:
295
- texts = fetch_from_url(url_input)
296
-
297
- if not texts:
298
- return (
299
- "⚠️ No valid input provided!",
300
- None, None, None, None, None
301
- )
302
-
303
- # Predict ratings
304
- results = predict_batch(texts)
305
-
306
- if not results:
307
- return (
308
- "❌ Error in prediction!",
309
- None, None, None, None, None
310
- )
311
-
312
- # Create visualizations
313
- summary = create_summary_stats(results)
314
- pie_chart = create_rating_pie_chart(results)
315
- bar_chart = create_sentiment_bar_chart(results)
316
- histogram = create_confidence_histogram(results)
317
- table = create_detailed_table(results)
318
-
319
- return summary, pie_chart, bar_chart, histogram, table
320
-
321
- # ============================================================================
322
- # SINGLE TEXT PREDICTION (CHAT MODE)
323
- # ============================================================================
324
-
325
- def predict_single_chat(text):
326
- """Predict rating for single text (chat interface)"""
327
- result = predict_single(text)
328
-
329
- if not result:
330
- return "⚠️ Please enter valid feedback", None, None
331
-
332
- # Create star display
333
- stars = "⭐" * result['rating'] + "☆" * (5 - result['rating'])
334
-
335
- # Create emoji
336
- emoji = "😞" if result['rating'] <= 2 else ("😐" if result['rating'] == 3 else "😊")
337
-
338
- # Response text
339
- response = f"""
340
- {emoji} **{result['sentiment']} Feedback**
341
-
342
- **Rating:** {stars} ({result['rating']}/5)
343
-
344
- **Confidence:** {result['confidence']*100:.1f}%
345
-
346
- **Analysis:**
347
- This feedback has been classified as **{result['sentiment'].lower()}** with high confidence.
348
- """
349
-
350
- # Probability chart
351
- prob_dict = {
352
- "1⭐": float(result['probabilities'][0]),
353
- "2⭐⭐": float(result['probabilities'][1]),
354
- "3⭐⭐⭐": float(result['probabilities'][2]),
355
- "4⭐⭐⭐⭐": float(result['probabilities'][3]),
356
- "5⭐⭐⭐⭐⭐": float(result['probabilities'][4])
357
- }
358
-
359
- # Create small viz
360
- fig = go.Figure(data=[go.Bar(
361
- x=list(prob_dict.keys()),
362
- y=list(prob_dict.values()),
363
- marker=dict(
364
- color=['#e74c3c', '#e67e22', '#f39c12', '#2ecc71', '#27ae60'],
365
- line=dict(color='white', width=2)
366
- ),
367
- text=[f"{v*100:.1f}%" for v in prob_dict.values()],
368
- textposition='outside'
369
- )])
370
-
371
- fig.update_layout(
372
- title="Rating Probabilities",
373
- height=300,
374
- showlegend=False,
375
- paper_bgcolor='rgba(0,0,0,0)',
376
- plot_bgcolor='rgba(240,240,240,0.5)'
377
- )
378
-
379
- return response, prob_dict, fig
380
-
381
- # ============================================================================
382
- # GRADIO INTERFACE
383
- # ============================================================================
384
-
385
- # Custom CSS
386
- custom_css = """
387
- <style>
388
- .gradio-container {
389
- font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif !important;
390
- }
391
- .main-header {
392
- text-align: center;
393
- background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
394
- padding: 2rem;
395
- border-radius: 10px;
396
- color: white;
397
- margin-bottom: 2rem;
398
- }
399
- .stat-box {
400
- background: linear-gradient(135deg, #f093fb 0%, #f5576c 100%);
401
- padding: 1rem;
402
- border-radius: 10px;
403
- text-align: center;
404
- color: white;
405
- margin: 0.5rem;
406
- }
407
- </style>
408
- """
409
-
410
- # Create interface
411
- with gr.Blocks(theme=gr.themes.Soft(), css=custom_css) as demo:
412
-
413
- gr.HTML("""
414
- <div class="main-header">
415
- <h1 style="font-size: 3em; margin: 0;">🌟 Customer Feedback Rating Predictor</h1>
416
- <p style="font-size: 1.2em; margin-top: 1rem;">AI-Powered Sentiment Analysis & Rating Dashboard</p>
417
- <p style="font-size: 0.9em; opacity: 0.9;">Analyze feedback from text, CSV, or URLs with beautiful visualizations</p>
418
- </div>
419
- """)
420
-
421
- with gr.Tabs() as tabs:
422
-
423
- # ====================================================================
424
- # TAB 1: CHAT MODE (Single Text)
425
- # ====================================================================
426
- with gr.Tab("💬 Quick Analysis", id=0):
427
- gr.Markdown("### Enter any feedback to get instant rating prediction")
428
-
429
- with gr.Row():
430
- with gr.Column(scale=2):
431
- chat_input = gr.Textbox(
432
- label="✍️ Enter Feedback",
433
- placeholder="Type feedback here... e.g., 'What a good food! Loved it!' or 'Ewww, terrible service'",
434
- lines=5
435
- )
436
- chat_btn = gr.Button("🔮 Predict Rating", variant="primary", size="lg")
437
-
438
- gr.Examples(
439
- examples=[
440
- ["What a good food! Absolutely delicious! 😋"],
441
- ["Ewww, terrible taste. Never ordering again! 🤮"],
442
- ["It's okay, nothing special but edible"],
443
- ["Amazing service! Best restaurant in town! ⭐⭐⭐⭐⭐"],
444
- ["Disappointed with the quality. Expected better"],
445
- ["Pretty decent meal. Good value for money"],
446
- ],
447
- inputs=chat_input
448
- )
449
-
450
- with gr.Column(scale=1):
451
- chat_output = gr.Markdown(label="📊 Result")
452
- chat_prob = gr.Label(label="Rating Probabilities", num_top_classes=5)
453
-
454
- chat_viz = gr.Plot(label="Probability Distribution")
455
-
456
- chat_btn.click(
457
- predict_single_chat,
458
- inputs=chat_input,
459
- outputs=[chat_output, chat_prob, chat_viz]
460
- )
461
-
462
- # ====================================================================
463
- # TAB 2: BATCH ANALYSIS (CSV/URL/Multiple Texts)
464
- # ====================================================================
465
- with gr.Tab("📊 Batch Analysis Dashboard", id=1):
466
- gr.Markdown("### Analyze multiple feedbacks with comprehensive dashboard")
467
-
468
- with gr.Row():
469
- input_type = gr.Radio(
470
- choices=["✍️ Manual Text", "📄 CSV Upload", "🌐 URL Fetch"],
471
- value="✍️ Manual Text",
472
- label="Input Method"
473
- )
474
-
475
- with gr.Row():
476
- with gr.Column():
477
- text_input = gr.Textbox(
478
- label="Enter Multiple Feedbacks (one per line)",
479
- placeholder="Enter feedbacks, one per line...\nExample:\nAmazing product!\nTerrible quality\nIt's okay",
480
- lines=10,
481
- visible=True
482
- )
483
-
484
- csv_input = gr.File(
485
- label="Upload CSV File (must have 'feedback' or 'review' column)",
486
- file_types=[".csv"],
487
- visible=False
488
- )
489
-
490
- url_input = gr.Textbox(
491
- label="Enter URL (e.g., review page URL)",
492
- placeholder="https://example.com/reviews",
493
- visible=False
494
- )
495
-
496
- analyze_btn = gr.Button("🚀 Analyze All", variant="primary", size="lg")
497
-
498
- # Change visibility based on input type
499
- def update_visibility(choice):
500
- return (
501
- gr.update(visible=choice == "✍️ Manual Text"),
502
- gr.update(visible=choice == "📄 CSV Upload"),
503
- gr.update(visible=choice == "🌐 URL Fetch")
504
- )
505
-
506
- input_type.change(
507
- update_visibility,
508
- inputs=input_type,
509
- outputs=[text_input, csv_input, url_input]
510
- )
511
-
512
- # Results section
513
- gr.Markdown("---")
514
- gr.Markdown("## 📈 Analysis Results")
515
-
516
- summary_output = gr.Markdown(label="Summary")
517
-
518
- with gr.Row():
519
- with gr.Column():
520
- pie_output = gr.Plot(label="Rating Distribution")
521
- with gr.Column():
522
- bar_output = gr.Plot(label="Sentiment Analysis")
523
-
524
- hist_output = gr.Plot(label="Confidence Distribution")
525
-
526
- table_output = gr.Dataframe(
527
- label="Detailed Results",
528
- headers=["Feedback", "Rating", "Stars", "Sentiment", "Confidence"],
529
- interactive=False
530
- )
531
-
532
- # Download button
533
- gr.Markdown("### 💾 Download Results")
534
- download_btn = gr.Button("📥 Download as CSV")
535
-
536
- analyze_btn.click(
537
- analyze_feedbacks,
538
- inputs=[input_type, text_input, csv_input, url_input],
539
- outputs=[summary_output, pie_output, bar_output, hist_output, table_output]
540
- )
541
-
542
- # ====================================================================
543
- # TAB 3: ABOUT & HELP
544
- # ====================================================================
545
- with gr.Tab("ℹ️ About & Help", id=2):
546
- gr.Markdown("""
547
- # 🌟 About This Application
548
-
549
- ## What is this?
550
- This is an AI-powered customer feedback rating predictor that automatically analyzes text feedback
551
- and predicts satisfaction ratings from 1 to 5 stars.
552
-
553
- ## 🎯 Features
554
-
555
- ### 💬 Quick Analysis
556
- - Instant single feedback analysis
557
- - Real-time rating prediction
558
- - Sentiment classification (Positive/Neutral/Negative)
559
- - Confidence scores
560
-
561
- ### 📊 Batch Analysis Dashboard
562
- - Analyze multiple feedbacks at once
563
- - Three input methods:
564
- - **Manual Text**: Enter feedbacks line by line
565
- - **CSV Upload**: Upload a CSV file with feedback data
566
- - **URL Fetch**: Extract reviews from a webpage
567
-
568
- ### 📈 Beautiful Visualizations
569
- - **Rating Distribution**: Pie chart showing breakdown of 1-5 star ratings
570
- - **Sentiment Analysis**: Bar chart of positive/neutral/negative sentiments
571
- - **Confidence Distribution**: Histogram of prediction confidence levels
572
- - **Detailed Table**: Comprehensive view of all analyzed feedbacks
573
-
574
- ## 🔧 How to Use
575
-
576
- ### Quick Analysis (Chat Mode)
577
- 1. Go to "Quick Analysis" tab
578
- 2. Type your feedback
579
- 3. Click "Predict Rating"
580
- 4. Get instant results!
581
-
582
- ### Batch Analysis
583
- 1. Go to "Batch Analysis Dashboard" tab
584
- 2. Choose input method:
585
- - **Manual**: Type feedbacks (one per line)
586
- - **CSV**: Upload file (must have 'feedback' or 'review' column)
587
- - **URL**: Paste review page URL
588
- 3. Click "Analyze All"
589
- 4. View comprehensive dashboard with graphs and statistics
590
-
591
- ## 📊 Understanding Results
592
-
593
- - **Rating**: 1-5 stars (1 = very negative, 5 = very positive)
594
- - **Sentiment**: Overall emotion (Positive/Neutral/Negative)
595
- - **Confidence**: How sure the model is (0-100%)
596
- - **Probabilities**: Likelihood for each rating level
597
-
598
- ## 💡 Tips for Best Results
599
-
600
- 1. **Clear Feedback**: More detailed feedback = better predictions
601
- 2. **Language**: Works best with English text
602
- 3. **Length**: 10-500 characters ideal
603
- 4. **CSV Format**: Use column names like 'feedback', 'review', or 'text'
604
- 5. **Batch Size**: For performance, analyze up to 100 feedbacks at once
605
-
606
- ## 🎨 Use Cases
607
-
608
- - **E-commerce**: Analyze product reviews
609
- - **Restaurants**: Monitor food and service feedback
610
- - **Hotels**: Assess guest satisfaction
611
- - **Customer Service**: Evaluate support interactions
612
- - **Market Research**: Understand customer sentiment
613
-
614
- ## 🤖 Model Details
615
-
616
- - **Architecture**: BERT-based transformer model
617
- - **Training**: Fine-tuned on customer review datasets
618
- - **Accuracy**: 75-85% (depending on feedback quality)
619
- - **Speed**: ~100-200ms per prediction
620
-
621
- ## 📧 Support
622
-
623
- Found a bug or have suggestions? Open an issue on GitHub or contact support.
624
-
625
- ---
626
-
627
- **Made with ❤️ using Transformers & Gradio**
628
- """)
629
-
630
- # Footer
631
- gr.HTML("""
632
- <div style="text-align: center; padding: 2rem; color: #7f8c8d;">
633
- <p style="font-size: 0.9em;">
634
- Powered by Hugging Face Transformers 🤗 | Built with Gradio ⚡ | Deployed on HF Spaces 🚀
635
- </p>
636
- </div>
637
- """)
638
-
639
- # ============================================================================
640
- # LAUNCH
641
- # ============================================================================
642
 
643
  if __name__ == "__main__":
644
- demo.launch(share=False, show_error=True)
 
1
+ # app.py (Hugging Face Space friendly)
2
+ import os, warnings
3
+ warnings.filterwarnings("ignore")
 
4
 
5
+ import numpy as np
 
 
6
  import pandas as pd
7
+ import yfinance as yf
8
+ from datetime import datetime, timedelta
9
+ import joblib
10
+ from sklearn.ensemble import RandomForestClassifier
11
+ from sklearn.model_selection import train_test_split
12
+ from sklearn.metrics import roc_auc_score
13
  import plotly.graph_objects as go
14
+ import gradio as gr
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
 
16
+ # ----- Utilities -----
17
+ def download_data(ticker, period='6y', interval='1d'):
18
+ df = yf.download(ticker, period=period, interval=interval, progress=False)
19
+ if df is None or df.empty:
20
+ raise ValueError(f"No data for {ticker}")
21
+ df.index = pd.to_datetime(df.index)
22
+ return df.dropna()
23
+
24
+ def add_features(df):
25
+ df = df.copy()
26
+ df['AdjClose'] = df['Adj Close']
27
+ df['ret'] = df['AdjClose'].pct_change()
28
+ df['logret'] = np.log(df['AdjClose']).diff()
29
+ df['ma5'] = df['AdjClose'].rolling(5).mean()
30
+ df['ma20'] = df['AdjClose'].rolling(20).mean()
31
+ df['vol20'] = df['logret'].rolling(20).std()
32
+ delta = df['AdjClose'].diff()
33
+ up = delta.clip(lower=0); down = -1*delta.clip(upper=0)
34
+ ma_up = up.rolling(14).mean(); ma_down = down.rolling(14).mean()
35
+ rs = ma_up / (ma_down + 1e-9)
36
+ df['rsi14'] = 100 - (100 / (1 + rs))
37
+ df['mom5'] = df['AdjClose'].pct_change(5)
38
+ return df.dropna()
39
+
40
+ def make_label(df, threshold_pct=-0.10, horizon=30):
41
+ closes = df['AdjClose'].values
42
+ n = len(closes)
43
+ label = np.zeros(n, dtype=int)
44
+ for i in range(n):
45
+ end = min(n, i + horizon + 1)
46
+ future = closes[i+1:end]
47
+ if future.size==0:
48
+ label[i]=0; continue
49
+ minf = np.min(future)
50
+ drop = (minf - closes[i]) / closes[i]
51
+ if drop <= threshold_pct:
52
+ label[i]=1
53
+ df['label']=label
54
+ return df
55
 
56
+ # ----- Training (light) -----
57
+ def train_if_missing(ticker, threshold_pct=-0.10, horizon=30):
58
+ model_path = f"models/{ticker}_rf.pkl"
59
+ os.makedirs("models", exist_ok=True)
60
+ if os.path.exists(model_path):
61
+ return model_path
62
+ df = download_data(ticker, period='6y')
63
+ df = add_features(df)
64
+ df = make_label(df, threshold_pct=threshold_pct, horizon=horizon)
65
+ features = ['ret','logret','ma5','ma20','vol20','rsi14','mom5']
66
+ df = df.dropna(subset=features+['label'])
67
+ X = df[features].values; y = df['label'].values
68
+ if len(y) < 250:
69
+ # still train but warn
70
+ pass
71
+ # LIGHTER model for Spaces: fewer trees
72
+ clf = RandomForestClassifier(n_estimators=50, random_state=42, n_jobs=-1, class_weight='balanced')
73
+ # Use time-ordered split (no shuffle)
74
+ split = int(len(X)*0.8)
75
+ X_train, y_train = X[:split], y[:split]
76
+ clf.fit(X_train, y_train)
77
+ joblib.dump({'model':clf, 'features':features}, model_path)
78
+ return model_path
79
+
80
+ # ----- Predict probability -----
81
+ def predict_prob(ticker, threshold_pct_pos, horizon):
82
+ ticker = ticker.strip().upper()
83
+ threshold = -abs(threshold_pct_pos)/100.0
84
+ model_path = train_if_missing(ticker, threshold_pct=threshold, horizon=horizon)
85
+ saved = joblib.load(model_path)
86
+ clf = saved['model']; features = saved['features']
87
+ df = download_data(ticker, period='6y')
88
+ df = add_features(df)
89
+ X_latest = df[features].iloc[-1].values.reshape(1,-1)
90
+ prob = float(clf.predict_proba(X_latest)[:,1][0])
91
+ return prob, df
92
+
93
+ # ----- GBM Monte Carlo (smaller sims default) -----
94
+ def simulate_gbm(S0, mu, sigma, days=252, n_sims=500, seed=0):
95
+ np.random.seed(seed)
96
+ dt = 1/252
97
+ paths = np.zeros((days+1, n_sims)); paths[0]=S0
98
+ for t in range(1, days+1):
99
+ z = np.random.normal(size=n_sims)
100
+ paths[t] = paths[t-1] * np.exp((mu - 0.5*sigma**2)*dt + sigma*np.sqrt(dt)*z)
101
+ return paths
102
+
103
+ def build_candles_from_paths(paths, start_date):
104
+ median = np.percentile(paths,50,axis=1)
105
+ q10 = np.percentile(paths,10,axis=1)
106
+ q90 = np.percentile(paths,90,axis=1)
107
+ o = median[:-1]; c = median[1:]
108
+ h = np.maximum(c, q90[1:]); l = np.minimum(c, q10[1:])
109
+ dates = pd.bdate_range(start=start_date, periods=len(c))
110
+ df = pd.DataFrame({'Open':o, 'High':h, 'Low':l, 'Close':c}, index=dates)
111
+ return df
112
 
113
+ def plot_candles(df):
114
+ fig = go.Figure(data=[go.Candlestick(x=df.index, open=df['Open'], high=df['High'],
115
+ low=df['Low'], close=df['Close'])])
116
+ fig.update_layout(xaxis_rangeslider_visible=False, height=600)
117
+ return fig
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
118
 
119
+ # ----- Main function used by Gradio -----
120
+ def run(ticker="RELIANCE.NS", threshold=10.0, horizon=30, sims=500):
121
  try:
122
+ prob, df = predict_prob(ticker, threshold, horizon)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
123
  except Exception as e:
124
+ return None, f"Error: {e}"
125
+ # VaR/CVaR simple (historical daily)
126
+ returns = df['Adj Close'].pct_change().dropna().values
127
+ sorted_ret = np.sort(returns)
128
+ idx = max(0, int(0.05*len(sorted_ret))-1)
129
+ var = -sorted_ret[idx]
130
+ cvar = -sorted_ret[:idx+1].mean() if idx>=0 else -sorted_ret.mean()
131
+ # GBM simulate
132
+ logrets = np.log(df['Adj Close']).diff().dropna()
133
+ mu = float(logrets.mean()*252); sigma = float(logrets.std()*np.sqrt(252))
134
+ S0 = float(df['Adj Close'].iloc[-1])
135
+ sims = int(max(100, min(2000, sims)))
136
+ model_paths = simulate_gbm(S0, mu, sigma, days=252, n_sims=sims, seed=1)
137
+ start_date = (df.index[-1] + pd.Timedelta(days=1)).normalize()
138
+ df_candles = build_candles_from_paths(model_paths, start_date)
139
+ fig = plot_candles(df_candles)
140
+ summary = (f"Ticker: {ticker}\nThreshold: {threshold}% drop within {horizon} days\n"
141
+ f"Predicted prob: {prob*100:.2f}%\nHistorical VaR(5%): {var:.4f}, CVaR: {cvar:.4f}\n"
142
+ f"Annual mu: {mu:.4f}, sigma: {sigma:.4f}")
143
+ return fig, summary
144
+
145
+ # ----- Gradio UI -----
146
+ title = "Stock Risk Predictor + 1Y Candle Simulator (Hugging Face Space)"
147
+ desc = "Enter ticker (eg RELIANCE.NS). Threshold (percent), horizon days, sims (keep small for hosted Space)."
148
+
149
+ iface = gr.Interface(
150
+ fn=run,
151
+ inputs=[gr.Textbox(label="Ticker", value="RELIANCE.NS"),
152
+ gr.Number(label="Threshold percent (drop)", value=10.0),
153
+ gr.Number(label="Horizon days", value=30, precision=0),
154
+ gr.Number(label="Monte Carlo sims (100-2000)", value=500, precision=0)],
155
+ outputs=[gr.Plot(label="Simulated 1Y Candles"), gr.Textbox(label="Summary")],
156
+ title=title, description=desc, allow_flagging="never",
157
+ examples=[["RELIANCE.NS",10,30,500], ["AAPL",15,30,500]]
158
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
159
 
160
  if __name__ == "__main__":
161
+ iface.launch()