asdf98 committed on
Commit
1754427
·
verified ·
1 Parent(s): ed4cea5

v3: Fix 429 rate limit — download-then-train with retry/backoff, disk-based DataLoader, 8 threads, HF token auth

Browse files
Files changed (1) hide show
  1. train_bokehflow.ipynb +205 -193
train_bokehflow.ipynb CHANGED
@@ -5,17 +5,18 @@
5
  "metadata": {},
6
  "source": [
7
  "# 🎬 BokehFlow Training Notebook\n",
8
- "## Zero-download streaming starts training in ~5 seconds\n",
9
  "\n",
10
- "**How it works:** Metadata (3960 tiny JSONs) fetched async in 3s. Images streamed on-demand via HTTP during training. **Zero disk usage, zero wait.**\n",
 
 
 
 
 
11
  "\n",
12
- "| Platform | GPU | Batch/s | Notes |\n",
13
- "|----------|-----|---------|-------|\n",
14
- "| Colab Free | T4 16GB | ~2-3s | 4 workers, prefetch hides latency |\n",
15
- "| Kaggle | 2×T4 | ~1.5s | DataParallel + 8 workers |\n",
16
- "| Colab Pro | A100 | ~1s | 8 workers |\n",
17
  "\n",
18
- "**Just run all cells. No config changes needed.**"
19
  ]
20
  },
21
  {
@@ -24,8 +25,8 @@
24
  "metadata": {},
25
  "outputs": [],
26
  "source": [
27
- "#@title Step 0: Install (15s)\n",
28
- "!pip install -q torch torchvision Pillow huggingface_hub tqdm aiohttp"
29
  ]
30
  },
31
  {
@@ -34,10 +35,10 @@
34
  "metadata": {},
35
  "outputs": [],
36
  "source": [
37
- "#@title Step 1: Download BokehFlow model code (2s)\n",
38
  "from huggingface_hub import hf_hub_download\n",
39
  "hf_hub_download(repo_id='asdf98/BokehFlow', filename='bokehflow.py', local_dir='.')\n",
40
- "print('✓ BokehFlow downloaded')"
41
  ]
42
  },
43
  {
@@ -48,26 +49,33 @@
48
  "source": [
49
  "#@title Step 2: Config\n",
50
  "CONFIG = {\n",
 
51
  " 'variant': 'nano', # 'nano'=583K, 'small'=3.1M, 'base'=12M\n",
 
 
 
 
 
 
 
 
52
  " 'batch_size': 4, # 4 for T4, 8 for A100\n",
53
- " 'crop_size': 256, # Training crop size\n",
54
- " 'num_epochs': 5,\n",
55
  " 'lr': 3e-4,\n",
56
  " 'weight_decay': 0.05,\n",
57
  " 'max_grad_norm': 1.0,\n",
58
- " 'num_workers': 4, # 4 for Colab, 8 for Kaggle\n",
59
- " 'target_fstop': 2.0, # Train on max bokeh (f/2.0)\n",
60
- " 'max_samples': None, # None=all 3958, or set 200 for quick test\n",
61
  " 'output_dir': './checkpoints',\n",
62
  "}\n",
63
  "\n",
64
- "import torch\n",
65
  "NUM_GPUS = torch.cuda.device_count()\n",
66
  "DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
67
  "print(f'Device: {DEVICE}' + (f' ({torch.cuda.get_device_name(0)})' if torch.cuda.is_available() else ''))\n",
68
  "if NUM_GPUS > 1:\n",
69
- " CONFIG['num_workers'] = 8\n",
70
- " print(f'Kaggle dual-GPU detected → {NUM_GPUS} GPUs, {CONFIG[\"num_workers\"]} workers')"
 
71
  ]
72
  },
