rtferraz committed
Commit e4d8561 · verified · 1 Parent(s): 165b138

Fix label leakage: temporal split — use first 70% of events as input, predict purchase in last 30%. Remove n_purchases/purchase_rate from features.
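
To see the leakage this commit removes, here is a minimal sketch (toy data, not from the notebook): when the label is "made at least one purchase" and the features include `n_purchases` counted over the same events, the label is a deterministic function of the feature, so any model separates the classes perfectly.

```python
import numpy as np
from sklearn.metrics import roc_auc_score

# Toy event sequences (hypothetical; the real data comes from ecommerce_artifacts.pkl)
users = [
    ['view', 'view', 'cart'],
    ['view', 'cart', 'purchase'],
    ['view'],
    ['view', 'cart', 'purchase', 'purchase'],
]

# Old, leaky setup: feature and label are computed over the SAME events
n_purchases = np.array([sum(e == 'purchase' for e in u) for u in users])
labels = (n_purchases > 0).astype(float)  # label: did the user ever purchase?

# The label is a deterministic function of the feature -> trivial AUC of 1.0
print(roc_auc_score(labels, n_purchases))  # 1.0
```

The temporal split below breaks this dependence: features are computed only on the first 70% of each user's events, and the label only on the last 30%.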

Files changed (1)
  1. notebooks/03_ecommerce_finetune.ipynb +111 -85
notebooks/03_ecommerce_finetune.ipynb CHANGED
@@ -4,11 +4,13 @@
4
  "cell_type": "markdown",
5
  "metadata": {},
6
  "source": [
7
- "# 03 — E-Commerce Fine-Tuning: Next-Purchase Prediction\n",
8
  "\n",
9
- "**Goal:** Fine-tune the pre-trained DomainTransformer for predicting whether a user will make a purchase, and compare against a LightGBM baseline on hand-crafted features.\n",
10
  "\n",
11
- "**Task:** Binary classification — given a user's event sequence, predict if they will purchase (1) or not (0).\n",
12
  "\n",
13
  "**Pre-trained model:** [rtferraz/ecommerce-domain-24m](https://huggingface.co/rtferraz/ecommerce-domain-24m)\n",
14
  "\n",
@@ -46,7 +48,7 @@
46
  "import matplotlib.pyplot as plt\n",
47
  "import torch\n",
48
  "from sklearn.model_selection import train_test_split\n",
49
- "from sklearn.metrics import roc_auc_score, classification_report\n",
50
  "\n",
51
  "if os.path.exists('../src'): sys.path.insert(0, '../src')\n",
52
  "elif os.path.exists('src'): sys.path.insert(0, 'src')\n",
@@ -54,7 +56,7 @@
54
  "from domain_tokenizer import (\n",
55
  " DomainTokenizerBuilder, DomainTransformerConfig,\n",
56
  " DomainTransformerForCausalLM, JointFusionModel,\n",
57
- " DomainFinetuneDataset, prepare_finetune_dataset, finetune_domain_model,\n",
58
  ")\n",
59
  "from domain_tokenizer.schema import DomainSchema, FieldSpec, FieldType\n",
60
  "\n",
@@ -82,9 +84,7 @@
82
  "cell_type": "markdown",
83
  "metadata": {},
84
  "source": [
85
- "## Step 1 — Load Pre-trained Artifacts\n",
86
- "\n",
87
- "Load the artifacts saved by `02_ecommerce_pretrain.ipynb`."
88
  ]
89
  },
90
  {
@@ -93,20 +93,16 @@
93
  "metadata": {},
94
  "outputs": [],
95
  "source": [
96
- "# Load user sequences from pre-training notebook\n",
97
  "with open('./ecommerce_artifacts.pkl', 'rb') as f:\n",
98
  " artifacts = pickle.load(f)\n",
99
- "\n",
100
  "user_sequences = artifacts['user_sequences']\n",
101
  "user_ids = artifacts['user_ids']\n",
102
  "print(f'Loaded {len(user_sequences):,} users')\n",
103
  "\n",
104
- "# Load tokenizer\n",
105
  "from transformers import PreTrainedTokenizerFast\n",
106
  "hf_tokenizer = PreTrainedTokenizerFast.from_pretrained('./ecommerce_tokenizer')\n",
107
  "print(f'Tokenizer vocab: {hf_tokenizer.vocab_size}')\n",
108
  "\n",
109
- "# Rebuild the schema and builder (needed for tokenize_event)\n",
110
  "ECOMMERCE_REES46_SCHEMA = DomainSchema(\n",
111
  " name='ecommerce_rees46',\n",
112
  " fields=[\n",
@@ -130,13 +126,7 @@
130
  "metadata": {},
131
  "outputs": [],
132
  "source": [
133
- "# Load pre-trained model using from_pretrained (handles safetensors natively)\n",
134
- "# Option A: from local checkpoint saved by notebook 02\n",
135
  "model = DomainTransformerForCausalLM.from_pretrained('./ecommerce_pretrain_checkpoints/final/')\n",
136
- "\n",
137
- "# Option B: from HuggingFace Hub (if local not available)\n",
138
- "# model = DomainTransformerForCausalLM.from_pretrained('rtferraz/ecommerce-domain-24m')\n",
139
- "\n",
140
  "print(f'Pre-trained model loaded: {sum(p.numel() for p in model.parameters()):,} params')"
141
  ]
142
  },
@@ -144,11 +134,15 @@
144
  "cell_type": "markdown",
145
  "metadata": {},
146
  "source": [
147
- "## Step 2 — Create Labels and Tabular Features\n",
148
  "\n",
149
- "**Label:** Binary did the user make at least one purchase? (1=yes, 0=no)\n",
150
  "\n",
151
- "**Tabular features:** Hand-crafted from user sequences (for the DCNv2 branch and LightGBM baseline)."
152
  ]
153
  },
154
  {
@@ -157,49 +151,87 @@
157
  "metadata": {},
158
  "outputs": [],
159
  "source": [
160
- "def compute_user_features(events):\n",
161
- " \"\"\"Extract tabular features from a user's event sequence.\"\"\"\n",
162
  " n_events = len(events)\n",
163
  " n_views = sum(1 for e in events if e['event_type'] == 'view')\n",
164
  " n_carts = sum(1 for e in events if e['event_type'] == 'cart')\n",
165
- " n_purchases = sum(1 for e in events if e['event_type'] == 'purchase')\n",
166
  " n_removes = sum(1 for e in events if e['event_type'] == 'remove_from_cart')\n",
167
  " \n",
168
  " prices = [e['price'] for e in events if e['price'] > 0]\n",
169
  " avg_price = np.mean(prices) if prices else 0\n",
170
  " max_price = max(prices) if prices else 0\n",
171
  " std_price = np.std(prices) if len(prices) > 1 else 0\n",
172
  " \n",
173
- " categories = set(e['category'] for e in events)\n",
174
- " n_unique_categories = len(categories)\n",
175
- " \n",
176
- " hours = [e['timestamp'].hour for e in events]\n",
177
- " avg_hour = np.mean(hours)\n",
178
  " \n",
 
179
  " cart_rate = n_carts / max(n_views, 1)\n",
180
- " purchase_rate = n_purchases / max(n_events, 1)\n",
181
  " remove_rate = n_removes / max(n_carts, 1) if n_carts > 0 else 0\n",
182
  " \n",
183
  " return [\n",
184
- " n_events, n_views, n_carts, n_purchases, n_removes,\n",
185
  " avg_price, max_price, std_price,\n",
186
  " n_unique_categories, avg_hour,\n",
187
- " cart_rate, purchase_rate, remove_rate,\n",
188
  " ]\n",
189
  "\n",
190
  "FEATURE_NAMES = [\n",
191
- " 'n_events', 'n_views', 'n_carts', 'n_purchases', 'n_removes',\n",
192
  " 'avg_price', 'max_price', 'std_price',\n",
193
  " 'n_unique_categories', 'avg_hour',\n",
194
- " 'cart_rate', 'purchase_rate', 'remove_rate',\n",
195
  "]\n",
196
  "\n",
197
- "print(f'Computing features for {len(user_sequences):,} users...')\n",
198
- "tabular_features = np.array([compute_user_features(seq) for seq in user_sequences], dtype=np.float32)\n",
199
- "labels = np.array([1.0 if any(e['event_type'] == 'purchase' for e in seq) else 0.0 for seq in user_sequences])\n",
200
- "\n",
201
- "print(f'Features shape: {tabular_features.shape}')\n",
202
- "print(f'Labels: {labels.sum():.0f} purchasers / {len(labels)} total ({labels.mean()*100:.1f}%)')"
203
  ]
204
  },
205
  {
@@ -210,27 +242,25 @@
210
  "source": [
211
  "# Train/test split (80/20, stratified)\n",
212
  "train_idx, test_idx = train_test_split(\n",
213
- " range(len(user_sequences)), test_size=0.2, random_state=42, stratify=labels\n",
214
  ")\n",
215
  "\n",
216
- "train_seqs = [user_sequences[i] for i in train_idx]\n",
217
- "test_seqs = [user_sequences[i] for i in test_idx]\n",
218
  "train_features = tabular_features[train_idx]\n",
219
  "test_features = tabular_features[test_idx]\n",
220
- "train_labels = labels[train_idx]\n",
221
- "test_labels = labels[test_idx]\n",
222
  "\n",
223
- "print(f'Train: {len(train_seqs):,} ({train_labels.mean()*100:.1f}% positive)')\n",
224
- "print(f'Test: {len(test_seqs):,} ({test_labels.mean()*100:.1f}% positive)')"
225
  ]
226
  },
227
  {
228
  "cell_type": "markdown",
229
  "metadata": {},
230
  "source": [
231
- "## Step 3 — LightGBM Baseline\n",
232
- "\n",
233
- "Standard ML baseline: LightGBM on hand-crafted tabular features. This is what we need to beat."
234
  ]
235
  },
236
  {
@@ -244,19 +274,15 @@
244
  "lgb_model = lgb.LGBMClassifier(n_estimators=200, learning_rate=0.05, max_depth=6, random_state=42, verbose=-1)\n",
245
  "lgb_model.fit(train_features, train_labels)\n",
246
  "\n",
247
- "lgb_train_probs = lgb_model.predict_proba(train_features)[:, 1]\n",
248
  "lgb_test_probs = lgb_model.predict_proba(test_features)[:, 1]\n",
249
- "\n",
250
- "lgb_train_auc = roc_auc_score(train_labels, lgb_train_probs)\n",
251
  "lgb_test_auc = roc_auc_score(test_labels, lgb_test_probs)\n",
252
  "\n",
253
- "print(f'LightGBM Baseline:')\n",
254
- "print(f' Train AUC: {lgb_train_auc:.4f}')\n",
255
- "print(f' Test AUC: {lgb_test_auc:.4f}')\n",
256
  "\n",
257
  "importance = pd.Series(lgb_model.feature_importances_, index=FEATURE_NAMES).sort_values(ascending=False)\n",
258
  "print(f'\\nTop features:')\n",
259
- "for feat, imp in importance.head(5).items(): print(f' {feat}: {imp}')"
260
  ]
261
  },
262
  {
@@ -265,10 +291,7 @@
265
  "source": [
266
  "## Step 4 — JointFusionModel Fine-Tuning\n",
267
  "\n",
268
- "Combines:\n",
269
- "- **Transaction branch:** Pre-trained DomainTransformer → user embedding\n",
270
- "- **Tabular branch:** DCNv2 with PLR embeddings on hand-crafted features\n",
271
- "- **Joint head:** MLP on concatenated embeddings → binary prediction"
272
  ]
273
  },
274
  {
@@ -284,8 +307,7 @@
284
  "test_dataset = DomainFinetuneDataset(\n",
285
  " test_seqs, test_features, test_labels, builder, hf_tokenizer, max_length=MAX_LENGTH)\n",
286
  "\n",
287
- "print(f'Train: {len(train_dataset)}, Test: {len(test_dataset)}')\n",
288
- "print(f'Sample keys: {set(train_dataset[0].keys())}')"
289
  ]
290
  },
291
  {
@@ -328,11 +350,11 @@
328
  " learning_rate=1e-4,\n",
329
  " warmup_steps=50,\n",
330
  " logging_steps=20,\n",
331
- " eval_steps=100 if USE_GPU else 50,\n",
332
  " save_strategy='no',\n",
333
  " bf16=USE_BF16, fp16=USE_FP16,\n",
334
  " report_to='wandb',\n",
335
- " run_name='ecommerce-finetune-joint-5ep',\n",
336
  " seed=42,\n",
337
  ")"
338
  ]
@@ -366,9 +388,7 @@
366
  " all_probs.extend(probs.cpu().numpy())\n",
367
  " all_labels_eval.extend(labels_batch.cpu().numpy())\n",
368
  "\n",
369
- "all_probs = np.array(all_probs)\n",
370
- "all_labels_eval = np.array(all_labels_eval)\n",
371
- "fusion_test_auc = roc_auc_score(all_labels_eval, all_probs)\n",
372
  "print(f'JointFusion Test AUC: {fusion_test_auc:.4f}')"
373
  ]
374
  },
@@ -378,19 +398,23 @@
378
  "metadata": {},
379
  "outputs": [],
380
  "source": [
381
- "print('=' * 50)\n",
382
- "print('MODEL COMPARISON — Purchase Prediction (AUC)')\n",
383
- "print('=' * 50)\n",
384
- "print(f' LightGBM (tabular only): {lgb_test_auc:.4f}')\n",
385
- "print(f' JointFusion (Transformer+DCNv2): {fusion_test_auc:.4f}')\n",
386
- "print(f' Difference: {fusion_test_auc - lgb_test_auc:+.4f}')\n",
387
- "print('=' * 50)\n",
388
  "\n",
389
  "if fusion_test_auc > lgb_test_auc:\n",
390
- " print(f'\\n✅ JointFusion beats LightGBM by {(fusion_test_auc - lgb_test_auc)*100:.2f} percentage points')\n",
391
  "else:\n",
392
- " print(f'\\n⚠️ LightGBM still leads by {(lgb_test_auc - fusion_test_auc)*100:.2f} percentage points')\n",
393
- " print(f' (More pre-training epochs and longer context would improve transformer embeddings.)')"
394
  ]
395
  },
396
  {
@@ -405,9 +429,9 @@
405
  "fig, ax = plt.subplots(figsize=(10, 5))\n",
406
  "ax.plot(losses, label='Train Loss', alpha=0.7)\n",
407
  "if eval_losses:\n",
408
- " eval_steps_x = np.linspace(0, len(losses), len(eval_losses))\n",
409
- " ax.plot(eval_steps_x, eval_losses, 'ro-', label='Eval Loss', markersize=4)\n",
410
- "ax.set_xlabel('Step'); ax.set_ylabel('Loss'); ax.set_title('Fine-Tuning Loss')\n",
411
  "ax.legend(); ax.grid(True, alpha=0.3); plt.tight_layout(); plt.show()"
412
  ]
413
  },
@@ -427,12 +451,14 @@
427
  "source": [
428
  "## Summary\n",
429
  "\n",
430
- "| Model | Test AUC | Notes |\n",
431
  "|-------|----------|-------|\n",
432
- "| LightGBM (tabular) | *see above* | 13 hand-crafted features |\n",
433
- "| JointFusion (Transformer+DCNv2) | *see above* | Pre-trained domain tokens + same 13 features |\n",
434
  "\n",
435
- "The pre-trained DomainTransformer captures sequential behavioral patterns (view→cart→purchase funnels, category stickiness, temporal habits) that hand-crafted features cannot fully represent."
436
  ]
437
  }
438
  ],
 
4
  "cell_type": "markdown",
5
  "metadata": {},
6
  "source": [
7
+ "# 03 — E-Commerce Fine-Tuning: Future Purchase Prediction\n",
8
  "\n",
9
+ "**Goal:** Fine-tune the pre-trained DomainTransformer for predicting whether a user will purchase in the future, using only their past browsing history.\n",
10
  "\n",
11
+ "**Task:** Binary classification — given the first 70% of a user's events, predict if they purchase in the remaining 30%.\n",
12
+ "\n",
13
+ "**Why temporal split:** Avoids label leakage. The previous version used `n_purchases` as a feature to predict `has_purchase` → trivial AUC 1.0. This version simulates the real production scenario: predict future behavior from past behavior.\n",
14
  "\n",
15
  "**Pre-trained model:** [rtferraz/ecommerce-domain-24m](https://huggingface.co/rtferraz/ecommerce-domain-24m)\n",
16
  "\n",
 
48
  "import matplotlib.pyplot as plt\n",
49
  "import torch\n",
50
  "from sklearn.model_selection import train_test_split\n",
51
+ "from sklearn.metrics import roc_auc_score\n",
52
  "\n",
53
  "if os.path.exists('../src'): sys.path.insert(0, '../src')\n",
54
  "elif os.path.exists('src'): sys.path.insert(0, 'src')\n",
 
56
  "from domain_tokenizer import (\n",
57
  " DomainTokenizerBuilder, DomainTransformerConfig,\n",
58
  " DomainTransformerForCausalLM, JointFusionModel,\n",
59
+ " DomainFinetuneDataset, finetune_domain_model,\n",
60
  ")\n",
61
  "from domain_tokenizer.schema import DomainSchema, FieldSpec, FieldType\n",
62
  "\n",
 
84
  "cell_type": "markdown",
85
  "metadata": {},
86
  "source": [
87
+ "## Step 1 — Load Pre-trained Artifacts"
88
  ]
89
  },
90
  {
 
93
  "metadata": {},
94
  "outputs": [],
95
  "source": [
 
96
  "with open('./ecommerce_artifacts.pkl', 'rb') as f:\n",
97
  " artifacts = pickle.load(f)\n",
 
98
  "user_sequences = artifacts['user_sequences']\n",
99
  "user_ids = artifacts['user_ids']\n",
100
  "print(f'Loaded {len(user_sequences):,} users')\n",
101
  "\n",
 
102
  "from transformers import PreTrainedTokenizerFast\n",
103
  "hf_tokenizer = PreTrainedTokenizerFast.from_pretrained('./ecommerce_tokenizer')\n",
104
  "print(f'Tokenizer vocab: {hf_tokenizer.vocab_size}')\n",
105
  "\n",
 
106
  "ECOMMERCE_REES46_SCHEMA = DomainSchema(\n",
107
  " name='ecommerce_rees46',\n",
108
  " fields=[\n",
 
126
  "metadata": {},
127
  "outputs": [],
128
  "source": [
129
  "model = DomainTransformerForCausalLM.from_pretrained('./ecommerce_pretrain_checkpoints/final/')\n",
130
  "print(f'Pre-trained model loaded: {sum(p.numel() for p in model.parameters()):,} params')"
131
  ]
132
  },
 
134
  "cell_type": "markdown",
135
  "metadata": {},
136
  "source": [
137
+ "## Step 2 — Temporal Split: Labels and Features\n",
138
  "\n",
139
+ "**The key design (avoids leakage):**\n",
140
+ "- Split each user's events at the 70% mark temporally\n",
141
+ "- **Input to model:** first 70% of events (history)\n",
142
+ "- **Label:** did the user purchase in the last 30%? (future)\n",
143
+ "- **Tabular features:** computed only from the first 70% (no future info)\n",
144
  "\n",
145
+ "This matches Nubank's setup: predict future behavior from past history."
146
  ]
147
  },
148
  {
 
151
  "metadata": {},
152
  "outputs": [],
153
  "source": [
154
+ "SPLIT_RATIO = 0.7 # 70% history, 30% future\n",
155
+ "MIN_HISTORY = 5 # need at least 5 events in history\n",
156
+ "MIN_FUTURE = 3 # need at least 3 events in future\n",
157
+ "\n",
158
+ "history_sequences = [] # input to model\n",
159
+ "future_labels = [] # target: purchased in future?\n",
160
+ "valid_user_ids = []\n",
161
+ "\n",
162
+ "for i, events in enumerate(user_sequences):\n",
163
+ " split_idx = int(len(events) * SPLIT_RATIO)\n",
164
+ " history = events[:split_idx]\n",
165
+ " future = events[split_idx:]\n",
166
+ " \n",
167
+ " if len(history) < MIN_HISTORY or len(future) < MIN_FUTURE:\n",
168
+ " continue\n",
169
+ " \n",
170
+ " # Label: did user purchase in the future window?\n",
171
+ " has_future_purchase = any(e['event_type'] == 'purchase' for e in future)\n",
172
+ " \n",
173
+ " history_sequences.append(history)\n",
174
+ " future_labels.append(1.0 if has_future_purchase else 0.0)\n",
175
+ " valid_user_ids.append(user_ids[i])\n",
176
+ "\n",
177
+ "future_labels = np.array(future_labels)\n",
178
+ "print(f'Valid users (enough history + future): {len(history_sequences):,}')\n",
179
+ "print(f'Future purchasers: {future_labels.sum():.0f} / {len(future_labels)} ({future_labels.mean()*100:.1f}%)')"
180
+ ]
181
+ },
182
+ {
183
+ "cell_type": "code",
184
+ "execution_count": null,
185
+ "metadata": {},
186
+ "outputs": [],
187
+ "source": [
188
+ "def compute_history_features(events):\n",
189
+ " \"\"\"Features from HISTORY ONLY — no future information leaks.\"\"\"\n",
190
  " n_events = len(events)\n",
191
  " n_views = sum(1 for e in events if e['event_type'] == 'view')\n",
192
  " n_carts = sum(1 for e in events if e['event_type'] == 'cart')\n",
 
193
  " n_removes = sum(1 for e in events if e['event_type'] == 'remove_from_cart')\n",
194
+ " # NOTE: n_purchases in HISTORY is allowed — it's past behavior, not future\n",
195
+ " n_hist_purchases = sum(1 for e in events if e['event_type'] == 'purchase')\n",
196
  " \n",
197
  " prices = [e['price'] for e in events if e['price'] > 0]\n",
198
  " avg_price = np.mean(prices) if prices else 0\n",
199
  " max_price = max(prices) if prices else 0\n",
200
  " std_price = np.std(prices) if len(prices) > 1 else 0\n",
201
  " \n",
202
+ " n_unique_categories = len(set(e['category'] for e in events))\n",
203
+ " avg_hour = np.mean([e['timestamp'].hour for e in events])\n",
204
  " \n",
205
+ " # Funnel ratios from history\n",
206
  " cart_rate = n_carts / max(n_views, 1)\n",
 
207
  " remove_rate = n_removes / max(n_carts, 1) if n_carts > 0 else 0\n",
208
+ " hist_purchase_rate = n_hist_purchases / max(n_events, 1)\n",
209
+ " \n",
210
+ " # Session intensity (events per day approximation)\n",
211
+ " if len(events) >= 2:\n",
212
+ " time_span = (events[-1]['timestamp'] - events[0]['timestamp']).total_seconds() / 86400 # days\n",
213
+ " events_per_day = n_events / max(time_span, 1)\n",
214
+ " else:\n",
215
+ " events_per_day = 0\n",
216
  " \n",
217
  " return [\n",
218
+ " n_events, n_views, n_carts, n_removes, n_hist_purchases,\n",
219
  " avg_price, max_price, std_price,\n",
220
  " n_unique_categories, avg_hour,\n",
221
+ " cart_rate, remove_rate, hist_purchase_rate, events_per_day,\n",
222
  " ]\n",
223
  "\n",
224
  "FEATURE_NAMES = [\n",
225
+ " 'n_events', 'n_views', 'n_carts', 'n_removes', 'n_hist_purchases',\n",
226
  " 'avg_price', 'max_price', 'std_price',\n",
227
  " 'n_unique_categories', 'avg_hour',\n",
228
+ " 'cart_rate', 'remove_rate', 'hist_purchase_rate', 'events_per_day',\n",
229
  "]\n",
230
  "\n",
231
+ "print(f'Computing features from history only...')\n",
232
+ "tabular_features = np.array([compute_history_features(seq) for seq in history_sequences], dtype=np.float32)\n",
233
+ "print(f'Features: {tabular_features.shape}, {len(FEATURE_NAMES)} features')\n",
234
+ "print(f'Feature names: {FEATURE_NAMES}')"
235
  ]
236
  },
237
  {
 
242
  "source": [
243
  "# Train/test split (80/20, stratified)\n",
244
  "train_idx, test_idx = train_test_split(\n",
245
+ " range(len(history_sequences)), test_size=0.2, random_state=42, stratify=future_labels\n",
246
  ")\n",
247
  "\n",
248
+ "train_seqs = [history_sequences[i] for i in train_idx]\n",
249
+ "test_seqs = [history_sequences[i] for i in test_idx]\n",
250
  "train_features = tabular_features[train_idx]\n",
251
  "test_features = tabular_features[test_idx]\n",
252
+ "train_labels = future_labels[train_idx]\n",
253
+ "test_labels = future_labels[test_idx]\n",
254
  "\n",
255
+ "print(f'Train: {len(train_seqs):,} ({train_labels.mean()*100:.1f}% will purchase in future)')\n",
256
+ "print(f'Test: {len(test_seqs):,} ({test_labels.mean()*100:.1f}% will purchase in future)')"
257
  ]
258
  },
259
  {
260
  "cell_type": "markdown",
261
  "metadata": {},
262
  "source": [
263
+ "## Step 3 — LightGBM Baseline (history features only)"
264
  ]
265
  },
266
  {
 
274
  "lgb_model = lgb.LGBMClassifier(n_estimators=200, learning_rate=0.05, max_depth=6, random_state=42, verbose=-1)\n",
275
  "lgb_model.fit(train_features, train_labels)\n",
276
  "\n",
 
277
  "lgb_test_probs = lgb_model.predict_proba(test_features)[:, 1]\n",
278
  "lgb_test_auc = roc_auc_score(test_labels, lgb_test_probs)\n",
279
  "\n",
280
+ "print(f'LightGBM Baseline (history features only):')\n",
281
+ "print(f' Test AUC: {lgb_test_auc:.4f}')\n",
 
282
  "\n",
283
  "importance = pd.Series(lgb_model.feature_importances_, index=FEATURE_NAMES).sort_values(ascending=False)\n",
284
  "print(f'\\nTop features:')\n",
285
+ "for feat, imp in importance.head(7).items(): print(f' {feat}: {imp}')"
286
  ]
287
  },
288
  {
 
291
  "source": [
292
  "## Step 4 — JointFusionModel Fine-Tuning\n",
293
  "\n",
294
+ "The transformer sees the **raw event sequence** (history only). The DCNv2 branch sees the **hand-crafted features** (also history only). The question: does the raw sequence add signal beyond what the features capture?"
295
  ]
296
  },
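
The `JointFusionModel` internals are not shown in this diff; the following is a rough sketch of the fusion pattern the cell above describes, with a generic encoder standing in for the pre-trained DomainTransformer and a plain MLP standing in for the DCNv2 branch (all names, sizes, and layers here are illustrative assumptions, not the library's API):

```python
import torch
import torch.nn as nn

class FusionSketch(nn.Module):
    """Illustrative only -- not the real JointFusionModel."""
    def __init__(self, vocab_size=1000, d_model=128, n_features=14, d_tab=32):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, d_model)
        layer = nn.TransformerEncoderLayer(d_model, nhead=4, batch_first=True)
        self.encoder = nn.TransformerEncoder(layer, num_layers=2)  # stand-in for DomainTransformer
        self.tab_branch = nn.Sequential(nn.Linear(n_features, d_tab), nn.ReLU())  # stand-in for DCNv2
        self.head = nn.Linear(d_model + d_tab, 1)  # joint head -> 1 logit

    def forward(self, input_ids, attention_mask, features):
        h = self.encoder(self.embed(input_ids),
                         src_key_padding_mask=~attention_mask.bool())
        mask = attention_mask.unsqueeze(-1).float()
        user_emb = (h * mask).sum(1) / mask.sum(1).clamp(min=1)  # masked mean-pool
        return self.head(torch.cat([user_emb, self.tab_branch(features)], dim=-1))

# Smoke test on random inputs
sketch = FusionSketch()
ids = torch.randint(0, 1000, (4, 20))
attn = torch.ones(4, 20)
feats = torch.randn(4, 14)
print(sketch(ids, attn, feats).shape)  # torch.Size([4, 1])
```

The point of the design: if the sequence branch only re-encoded what the 14 aggregate features already capture, the joint head would gain nothing over LightGBM, so any AUC gap is evidence of extra sequential signal.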
297
  {
 
307
  "test_dataset = DomainFinetuneDataset(\n",
308
  " test_seqs, test_features, test_labels, builder, hf_tokenizer, max_length=MAX_LENGTH)\n",
309
  "\n",
310
+ "print(f'Train: {len(train_dataset)}, Test: {len(test_dataset)}')"
 
311
  ]
312
  },
313
  {
 
350
  " learning_rate=1e-4,\n",
351
  " warmup_steps=50,\n",
352
  " logging_steps=20,\n",
353
+ " eval_steps=200 if USE_GPU else 50,\n",
354
  " save_strategy='no',\n",
355
  " bf16=USE_BF16, fp16=USE_FP16,\n",
356
  " report_to='wandb',\n",
357
+ " run_name='ecommerce-finetune-temporal-5ep',\n",
358
  " seed=42,\n",
359
  ")"
360
  ]
 
388
  " all_probs.extend(probs.cpu().numpy())\n",
389
  " all_labels_eval.extend(labels_batch.cpu().numpy())\n",
390
  "\n",
391
+ "fusion_test_auc = roc_auc_score(np.array(all_labels_eval), np.array(all_probs))\n",
392
  "print(f'JointFusion Test AUC: {fusion_test_auc:.4f}')"
393
  ]
394
  },
 
398
  "metadata": {},
399
  "outputs": [],
400
  "source": [
401
+ "print('=' * 60)\n",
402
+ "print('MODEL COMPARISON — Future Purchase Prediction (AUC)')\n",
403
+ "print('=' * 60)\n",
404
+ "print(f' LightGBM (history features only): {lgb_test_auc:.4f}')\n",
405
+ "print(f' JointFusion (Transformer + features): {fusion_test_auc:.4f}')\n",
406
+ "print(f' Difference: {fusion_test_auc - lgb_test_auc:+.4f}')\n",
407
+ "print('=' * 60)\n",
408
  "\n",
409
  "if fusion_test_auc > lgb_test_auc:\n",
410
+ " print(f'\\n✅ JointFusion beats LightGBM by {(fusion_test_auc - lgb_test_auc)*100:.2f} pp')\n",
411
+ " print(f' The sequential patterns from domain tokens add value beyond tabular features.')\n",
412
+ "elif abs(fusion_test_auc - lgb_test_auc) < 0.005:\n",
413
+ " print(f'\\n≈ Roughly tied. The transformer embeddings match LightGBM.')\n",
414
+ " print(f' More pre-training epochs would likely push JointFusion ahead.')\n",
415
  "else:\n",
416
+ " print(f'\\n⚠️ LightGBM leads by {(lgb_test_auc - fusion_test_auc)*100:.2f} pp')\n",
417
+ " print(f' More pre-training (10+ epochs) and longer context (1024+) needed.')"
418
  ]
419
  },
420
  {
 
429
  "fig, ax = plt.subplots(figsize=(10, 5))\n",
430
  "ax.plot(losses, label='Train Loss', alpha=0.7)\n",
431
  "if eval_losses:\n",
432
+ " eval_x = np.linspace(0, len(losses), len(eval_losses))\n",
433
+ " ax.plot(eval_x, eval_losses, 'ro-', label='Eval Loss', markersize=4)\n",
434
+ "ax.set_xlabel('Step'); ax.set_ylabel('Loss'); ax.set_title('Fine-Tuning Loss (Temporal Split)')\n",
435
  "ax.legend(); ax.grid(True, alpha=0.3); plt.tight_layout(); plt.show()"
436
  ]
437
  },
 
451
  "source": [
452
  "## Summary\n",
453
  "\n",
454
+ "| Model | Test AUC | Input |\n",
455
  "|-------|----------|-------|\n",
456
+ "| LightGBM | *see above* | 14 history-only features |\n",
457
+ "| JointFusion | *see above* | Pre-trained domain token sequence + same 14 features |\n",
458
+ "\n",
459
+ "**Task:** Predict future purchase from past browsing history (temporal split, no leakage).\n",
460
  "\n",
461
+ "The pre-trained DomainTransformer captures sequential patterns (browsing funnels, category stickiness, temporal habits) that may add predictive signal beyond aggregate features."
462
  ]
463
  }
464
  ],