rtferraz committed
Commit 165b138 · verified · 1 parent: 857ec9a

Fix model loading: use from_pretrained() instead of torch.load() for safetensors format
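Context for the fix: torch.load() deserializes pickle archives, so it cannot read the model.safetensors file saved by notebook 02, while from_pretrained() resolves the config and safetensors weights from the checkpoint directory together. A minimal sketch of the two loading paths, assuming the repo's DomainTransformerForCausalLM is a PreTrainedModel subclass and the safetensors package is installed:

```python
# Sketch only, not repo code. torch.load() expects a pickle/zip archive,
# so pointing it at model.safetensors fails to deserialize:
# torch.load('./ecommerce_pretrain_checkpoints/final/model.safetensors')  # raises

# Manual route: read the safetensors tensors into a plain state dict.
from safetensors.torch import load_file
state_dict = load_file('./ecommerce_pretrain_checkpoints/final/model.safetensors')

# Preferred route (what this commit adopts): from_pretrained() loads
# config.json + model.safetensors from the checkpoint directory in one call.
model = DomainTransformerForCausalLM.from_pretrained('./ecommerce_pretrain_checkpoints/final/')
```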

Files changed (1):
  notebooks/03_ecommerce_finetune.ipynb +34 -57
notebooks/03_ecommerce_finetune.ipynb CHANGED
@@ -28,7 +28,7 @@
  "metadata": {},
  "outputs": [],
  "source": [
- "# !pip install datasets transformers torch accelerate tokenizers numpy pandas matplotlib scikit-learn wandb huggingface_hub lightgbm"
+ "# !pip install datasets transformers torch accelerate tokenizers numpy pandas matplotlib scikit-learn wandb huggingface_hub lightgbm safetensors"
  ]
  },
  {
@@ -130,11 +130,13 @@
  "metadata": {},
  "outputs": [],
  "source": [
- "# Load pre-trained model\n",
- "config = DomainTransformerConfig.from_preset('24m', vocab_size=hf_tokenizer.vocab_size)\n",
- "model = DomainTransformerForCausalLM(config)\n",
- "model.load_state_dict(torch.load('./ecommerce_pretrain_checkpoints/final/model.safetensors',\n",
- "                                 map_location='cpu', weights_only=True), strict=False)\n",
+ "# Load pre-trained model using from_pretrained (handles safetensors natively)\n",
+ "# Option A: from local checkpoint saved by notebook 02\n",
+ "model = DomainTransformerForCausalLM.from_pretrained('./ecommerce_pretrain_checkpoints/final/')\n",
+ "\n",
+ "# Option B: from HuggingFace Hub (if local not available)\n",
+ "# model = DomainTransformerForCausalLM.from_pretrained('rtferraz/ecommerce-domain-24m')\n",
+ "\n",
  "print(f'Pre-trained model loaded: {sum(p.numel() for p in model.parameters()):,} params')"
  ]
  },
@@ -171,11 +173,9 @@
  "    categories = set(e['category'] for e in events)\n",
  "    n_unique_categories = len(categories)\n",
  "    \n",
- "    # Temporal features\n",
  "    hours = [e['timestamp'].hour for e in events]\n",
  "    avg_hour = np.mean(hours)\n",
  "    \n",
- "    # Conversion funnel ratios\n",
  "    cart_rate = n_carts / max(n_views, 1)\n",
  "    purchase_rate = n_purchases / max(n_events, 1)\n",
  "    remove_rate = n_removes / max(n_carts, 1) if n_carts > 0 else 0\n",
@@ -183,16 +183,14 @@
  "    return [\n",
  "        n_events, n_views, n_carts, n_purchases, n_removes,\n",
  "        avg_price, max_price, std_price,\n",
- "        n_unique_categories,\n",
- "        avg_hour,\n",
+ "        n_unique_categories, avg_hour,\n",
  "        cart_rate, purchase_rate, remove_rate,\n",
  "    ]\n",
  "\n",
  "FEATURE_NAMES = [\n",
  "    'n_events', 'n_views', 'n_carts', 'n_purchases', 'n_removes',\n",
  "    'avg_price', 'max_price', 'std_price',\n",
- "    'n_unique_categories',\n",
- "    'avg_hour',\n",
+ "    'n_unique_categories', 'avg_hour',\n",
  "    'cart_rate', 'purchase_rate', 'remove_rate',\n",
  "]\n",
  "\n",
@@ -201,8 +199,7 @@
  "labels = np.array([1.0 if any(e['event_type'] == 'purchase' for e in seq) else 0.0 for seq in user_sequences])\n",
  "\n",
  "print(f'Features shape: {tabular_features.shape}')\n",
- "print(f'Labels: {labels.sum():.0f} purchasers / {len(labels)} total ({labels.mean()*100:.1f}%)')\n",
- "print(f'Feature names: {FEATURE_NAMES}')"
+ "print(f'Labels: {labels.sum():.0f} purchasers / {len(labels)} total ({labels.mean()*100:.1f}%)')"
  ]
  },
  {
@@ -211,7 +208,7 @@
  "metadata": {},
  "outputs": [],
  "source": [
- "# Train/test split (80/20, stratified by label)\n",
+ "# Train/test split (80/20, stratified)\n",
  "train_idx, test_idx = train_test_split(\n",
  "    range(len(user_sequences)), test_size=0.2, random_state=42, stratify=labels\n",
  ")\n",
@@ -223,8 +220,8 @@
  "train_labels = labels[train_idx]\n",
  "test_labels = labels[test_idx]\n",
  "\n",
- "print(f'Train: {len(train_seqs):,} users ({train_labels.mean()*100:.1f}% positive)')\n",
- "print(f'Test: {len(test_seqs):,} users ({test_labels.mean()*100:.1f}% positive)')"
+ "print(f'Train: {len(train_seqs):,} ({train_labels.mean()*100:.1f}% positive)')\n",
+ "print(f'Test: {len(test_seqs):,} ({test_labels.mean()*100:.1f}% positive)')"
  ]
  },
  {
@@ -257,11 +254,9 @@
  "print(f'  Train AUC: {lgb_train_auc:.4f}')\n",
  "print(f'  Test AUC:  {lgb_test_auc:.4f}')\n",
  "\n",
- "# Feature importance\n",
  "importance = pd.Series(lgb_model.feature_importances_, index=FEATURE_NAMES).sort_values(ascending=False)\n",
  "print(f'\\nTop features:')\n",
- "for feat, imp in importance.head(5).items():\n",
- "    print(f'  {feat}: {imp}')"
+ "for feat, imp in importance.head(5).items(): print(f'  {feat}: {imp}')"
  ]
  },
  {
@@ -270,7 +265,7 @@
  "source": [
  "## Step 4 — JointFusionModel Fine-Tuning\n",
  "\n",
- "The JointFusionModel combines:\n",
+ "Combines:\n",
  "- **Transaction branch:** Pre-trained DomainTransformer → user embedding\n",
  "- **Tabular branch:** DCNv2 with PLR embeddings on hand-crafted features\n",
  "- **Joint head:** MLP on concatenated embeddings → binary prediction"
@@ -282,21 +277,15 @@
  "metadata": {},
  "outputs": [],
  "source": [
- "# Create fine-tuning datasets\n",
- "MAX_LENGTH = 256  # tokens per user sequence\n",
+ "MAX_LENGTH = 256\n",
  "\n",
  "train_dataset = DomainFinetuneDataset(\n",
- "    train_seqs, train_features, train_labels,\n",
- "    builder, hf_tokenizer, max_length=MAX_LENGTH,\n",
- ")\n",
+ "    train_seqs, train_features, train_labels, builder, hf_tokenizer, max_length=MAX_LENGTH)\n",
  "test_dataset = DomainFinetuneDataset(\n",
- "    test_seqs, test_features, test_labels,\n",
- "    builder, hf_tokenizer, max_length=MAX_LENGTH,\n",
- ")\n",
+ "    test_seqs, test_features, test_labels, builder, hf_tokenizer, max_length=MAX_LENGTH)\n",
  "\n",
- "print(f'Train dataset: {len(train_dataset)} samples')\n",
- "print(f'Test dataset: {len(test_dataset)} samples')\n",
- "print(f'Sample: {set(train_dataset[0].keys())}')"
+ "print(f'Train: {len(train_dataset)}, Test: {len(test_dataset)}')\n",
+ "print(f'Sample keys: {set(train_dataset[0].keys())}')"
  ]
  },
  {
@@ -305,22 +294,15 @@
  "metadata": {},
  "outputs": [],
  "source": [
- "# Create JointFusionModel\n",
  "fusion_model = JointFusionModel(\n",
  "    transformer_model=model,\n",
  "    n_tabular_features=len(FEATURE_NAMES),\n",
- "    n_classes=1,  # binary\n",
- "    plr_frequencies=32,\n",
- "    plr_embedding_dim=32,\n",
- "    dcn_cross_layers=3,\n",
- "    dcn_deep_layers=2,\n",
- "    dcn_deep_dim=128,\n",
- "    head_hidden_dim=128,\n",
- "    dropout=0.1,\n",
+ "    n_classes=1,\n",
+ "    plr_frequencies=32, plr_embedding_dim=32,\n",
+ "    dcn_cross_layers=3, dcn_deep_layers=2, dcn_deep_dim=128,\n",
+ "    head_hidden_dim=128, dropout=0.1,\n",
  ")\n",
- "\n",
- "n_params = sum(p.numel() for p in fusion_model.parameters())\n",
- "print(f'JointFusion model: {n_params:,} params (transformer + DCNv2 + head)')"
+ "print(f'JointFusion: {sum(p.numel() for p in fusion_model.parameters()):,} params')"
  ]
  },
  {
@@ -348,8 +330,7 @@
  "    logging_steps=20,\n",
  "    eval_steps=100 if USE_GPU else 50,\n",
  "    save_strategy='no',\n",
- "    bf16=USE_BF16,\n",
- "    fp16=USE_FP16,\n",
+ "    bf16=USE_BF16, fp16=USE_FP16,\n",
  "    report_to='wandb',\n",
  "    run_name='ecommerce-finetune-joint-5ep',\n",
  "    seed=42,\n",
@@ -369,12 +350,11 @@
  "metadata": {},
  "outputs": [],
  "source": [
- "# Get predictions from JointFusion model\n",
  "fusion_model.eval()\n",
  "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
  "fusion_model = fusion_model.to(device)\n",
  "\n",
- "all_probs, all_labels = [], []\n",
+ "all_probs, all_labels_eval = [], []\n",
  "loader = torch.utils.data.DataLoader(test_dataset, batch_size=64, shuffle=False)\n",
  "\n",
  "with torch.no_grad():\n",
@@ -384,12 +364,11 @@
  "        out = fusion_model(**batch)\n",
  "        probs = torch.sigmoid(out['logits'].squeeze(-1))\n",
  "        all_probs.extend(probs.cpu().numpy())\n",
- "        all_labels.extend(labels_batch.cpu().numpy())\n",
+ "        all_labels_eval.extend(labels_batch.cpu().numpy())\n",
  "\n",
  "all_probs = np.array(all_probs)\n",
- "all_labels = np.array(all_labels)\n",
- "\n",
- "fusion_test_auc = roc_auc_score(all_labels, all_probs)\n",
+ "all_labels_eval = np.array(all_labels_eval)\n",
+ "fusion_test_auc = roc_auc_score(all_labels_eval, all_probs)\n",
  "print(f'JointFusion Test AUC: {fusion_test_auc:.4f}')"
  ]
  },
@@ -399,7 +378,6 @@
  "metadata": {},
  "outputs": [],
  "source": [
- "# Comparison table\n",
  "print('=' * 50)\n",
  "print('MODEL COMPARISON — Purchase Prediction (AUC)')\n",
  "print('=' * 50)\n",
@@ -412,7 +390,7 @@
  "    print(f'\\n✅ JointFusion beats LightGBM by {(fusion_test_auc - lgb_test_auc)*100:.2f} percentage points')\n",
  "else:\n",
  "    print(f'\\n⚠️ LightGBM still leads by {(lgb_test_auc - fusion_test_auc)*100:.2f} percentage points')\n",
- "    print(f'   (Expected with only 3-epoch pre-training. More epochs would improve the transformer embeddings.)')"
+ "    print(f'   (More pre-training epochs and longer context would improve transformer embeddings.)')"
  ]
  },
  {
@@ -421,15 +399,14 @@
  "metadata": {},
  "outputs": [],
  "source": [
- "# Loss curve\n",
  "losses = [h['loss'] for h in trainer.state.log_history if 'loss' in h]\n",
  "eval_losses = [h['eval_loss'] for h in trainer.state.log_history if 'eval_loss' in h]\n",
  "\n",
  "fig, ax = plt.subplots(figsize=(10, 5))\n",
  "ax.plot(losses, label='Train Loss', alpha=0.7)\n",
  "if eval_losses:\n",
- "    eval_steps = np.linspace(0, len(losses), len(eval_losses))\n",
- "    ax.plot(eval_steps, eval_losses, 'ro-', label='Eval Loss', markersize=4)\n",
+ "    eval_steps_x = np.linspace(0, len(losses), len(eval_losses))\n",
+ "    ax.plot(eval_steps_x, eval_losses, 'ro-', label='Eval Loss', markersize=4)\n",
  "ax.set_xlabel('Step'); ax.set_ylabel('Loss'); ax.set_title('Fine-Tuning Loss')\n",
  "ax.legend(); ax.grid(True, alpha=0.3); plt.tight_layout(); plt.show()"
  ]
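A minimal sketch of the fusion wiring described in the Step 4 cell above: the transformer's user embedding is concatenated with a tabular embedding and passed to an MLP head. The class name, the last-token pooling, and the plain linear layer standing in for the DCNv2 + PLR branch are illustrative assumptions, not the repo's actual JointFusionModel:

```python
import torch
import torch.nn as nn

class FusionSketch(nn.Module):
    """Hypothetical stand-in for JointFusionModel's wiring, not the repo's class."""

    def __init__(self, transformer, d_model, n_tabular, d_tab=64, d_hidden=128, dropout=0.1):
        super().__init__()
        self.transformer = transformer                    # pre-trained transaction branch
        self.tabular = nn.Sequential(                     # stand-in for the DCNv2 + PLR branch
            nn.Linear(n_tabular, d_tab), nn.ReLU())
        self.head = nn.Sequential(                        # joint MLP over concatenated embeddings
            nn.Linear(d_model + d_tab, d_hidden), nn.ReLU(),
            nn.Dropout(dropout), nn.Linear(d_hidden, 1))  # single logit for binary prediction

    def forward(self, input_ids, attention_mask, tabular_features):
        hidden = self.transformer(input_ids=input_ids, attention_mask=attention_mask,
                                  output_hidden_states=True).hidden_states[-1]
        user_emb = hidden[:, -1, :]                       # last-token pooling -> user embedding
        tab_emb = self.tabular(tabular_features.float())
        return self.head(torch.cat([user_emb, tab_emb], dim=-1))
```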