73
  {
@@ -76,88 +84,158 @@
76
  "metadata": {},
77
  "outputs": [],
78
  "source": [
79
- "#@title Step 3: Streaming Dataset — NO download, starts in ~3s\n",
80
- "import asyncio, aiohttp, json, io, os, random, time, requests\n",
81
- "from PIL import Image\n",
82
- "from torch.utils.data import Dataset, DataLoader\n",
83
- "from torchvision import transforms\n",
84
- "from concurrent.futures import ThreadPoolExecutor\n",
85
  "\n",
86
  "HF_BASE = 'https://huggingface.co/datasets/timseizinger/RealBokeh_3MP/resolve/main'\n",
 
87
  "\n",
88
- "# ---- Async metadata fetch (3960 JSONs in ~3s) ----\n",
89
- "async def _fetch_all_metadata(split='train', concurrency=50):\n",
90
- " split_counts = {'train': 3960, 'validation': 220, 'test': 220}\n",
91
- " n = split_counts.get(split, 220)\n",
92
- " async def fetch_one(session, sem, sid):\n",
 
 
 
93
  " async with sem:\n",
94
- " url = f'{HF_BASE}/{split}/metadata/{sid}.json'\n",
95
  " try:\n",
96
  " async with session.get(url) as r:\n",
97
- " if r.status == 200:\n",
98
- " return await r.json(content_type=None)\n",
99
- " except:\n",
100
- " pass\n",
101
  " return None\n",
102
- " sem = asyncio.Semaphore(concurrency)\n",
103
- " conn = aiohttp.TCPConnector(limit=concurrency, force_close=False)\n",
104
- " async with aiohttp.ClientSession(connector=conn) as session:\n",
105
- " results = await asyncio.gather(*[fetch_one(session, sem, i) for i in range(1, n+1)])\n",
106
- " return [r for r in results if r is not None]\n",
107
- "\n",
108
- "def _build_pairs(metas, split, target_fstop=None):\n",
109
- " pairs = []\n",
110
- " for m in metas:\n",
111
- " for tgt_path, tgt_av in zip(m['target_images'], m['target_avs']):\n",
112
- " if target_fstop is not None and abs(tgt_av - target_fstop) > 0.05:\n",
113
- " continue\n",
114
- " pairs.append({\n",
115
- " 'input_path': f\"{split}/{m['source_image']}\",\n",
116
- " 'gt_path': f'{split}/{tgt_path}',\n",
117
- " 'f_number': tgt_av,\n",
118
- " 'focal_mm': float(m.get('focal_length', 50)),\n",
119
- " 'focus_m': float(m.get('focus_plane_distance', 2.0)),\n",
120
- " })\n",
121
- " return pairs\n",
122
- "\n",
123
- "def _fetch_img(path):\n",
124
- " \"\"\"HTTP fetch image → PIL. No disk write.\"\"\"\n",
125
- " r = requests.get(f'{HF_BASE}/{path}', timeout=30)\n",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
126
  " r.raise_for_status()\n",
127
- " return Image.open(io.BytesIO(r.content)).convert('RGB')\n",
128
- "\n",
129
- "class RealBokehStream(Dataset):\n",
130
- " \"\"\"Streaming dataset. Zero disk. Images fetched on-demand via HTTP.\"\"\"\n",
131
- " def __init__(self, split='train', crop_size=256, target_fstop=2.0, max_samples=None):\n",
132
- " t0 = time.time()\n",
133
- " # Async fetch all metadata (~3s)\n",
134
- " try:\n",
135
- " loop = asyncio.get_event_loop()\n",
136
- " if loop.is_running(): # Colab/Jupyter has running loop\n",
137
- " import nest_asyncio; nest_asyncio.apply()\n",
138
- " except RuntimeError:\n",
139
- " pass\n",
140
- " metas = asyncio.run(_fetch_all_metadata(split))\n",
141
- " self.pairs = _build_pairs(metas, split, target_fstop)\n",
142
- " random.shuffle(self.pairs)\n",
143
- " if max_samples:\n",
144
- " self.pairs = self.pairs[:max_samples]\n",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
145
  " self.crop_size = crop_size\n",
146
  " self.to_tensor = transforms.ToTensor()\n",
147
- " print(f' {split}: {len(self.pairs)} pairs ready in {time.time()-t0:.1f}s (zero disk)')\n",
 
 
 
 
148
  "\n",
149
  " def __len__(self):\n",
150
  " return len(self.pairs)\n",
151
  "\n",
152
  " def __getitem__(self, idx):\n",
153
  " p = self.pairs[idx]\n",
154
- " # Fetch input + GT concurrently (2 threads)\n",
155
- " with ThreadPoolExecutor(2) as ex:\n",
156
- " f1 = ex.submit(_fetch_img, p['input_path'])\n",
157
- " f2 = ex.submit(_fetch_img, p['gt_path'])\n",
158
- " inp, gt = f1.result(), f2.result()\n",
159
  "\n",
160
- " # Synchronized random crop + flip on both images\n",
161
  " cs = self.crop_size\n",
162
  " w, h = inp.size\n",
163
  " if w >= cs and h >= cs:\n",
@@ -174,56 +252,26 @@
174
  " return {\n",
175
  " 'input': self.to_tensor(inp),\n",
176
  " 'target': self.to_tensor(gt),\n",
177
- " 'f_number': torch.tensor(p['f_number'], dtype=torch.float32),\n",
178
- " 'focal_length_mm': torch.tensor(p['focal_mm'], dtype=torch.float32),\n",
179
- " 'focus_distance_m':torch.tensor(p['focus_m'], dtype=torch.float32),\n",
180
  " }\n",
181
  "\n",
182
- "# ---- Create dataset + loader ----\n",
183
- "print('Fetching metadata (no images downloaded yet)...')\n",
184
- "try:\n",
185
- " import nest_asyncio; nest_asyncio.apply() # needed for Jupyter\n",
186
- "except ImportError:\n",
187
- " !pip install -q nest_asyncio\n",
188
- " import nest_asyncio; nest_asyncio.apply()\n",
189
- "\n",
190
- "train_ds = RealBokehStream(\n",
191
- " split='train',\n",
192
- " crop_size=CONFIG['crop_size'],\n",
193
- " target_fstop=CONFIG['target_fstop'],\n",
194
- " max_samples=CONFIG['max_samples'],\n",
195
- ")\n",
196
- "\n",
197
  "train_loader = DataLoader(\n",
198
  " train_ds,\n",
199
  " batch_size=CONFIG['batch_size'],\n",
200
  " shuffle=True,\n",
201
  " num_workers=CONFIG['num_workers'],\n",
202
- " prefetch_factor=2,\n",
203
- " persistent_workers=True,\n",
204
  " drop_last=True,\n",
 
205
  ")\n",
206
- "print(f'✓ DataLoader: {len(train_loader)} batches/epoch, {CONFIG[\"num_workers\"]} workers')\n",
207
- "print(f' Images streamed on-the-fly. Disk usage: 0 MB')"
208
- ]
209
- },
210
- {
211
- "cell_type": "code",
212
- "execution_count": null,
213
- "metadata": {},
214
- "outputs": [],
215
- "source": [
216
- "#@title Step 4: Sanity check — fetch 1 batch\n",
217
- "import time\n",
218
- "t0 = time.time()\n",
219
  "batch = next(iter(train_loader))\n",
220
- "t1 = time.time()\n",
221
- "print(f'First batch fetched in {t1-t0:.1f}s')\n",
222
- "print(f' input: {batch[\"input\"].shape}')\n",
223
- "print(f' target: {batch[\"target\"].shape}')\n",
224
- "print(f' f_number: {batch[\"f_number\"]}')\n",
225
- "print(f' focal_mm: {batch[\"focal_length_mm\"]}')\n",
226
- "print(f' focus_m: {batch[\"focus_distance_m\"]}')"
227
  ]
228
  },
