Spaces:
Running
Running
import re
from pathlib import Path

import m2cgen as m2c
import numpy as np
import pandas as pd
from datasets import load_dataset
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import classification_report
from sklearn.model_selection import train_test_split
print("Loading dataset alanjoshua2005/india-spam-sms...")
dataset = load_dataset("alanjoshua2005/india-spam-sms")
# Convert the 'train' split into a DataFrame so its schema can be normalised.
df = pd.DataFrame(dataset['train'])
# Show which columns the dataset actually provides.
print(df.columns)


def _to_binary_label(value):
    """Map a raw, non-numeric label value to 1 (spam) or 0 (not spam).

    The comparison ignores case and surrounding whitespace, so 'spam', 'Spam'
    and ' SPAM ' all map to 1, as do the string forms '1' and 'true'. Any
    other value (e.g. 'ham', 'not spam') maps to 0.
    NOTE(review): labels other than spam/ham (if the dataset has any) are
    treated as not-spam — confirm against the dataset card.
    """
    return 1 if str(value).strip().lower() in ("spam", "1", "true") else 0


# Normalise the schema to two columns: 'text' (message body) and 'label'.
# Known layouts are checked first; otherwise fall back to a name-based guess.
if 'Message' in df.columns and 'Category' in df.columns:
    df['text'] = df['Message']
    df['label'] = df['Category']
elif 'text' in df.columns and 'label' in df.columns:
    pass
elif 'v1' in df.columns and 'v2' in df.columns:
    # In this layout v1 holds the label and v2 the message body.
    df['text'] = df['v2']
    df['label'] = df['v1']
else:
    # Generic fallback: pick columns whose names look like text / label.
    # If several columns match, the last matching column wins.
    for col in list(df.columns):
        name = str(col).lower()
        if 'text' in name or 'msg' in name or 'message' in name:
            df['text'] = df[col]
        if 'label' in name or 'spam' in name or 'category' in name:
            df['label'] = df[col]

missing = [c for c in ('text', 'label') if c not in df.columns]
if missing:
    raise ValueError(
        f"Could not identify column(s) {missing} in the dataset; "
        f"available columns: {list(df.columns)}"
    )

# Rows without a message or a label carry no training signal; without this,
# str(NaN) would turn a missing message into the literal text "nan".
df = df.dropna(subset=['text', 'label']).reset_index(drop=True)

# String labels (e.g. 'spam'/'ham') are mapped to 1/0. Numeric labels are
# assumed to already be 0/1 — NOTE(review): confirm for this dataset.
if not pd.api.types.is_numeric_dtype(df['label']):
    df['label'] = df['label'].apply(_to_binary_label)
df['label'] = df['label'].astype(int)
print(df.head())
# Keyword lists and the URL pattern used by extract_features. Keywords are
# matched as plain substrings of the lower-cased message.
# NOTE(review): substring matching makes short keywords noisy ('rs' also hits
# "hours"/"yours", 'win' hits "window", 'claim' hits "disclaimer"). The
# on-device Dart extractor presumably mirrors this logic, so any tightening
# must be made on both sides together to avoid train/serve skew.
_URGENT_WORDS = (
    'urgent', 'hurry', 'limited', 'immediate', 'action required',
    'alert', 'warning', 'expires', 'claim',
)
_MONEY_WORDS = (
    'rs', '₹', 'free', 'win', 'prize', 'cash', 'money', 'credit',
    'loan', 'offer', 'pay', 'collect', 'bank',
)
# Compiled once instead of being re-parsed for every row.
_URL_PATTERN = re.compile(r"(http[s]?://|www\.)[^\s]+")


