Fix notebook: total_mem → total_memory, add hub_model_id push, add wandb logging support
notebooks/01_finance_pretrain.ipynb  (+75 -138)  CHANGED
@@ -38,7 +38,7 @@
  "outputs": [],
  "source": [
   "# Uncomment and run once to install dependencies:\n",
-  "# !pip install datasets transformers torch accelerate tokenizers numpy pandas matplotlib scikit-learn"
+  "# !pip install datasets transformers torch accelerate tokenizers numpy pandas matplotlib scikit-learn wandb"
  ]
 },
 {
@@ -75,7 +75,23 @@
  "logging.basicConfig(level=logging.INFO, format='%(asctime)s %(message)s')\n",
  "print(f'torch: {torch.__version__}, CUDA: {torch.cuda.is_available()}')\n",
  "if torch.cuda.is_available():\n",
-  "    print(f'GPU: {torch.cuda.get_device_name(0)}, VRAM: {torch.cuda.get_device_properties(0).total_mem / 1e9:.1f}GB')"
+  "    print(f'GPU: {torch.cuda.get_device_name(0)}, VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f}GB')"
+ ]
+},
+{
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+  "# wandb setup — logs persist even if notebook kernel disconnects\n",
+  "# Run `wandb login` in terminal first, or set WANDB_API_KEY env var\n",
+  "import wandb\n",
+  "wandb.login()\n",
+  "\n",
+  "WANDB_PROJECT = 'domainTokenizer'  # all runs grouped under this project\n",
+  "os.environ['WANDB_PROJECT'] = WANDB_PROJECT\n",
+  "print(f'wandb project: {WANDB_PROJECT}')"
  ]
 },
 {
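Note on the wandb wiring added above: the `WANDB_PROJECT` environment variable is what the Hugging Face wandb callback reads for the project name, while `report_to='wandb'` and `run_name` (added in the trainer hunk further down) select the callback and the run label. A minimal sketch with stock `transformers.TrainingArguments` follows; the notebook's own trainer helper presumably forwards equivalent fields, so the exact argument names in that helper are an assumption here.

```python
import os
from transformers import TrainingArguments

os.environ["WANDB_PROJECT"] = "domainTokenizer"   # picked up by the wandb callback

args = TrainingArguments(
    output_dir="./finance_pretrain_checkpoints",
    report_to="wandb",                  # enable the WandbCallback
    run_name="finance-pretrain-24m-3ep",
    logging_steps=50,
    save_steps=1000,
    bf16=True,
    seed=42,
)
```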
@@ -143,7 +159,6 @@
  "metadata": {},
  "outputs": [],
  "source": [
-  "# Events per user distribution\n",
   "events_per_user = df.groupby('sender_account').size()\n",
   "print(f\"Events per user: min={events_per_user.min()}, max={events_per_user.max()}, \"\n",
   "      f\"mean={events_per_user.mean():.1f}, median={events_per_user.median():.1f}\")\n",
@@ -151,23 +166,13 @@
  "print(f\"Users with 10+ events: {(events_per_user >= 10).sum():,}\")\n",
  "\n",
  "fig, axes = plt.subplots(1, 3, figsize=(15, 4))\n",
-  "\n",
  "axes[0].hist(np.log10(df['amount_ngn'].clip(lower=1)), bins=50, edgecolor='black', alpha=0.7)\n",
-  "axes[0].set_xlabel('log10(Amount NGN)')\n",
-  "axes[0].set_ylabel('Count')\n",
-  "axes[0].set_title('Amount Distribution (log scale)')\n",
-  "\n",
+  "axes[0].set_xlabel('log10(Amount NGN)'); axes[0].set_ylabel('Count'); axes[0].set_title('Amount Distribution (log scale)')\n",
  "axes[1].hist(events_per_user.clip(upper=50), bins=50, edgecolor='black', alpha=0.7)\n",
-  "axes[1].set_xlabel('Events per User')\n",
-  "axes[1].set_ylabel('Count')\n",
-  "axes[1].set_title('Events per User')\n",
-  "\n",
+  "axes[1].set_xlabel('Events per User'); axes[1].set_ylabel('Count'); axes[1].set_title('Events per User')\n",
  "df['transaction_type'].value_counts().head(10).plot(kind='barh', ax=axes[2])\n",
-  "axes[2].set_xlabel('Count')\n",
-  "axes[2].set_title('Transaction Types')\n",
-  "\n",
-  "plt.tight_layout()\n",
-  "plt.show()"
+  "axes[2].set_xlabel('Count'); axes[2].set_title('Transaction Types')\n",
+  "plt.tight_layout(); plt.show()"
  ]
 },
 {
@@ -190,21 +195,14 @@
  "outputs": [],
  "source": [
   "def row_to_event(row):\n",
-  "    \"\"\"Convert a DataFrame row to a FINANCE_SCHEMA event dict.\"\"\"\n",
   "    dt = datetime.strptime(row['timestamp'][:19], '%Y-%m-%d %H:%M:%S')\n",
   "    desc = f\"{row['merchant_category']} {row['transaction_type']}\"\n",
   "    amt = row['amount_ngn']\n",
   "    if row['transaction_type'] == 'withdrawal':\n",
   "        amt = -abs(amt)\n",
-  "    return {\n",
-  "        'amount_sign': amt,\n",
-  "        'amount': amt,\n",
-  "        'timestamp': dt,\n",
-  "        'description': desc,\n",
-  "    }\n",
-  "\n",
-  "sample = row_to_event(df.iloc[0])\n",
-  "print(f'Sample event: {sample}')"
+  "    return {'amount_sign': amt, 'amount': amt, 'timestamp': dt, 'description': desc}\n",
+  "\n",
+  "print(f'Sample event: {row_to_event(df.iloc[0])}')"
  ]
 },
 {
@@ -215,12 +213,9 @@
  "source": [
   "%%time\n",
   "MIN_EVENTS = 5\n",
-  "MAX_EVENTS = 500\n",
-  "\n",
-  "user_sequences = []\n",
-  "user_ids = []\n",
-  "user_fraud_labels = []\n",
+  "MAX_EVENTS = 500\n",
   "\n",
+  "user_sequences, user_ids, user_fraud_labels = [], [], []\n",
   "for sender, group in df.sort_values('timestamp').groupby('sender_account'):\n",
   "    if len(group) < MIN_EVENTS:\n",
   "        continue\n",
@@ -231,9 +226,7 @@
   "\n",
   "print(f'Users with {MIN_EVENTS}+ events: {len(user_sequences):,}')\n",
   "print(f'Total events: {sum(len(s) for s in user_sequences):,}')\n",
-  "print(f'Events per user: min={min(len(s) for s in user_sequences)}, '\n",
-  "      f'max={max(len(s) for s in user_sequences)}, '\n",
-  "      f'mean={np.mean([len(s) for s in user_sequences]):.1f}')\n",
+  "print(f'Events/user: min={min(len(s) for s in user_sequences)}, max={max(len(s) for s in user_sequences)}, mean={np.mean([len(s) for s in user_sequences]):.1f}')\n",
   "print(f'Fraud rate (user-level): {np.mean(user_fraud_labels)*100:.2f}%')"
  ]
 },
@@ -243,8 +236,7 @@
  "source": [
   "## Step 4 — Build Domain Tokenizer\n",
   "\n",
-  "Hybrid vocabulary: 97 special tokens (sign + amount bins + calendar) + BPE for descriptions.\n",
-  "Following Nubank nuFormer's tokenization approach."
+  "Hybrid vocabulary: 97 special tokens (sign + amount bins + calendar) + BPE for descriptions."
  ]
 },
 {
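The builder's internals aren't part of this diff; purely as an illustration of the "sign + amount bins + calendar" idea in the markdown cell above, special tokens for one event could be derived roughly like this. Token names, `N_BINS`, the quantile binning, and the helper functions are assumptions for the sketch, not the library's actual scheme.

```python
import numpy as np

N_BINS = 40  # hypothetical number of amount bins (part of the ~97 special tokens)

def amount_to_tokens(amount, bin_edges):
    # Illustrative only: one sign token plus one quantile-bin token per amount
    sign = "[POS]" if amount >= 0 else "[NEG]"
    b = int(np.digitize(abs(amount), bin_edges))
    return [sign, f"[AMT_BIN_{b}]"]

def calendar_tokens(dt):
    # Illustrative calendar features: day-of-week and hour buckets
    return [f"[DOW_{dt.weekday()}]", f"[HOUR_{dt.hour}]"]

# bin edges would be fit on the training amounts, e.g.:
# bin_edges = np.quantile(np.abs(train_amounts), np.linspace(0, 1, N_BINS + 1)[1:-1])
```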
@@ -262,16 +254,10 @@
  "text_corpus = [e['description'] for e in all_events]\n",
  "unique_descs = sorted(set(text_corpus))\n",
  "print(f'Unique descriptions: {len(unique_descs)}')\n",
-  "for d in unique_descs[:10]:\n",
-  "    print(f\" '{d}'\")\n",
-  "if len(unique_descs) > 10:\n",
-  "    print(f'  ... and {len(unique_descs) - 10} more')\n",
-  "\n",
-  "hf_tokenizer = builder.build(\n",
-  "    text_corpus=text_corpus,\n",
-  "    bpe_vocab_size=2000,\n",
-  ")\n",
+  "for d in unique_descs[:10]: print(f\" '{d}'\")\n",
+  "if len(unique_descs) > 10: print(f'  ... and {len(unique_descs) - 10} more')\n",
  "\n",
+  "hf_tokenizer = builder.build(text_corpus=text_corpus, bpe_vocab_size=2000)\n",
  "print(f'\\nVocab size: {hf_tokenizer.vocab_size}')\n",
  "print(f'Stats: {builder.get_stats()}')"
  ]
@@ -282,16 +268,12 @@
  "metadata": {},
  "outputs": [],
  "source": [
-  "# Inspect tokenized output\n",
   "print('--- Sample event tokenized ---')\n",
-  "sample_tokens = builder.tokenize_event(user_sequences[0][0])\n",
-  "for i, t in enumerate(sample_tokens):\n",
-  "    print(f'  [{i}] {t}')\n",
+  "for i, t in enumerate(builder.tokenize_event(user_sequences[0][0])): print(f'  [{i}] {t}')\n",
   "\n",
   "print(f'\\n--- First user, first 3 events ---')\n",
   "seq_tokens = builder.tokenize_sequence(user_sequences[0][:3])\n",
-  "for i, t in enumerate(seq_tokens):\n",
-  "    print(f'  [{i:3d}] {t}')\n",
+  "for i, t in enumerate(seq_tokens): print(f'  [{i:3d}] {t}')\n",
   "\n",
   "seq_ids = hf_tokenizer(' '.join(seq_tokens), add_special_tokens=False)['input_ids']\n",
   "unk_id = hf_tokenizer.unk_token_id\n",
@@ -305,8 +287,7 @@
  "source": [
   "## Step 5 — Pack into CLM Training Dataset\n",
   "\n",
-  "Sequence packing: concatenate all user sequences, split into fixed-length blocks.\n",
-  "100% token utilization, zero padding waste."
+  "Sequence packing: concatenate all user sequences, split into fixed-length blocks. 100% token utilization."
  ]
 },
 {
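For readers unfamiliar with sequence packing as described in the markdown cell above: the idea is to tokenize every user sequence, concatenate all token ids into one stream, and cut the stream into fixed-length blocks so no padding tokens are needed. A rough sketch of what `prepare_clm_dataset` presumably does follows; the function name, the drop-the-tail detail, and the duplicated labels field are assumptions.

```python
def pack_into_blocks(user_sequences, builder, tokenizer, block_size=512):
    # Illustrative sketch of CLM sequence packing (not the library's own code)
    ids = []
    for seq in user_sequences:
        tokens = builder.tokenize_sequence(seq)  # domain tokens for one user
        ids.extend(tokenizer(" ".join(tokens), add_special_tokens=False)["input_ids"])
    # drop the tail so every block is exactly block_size long
    n_blocks = len(ids) // block_size
    return [{"input_ids": ids[i * block_size:(i + 1) * block_size],
             "labels":    ids[i * block_size:(i + 1) * block_size]}
            for i in range(n_blocks)]
```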
@@ -316,13 +297,8 @@
  "outputs": [],
  "source": [
   "%%time\n",
-  "BLOCK_SIZE = 512\n",
-  "\n",
-  "dataset = prepare_clm_dataset(\n",
-  "    user_sequences, builder, hf_tokenizer,\n",
-  "    block_size=BLOCK_SIZE,\n",
-  ")\n",
-  "\n",
+  "BLOCK_SIZE = 512\n",
+  "dataset = prepare_clm_dataset(user_sequences, builder, hf_tokenizer, block_size=BLOCK_SIZE)\n",
   "print(f'Packed: {len(dataset):,} blocks x {BLOCK_SIZE} = {len(dataset)*BLOCK_SIZE:,} training tokens')"
  ]
 },
@@ -332,25 +308,15 @@
  "metadata": {},
  "outputs": [],
  "source": [
-  "# Decode a sample block to verify it looks right\n",
-  "sample_block = dataset[0]['input_ids']\n",
   "print(f'Sample block decoded (first 60 tokens):')\n",
-  "print(hf_tokenizer.decode(sample_block[:60]))\n",
+  "print(hf_tokenizer.decode(dataset[0]['input_ids'][:60]))\n",
   "\n",
-  "# Token frequency analysis\n",
   "all_ids = [i for row in dataset for i in row['input_ids']]\n",
   "counts = Counter(all_ids)\n",
-  "unk_pct = counts.get(unk_id, 0) / len(all_ids) * 100\n",
-  "\n",
-  "print(f'\\nTotal tokens: {len(all_ids):,}')\n",
-  "print(f'Unique token IDs used: {len(counts)}/{hf_tokenizer.vocab_size}')\n",
-  "print(f'UNK tokens: {counts.get(unk_id, 0):,} ({unk_pct:.2f}%)')\n",
-  "\n",
+  "print(f'\\nTotal tokens: {len(all_ids):,}, Unique: {len(counts)}/{hf_tokenizer.vocab_size}, UNK: {counts.get(unk_id,0)} ({counts.get(unk_id,0)/len(all_ids)*100:.2f}%)')\n",
   "print(f'\\nTop 20 tokens:')\n",
   "for tid, count in counts.most_common(20):\n",
-  "    tok_str = hf_tokenizer.decode([tid]).strip() or '(space)'\n",
-  "    pct = count / len(all_ids) * 100\n",
-  "    print(f'  {tid:5d} {count:8,} ({pct:5.1f}%) {tok_str}')"
+  "    print(f'  {tid:5d} {count:8,} ({count/len(all_ids)*100:5.1f}%) {hf_tokenizer.decode([tid]).strip() or \"(space)\"}')"
  ]
 },
 {
@@ -359,11 +325,12 @@
  "source": [
   "## Step 6 — Pre-Train 24M DomainTransformer\n",
   "\n",
-  "Architecture\n",
+  "Architecture:\n",
   "- GPT-style causal decoder, NoPE (no positional encoding)\n",
   "- 24M preset: d=512, 6 layers, 8 heads, FFN=2048\n",
-  "- Cosine LR schedule with warmup, AdamW\n",
-  "- CLM objective (next token prediction)"
+  "- Cosine LR schedule with warmup, AdamW\n",
+  "- CLM objective (next token prediction)\n",
+  "- wandb logging for persistent monitoring"
  ]
 },
 {
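The "Cosine LR schedule with warmup, AdamW" bullet in the markdown above corresponds to a standard setup; a minimal sketch using `transformers.get_cosine_schedule_with_warmup` follows. The step counts and learning rate are placeholders, and `model` is the one built in the next cell — the notebook's trainer presumably derives these values itself.

```python
import torch
from transformers import get_cosine_schedule_with_warmup

# placeholder numbers; a real run derives these from dataset size, batch size, epochs
num_training_steps = 10_000
num_warmup_steps = int(0.05 * num_training_steps)

optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=0.1)
scheduler = get_cosine_schedule_with_warmup(
    optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps
)
# inside the training loop: optimizer.step(); scheduler.step()
```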
@@ -374,11 +341,8 @@
  "source": [
   "config = DomainTransformerConfig.from_preset('24m', vocab_size=hf_tokenizer.vocab_size)\n",
   "model = DomainTransformerForCausalLM(config)\n",
-  "\n",
   "n_params = sum(p.numel() for p in model.parameters())\n",
-  "print(f'Model: {n_params:,} parameters')\n",
-  "print(f'Config: d={config.hidden_size}, L={config.num_hidden_layers}, H={config.num_attention_heads}')\n",
-  "print(f'VRAM estimate: ~{n_params * 2 / 1e9:.1f}GB (bf16 training with optimizer states ~3x)')"
+  "print(f'Model: {n_params:,} params | d={config.hidden_size}, L={config.num_hidden_layers}, H={config.num_attention_heads}')"
  ]
 },
 {
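As a sanity check on the parameter count printed above, a back-of-envelope estimate for a GPT-style stack is roughly 4·d² attention weights plus 2·d·FFN weights per layer, times the layer count, plus the token embedding; the exact total also depends on biases, layer norms, and whether the LM head is tied, so this is only a rough estimate, assuming the notebook's `hf_tokenizer` is in scope.

```python
# Rough parameter estimate for the 24m preset (ignores biases, LayerNorm, untied head)
d, n_layers, ffn, vocab = 512, 6, 2048, hf_tokenizer.vocab_size
per_layer = 4 * d * d + 2 * d * ffn          # attention projections + FFN matrices
total = n_layers * per_layer + vocab * d     # + token embedding (tied LM head assumed)
print(f'~{total / 1e6:.1f}M parameters (rough estimate)')
```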
@@ -395,7 +359,7 @@
   "    tokenizer=hf_tokenizer,\n",
   "    train_dataset=dataset,\n",
   "    output_dir='./finance_pretrain_checkpoints',\n",
-  "    hub_model_id=
+  "    hub_model_id='rtferraz/finance-domain-24m',\n",
   "    num_epochs=3 if USE_GPU else 1,\n",
   "    per_device_batch_size=32 if USE_GPU else 4,\n",
   "    gradient_accumulation_steps=4 if USE_GPU else 1,\n",
@@ -404,7 +368,8 @@
   "    logging_steps=50 if USE_GPU else 10,\n",
   "    save_steps=1000 if USE_GPU else 999999,\n",
   "    bf16=USE_GPU,\n",
-  "    report_to='
+  "    report_to='wandb',\n",
+  "    run_name='finance-pretrain-24m-3ep',\n",
   "    seed=42,\n",
   ")"
  ]
@@ -422,9 +387,7 @@
  "metadata": {},
  "outputs": [],
  "source": [
-  "# Loss curve\n",
   "losses = [h['loss'] for h in trainer.state.log_history if 'loss' in h]\n",
-  "\n",
   "print(f'Steps: {trainer.state.global_step:,}')\n",
   "print(f'Loss: {losses[0]:.4f} -> {losses[-1]:.4f} ({(1-losses[-1]/losses[0])*100:.1f}% reduction)')\n",
   "print(f'Min loss: {min(losses):.4f}')\n",
@@ -433,15 +396,9 @@
  "ax.plot(losses, linewidth=0.5, alpha=0.5, label='Per-step')\n",
  "window = max(len(losses) // 50, 1)\n",
  "if len(losses) > window:\n",
-  "    smoothed = pd.Series(losses).rolling(window=window, min_periods=1).mean()\n",
-  "    ax.plot(smoothed, linewidth=2, color='red', label=f'Smoothed (w={window})')\n",
-  "ax.set_xlabel('Step')\n",
-  "ax.set_ylabel('Loss')\n",
-  "ax.set_title('Pre-Training Loss Curve')\n",
-  "ax.legend()\n",
-  "ax.grid(True, alpha=0.3)\n",
-  "plt.tight_layout()\n",
-  "plt.show()"
+  "    ax.plot(pd.Series(losses).rolling(window=window, min_periods=1).mean(), linewidth=2, color='red', label=f'Smoothed (w={window})')\n",
+  "ax.set_xlabel('Step'); ax.set_ylabel('Loss'); ax.set_title('Pre-Training Loss Curve')\n",
+  "ax.legend(); ax.grid(True, alpha=0.3); plt.tight_layout(); plt.show()"
  ]
 },
 {
@@ -450,25 +407,18 @@
  "metadata": {},
  "outputs": [],
  "source": [
-  "# Next-token prediction test\n",
   "model.eval()\n",
   "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
   "model = model.to(device)\n",
   "\n",
-  "test_tokens = builder.tokenize_sequence(user_sequences[0][:3])\n",
-  "test_ids = hf_tokenizer(' '.join(test_tokens), return_tensors='pt', add_special_tokens=False)['input_ids'].to(device)\n",
-  "\n",
+  "test_ids = hf_tokenizer(' '.join(builder.tokenize_sequence(user_sequences[0][:3])), return_tensors='pt', add_special_tokens=False)['input_ids'].to(device)\n",
   "with torch.no_grad():\n",
-  "    logits = model(input_ids=test_ids).logits\n",
-  "    top5 = torch.topk(logits[0, -1, :], 5)\n",
+  "    top5 = torch.topk(model(input_ids=test_ids).logits[0, -1, :], 5)\n",
   "\n",
   "print('Last 5 input tokens:')\n",
-  "for tid in test_ids[0, -5:]:\n",
-  "    print(f\"  {tid.item():5d} -> '{hf_tokenizer.decode([tid.item()])}'\")\n",
-  "\n",
+  "for tid in test_ids[0, -5:]: print(f\"  {tid.item():5d} -> '{hf_tokenizer.decode([tid.item()])}'\")\n",
   "print('\\nTop-5 next token predictions:')\n",
-  "for score, tid in zip(top5.values, top5.indices):\n",
-  "    print(f\"  {tid.item():5d} -> '{hf_tokenizer.decode([tid.item()])}' (score={score.item():.3f})\")"
+  "for score, tid in zip(top5.values, top5.indices): print(f\"  {tid.item():5d} -> '{hf_tokenizer.decode([tid.item()])}' (score={score.item():.3f})\")"
  ]
 },
 {
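The top-5 scores printed by this cell are raw logits; if probabilities read better, a softmax over the final position is a small optional addition (same `test_ids`, `model`, and `hf_tokenizer` as in the cell above):

```python
import torch
import torch.nn.functional as F

with torch.no_grad():
    probs = F.softmax(model(input_ids=test_ids).logits[0, -1, :], dim=-1)
top5 = torch.topk(probs, 5)
for p, tid in zip(top5.values, top5.indices):
    print(f"  {tid.item():5d} -> '{hf_tokenizer.decode([tid.item()])}' (p={p.item():.3f})")
```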
@@ -477,45 +427,35 @@
  "metadata": {},
  "outputs": [],
  "source": [
-  "#
+  "# t-SNE user embeddings colored by fraud label\n",
   "n_sample = min(200, len(user_sequences))\n",
-  "embeddings = []\n",
-  "labels_sample = []\n",
-  "\n",
+  "embeddings, labels_sample = [], []\n",
   "for i in range(n_sample):\n",
-  "    tokens = builder.tokenize_sequence(user_sequences[i][:50])\n",
-  "    enc = hf_tokenizer(' '.join(tokens), return_tensors='pt', add_special_tokens=False,\n",
-  "                       max_length=256, truncation=True, padding='max_length')\n",
+  "    enc = hf_tokenizer(' '.join(builder.tokenize_sequence(user_sequences[i][:50])),\n",
+  "                       return_tensors='pt', add_special_tokens=False, max_length=256, truncation=True, padding='max_length')\n",
   "    with torch.no_grad():\n",
-  "        emb = model.get_user_embedding(enc['input_ids'].to(device), enc['attention_mask'].to(device))\n",
-  "    embeddings.append(emb.cpu().numpy().flatten())\n",
+  "        embeddings.append(model.get_user_embedding(enc['input_ids'].to(device), enc['attention_mask'].to(device)).cpu().numpy().flatten())\n",
   "    labels_sample.append(user_fraud_labels[i])\n",
   "\n",
-  "embeddings = np.array(embeddings)\n",
-  "labels_sample = np.array(labels_sample)\n",
+  "embeddings = np.array(embeddings); labels_sample = np.array(labels_sample)\n",
   "print(f'Embeddings: {embeddings.shape}, Fraud: {labels_sample.sum()}/{len(labels_sample)}')\n",
   "\n",
   "if len(embeddings) >= 20:\n",
   "    from sklearn.manifold import TSNE\n",
   "    coords = TSNE(n_components=2, random_state=42, perplexity=min(30, len(embeddings)-1)).fit_transform(embeddings)\n",
-  "    \n",
   "    fig, ax = plt.subplots(figsize=(8, 6))\n",
   "    for label, color, name in [(0, 'tab:green', 'Normal'), (1, 'tab:red', 'Fraud')]:\n",
   "        mask = labels_sample == label\n",
   "        ax.scatter(coords[mask, 0], coords[mask, 1], c=color, label=name, alpha=0.6, edgecolors='black', linewidth=0.3, s=30)\n",
-  "    ax.set_title('User Embeddings (t-SNE) — Pre-trained DomainTransformer')\n",
-  "    ax.legend()\n",
-  "    plt.tight_layout()\n",
-  "    plt.show()"
+  "    ax.set_title('User Embeddings (t-SNE) — Pre-trained DomainTransformer'); ax.legend()\n",
+  "    plt.tight_layout(); plt.show()"
  ]
 },
 {
  "cell_type": "markdown",
  "metadata": {},
  "source": [
-  "## Save Artifacts\n",
-  "\n",
-  "Saves the pre-trained model, tokenizer, and user data so `02_finance_finetune.ipynb` can pick up where we left off."
+  "## Save Artifacts"
  ]
 },
 {
@@ -524,26 +464,23 @@
  "metadata": {},
  "outputs": [],
  "source": [
-  "# Save tokenizer\n",
   "hf_tokenizer.save_pretrained('./finance_tokenizer')\n",
   "builder.save('./finance_tokenizer')\n",
-  "\n",
-  "# Save model\n",
   "model.save_pretrained('./finance_pretrain_checkpoints/final')\n",
   "\n",
-  "# Save user data\n",
-  "artifacts = {\n",
-  "    'user_sequences': user_sequences,\n",
-  "    'user_ids': user_ids,\n",
-  "    'user_fraud_labels': user_fraud_labels,\n",
-  "}\n",
   "with open('./finance_artifacts.pkl', 'wb') as f:\n",
-  "    pickle.dump(artifacts, f)\n",
+  "    pickle.dump({'user_sequences': user_sequences, 'user_ids': user_ids, 'user_fraud_labels': user_fraud_labels}, f)\n",
   "\n",
-  "print('Saved:
+  "print('Saved: ./finance_tokenizer/, ./finance_pretrain_checkpoints/final/, ./finance_artifacts.pkl')"
+ ]
+},
+{
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+  "wandb.finish()  # close wandb run cleanly"
  ]
 },
 {
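For the hand-off to `02_finance_finetune.ipynb`, the artifacts saved above can be reloaded along these lines. The paths are the ones written in this notebook; the loading code itself is a sketch rather than the fine-tuning notebook's actual cell, and it assumes the saved tokenizer loads back through `AutoTokenizer`.

```python
import pickle
from transformers import AutoTokenizer

hf_tokenizer = AutoTokenizer.from_pretrained('./finance_tokenizer')  # assumes a standard fast tokenizer was saved
with open('./finance_artifacts.pkl', 'rb') as f:
    artifacts = pickle.load(f)
user_sequences = artifacts['user_sequences']
user_ids = artifacts['user_ids']
user_fraud_labels = artifacts['user_fraud_labels']
```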