229
  {
@@ -233,18 +281,16 @@
233
  "outputs": [],
234
  "source": [
235
  "#@title Step 5: Create model\n",
236
- "from bokehflow import BokehFlow, BokehFlowConfig, BokehFlowLoss, model_summary\n",
237
  "\n",
238
  "config = BokehFlowConfig(variant=CONFIG['variant'])\n",
239
  "model = BokehFlow(config)\n",
240
- "\n",
241
  "if NUM_GPUS > 1:\n",
242
  " model = torch.nn.DataParallel(model)\n",
243
- " print(f'DataParallel on {NUM_GPUS} GPUs')\n",
244
  "model = model.to(DEVICE)\n",
245
  "\n",
246
- "total_params = sum(p.numel() for p in model.parameters())\n",
247
- "print(f'\\n✓ BokehFlow-{CONFIG[\"variant\"].capitalize()}: {total_params:,} params on {DEVICE}')"
248
  ]
249
  },
250
  {
@@ -253,27 +299,23 @@
253
  "metadata": {},
254
  "outputs": [],
255
  "source": [
256
- "#@title Step 6: Train!\n",
257
- "from tqdm.auto import tqdm\n",
258
- "import torch.nn.functional as F\n",
259
- "\n",
260
  "optimizer = torch.optim.AdamW(model.parameters(), lr=CONFIG['lr'], weight_decay=CONFIG['weight_decay'])\n",
261
- "scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=CONFIG['num_epochs'] * len(train_loader))\n",
262
  "criterion = BokehFlowLoss(lambda_depth=0.5)\n",
263
  "os.makedirs(CONFIG['output_dir'], exist_ok=True)\n",
264
  "\n",
265
- "print(f'Training: {CONFIG[\"num_epochs\"]} epochs × {len(train_loader)} batches')\n",
266
- "print(f'Images streamed from HF Hub — no disk needed\\n')\n",
267
  "\n",
268
  "for epoch in range(CONFIG['num_epochs']):\n",
269
  " model.train()\n",
270
- " running_loss = 0.0\n",
271
- " t_epoch = time.time()\n",
272
  " pbar = tqdm(train_loader, desc=f'Epoch {epoch+1}/{CONFIG[\"num_epochs\"]}')\n",
273
  "\n",
274
- " for step, batch in enumerate(pbar):\n",
275
- " inp = batch['input'].to(DEVICE)\n",
276
- " tgt = batch['target'].to(DEVICE)\n",
277
  " f_num = batch['f_number'].to(DEVICE)\n",
278
  " focal = batch['focal_length_mm'].to(DEVICE)\n",
279
  " focus = batch['focus_distance_m'].to(DEVICE)\n",
@@ -288,20 +330,19 @@
288
  " optimizer.step()\n",
289
  " scheduler.step()\n",
290
  "\n",
291
- " running_loss += loss.item()\n",
292
  " pbar.set_postfix(loss=f'{loss.item():.4f}', lr=f'{scheduler.get_last_lr()[0]:.1e}')\n",
293
  "\n",
294
- " avg = running_loss / len(train_loader)\n",
295
- " elapsed = time.time() - t_epoch\n",
296
- " print(f' avg_loss={avg:.4f} time={elapsed:.0f}s ({elapsed/len(train_loader):.1f}s/batch)')\n",
297
  "\n",
298
- " # Save checkpoint\n",
299
  " state = model.module.state_dict() if hasattr(model, 'module') else model.state_dict()\n",
300
  " ckpt = f'{CONFIG[\"output_dir\"]}/bokehflow_{CONFIG[\"variant\"]}_ep{epoch+1}.pt'\n",
301
- " torch.save({'epoch': epoch+1, 'model': state, 'loss': avg, 'config': CONFIG}, ckpt)\n",
302
- " print(f' ✓ Saved {ckpt}')\n",
303
  "\n",
304
- "print(f'\\n✓ Training complete!')"
305
  ]
306
  },
