leygit commited on
Commit
a32e6e8
·
verified ·
1 Parent(s): ba5291d

Create app.py file

Browse files
Files changed (1) hide show
  1. app.py +335 -0
app.py ADDED
@@ -0,0 +1,335 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pandas as pd
2
+ import numpy as np
3
+ import torch
4
+ from transformers import BertTokenizer
5
+ import seaborn as sns
6
+ import matplotlib.pyplot as plt
7
+ from sklearn.feature_extraction.text import CountVectorizer
8
+
9
+
10
+ # Load dataset
11
+ file_path = 'spam_ham_dataset.csv'
12
+ df = pd.read_csv(file_path)
13
+ df.head()
14
+
15
+ # Preprocessing
16
+ #.str.replace(r'[^\w\s]', '', regex=True) removes everthing except letters, numbers, and spaces
17
+ # df['text'].str.lower() converts everything in the text column to lower case only
18
+ df['text'] = df['text'].str.lower().str.replace(r'[^\w\s]', '', regex=True)
19
+ df['text'].head()
20
+
21
+
22
+ sns.countplot(x=df['label'])
23
+ plt.title("Spam vs Ham Distribution")
24
+ plt.show()
25
+
26
+ # Calculate text length metrics
27
+ df['char_count'] = df['text'].apply(len)
28
+ df['word_count'] = df['text'].apply(lambda x: len(x.split()))
29
+ # Plot word count distribution for spam and ham
30
+ plt.figure(figsize=(12, 5))
31
+ sns.histplot(data=df, x='word_count', hue='label', bins=30, kde=True)
32
+ plt.xlim(0, 1000)
33
+ plt.title("Word Count Distribution by Label")
34
+ plt.xlabel("Number of Words")
35
+ plt.ylabel("Frequency")
36
+ plt.show()
37
+
38
+ def get_top_words(corpus, n=None):
39
+ vec = CountVectorizer(stop_words='english').fit(corpus)
40
+ bag_of_words = vec.transform(corpus)
41
+ sum_words = bag_of_words.sum(axis=0)
42
+ words_freq = [(word, sum_words[0, idx]) for word, idx in vec.vocabulary_.items()]
43
+ words_freq = sorted(words_freq, key=lambda x: x[1], reverse=True)
44
+ return words_freq[:n]
45
+
46
+ # Top 10 words for spam
47
+ top_spam_words = get_top_words(df[df['label'] == "spam"]['text'], n=10)
48
+ print("Top spam words:", top_spam_words)
49
+
50
+ # Top 10 words for ham
51
+ top_ham_words = get_top_words(df[df['label'] == "ham"]['text'], n=10)
52
+ print("Top ham words:", top_ham_words)
53
+
54
+ from sklearn.feature_extraction.text import TfidfVectorizer
55
+ from sklearn.naive_bayes import MultinomialNB
56
+ from sklearn.metrics import classification_report
57
+
58
+ # TF-IDF Vectorization
59
+ vectorizer = TfidfVectorizer()
60
+ X = vectorizer.fit_transform(df['text'])
61
+ y = df['label_num']
62
+
63
+ # Train-Test Split
64
+ from sklearn.model_selection import train_test_split
65
+ X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
66
+
67
+ # Train Naïve Bayes Model
68
+ nb_model = MultinomialNB()
69
+ nb_model.fit(X_train, y_train)
70
+
71
+ # Predictions
72
+ y_pred = nb_model.predict(X_test)
73
+ print(classification_report(y_test, y_pred))
74
+
75
+ import pandas as pd
76
+ import torch
77
+ import torch.nn as nn
78
+ import torch.optim as optim
79
+ from transformers import BertTokenizer, BertForSequenceClassification
80
+ from torch.utils.data import Dataset, DataLoader
81
+
82
+ # Load dataset
83
+ file_path = 'spam_ham_dataset.csv'
84
+ df = pd.read_csv(file_path)
85
+
86
+ # Convert label column to numeric (0 for ham, 1 for spam)
87
+ df['label_num'] = df['label'].astype('category').cat.codes
88
+
89
+ # Load tokenizer
90
+ tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
91
+
92
+ # Tokenize dataset
93
+ encodings = tokenizer(df['text'].tolist(), padding=True, truncation=True, max_length=128, return_tensors="pt")
94
+ labels = torch.tensor(df['label_num'].values)
95
+
96
+ # Custom Dataset
97
+ class SpamDataset(Dataset):
98
+ def __init__(self, encodings, labels):
99
+ self.encodings = encodings
100
+ self.labels = labels
101
+
102
+ def __len__(self):
103
+ return len(self.labels)
104
+
105
+ def __getitem__(self, idx):
106
+ item = {key: val[idx] for key, val in self.encodings.items()} # Keep as PyTorch tensors
107
+ item['labels'] = torch.tensor(self.labels[idx], dtype=torch.long) # Ensure labels are `long`
108
+ return item
109
+
110
+ # Create dataset
111
+ dataset = SpamDataset(encodings, labels)
112
+
113
+ # Split dataset (80% train, 20% validation)
114
+ train_size = int(0.8 * len(dataset))
115
+ val_size = len(dataset) - train_size
116
+ train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])
117
+
118
+ # DataLoader Function (Fix Collate)
119
+ def collate_fn(batch):
120
+ keys = batch[0].keys()
121
+ collated = {key: torch.stack([b[key] for b in batch]) for key in keys}
122
+ return collated
123
+
124
+ # Create DataLoader
125
+ train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True, collate_fn=collate_fn)
126
+ val_loader = DataLoader(val_dataset, batch_size=8, shuffle=False, collate_fn=collate_fn)
127
+
128
+ # Load BERT model
129
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
130
+ model = BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=2)
131
+ model.to(device)
132
+
133
+ # Define optimizer and loss function
134
+ optimizer = optim.AdamW(model.parameters(), lr=5e-5)
135
+ loss_fn = nn.CrossEntropyLoss()
136
+
137
+ # Training Loop
138
+ EPOCHS = 10
139
+
140
+ for epoch in range(EPOCHS):
141
+ model.train()
142
+ total_loss = 0
143
+
144
+ for batch in train_loader:
145
+ optimizer.zero_grad()
146
+
147
+ # Move batch to device
148
+ inputs = {key: val.to(device) for key, val in batch.items()}
149
+ labels = inputs.pop("labels").to(device) # Move labels to device
150
+
151
+ # Forward pass
152
+ outputs = model(**inputs)
153
+ loss = loss_fn(outputs.logits, labels)
154
+
155
+ # Backward pass
156
+ loss.backward()
157
+ optimizer.step()
158
+
159
+ total_loss += loss.item()
160
+
161
+ avg_loss = total_loss / len(train_loader)
162
+ print(f"Epoch {epoch+1}, Loss: {avg_loss:.4f}")
163
+
164
+ print("Training complete!")
165
+
166
+ from sklearn.metrics import classification_report
167
+ from transformers import BertTokenizer
168
+ import torch
169
+ import torch.nn.functional as F
170
+
171
+ # Classification function
172
+ def classify_email(email_text):
173
+ model.eval() # Set model to evaluation mode
174
+
175
+ with torch.no_grad():
176
+ # Tokenize and convert input text to tensor
177
+ inputs = tokenizer(email_text, padding=True, truncation=True, max_length=256, return_tensors="pt")
178
+
179
+ # Move inputs to the appropriate device
180
+ inputs = {key: val.to(device) for key, val in inputs.items()}
181
+
182
+ # Get model predictions
183
+ outputs = model(**inputs)
184
+ logits = outputs.logits
185
+
186
+ # Convert logits to predicted class
187
+ predictions = torch.argmax(logits, dim=1)
188
+
189
+ # Convert logits to probabilities using softmax
190
+ probs = F.softmax(logits, dim=1)
191
+ confidence = torch.max(probs).item() * 100 # Convert to percentage
192
+
193
+ # Convert numeric prediction to label
194
+ result = "Spam" if predictions.item() == 1 else "Ham"
195
+
196
+ return {
197
+ "result": result,
198
+ "confidence": f"{confidence:.2f}%",
199
+ }
200
+
201
+ # Evaluation function with detailed classification report
202
+ def evaluate_model_with_report(val_loader):
203
+ model.eval() # Set model to evaluation mode
204
+ y_true = []
205
+ y_pred = []
206
+ correct = 0
207
+ total = 0
208
+
209
+ with torch.no_grad():
210
+ for batch in val_loader:
211
+ inputs = {key: val.to(device) for key, val in batch.items()}
212
+ labels = inputs.pop("labels").to(device)
213
+
214
+ outputs = model(**inputs)
215
+ predictions = torch.argmax(outputs.logits, dim=1)
216
+
217
+ # Collect labels and predictions
218
+ y_true.extend(labels.cpu().numpy())
219
+ y_pred.extend(predictions.cpu().numpy())
220
+
221
+ # Calculate accuracy
222
+ correct += (predictions == labels).sum().item()
223
+ total += labels.size(0)
224
+
225
+ # Calculate accuracy
226
+ accuracy = correct / total if total > 0 else 0
227
+ print(f"Validation Accuracy: {accuracy:.4f}")
228
+
229
+ # Print classification report
230
+ print("\nClassification Report:")
231
+ print(classification_report(y_true, y_pred, target_names=["Ham", "Spam"]))
232
+
233
+ return accuracy
234
+
235
+ # Run evaluation with classification report
236
+ accuracy = evaluate_model_with_report(val_loader)
237
+ print(f"Model Validation Accuracy: {accuracy:.4f}")
238
+
239
+
240
+ ## App Deployment Functions
241
+
242
+ def generate_performance_metrics():
243
+ y_pred = model.predict(X_test)
244
+ accuracy = evaluate_model_with_report(val_loader)
245
+ report = classification_report(y_true, y_pred, target_names=["Ham", "Spam"])
246
+ return {
247
+ "accuracy": f"{accuracy:.2%}",
248
+ "precision": f"{report['1']['precision']:.2%}",
249
+ "recall": f"{report['1']['recall']:.2%}",
250
+ "f1_score": f"{report['1']['f1-score']:.2%}"
251
+ }
252
+
253
+ def email_analysis_pipeline(email_text):
254
+ results = classify_email(email_text)
255
+ accuracy = evaluate_model_with_report(val_loader)
256
+ return {
257
+ results["result"],
258
+ results["confidence"],
259
+ accuracy
260
+ }
261
+
262
+ ## Gradio Interface
263
+
264
+ !pip install gradio
265
+ import gradio as gr
266
+
267
+
268
+
269
+ # Create Gradio Interface
270
+ def create_interface():
271
+ performance_metrics = generate_performance_metrics()
272
+
273
+ # Introduction - Title + Brief Description
274
+ with gr.Blocks(css=custom_css) as interface:
275
+ gr.Markdown("Spam Email Classification")
276
+ gr.Markdown(
277
+ """
278
+ Brief description of the project here
279
+
280
+ """
281
+ )
282
+
283
+ # Email Text Input
284
+ with gr.Row():
285
+ email_input = gr.Textbox(
286
+ lines=8, placeholder="Type or paste your email content here...", label="Email Content"
287
+ )
288
+
289
+ # Email Text Results and Analysis
290
+ with gr.Row():
291
+ result_output = gr.HTML(label="Classification Result") # label = [function that prints classification result]
292
+ confidence_output = gr.Textbox(label="Confidence Score", interactive=False)
293
+ accuracy_output = gr.Textbox(label="Accuracy", interactive=False)
294
+
295
+
296
+ analyze_button = gr.Button("Analyze Email 🕵️‍♂️")
297
+
298
+ analyze_button.click(
299
+ fn=email_analysis_pipeline,
300
+ inputs=email_input,
301
+ outputs=[result_output, confidence_output, accuracy_output]
302
+ )
303
+
304
+ # Analysis
305
+ gr.Markdown("## 📊 Model Performance Analytics")
306
+ with gr.Row():
307
+ with gr.Column():
308
+ gr.Textbox(value=performance_metrics["accuracy"], label="Accuracy", interactive=False, elem_classes=["metric"])
309
+ gr.Textbox(value=performance_metrics["precision"], label="Precision", interactive=False, elem_classes=["metric"])
310
+ gr.Textbox(value=performance_metrics["recall"], label="Recall", interactive=False, elem_classes=["metric"])
311
+ gr.Textbox(value=performance_metrics["f1_score"], label="F1 Score", interactive=False, elem_classes=["metric"])
312
+ with gr.Column():
313
+ gr.Markdown("### Confusion Matrix")
314
+ gr.HTML(f"<img src='data:image/png;base64,{performance_metrics['confusion_matrix_plot']}' style='max-width: 100%; height: auto;' />")
315
+
316
+ gr.Markdown("## 📘 Glossary and Explanation of Labels")
317
+ gr.Markdown(
318
+ """
319
+ ### Labels:
320
+ - **Spam:** Unwanted or harmful emails flagged by the system.
321
+ - **Ham:** Legitimate, safe emails.
322
+
323
+ ### Metrics:
324
+ - **Accuracy:** The percentage of correct classifications.
325
+ - **Precision:** Out of predicted Spam, how many are actually Spam.
326
+ - **Recall:** Out of all actual Spam emails, how many are predicted as Spam.
327
+ - **F1 Score:** Harmonic mean of Precision and Recall.
328
+ """
329
+ )
330
+
331
+ return interface
332
+
333
+ # Launch the interface
334
+ interface = create_interface()
335
+ interface.launch(share=True)