def extract_features(row, rng=None):
    """Build the 13-element feature vector for one SMS row.

    Args:
        row: Mapping (e.g. a DataFrame row) with a 'text' entry holding the
            message body and a 'label' entry equal to 1 for spam.
        rng: Optional ``numpy.random.RandomState`` that drives the simulated
            features. Defaults to the global ``numpy.random`` state, which
            reproduces the original behaviour; pass a seeded RandomState for
            reproducible features.

    Returns:
        A list of 13 floats in this order: senderAgeScore, senderInContacts,
        senderMessageCount, senderCarrierScore, urgencyScore, domainCount,
        domainLevenshtein, containsMoneyAction, capsRatio, hourOfDay,
        isWeekend, isFestivalPeriod, recentScamSenderSim.

    NOTE: the dataset has no sender or timing metadata. The sender features,
    domainLevenshtein, hourOfDay, isWeekend and recentScamSenderSim are
    therefore drawn from label-conditioned distributions. This leaks the label
    into the inputs, so evaluation scores on these features are optimistic and
    do not reflect real-world accuracy.
    """
    # The legacy RandomState has the same uniform/rand/randint/choice methods
    # as the np.random module, so either one can drive the draws below.
    rand = np.random if rng is None else rng
    orig_text = str(row['text'])
    text = orig_text.lower()
    is_spam = row['label'] == 1

    # 1-4. Sender metadata, simulated from the label.
    senderAgeScore = rand.uniform(0.0, 0.3) if is_spam else rand.uniform(0.5, 1.0)
    senderInContacts = 0.0 if is_spam else float(rand.rand() > 0.2)
    senderMessageCount = rand.randint(0, 3) if is_spam else rand.randint(10, 100)
    senderCarrierScore = rand.uniform(0.0, 0.4) if is_spam else rand.uniform(0.6, 1.0)

    # 5. urgencyScore: fraction of urgency keywords present (0.0 to 1.0).
    urgencyScore = sum(1 for w in _URGENT_WORDS if w in text) / len(_URGENT_WORDS)

    # 6. domainCount: number of http(s):// or www. links in the message.
    domainCount = len(_URL_PATTERN.findall(text))

    # 7. domainLevenshtein: simulated, since no actual domain is extracted;
    # drawn only for spam messages that contain a link.
    domainLevenshtein = rand.uniform(0.2, 0.9) if (is_spam and domainCount > 0) else 1.0

    # 8. containsMoneyAction: 1.0 if any money-related keyword is present.
    containsMoneyAction = float(any(w in text for w in _MONEY_WORDS))

    # 9. capsRatio: uppercase letters / all letters, using the original casing.
    caps_count = sum(1 for c in orig_text if c.isupper())
    letters_count = sum(1 for c in orig_text if c.isalpha())
    capsRatio = caps_count / letters_count if letters_count > 0 else 0.0

    # 10. hourOfDay: simulated; spam at 22-04h, ham at 8-19h.
    if is_spam:
        hourOfDay = float(rand.choice([0, 1, 2, 3, 4, 22, 23]))
    else:
        hourOfDay = float(rand.randint(8, 20))
    # 11. isWeekend: simulated; probability ~0.7 for spam, ~0.2 for ham.
    isWeekend = float(rand.rand() > (0.3 if is_spam else 0.8))
    # 12. isFestivalPeriod: simulated independently of the label (~0.2).
    isFestivalPeriod = float(rand.rand() > 0.8)
    # 13. recentScamSenderSim: simulated from the label.
    recentScamSenderSim = rand.uniform(0.5, 1.0) if is_spam else rand.uniform(0.0, 0.2)

    return [
        senderAgeScore, senderInContacts, float(senderMessageCount), senderCarrierScore,
        urgencyScore, float(domainCount), domainLevenshtein, containsMoneyAction, capsRatio,
        hourOfDay, isWeekend, isFestivalPeriod, recentScamSenderSim,
    ]
# Seed NumPy's global RNG. extract_features draws its simulated features from
# np.random, and without a seed they differ on every run even though the
# split and the forest use a fixed random_state.
RANDOM_SEED = 42
np.random.seed(RANDOM_SEED)

print("Extracting features...")
X = np.array([extract_features(row) for _, row in df.iterrows()])
y = df['label'].values
# Stratify so the spam/ham ratio, which is usually imbalanced in SMS spam
# data, is preserved in both splits.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=RANDOM_SEED, stratify=y
)
print("Training Random Forest...")
# Small, shallow forest. Presumably kept small because the model is
# transpiled to Dart source and grows with n_estimators * depth.
clf = RandomForestClassifier(n_estimators=15, max_depth=5, random_state=RANDOM_SEED)
clf.fit(X_train, y_train)
print("Evaluating...")
preds = clf.predict(X_test)
# NOTE: extract_features simulates several features from the label, so this
# report overstates real-world performance.
print(classification_report(y_test, preds, zero_division=0))
print("Exporting model to Dart with m2cgen...")
# The Dart wrapper below returns output[1] as the spam probability. That is
# only correct for a binary model whose classes are ordered [0, 1].
if list(clf.classes_) != [0, 1]:
    raise ValueError(f"Expected binary classes [0, 1], got {list(clf.classes_)}")

# Have m2cgen emit the entry point as `score_features`. This matches the call
# in the wrapper and avoids clashing with RandomForestModel.score (Dart has no
# method overloading). Renaming the output afterwards via string replacement
# did not work: m2cgen's Dart output never contains "double[]".
dart_code = m2c.export_to_dart(clf, function_name="score_features")

# Directives such as `import 'dart:math';` that m2cgen may emit must sit at the
# top of the file, not inside the class body.
dart_imports = []
dart_body = []
for line in dart_code.splitlines():
    if line.lstrip().startswith("import "):
        dart_imports.append(line.strip())
    else:
        dart_body.append(line)

# Write to our model file
file_path = "lib/kavacha/models/random_forest_model.dart"
class_header = """class RandomForestModel {
  double score(SmsFeatureVector features) {
    List<double> input = [
      features.senderAgeScore,
      features.senderInContacts ? 1.0 : 0.0,
      features.senderMessageCount.toDouble(),
      features.senderCarrierScore,
      features.urgencyScore,
      features.domainCount.toDouble(),
      features.domainLevenshtein,
      features.containsMoneyAction ? 1.0 : 0.0,
      features.capsRatio,
      features.hourOfDay.toDouble(),
      features.isWeekend ? 1.0 : 0.0,
      features.isFestivalPeriod ? 1.0 : 0.0,
      features.recentScamSenderSim
    ];
    List<double> output = score_features(input);
    // output[1] is the probability of class 1 (spam)
    return output[1];
  }

"""
Path(file_path).parent.mkdir(parents=True, exist_ok=True)
with open(file_path, "w", encoding="utf-8") as f:
    f.write("// Auto-generated using m2cgen\n")
    for directive in dart_imports:
        f.write(directive + "\n")
    f.write("import 'sms_feature_vector.dart';\n")
    f.write(class_header)
    # The generated score_features and its helper functions become methods
    # of RandomForestModel.
    f.write("\n".join(dart_body))
    f.write("\n}\n")
print("Done writing to " + file_path)