307
  {
@@ -310,61 +351,32 @@
310
  "metadata": {},
311
  "outputs": [],
312
  "source": [
313
- "#@title Step 7: Visualize result\n",
314
  "import matplotlib.pyplot as plt\n",
315
  "\n",
316
  "model.eval()\n",
317
- "sample = train_ds[0]\n",
318
  "with torch.no_grad():\n",
319
  " out = model(\n",
320
- " sample['input'].unsqueeze(0).to(DEVICE),\n",
321
- " sample['f_number'].unsqueeze(0).to(DEVICE),\n",
322
- " sample['focal_length_mm'].unsqueeze(0).to(DEVICE),\n",
323
- " sample['focus_distance_m'].unsqueeze(0).to(DEVICE),\n",
324
  " )\n",
325
  "\n",
326
- "fig, axes = plt.subplots(1, 3, figsize=(15, 5))\n",
327
- "axes[0].imshow(sample['input'].permute(1,2,0).cpu().numpy())\n",
328
- "axes[0].set_title('Input (f/22 sharp)')\n",
329
- "axes[1].imshow(out['bokeh'][0].permute(1,2,0).cpu().clamp(0,1).numpy())\n",
330
- "axes[1].set_title('BokehFlow output')\n",
331
- "axes[2].imshow(sample['target'].permute(1,2,0).cpu().numpy())\n",
332
- "axes[2].set_title('Ground truth (f/2.0)')\n",
333
- "for ax in axes: ax.axis('off')\n",
334
- "plt.tight_layout()\n",
335
- "plt.savefig('result.png', dpi=100, bbox_inches='tight')\n",
336
- "plt.show()\n",
337
  "print('✓ Done!')"
338
  ]
339
- },
340
- {
341
- "cell_type": "code",
342
- "execution_count": null,
343
- "metadata": {},
344
- "outputs": [],
345
- "source": [
346
- "#@title (Optional) Push trained model to HuggingFace Hub\n",
347
- "# from huggingface_hub import HfApi, login\n",
348
- "# login() # paste your HF token\n",
349
- "# api = HfApi()\n",
350
- "# api.upload_file(\n",
351
- "# path_or_fileobj=f'{CONFIG[\"output_dir\"]}/bokehflow_{CONFIG[\"variant\"]}_ep{CONFIG[\"num_epochs\"]}.pt',\n",
352
- "# path_in_repo=f'checkpoints/bokehflow_{CONFIG[\"variant\"]}.pt',\n",
353
- "# repo_id='YOUR_USERNAME/BokehFlow-trained',\n",
354
- "# )"
355
- ]
356
  }
357
  ],
358
  "metadata": {
359
- "kernelspec": {
360
- "display_name": "Python 3",
361
- "language": "python",
362
- "name": "python3"
363
- },
364
- "language_info": {
365
- "name": "python",
366
- "version": "3.10.0"
367
- },
368
  "accelerator": "GPU"
369
  },
370
  "nbformat": 4,
 
5
  "metadata": {},
6
  "source": [
7
  "# 🎬 BokehFlow Training Notebook\n",
8
+ "## Smart download: only f/2.0 pairs, parallel, with resume\n",
9
  "\n",
10
+ "**Downloads only what's needed:**\n",
11
+ "| Subset | Files | Size | Download Time |\n",
12
+ "|--------|-------|------|---------------|\n",
13
+ "| 200 scenes | 400 images | ~234 MB | ~2 min |\n",
14
+ "| 500 scenes | 1000 images | ~586 MB | ~4 min |\n",
15
+ "| All 3958 | 7916 images | ~4.5 GB | ~25 min |\n",
16
  "\n",
17
+ "Default: **500 scenes (~586MB)**. Cached — re-running skips downloaded files.\n",
 
 
 
 
18
  "\n",
19
+ "**Just run all cells.**"
20
  ]
21
  },
22
  {
 
25
  "metadata": {},
26
  "outputs": [],
27
  "source": [
28
+ "#@title Step 0: Install\n",
29
+ "!pip install -q torch torchvision Pillow huggingface_hub tqdm aiohttp nest_asyncio"
30
  ]
31
  },
32
  {
 
35
  "metadata": {},
36
  "outputs": [],
37
  "source": [
38
+ "#@title Step 1: Download BokehFlow code\n",
39
  "from huggingface_hub import hf_hub_download\n",
40
  "hf_hub_download(repo_id='asdf98/BokehFlow', filename='bokehflow.py', local_dir='.')\n",
41
+ "print('✓ BokehFlow code ready')"
42
  ]
43
  },
44
  {
 
49
  "source": [
50
  "#@title Step 2: Config\n",
51
  "CONFIG = {\n",
52
+ " # Model\n",
53
  " 'variant': 'nano', # 'nano'=583K, 'small'=3.1M, 'base'=12M\n",
54
+ " \n",
55
+ " # Data\n",
56
+ " 'max_scenes': 500, # 200=quick test(234MB), 500=good(586MB), None=all(4.5GB)\n",
57
+ " 'target_fstop': 2.0,\n",
58
+ " 'crop_size': 256,\n",
59
+ " 'data_dir': '/tmp/realbokeh', # /tmp = fast SSD on Colab/Kaggle\n",
60
+ " \n",
61
+ " # Training\n",
62
  " 'batch_size': 4, # 4 for T4, 8 for A100\n",
63
+ " 'num_epochs': 10,\n",
 
64
  " 'lr': 3e-4,\n",
65
  " 'weight_decay': 0.05,\n",
66
  " 'max_grad_norm': 1.0,\n",
67
+ " 'num_workers': 2, # 2 for Colab, 4 for Kaggle\n",
 
 
68
  " 'output_dir': './checkpoints',\n",
69
  "}\n",
70
  "\n",
71
+ "import torch, os\n",
72
  "NUM_GPUS = torch.cuda.device_count()\n",
73
  "DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
74
  "print(f'Device: {DEVICE}' + (f' ({torch.cuda.get_device_name(0)})' if torch.cuda.is_available() else ''))\n",
75
  "if NUM_GPUS > 1:\n",
76
+ " CONFIG['num_workers'] = 4\n",
77
+ " CONFIG['batch_size'] = 8\n",
78
+ " print(f'Multi-GPU: {NUM_GPUS} GPUs')"
79
  ]
80
  },
81
  {
 
84
  "metadata": {},
85
  "outputs": [],
86
  "source": [
87
+ "#@title Step 3: Smart download — only f/2.0 input+GT pairs, parallel, cached\n",
88
+ "import asyncio, aiohttp, json, time, random\n",
89
+ "from pathlib import Path\n",
90
+ "from concurrent.futures import ThreadPoolExecutor, as_completed\n",
91
+ "from tqdm.auto import tqdm\n",
92
+ "import nest_asyncio; nest_asyncio.apply()\n",
93
  "\n",
94
  "HF_BASE = 'https://huggingface.co/datasets/timseizinger/RealBokeh_3MP/resolve/main'\n",
95
+ "DATA = Path(CONFIG['data_dir'])\n",
96
  "\n",
97
+ "# --- Phase 1: Fetch metadata (3s, async) ---\n",
98
+ "print('Phase 1: Fetching metadata...')\n",
99
+ "t0 = time.time()\n",
100
+ "\n",
101
+ "async def _fetch_metas(concurrency=50):\n",
102
+ " sem = asyncio.Semaphore(concurrency)\n",
103
+ " conn = aiohttp.TCPConnector(limit=concurrency)\n",
104
+ " async def fetch(session, i):\n",
105
  " async with sem:\n",
106
+ " url = f'{HF_BASE}/train/metadata/{i}.json'\n",
107
  " try:\n",
108
  " async with session.get(url) as r:\n",
109
+ " if r.status == 200: return await r.json(content_type=None)\n",
110
+ " except: pass\n",
 
 
111
  " return None\n",
112
+ " async with aiohttp.ClientSession(connector=conn) as s:\n",
113
+ " return await asyncio.gather(*[fetch(s, i) for i in range(1, 3961)])\n",
114
+ "\n",
115
+ "metas = [m for m in asyncio.run(_fetch_metas()) if m]\n",
116
+ "print(f' {len(metas)} scenes in {time.time()-t0:.1f}s')\n",
117
+ "\n",
118
+ "# Build download list: only input + f/2.0 GT\n",
119
+ "pairs = []\n",
120
+ "for m in metas:\n",
121
+ " gt_path = None\n",
122
+ " for tp, av in zip(m['target_images'], m['target_avs']):\n",
123
+ " if abs(av - CONFIG['target_fstop']) < 0.05:\n",
124
+ " gt_path = tp; break\n",
125
+ " if gt_path is None: continue\n",
126
+ " pairs.append({\n",
127
+ " 'input_rel': m['source_image'], # e.g. 'in/1_f22.JPG'\n",
128
+ " 'gt_rel': gt_path, # e.g. 'gt/1/1_f2.0.JPG'\n",
129
+ " 'f_number': CONFIG['target_fstop'],\n",
130
+ " 'focal_mm': float(m.get('focal_length', 50)),\n",
131
+ " 'focus_m': float(m.get('focus_plane_distance', 2.0)),\n",
132
+ " })\n",
133
+ "random.shuffle(pairs)\n",
134
+ "if CONFIG['max_scenes']:\n",
135
+ " pairs = pairs[:CONFIG['max_scenes']]\n",
136
+ "print(f' {len(pairs)} pairs selected for download')\n",
137
+ "\n",
138
+ "# --- Phase 2: Download images (parallel, with retry + skip cached) ---\n",
139
+ "print(f'\\nPhase 2: Downloading images to {DATA}...')\n",
140
+ "import requests\n",
141
+ "from requests.adapters import HTTPAdapter\n",
142
+ "from urllib3.util.retry import Retry\n",
143
+ "\n",
144
+ "def _make_session():\n",
145
+ " \"\"\"Session with automatic retry on 429/500/503.\"\"\"\n",
146
+ " s = requests.Session()\n",
147
+ " retries = Retry(\n",
148
+ " total=5,\n",
149
+ " backoff_factor=1.0, # 1s, 2s, 4s, 8s, 16s\n",
150
+ " status_forcelist=[429, 500, 502, 503],\n",
151
+ " allowed_methods=['GET'],\n",
152
+ " )\n",
153
+ " s.mount('https://', HTTPAdapter(max_retries=retries))\n",
154
+ " # Add HF token if available (higher rate limits)\n",
155
+ " hf_token = os.environ.get('HF_TOKEN', '')\n",
156
+ " if hf_token:\n",
157
+ " s.headers['Authorization'] = f'Bearer {hf_token}'\n",
158
+ " return s\n",
159
+ "\n",
160
+ "def _download_file(rel_path, session):\n",
161
+ " \"\"\"Download one file to DATA/train/{rel_path}. Skips if exists.\"\"\"\n",
162
+ " local = DATA / 'train' / rel_path\n",
163
+ " if local.exists() and local.stat().st_size > 1000:\n",
164
+ " return 'cached'\n",
165
+ " local.parent.mkdir(parents=True, exist_ok=True)\n",
166
+ " url = f'{HF_BASE}/train/{rel_path}'\n",
167
+ " r = session.get(url, timeout=60)\n",
168
  " r.raise_for_status()\n",
169
+ " local.write_bytes(r.content)\n",
170
+ " return 'downloaded'\n",
171
+ "\n",
172
+ "# Collect all files to download\n",
173
+ "all_files = set()\n",
174
+ "for p in pairs:\n",
175
+ " all_files.add(p['input_rel'])\n",
176
+ " all_files.add(p['gt_rel'])\n",
177
+ "\n",
178
+ "# Download with 8 threads (conservative to avoid 429)\n",
179
+ "t0 = time.time()\n",
180
+ "downloaded, cached = 0, 0\n",
181
+ "pbar = tqdm(total=len(all_files), desc='Downloading')\n",
182
+ "\n",
183
+ "# Use thread-local sessions to avoid connection pool issues\n",
184
+ "import threading\n",
185
+ "_local = threading.local()\n",
186
+ "\n",
187
+ "def _dl(rel_path):\n",
188
+ " if not hasattr(_local, 'session'):\n",
189
+ " _local.session = _make_session()\n",
190
+ " return _download_file(rel_path, _local.session)\n",
191
+ "\n",
192
+ "with ThreadPoolExecutor(max_workers=8) as ex:\n",
193
+ " futures = {ex.submit(_dl, f): f for f in all_files}\n",
194
+ " for fut in as_completed(futures):\n",
195
+ " result = fut.result()\n",
196
+ " if result == 'cached': cached += 1\n",
197
+ " else: downloaded += 1\n",
198
+ " pbar.update(1)\n",
199
+ "pbar.close()\n",
200
+ "\n",
201
+ "elapsed = time.time() - t0\n",
202
+ "print(f'\\n✓ Done in {elapsed:.0f}s: {downloaded} downloaded, {cached} cached')\n",
203
+ "print(f' Disk usage: ~{sum(f.stat().st_size for f in DATA.rglob(\"*.JPG\"))/1e6:.0f} MB')"
204
+ ]
205
+ },
206
+ {
207
+ "cell_type": "code",
208
+ "execution_count": null,
209
+ "metadata": {},
210
+ "outputs": [],
211
+ "source": [
212
+ "#@title Step 4: Dataset (reads from disk — fast, no network)\n",
213
+ "from torch.utils.data import Dataset, DataLoader\n",
214
+ "from torchvision import transforms\n",
215
+ "from PIL import Image\n",
216
+ "\n",
217
+ "class RealBokehDisk(Dataset):\n",
218
+ " \"\"\"Reads pre-downloaded image pairs from disk. Zero network at training time.\"\"\"\n",
219
+ " def __init__(self, pairs, data_dir, crop_size=256):\n",
220
+ " self.pairs = pairs\n",
221
+ " self.data_dir = Path(data_dir) / 'train'\n",
222
  " self.crop_size = crop_size\n",
223
  " self.to_tensor = transforms.ToTensor()\n",
224
+ " # Verify a sample\n",
225
+ " p = pairs[0]\n",
226
+ " assert (self.data_dir / p['input_rel']).exists(), f\"Missing: {p['input_rel']}\"\n",
227
+ " assert (self.data_dir / p['gt_rel']).exists(), f\"Missing: {p['gt_rel']}\"\n",
228
+ " print(f' Dataset: {len(pairs)} pairs, reading from disk (fast)')\n",
229
  "\n",
230
  " def __len__(self):\n",
231
  " return len(self.pairs)\n",
232
  "\n",
233
  " def __getitem__(self, idx):\n",
234
  " p = self.pairs[idx]\n",
235
+ " inp = Image.open(self.data_dir / p['input_rel']).convert('RGB')\n",
236
+ " gt = Image.open(self.data_dir / p['gt_rel']).convert('RGB')\n",
 
 
 
237
  "\n",
238
+ " # Synchronized random crop + flip\n",
239
  " cs = self.crop_size\n",
240
  " w, h = inp.size\n",
241
  " if w >= cs and h >= cs:\n",
 
252
  " return {\n",
253
  " 'input': self.to_tensor(inp),\n",
254
  " 'target': self.to_tensor(gt),\n",
255
+ " 'f_number': torch.tensor(p['f_number'], dtype=torch.float32),\n",
256
+ " 'focal_length_mm': torch.tensor(p['focal_mm'], dtype=torch.float32),\n",
257
+ " 'focus_distance_m': torch.tensor(p['focus_m'], dtype=torch.float32),\n",
258
  " }\n",
259
  "\n",
260
+ "train_ds = RealBokehDisk(pairs, CONFIG['data_dir'], CONFIG['crop_size'])\n",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
261
  "train_loader = DataLoader(\n",
262
  " train_ds,\n",
263
  " batch_size=CONFIG['batch_size'],\n",
264
  " shuffle=True,\n",
265
  " num_workers=CONFIG['num_workers'],\n",
266
+ " pin_memory=True,\n",
 
267
  " drop_last=True,\n",
268
+ " persistent_workers=True,\n",
269
  ")\n",
270
+ "print(f'✓ DataLoader: {len(train_loader)} batches/epoch')\n",
271
+ "\n",
272
+ "# Quick sanity check\n",
 
 
 
 
 
 
 
 
 
 
273
  "batch = next(iter(train_loader))\n",
274
+ "print(f' Batch shapes: input={batch[\"input\"].shape}, target={batch[\"target\"].shape}')"
 
 
 
 
 
 
275
  ]
276
  },
277
  {
 
281
  "outputs": [],
282
  "source": [
283
  "#@title Step 5: Create model\n",
284
+ "from bokehflow import BokehFlow, BokehFlowConfig, BokehFlowLoss\n",
285
  "\n",
286
  "config = BokehFlowConfig(variant=CONFIG['variant'])\n",
287
  "model = BokehFlow(config)\n",
 
288
  "if NUM_GPUS > 1:\n",
289
  " model = torch.nn.DataParallel(model)\n",
 
290
  "model = model.to(DEVICE)\n",
291
  "\n",
292
+ "n_params = sum(p.numel() for p in model.parameters())\n",
293
+ "print(f'✓ BokehFlow-{CONFIG[\"variant\"].capitalize()}: {n_params:,} params on {DEVICE}')"
294
  ]
295
  },
296
  {
 
299
  "metadata": {},
300
  "outputs": [],
301
  "source": [
302
+ "#@title Step 6: Train\n",
 
 
 
303
  "optimizer = torch.optim.AdamW(model.parameters(), lr=CONFIG['lr'], weight_decay=CONFIG['weight_decay'])\n",
304
+ "scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=CONFIG['num_epochs']*len(train_loader))\n",
305
  "criterion = BokehFlowLoss(lambda_depth=0.5)\n",
306
  "os.makedirs(CONFIG['output_dir'], exist_ok=True)\n",
307
  "\n",
308
+ "print(f'Training: {CONFIG[\"num_epochs\"]} epochs × {len(train_loader)} batches\\n')\n",
 
309
  "\n",
310
  "for epoch in range(CONFIG['num_epochs']):\n",
311
  " model.train()\n",
312
+ " total_loss = 0.0\n",
313
+ " t0 = time.time()\n",
314
  " pbar = tqdm(train_loader, desc=f'Epoch {epoch+1}/{CONFIG[\"num_epochs\"]}')\n",
315
  "\n",
316
+ " for batch in pbar:\n",
317
+ " inp = batch['input'].to(DEVICE)\n",
318
+ " tgt = batch['target'].to(DEVICE)\n",
319
  " f_num = batch['f_number'].to(DEVICE)\n",
320
  " focal = batch['focal_length_mm'].to(DEVICE)\n",
321
  " focus = batch['focus_distance_m'].to(DEVICE)\n",
 
330
  " optimizer.step()\n",
331
  " scheduler.step()\n",
332
  "\n",
333
+ " total_loss += loss.item()\n",
334
  " pbar.set_postfix(loss=f'{loss.item():.4f}', lr=f'{scheduler.get_last_lr()[0]:.1e}')\n",
335
  "\n",
336
+ " avg = total_loss / len(train_loader)\n",
337
+ " dt = time.time() - t0\n",
338
+ " print(f' avg_loss={avg:.4f} time={dt:.0f}s ({dt/len(train_loader):.2f}s/batch)')\n",
339
  "\n",
 
340
  " state = model.module.state_dict() if hasattr(model, 'module') else model.state_dict()\n",
341
  " ckpt = f'{CONFIG[\"output_dir\"]}/bokehflow_{CONFIG[\"variant\"]}_ep{epoch+1}.pt'\n",
342
+ " torch.save({'epoch': epoch+1, 'model': state, 'loss': avg}, ckpt)\n",
343
+ " print(f' ✓ {ckpt}')\n",
344
  "\n",
345
+ "print('\\n✓ Training complete!')"
346
  ]
347
  },
348
  {
 
351
  "metadata": {},
352
  "outputs": [],
353
  "source": [
354
+ "#@title Step 7: Visualize\n",
355
  "import matplotlib.pyplot as plt\n",
356
  "\n",
357
  "model.eval()\n",
358
+ "s = train_ds[0]\n",
359
  "with torch.no_grad():\n",
360
  " out = model(\n",
361
+ " s['input'].unsqueeze(0).to(DEVICE),\n",
362
+ " s['f_number'].unsqueeze(0).to(DEVICE),\n",
363
+ " s['focal_length_mm'].unsqueeze(0).to(DEVICE),\n",
364
+ " s['focus_distance_m'].unsqueeze(0).to(DEVICE),\n",
365
  " )\n",
366
  "\n",
367
+ "fig, ax = plt.subplots(1, 3, figsize=(15, 5))\n",
368
+ "ax[0].imshow(s['input'].permute(1,2,0).cpu()); ax[0].set_title('Input (f/22)')\n",
369
+ "ax[1].imshow(out['bokeh'][0].permute(1,2,0).cpu().clamp(0,1)); ax[1].set_title('BokehFlow')\n",
370
+ "ax[2].imshow(s['target'].permute(1,2,0).cpu()); ax[2].set_title('GT (f/2.0)')\n",
371
+ "for a in ax: a.axis('off')\n",
372
+ "plt.tight_layout(); plt.savefig('result.png', dpi=100); plt.show()\n",
 
 
 
 
 
373
  "print('✓ Done!')"
374
  ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
375
  }
376
  ],
377
  "metadata": {
378
+ "kernelspec": {"display_name": "Python 3", "language": "python", "name": "python3"},
379
+ "language_info": {"name": "python", "version": "3.10.0"},
 
 
 
 
 
 
 
380
  "accelerator": "GPU"
381
  },
382
  "nbformat": 4,