harsh99 commited on
Commit
d1131f4
·
1 Parent(s): 76ee43d
Files changed (1) hide show
  1. training.ipynb +170 -565
training.ipynb CHANGED
@@ -11,12 +11,12 @@
11
  "output_type": "stream",
12
  "text": [
13
  "Cloning into 'stable-diffusion'...\n",
14
- "remote: Enumerating objects: 156, done.\u001b[K\n",
15
- "remote: Counting objects: 100% (156/156), done.\u001b[K\n",
16
- "remote: Compressing objects: 100% (129/129), done.\u001b[K\n",
17
- "remote: Total 156 (delta 41), reused 141 (delta 27), pack-reused 0 (from 0)\u001b[K\n",
18
- "Receiving objects: 100% (156/156), 9.12 MiB | 37.38 MiB/s, done.\n",
19
- "Resolving deltas: 100% (41/41), done.\n"
20
  ]
21
  }
22
  ],
@@ -70,25 +70,25 @@
70
  "name": "stdout",
71
  "output_type": "stream",
72
  "text": [
73
- "--2025-06-16 17:29:32-- https://huggingface.co/sd-legacy/stable-diffusion-inpainting/resolve/main/sd-v1-5-inpainting.ckpt\n",
74
- "Resolving huggingface.co (huggingface.co)... 18.239.50.103, 18.239.50.49, 18.239.50.16, ...\n",
75
- "Connecting to huggingface.co (huggingface.co)|18.239.50.103|:443... connected.\n",
76
  "HTTP request sent, awaiting response... 307 Temporary Redirect\n",
77
  "Location: /stable-diffusion-v1-5/stable-diffusion-inpainting/resolve/main/sd-v1-5-inpainting.ckpt [following]\n",
78
- "--2025-06-16 17:29:32-- https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-inpainting/resolve/main/sd-v1-5-inpainting.ckpt\n",
79
  "Reusing existing connection to huggingface.co:443.\n",
80
  "HTTP request sent, awaiting response... 302 Found\n",
81
- "Location: https://cdn-lfs.hf.co/repos/f6/56/f656f0fa3b8a40ac76d297fa2a4b00f981e8eb1261963460764e7dd3b35ec97f/c6bbc15e3224e6973459ba78de4998b80b50112b0ae5b5c67113d56b4e366b19?response-content-disposition=inline%3B+filename*%3DUTF-8%27%27sd-v1-5-inpainting.ckpt%3B+filename%3D%22sd-v1-5-inpainting.ckpt%22%3B&Expires=1750097473&Policy=eyJTdGF0ZW1lbnQiOlt7IkNvbmRpdGlvbiI6eyJEYXRlTGVzc1RoYW4iOnsiQVdTOkVwb2NoVGltZSI6MTc1MDA5NzQ3M319LCJSZXNvdXJjZSI6Imh0dHBzOi8vY2RuLWxmcy5oZi5jby9yZXBvcy9mNi81Ni9mNjU2ZjBmYTNiOGE0MGFjNzZkMjk3ZmEyYTRiMDBmOTgxZThlYjEyNjE5NjM0NjA3NjRlN2RkM2IzNWVjOTdmL2M2YmJjMTVlMzIyNGU2OTczNDU5YmE3OGRlNDk5OGI4MGI1MDExMmIwYWU1YjVjNjcxMTNkNTZiNGUzNjZiMTk%7EcmVzcG9uc2UtY29udGVudC1kaXNwb3NpdGlvbj0qIn1dfQ__&Signature=ixhNuL21wGqTYSmWbp-FTGAc-mnEAOyFNxhrmGSYcIj2jFokr-VLv3n46s1W3-d73DrLo%7EKYv1-vSbbTeJMf-q1drmOflxD-6HmdhijgDBedxnEcqrN%7EJ1vPLNTxQvveD2Sk%7Es6Zpdb045ylv7k8RRxqP4rdZtJRLLb6JK2wze-fu8LKBxUEVlTnPo4Mf6fo-cqhuP16GG384BlCT-HjlgM7urHKvH%7E5HAPxNmiqoMEyE7W7essWnpJYQxJKaG1U96CqHWXfGAP8HuzKqCGOpWwNPzHTIXhvOIOY7Gc%7EdDc91QBdknj%7EYaY6aGq%7E8VKou1PjmS0F1r6AQbm3JSexvw__&Key-Pair-Id=K3RPWS32NSSJCE [following]\n",
82
- "--2025-06-16 17:29:32-- https://cdn-lfs.hf.co/repos/f6/56/f656f0fa3b8a40ac76d297fa2a4b00f981e8eb1261963460764e7dd3b35ec97f/c6bbc15e3224e6973459ba78de4998b80b50112b0ae5b5c67113d56b4e366b19?response-content-disposition=inline%3B+filename*%3DUTF-8%27%27sd-v1-5-inpainting.ckpt%3B+filename%3D%22sd-v1-5-inpainting.ckpt%22%3B&Expires=1750097473&Policy=eyJTdGF0ZW1lbnQiOlt7IkNvbmRpdGlvbiI6eyJEYXRlTGVzc1RoYW4iOnsiQVdTOkVwb2NoVGltZSI6MTc1MDA5NzQ3M319LCJSZXNvdXJjZSI6Imh0dHBzOi8vY2RuLWxmcy5oZi5jby9yZXBvcy9mNi81Ni9mNjU2ZjBmYTNiOGE0MGFjNzZkMjk3ZmEyYTRiMDBmOTgxZThlYjEyNjE5NjM0NjA3NjRlN2RkM2IzNWVjOTdmL2M2YmJjMTVlMzIyNGU2OTczNDU5YmE3OGRlNDk5OGI4MGI1MDExMmIwYWU1YjVjNjcxMTNkNTZiNGUzNjZiMTk%7EcmVzcG9uc2UtY29udGVudC1kaXNwb3NpdGlvbj0qIn1dfQ__&Signature=ixhNuL21wGqTYSmWbp-FTGAc-mnEAOyFNxhrmGSYcIj2jFokr-VLv3n46s1W3-d73DrLo%7EKYv1-vSbbTeJMf-q1drmOflxD-6HmdhijgDBedxnEcqrN%7EJ1vPLNTxQvveD2Sk%7Es6Zpdb045ylv7k8RRxqP4rdZtJRLLb6JK2wze-fu8LKBxUEVlTnPo4Mf6fo-cqhuP16GG384BlCT-HjlgM7urHKvH%7E5HAPxNmiqoMEyE7W7essWnpJYQxJKaG1U96CqHWXfGAP8HuzKqCGOpWwNPzHTIXhvOIOY7Gc%7EdDc91QBdknj%7EYaY6aGq%7E8VKou1PjmS0F1r6AQbm3JSexvw__&Key-Pair-Id=K3RPWS32NSSJCE\n",
83
- "Resolving cdn-lfs.hf.co (cdn-lfs.hf.co)... 18.239.83.87, 18.239.83.31, 18.239.83.30, ...\n",
84
- "Connecting to cdn-lfs.hf.co (cdn-lfs.hf.co)|18.239.83.87|:443... connected.\n",
85
  "HTTP request sent, awaiting response... 200 OK\n",
86
  "Length: 4265437280 (4.0G) [binary/octet-stream]\n",
87
  "Saving to: ‘sd-v1-5-inpainting.ckpt’\n",
88
  "\n",
89
- "sd-v1-5-inpainting. 100%[===================>] 3.97G 307MB/s in 11s \n",
90
  "\n",
91
- "2025-06-16 17:29:43 (372 MB/s) - ‘sd-v1-5-inpainting.ckpt’ saved [4265437280/4265437280]\n",
92
  "\n"
93
  ]
94
  }
@@ -99,7 +99,7 @@
99
  },
100
  {
101
  "cell_type": "code",
102
- "execution_count": 12,
103
  "id": "4c5198ca",
104
  "metadata": {},
105
  "outputs": [
@@ -107,11 +107,12 @@
107
  "name": "stdout",
108
  "output_type": "stream",
109
  "text": [
110
- "attention.py encoder.py\t pipeline.py\t\t test.ipynb\n",
111
- "clip.py interface.py\t README.md\t\t training.ipynb\n",
112
- "ddpm.py merges.txt\t requirements.txt\t utils.py\n",
113
- "decoder.py model_converter.py sample_dataset\t VITON_Dataset.py\n",
114
- "diffusion.py model.py\t\t sd-v1-5-inpainting.ckpt vocab.json\n"
 
115
  ]
116
  }
117
  ],
@@ -172,39 +173,6 @@
172
  "# !pip install -U --no-cache-dir gdown --pre"
173
  ]
174
  },
175
- {
176
- "cell_type": "code",
177
- "execution_count": null,
178
- "id": "4467b7c7",
179
- "metadata": {},
180
- "outputs": [
181
- {
182
- "name": "stdout",
183
- "output_type": "stream",
184
- "text": [
185
- "/usr/local/lib/python3.11/dist-packages/gdown/__main__.py:140: FutureWarning: Option `--id` was deprecated in version 4.3.1 and will be removed in 5.0. You don't need to pass it anymore to use a file ID.\n",
186
- " warnings.warn(\n",
187
- "Failed to retrieve file url:\n",
188
- "\n",
189
- "\tToo many users have viewed or downloaded this file recently. Please\n",
190
- "\ttry accessing the file again later. If the file you are trying to\n",
191
- "\taccess is particularly large or is shared with many people, it may\n",
192
- "\ttake up to 24 hours to be able to view or download the file. If you\n",
193
- "\tstill can't access a file after 24 hours, contact your domain\n",
194
- "\tadministrator.\n",
195
- "\n",
196
- "You may still be able to access the file from the browser:\n",
197
- "\n",
198
- "\thttps://drive.google.com/uc?id=1tLx8LRp-sxDp0EcYmYoV_vXdSc-jJ79w\n",
199
- "\n",
200
- "but Gdown can't. Please check connections and permissions.\n"
201
- ]
202
- }
203
- ],
204
- "source": [
205
- "# !gdown --id 1tLx8LRp-sxDp0EcYmYoV_vXdSc-jJ79w\n"
206
- ]
207
- },
208
  {
209
  "cell_type": "code",
210
  "execution_count": null,
@@ -272,65 +240,26 @@
272
  },
273
  {
274
  "cell_type": "code",
275
- "execution_count": 5,
276
  "id": "53095103",
277
  "metadata": {},
278
- "outputs": [],
279
- "source": [
280
- "!mkdir output\n",
281
- "!mkdir checkpoints"
282
- ]
283
- },
284
- {
285
- "cell_type": "code",
286
- "execution_count": null,
287
- "id": "dcb8885d",
288
- "metadata": {},
289
  "outputs": [
290
  {
291
  "name": "stdout",
292
  "output_type": "stream",
293
  "text": [
294
- "Requirement already satisfied: diffusers in /usr/local/lib/python3.11/dist-packages (0.32.2)\n",
295
- "Requirement already satisfied: importlib-metadata in /usr/local/lib/python3.11/dist-packages (from diffusers) (8.6.1)\n",
296
- "Requirement already satisfied: filelock in /usr/local/lib/python3.11/dist-packages (from diffusers) (3.18.0)\n",
297
- "Requirement already satisfied: huggingface-hub>=0.23.2 in /usr/local/lib/python3.11/dist-packages (from diffusers) (0.30.2)\n",
298
- "Requirement already satisfied: numpy in /usr/local/lib/python3.11/dist-packages (from diffusers) (1.26.4)\n",
299
- "Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.11/dist-packages (from diffusers) (2024.11.6)\n",
300
- "Requirement already satisfied: requests in /usr/local/lib/python3.11/dist-packages (from diffusers) (2.32.3)\n",
301
- "Requirement already satisfied: safetensors>=0.3.1 in /usr/local/lib/python3.11/dist-packages (from diffusers) (0.5.2)\n",
302
- "Requirement already satisfied: Pillow in /usr/local/lib/python3.11/dist-packages (from diffusers) (11.1.0)\n",
303
- "Requirement already satisfied: fsspec>=2023.5.0 in /usr/local/lib/python3.11/dist-packages (from huggingface-hub>=0.23.2->diffusers) (2024.12.0)\n",
304
- "Requirement already satisfied: packaging>=20.9 in /usr/local/lib/python3.11/dist-packages (from huggingface-hub>=0.23.2->diffusers) (24.2)\n",
305
- "Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.11/dist-packages (from huggingface-hub>=0.23.2->diffusers) (6.0.2)\n",
306
- "Requirement already satisfied: tqdm>=4.42.1 in /usr/local/lib/python3.11/dist-packages (from huggingface-hub>=0.23.2->diffusers) (4.67.1)\n",
307
- "Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.11/dist-packages (from huggingface-hub>=0.23.2->diffusers) (4.13.1)\n",
308
- "Requirement already satisfied: zipp>=3.20 in /usr/local/lib/python3.11/dist-packages (from importlib-metadata->diffusers) (3.21.0)\n",
309
- "Requirement already satisfied: mkl_fft in /usr/local/lib/python3.11/dist-packages (from numpy->diffusers) (1.3.8)\n",
310
- "Requirement already satisfied: mkl_random in /usr/local/lib/python3.11/dist-packages (from numpy->diffusers) (1.2.4)\n",
311
- "Requirement already satisfied: mkl_umath in /usr/local/lib/python3.11/dist-packages (from numpy->diffusers) (0.1.1)\n",
312
- "Requirement already satisfied: mkl in /usr/local/lib/python3.11/dist-packages (from numpy->diffusers) (2025.1.0)\n",
313
- "Requirement already satisfied: tbb4py in /usr/local/lib/python3.11/dist-packages (from numpy->diffusers) (2022.1.0)\n",
314
- "Requirement already satisfied: mkl-service in /usr/local/lib/python3.11/dist-packages (from numpy->diffusers) (2.4.1)\n",
315
- "Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.11/dist-packages (from requests->diffusers) (3.4.1)\n",
316
- "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.11/dist-packages (from requests->diffusers) (3.10)\n",
317
- "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.11/dist-packages (from requests->diffusers) (2.3.0)\n",
318
- "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.11/dist-packages (from requests->diffusers) (2025.1.31)\n",
319
- "Requirement already satisfied: intel-openmp<2026,>=2024 in /usr/local/lib/python3.11/dist-packages (from mkl->numpy->diffusers) (2024.2.0)\n",
320
- "Requirement already satisfied: tbb==2022.* in /usr/local/lib/python3.11/dist-packages (from mkl->numpy->diffusers) (2022.1.0)\n",
321
- "Requirement already satisfied: tcmlib==1.* in /usr/local/lib/python3.11/dist-packages (from tbb==2022.*->mkl->numpy->diffusers) (1.2.0)\n",
322
- "Requirement already satisfied: intel-cmplr-lib-rt in /usr/local/lib/python3.11/dist-packages (from mkl_umath->numpy->diffusers) (2024.2.0)\n",
323
- "Requirement already satisfied: intel-cmplr-lib-ur==2024.2.0 in /usr/local/lib/python3.11/dist-packages (from intel-openmp<2026,>=2024->mkl->numpy->diffusers) (2024.2.0)\n"
324
  ]
325
  }
326
  ],
327
  "source": [
328
- "!pip install diffusers"
 
329
  ]
330
  },
331
  {
332
  "cell_type": "code",
333
- "execution_count": 11,
334
  "id": "7efe325c",
335
  "metadata": {},
336
  "outputs": [],
@@ -352,7 +281,7 @@
352
  },
353
  {
354
  "cell_type": "code",
355
- "execution_count": 16,
356
  "id": "a48f2753",
357
  "metadata": {},
358
  "outputs": [
@@ -363,7 +292,7 @@
363
  "traceback": [
364
  "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
365
  "\u001b[0;31mKeyError\u001b[0m Traceback (most recent call last)",
366
- "\u001b[0;32m/tmp/ipykernel_71/1017109895.py\u001b[0m in \u001b[0;36m<cell line: 0>\u001b[0;34m()\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcuda\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mempty_cache\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;31m# Release unused GPU memory\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 5\u001b[0;31m \u001b[0mgc\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcollect\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;31m# Run Python garbage collector\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
367
  "\u001b[0;32m/usr/local/lib/python3.11/dist-packages/IPython/core/displayhook.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, result)\u001b[0m\n\u001b[1;32m 261\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mwrite_output_prompt\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 262\u001b[0m \u001b[0mformat_dict\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmd_dict\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcompute_format_data\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mresult\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 263\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mupdate_user_ns\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mresult\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 264\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfill_exec_result\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mresult\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 265\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mformat_dict\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
368
  "\u001b[0;32m/usr/local/lib/python3.11/dist-packages/IPython/core/displayhook.py\u001b[0m in \u001b[0;36mupdate_user_ns\u001b[0;34m(self, result)\u001b[0m\n\u001b[1;32m 199\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 200\u001b[0m \u001b[0;31m# Avoid recursive reference when displaying _oh/Out\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 201\u001b[0;31m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcache_size\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0mresult\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshell\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0muser_ns\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'_oh'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 202\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshell\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0muser_ns\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'_oh'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m>=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcache_size\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdo_full_cache\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 203\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcull_cache\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
369
  "\u001b[0;31mKeyError\u001b[0m: '_oh'"
@@ -380,7 +309,7 @@
380
  },
381
  {
382
  "cell_type": "code",
383
- "execution_count": 17,
384
  "id": "5a57d765",
385
  "metadata": {},
386
  "outputs": [],
@@ -403,21 +332,10 @@
403
  },
404
  {
405
  "cell_type": "code",
406
- "execution_count": 14,
407
  "id": "5957ec57",
408
  "metadata": {},
409
- "outputs": [
410
- {
411
- "name": "stderr",
412
- "output_type": "stream",
413
- "text": [
414
- "2025-06-16 17:40:54.825758: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n",
415
- "WARNING: All log messages before absl::InitializeLog() is called are written to STDERR\n",
416
- "E0000 00:00:1750095655.110921 71 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n",
417
- "E0000 00:00:1750095655.201950 71 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n"
418
- ]
419
- }
420
- ],
421
  "source": [
422
  "import tensorflow as tf\n",
423
  "tf.keras.backend.clear_session()"
@@ -425,7 +343,7 @@
425
  },
426
  {
427
  "cell_type": "code",
428
- "execution_count": 18,
429
  "id": "796e8ef7",
430
  "metadata": {},
431
  "outputs": [
@@ -450,7 +368,7 @@
450
  },
451
  {
452
  "cell_type": "code",
453
- "execution_count": 19,
454
  "id": "32ed173e",
455
  "metadata": {},
456
  "outputs": [
@@ -459,7 +377,7 @@
459
  "output_type": "stream",
460
  "text": [
461
  "Total RAM: 31.35 GB\n",
462
- "Available RAM: 27.30 GB\n"
463
  ]
464
  }
465
  ],
@@ -483,7 +401,7 @@
483
  },
484
  {
485
  "cell_type": "code",
486
- "execution_count": 20,
487
  "id": "3ce888b6",
488
  "metadata": {},
489
  "outputs": [],
@@ -495,16 +413,14 @@
495
  " (image_tensor.shape[0], 4, image_tensor.shape[2] // 8, image_tensor.shape[3] // 8),\n",
496
  " device=device,\n",
497
  " )\n",
498
- " \n",
499
- " # Encode using your custom encoder\n",
500
- " latent = encoder(image_tensor, encoder_noise)\n",
501
- " return latent"
502
  ]
503
  },
504
  {
505
  "cell_type": "code",
506
- "execution_count": 21,
507
- "id": "081c5b70",
508
  "metadata": {},
509
  "outputs": [
510
  {
@@ -515,312 +431,24 @@
515
  ]
516
  },
517
  {
518
- "name": "stderr",
519
- "output_type": "stream",
520
- "text": [
521
- "/tmp/ipykernel_71/658570771.py:77: FutureWarning: `torch.cuda.amp.GradScaler(args...)` is deprecated. Please use `torch.amp.GradScaler('cuda', args...)` instead.\n",
522
- " self.scaler = torch.cuda.amp.GradScaler()\n"
523
- ]
524
- },
525
- {
526
- "name": "stdout",
527
- "output_type": "stream",
528
- "text": [
529
- "Creating dataloaders...\n",
530
- "Dataset vitonhd loaded, total 11647 pairs.\n",
531
- "Training for 178 epochs (16000 steps)\n",
532
- "Steps per epoch: 90\n",
533
- "Total training steps: 16000\n",
534
- "Total epochs: 178\n",
535
- "Initializing trainer...\n",
536
- "Enabling PEFT training (self-attention layers only)\n",
537
- "Total parameters: 1,022,287,147\n",
538
- "Trainable parameters: 6,554,880 (0.64%)\n",
539
- "Warning: Expected ~49,570,000 trainable parameters, got 6,554,880\n",
540
- "Starting training...\n",
541
- "Starting training for 178 epochs\n",
542
- "Total training steps: 2073166\n",
543
- "Using DREAM with lambda = 0\n",
544
- "Mixed precision: True\n"
545
- ]
546
- },
547
- {
548
- "name": "stderr",
549
- "output_type": "stream",
550
- "text": [
551
- "Epoch 1: 0%| | 0/11647 [00:00<?, ?it/s]/tmp/ipykernel_71/658570771.py:292: FutureWarning: `torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead.\n",
552
- " with torch.cuda.amp.autocast():\n",
553
- "/tmp/ipykernel_71/658570771.py:195: FutureWarning: `torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead.\n",
554
- " with torch.cuda.amp.autocast(enabled=self.use_mixed_precision):\n",
555
- "Epoch 1: 0%| | 1/11647 [00:09<29:46:54, 9.21s/it, loss=1.88, lr=1e-5, step=0]"
556
- ]
557
- },
558
- {
559
- "name": "stdout",
560
- "output_type": "stream",
561
- "text": [
562
- "Checkpoint saved: checkpoints/checkpoint_step_0.pth\n"
563
- ]
564
- },
565
- {
566
- "name": "stderr",
567
- "output_type": "stream",
568
- "text": [
569
- "Epoch 1: 0%| | 2/11647 [00:16<26:03:12, 8.05s/it, loss=2.69, lr=1e-5, step=0]"
570
- ]
571
- },
572
- {
573
- "name": "stdout",
574
- "output_type": "stream",
575
- "text": [
576
- "Checkpoint saved: checkpoints/checkpoint_step_0.pth\n"
577
- ]
578
- },
579
- {
580
- "name": "stderr",
581
- "output_type": "stream",
582
- "text": [
583
- "Epoch 1: 0%| | 3/11647 [00:23<24:01:31, 7.43s/it, loss=1.63, lr=1e-5, step=0]"
584
- ]
585
- },
586
- {
587
- "name": "stdout",
588
- "output_type": "stream",
589
- "text": [
590
- "Checkpoint saved: checkpoints/checkpoint_step_0.pth\n"
591
- ]
592
- },
593
- {
594
- "name": "stderr",
595
- "output_type": "stream",
596
- "text": [
597
- "Epoch 1: 0%| | 4/11647 [00:34<28:42:23, 8.88s/it, loss=1.67, lr=1e-5, step=0]"
598
- ]
599
- },
600
- {
601
- "name": "stdout",
602
- "output_type": "stream",
603
- "text": [
604
- "Checkpoint saved: checkpoints/checkpoint_step_0.pth\n"
605
- ]
606
- },
607
- {
608
- "name": "stderr",
609
- "output_type": "stream",
610
- "text": [
611
- "Epoch 1: 0%| | 5/11647 [00:47<33:21:58, 10.32s/it, loss=2.06, lr=1e-5, step=0]"
612
- ]
613
- },
614
- {
615
- "name": "stdout",
616
- "output_type": "stream",
617
- "text": [
618
- "Checkpoint saved: checkpoints/checkpoint_step_0.pth\n"
619
- ]
620
- },
621
- {
622
- "name": "stderr",
623
- "output_type": "stream",
624
- "text": [
625
- "Epoch 1: 0%| | 6/11647 [01:16<54:28:11, 16.84s/it, loss=2.37, lr=1e-5, step=0]"
626
- ]
627
- },
628
- {
629
- "name": "stdout",
630
- "output_type": "stream",
631
- "text": [
632
- "Checkpoint saved: checkpoints/checkpoint_step_0.pth\n"
633
- ]
634
- },
635
- {
636
- "name": "stderr",
637
- "output_type": "stream",
638
- "text": [
639
- "Epoch 1: 0%| | 7/11647 [01:22<43:07:22, 13.34s/it, loss=2.64, lr=1e-5, step=0]"
640
- ]
641
- },
642
- {
643
- "name": "stdout",
644
- "output_type": "stream",
645
- "text": [
646
- "Checkpoint saved: checkpoints/checkpoint_step_0.pth\n"
647
- ]
648
- },
649
- {
650
- "name": "stderr",
651
- "output_type": "stream",
652
- "text": [
653
- "Epoch 1: 0%| | 8/11647 [01:39<46:12:47, 14.29s/it, loss=2.49, lr=1e-5, step=0]"
654
- ]
655
- },
656
- {
657
- "name": "stdout",
658
- "output_type": "stream",
659
- "text": [
660
- "Checkpoint saved: checkpoints/checkpoint_step_0.pth\n"
661
- ]
662
- },
663
- {
664
- "name": "stderr",
665
- "output_type": "stream",
666
- "text": [
667
- "Epoch 1: 0%| | 9/11647 [01:45<38:37:52, 11.95s/it, loss=1.77, lr=1e-5, step=0]"
668
- ]
669
- },
670
- {
671
- "name": "stdout",
672
- "output_type": "stream",
673
- "text": [
674
- "Checkpoint saved: checkpoints/checkpoint_step_0.pth\n"
675
- ]
676
- },
677
- {
678
- "name": "stderr",
679
- "output_type": "stream",
680
- "text": [
681
- "Epoch 1: 0%| | 10/11647 [01:57<37:48:49, 11.70s/it, loss=2.18, lr=1e-5, step=0]"
682
- ]
683
- },
684
- {
685
- "name": "stdout",
686
- "output_type": "stream",
687
- "text": [
688
- "Checkpoint saved: checkpoints/checkpoint_step_0.pth\n"
689
- ]
690
- },
691
- {
692
- "name": "stderr",
693
- "output_type": "stream",
694
- "text": [
695
- "Epoch 1: 0%| | 11/11647 [02:20<49:32:59, 15.33s/it, loss=3.05, lr=1e-5, step=0]"
696
- ]
697
- },
698
- {
699
- "name": "stdout",
700
- "output_type": "stream",
701
- "text": [
702
- "Checkpoint saved: checkpoints/checkpoint_step_0.pth\n"
703
- ]
704
- },
705
- {
706
- "name": "stderr",
707
- "output_type": "stream",
708
- "text": [
709
- "Epoch 1: 0%| | 12/11647 [02:28<41:54:59, 12.97s/it, loss=2.02, lr=1e-5, step=0]"
710
- ]
711
- },
712
- {
713
- "name": "stdout",
714
- "output_type": "stream",
715
- "text": [
716
- "Checkpoint saved: checkpoints/checkpoint_step_0.pth\n"
717
- ]
718
- },
719
- {
720
- "name": "stderr",
721
- "output_type": "stream",
722
- "text": [
723
- "Epoch 1: 0%| | 13/11647 [02:41<42:09:43, 13.05s/it, loss=2.42, lr=1e-5, step=0]"
724
- ]
725
- },
726
- {
727
- "name": "stdout",
728
- "output_type": "stream",
729
- "text": [
730
- "Checkpoint saved: checkpoints/checkpoint_step_0.pth\n"
731
- ]
732
- },
733
- {
734
- "name": "stderr",
735
- "output_type": "stream",
736
- "text": [
737
- "Epoch 1: 0%| | 14/11647 [02:53<41:07:53, 12.73s/it, loss=1.64, lr=1e-5, step=0]"
738
- ]
739
- },
740
- {
741
- "name": "stdout",
742
- "output_type": "stream",
743
- "text": [
744
- "Checkpoint saved: checkpoints/checkpoint_step_0.pth\n"
745
- ]
746
- },
747
- {
748
- "name": "stderr",
749
- "output_type": "stream",
750
- "text": [
751
- "Epoch 1: 0%| | 15/11647 [03:06<41:56:49, 12.98s/it, loss=1.75, lr=1e-5, step=0]"
752
- ]
753
- },
754
- {
755
- "name": "stdout",
756
- "output_type": "stream",
757
- "text": [
758
- "Checkpoint saved: checkpoints/checkpoint_step_0.pth\n"
759
- ]
760
- },
761
- {
762
- "name": "stderr",
763
- "output_type": "stream",
764
- "text": [
765
- "Epoch 1: 0%| | 16/11647 [03:20<42:46:06, 13.24s/it, loss=2.43, lr=1e-5, step=0]"
766
- ]
767
- },
768
- {
769
- "name": "stdout",
770
- "output_type": "stream",
771
- "text": [
772
- "Checkpoint saved: checkpoints/checkpoint_step_0.pth\n"
773
- ]
774
- },
775
- {
776
- "name": "stderr",
777
- "output_type": "stream",
778
- "text": [
779
- "Epoch 1: 0%| | 17/11647 [03:43<51:45:46, 16.02s/it, loss=1.81, lr=1e-5, step=0]"
780
- ]
781
- },
782
- {
783
- "name": "stdout",
784
- "output_type": "stream",
785
- "text": [
786
- "Checkpoint saved: checkpoints/checkpoint_step_0.pth\n"
787
- ]
788
- },
789
- {
790
- "name": "stderr",
791
- "output_type": "stream",
792
- "text": [
793
- "Epoch 1: 0%| | 17/11647 [03:43<42:31:35, 13.16s/it, loss=1.81, lr=1e-5, step=0]\n"
794
- ]
795
- },
796
- {
797
- "ename": "KeyboardInterrupt",
798
- "evalue": "",
799
  "output_type": "error",
800
  "traceback": [
801
  "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
802
- "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
803
- "\u001b[0;32m/tmp/ipykernel_71/658570771.py\u001b[0m in \u001b[0;36m<cell line: 0>\u001b[0;34m()\u001b[0m\n\u001b[1;32m 525\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 526\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0m__name__\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;34m\"__main__\"\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 527\u001b[0;31m \u001b[0mmain\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
804
- "\u001b[0;32m/tmp/ipykernel_71/658570771.py\u001b[0m in \u001b[0;36mmain\u001b[0;34m()\u001b[0m\n\u001b[1;32m 522\u001b[0m \u001b[0;31m# Start training\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 523\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"Starting training...\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 524\u001b[0;31m \u001b[0mtrainer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrain\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 525\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 526\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0m__name__\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;34m\"__main__\"\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
805
- "\u001b[0;32m/tmp/ipykernel_71/658570771.py\u001b[0m in \u001b[0;36mtrain\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 358\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 359\u001b[0m \u001b[0;31m# Train\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 360\u001b[0;31m \u001b[0mtrain_loss\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrain_epoch\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 361\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 362\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34mf\"Epoch {epoch+1}/{self.num_epochs}\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
806
- "\u001b[0;32m/tmp/ipykernel_71/658570771.py\u001b[0m in \u001b[0;36mtrain_epoch\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 291\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0muse_mixed_precision\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 292\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcuda\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mamp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mautocast\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 293\u001b[0;31m \u001b[0mloss\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcompute_loss\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbatch\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 294\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 295\u001b[0m \u001b[0;31m# Scale loss for gradient accumulation\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
807
- "\u001b[0;32m/tmp/ipykernel_71/658570771.py\u001b[0m in \u001b[0;36mcompute_loss\u001b[0;34m(self, batch)\u001b[0m\n\u001b[1;32m 269\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 270\u001b[0m \u001b[0;31m# Standard training without DREAM\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 271\u001b[0;31m predicted_noise = self.diffusion(\n\u001b[0m\u001b[1;32m 272\u001b[0m \u001b[0minpainting_latent_model_input\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 273\u001b[0m \u001b[0mtimesteps_embedding\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
808
- "\u001b[0;32m/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1734\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_compiled_call_impl\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;31m# type: ignore[misc]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1735\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1736\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_call_impl\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1737\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1738\u001b[0m \u001b[0;31m# torchrec tests the code consistency with the following code\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
809
- "\u001b[0;32m/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1745\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0m_global_backward_pre_hooks\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0m_global_backward_hooks\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1746\u001b[0m or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[0;32m-> 1747\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1748\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1749\u001b[0m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
810
- "\u001b[0;32m/kaggle/working/stable-diffusion/diffusion.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, latent, time)\u001b[0m\n\u001b[1;32m 255\u001b[0m \u001b[0mtime\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtime_embedding\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtime\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 256\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 257\u001b[0;31m \u001b[0moutput\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0munet\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlatent\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtime\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 258\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 259\u001b[0m \u001b[0moutput\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfinal\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0moutput\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
811
- "\u001b[0;32m/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1734\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_compiled_call_impl\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;31m# type: ignore[misc]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1735\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1736\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_call_impl\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1737\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1738\u001b[0m \u001b[0;31m# torchrec tests the code consistency with the following code\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
812
- "\u001b[0;32m/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1745\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0m_global_backward_pre_hooks\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0m_global_backward_hooks\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1746\u001b[0m or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[0;32m-> 1747\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1748\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1749\u001b[0m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
813
- "\u001b[0;32m/kaggle/working/stable-diffusion/diffusion.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, x, time)\u001b[0m\n\u001b[1;32m 226\u001b[0m \u001b[0;31m# Since we always concat with the skip connection of the encoder, the number of features increases before being sent to the decoder's layer\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 227\u001b[0m \u001b[0mx\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mskip_connections\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpop\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdim\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 228\u001b[0;31m \u001b[0mx\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mlayers\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtime\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 229\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 230\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
814
- "\u001b[0;32m/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1734\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_compiled_call_impl\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;31m# type: ignore[misc]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1735\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1736\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_call_impl\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1737\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1738\u001b[0m \u001b[0;31m# torchrec tests the code consistency with the following code\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
815
- "\u001b[0;32m/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1745\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0m_global_backward_pre_hooks\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0m_global_backward_hooks\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1746\u001b[0m or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[0;32m-> 1747\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1748\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1749\u001b[0m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
816
- "\u001b[0;32m/kaggle/working/stable-diffusion/diffusion.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, x, time)\u001b[0m\n\u001b[1;32m 127\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mlayer\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 128\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlayer\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mUNET_AttentionBlock\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 129\u001b[0;31m \u001b[0mx\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mlayer\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 130\u001b[0m \u001b[0;32melif\u001b[0m \u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlayer\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mUNET_ResidualBlock\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 131\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mlayer\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtime\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
817
- "\u001b[0;32m/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1734\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_compiled_call_impl\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;31m# type: ignore[misc]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1735\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1736\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_call_impl\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1737\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1738\u001b[0m \u001b[0;31m# torchrec tests the code consistency with the following code\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
818
- "\u001b[0;32m/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1745\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0m_global_backward_pre_hooks\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0m_global_backward_hooks\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1746\u001b[0m or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[0;32m-> 1747\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1748\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1749\u001b[0m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
819
- "\u001b[0;32m/kaggle/working/stable-diffusion/diffusion.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, x)\u001b[0m\n\u001b[1;32m 86\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 87\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlayernorm_1\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 88\u001b[0;31m \u001b[0mx\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mattention_1\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 89\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 90\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m+=\u001b[0m\u001b[0mresidue_short\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
820
- "\u001b[0;32m/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1734\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_compiled_call_impl\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;31m# type: ignore[misc]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1735\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1736\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_call_impl\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1737\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1738\u001b[0m \u001b[0;31m# torchrec tests the code consistency with the following code\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
821
- "\u001b[0;32m/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1745\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0m_global_backward_pre_hooks\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0m_global_backward_hooks\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1746\u001b[0m or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[0;32m-> 1747\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1748\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1749\u001b[0m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
822
- "\u001b[0;32m/kaggle/working/stable-diffusion/attention.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, x, causal_mask)\u001b[0m\n\u001b[1;32m 41\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 42\u001b[0m \u001b[0;31m# Scaling by sqrt(d_head)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 43\u001b[0;31m \u001b[0mattention_weights\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mattention_weights\u001b[0m \u001b[0;34m/\u001b[0m \u001b[0mmath\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msqrt\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0md_head\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 44\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 45\u001b[0m \u001b[0;31m# Causal mask to prevent attending to future tokens\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
823
- "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
824
  ]
825
  }
826
  ],
@@ -846,12 +474,12 @@
846
  "\n",
847
  "# Import your custom modules\n",
848
  "from load_model import preload_models_from_standard_weights\n",
849
- "from ddpm import DDPMSampler # Fixed import\n",
850
  "from utils import check_inputs, get_time_embedding, prepare_image, prepare_mask_image\n",
851
  "from diffusers.utils.torch_utils import randn_tensor\n",
852
  "\n",
853
  "class CatVTONTrainer:\n",
854
- " \"\"\"CatVTON Training Class with PEFT, CFG and DREAM support\"\"\"\n",
855
  " \n",
856
  " def __init__(\n",
857
  " self,\n",
@@ -859,19 +487,16 @@
859
  " train_dataloader: DataLoader,\n",
860
  " val_dataloader: Optional[DataLoader] = None,\n",
861
  " device: str = \"cuda\",\n",
862
- " learning_rate: float = 1e-5, # Updated to paper value\n",
863
- " num_epochs: int = 100,\n",
864
  " save_steps: int = 1000,\n",
865
  " output_dir: str = \"./checkpoints\",\n",
866
  " cfg_dropout_prob: float = 0.1,\n",
867
- " guidance_scale: float = 2.5,\n",
868
- " num_inference_steps: int = 50,\n",
869
- " gradient_accumulation_steps: int = 1,\n",
870
  " max_grad_norm: float = 1.0,\n",
871
  " use_peft: bool = True,\n",
872
- " dream_lambda: float = 10.0, # DREAM parameter\n",
873
  " resume_from_checkpoint: Optional[str] = None,\n",
874
- " use_mixed_precision: bool = True, # For memory optimization\n",
875
  " height=512,\n",
876
  " width=384,\n",
877
  " ):\n",
@@ -885,15 +510,12 @@
885
  " self.save_steps = save_steps\n",
886
  " self.output_dir = Path(output_dir)\n",
887
  " self.cfg_dropout_prob = cfg_dropout_prob\n",
888
- " self.guidance_scale = guidance_scale\n",
889
- " self.num_inference_steps = num_inference_steps\n",
890
- " self.gradient_accumulation_steps = gradient_accumulation_steps\n",
891
  " self.max_grad_norm = max_grad_norm\n",
892
  " self.use_peft = use_peft\n",
893
  " self.dream_lambda = dream_lambda\n",
894
  " self.use_mixed_precision = use_mixed_precision\n",
895
- " self.height=height\n",
896
- " self.width=width\n",
897
  " self.generator = torch.Generator(device=device)\n",
898
  " \n",
899
  " # Create output directory\n",
@@ -914,20 +536,18 @@
914
  " if resume_from_checkpoint:\n",
915
  " self._load_checkpoint(resume_from_checkpoint)\n",
916
  " \n",
917
- " self.encoder=self.models.get('encoder', None)\n",
918
- " self.decoder=self.models.get('decoder', None)\n",
919
- " self.diffusion=self.models.get('diffusion', None)\n",
920
  "\n",
921
  " # Setup models and optimizers\n",
922
  " self._setup_training()\n",
923
  " \n",
924
  " def _setup_training(self):\n",
925
  " \"\"\"Setup models for training with PEFT\"\"\"\n",
926
- " # Move models to device with mixed precision\n",
927
  " for name, model in self.models.items():\n",
928
  " model.to(self.device)\n",
929
- " # if self.use_mixed_precision and name != 'encoder': # Keep encoder in float32 for stability\n",
930
- " # model.half()\n",
931
  " \n",
932
  " # Freeze all parameters first\n",
933
  " for model in self.models.values():\n",
@@ -939,7 +559,7 @@
939
  " self._enable_peft_training()\n",
940
  " else:\n",
941
  " # Enable full training for diffusion model\n",
942
- " for param in self.models['diffusion'].parameters():\n",
943
  " param.requires_grad = True\n",
944
  " \n",
945
  " # Collect trainable parameters\n",
@@ -957,12 +577,6 @@
957
  " print(f\"Total parameters: {total_params:,}\")\n",
958
  " print(f\"Trainable parameters: {trainable_count:,} ({trainable_count/total_params*100:.2f}%)\")\n",
959
  " \n",
960
- " # Verify we're close to the paper's 49.57M parameters for self-attention only\n",
961
- " if self.use_peft:\n",
962
- " expected_params = 49_570_000 # 49.57M\n",
963
- " if abs(trainable_count - expected_params) > 5_000_000: # 5M tolerance\n",
964
- " print(f\"Warning: Expected ~{expected_params:,} trainable parameters, got {trainable_count:,}\")\n",
965
- " \n",
966
  " # Setup optimizer - AdamW as per paper\n",
967
  " self.optimizer = AdamW(\n",
968
  " trainable_params,\n",
@@ -972,31 +586,31 @@
972
  " eps=1e-8\n",
973
  " )\n",
974
  " \n",
975
- " # Setup learning rate scheduler (constant as per paper)\n",
976
- " # For constant LR, we can use a dummy scheduler\n",
977
  " self.lr_scheduler = torch.optim.lr_scheduler.LambdaLR(\n",
978
  " self.optimizer, lr_lambda=lambda epoch: 1.0\n",
979
  " )\n",
980
  " \n",
981
  " def _enable_peft_training(self):\n",
982
- " \"\"\"Enable PEFT training - only self-attention layers (49.57M parameters)\"\"\"\n",
983
  " print(\"Enabling PEFT training (self-attention layers only)\")\n",
984
  " \n",
985
  " unet = self.diffusion.unet\n",
986
  " \n",
987
- " # Enable attention layers in encoders\n",
988
  " for layers in [unet.encoders, unet.decoders]:\n",
989
  " for layer in layers:\n",
990
- " if hasattr(layer, 'attention_1'): # Alternative naming\n",
991
- " for param in layer.attention_1.parameters():\n",
992
- " param.requires_grad = True\n",
993
- " \n",
 
994
  " # Enable attention layers in bottleneck\n",
995
  " for layer in unet.bottleneck:\n",
996
- " if hasattr(layer, 'attention_1'):\n",
997
- " for param in layer.attention_1.parameters():\n",
998
  " param.requires_grad = True\n",
999
- " \n",
1000
  " def _apply_cfg_dropout(self, garment_latent: torch.Tensor) -> torch.Tensor:\n",
1001
  " \"\"\"Apply classifier-free guidance dropout (10% chance)\"\"\"\n",
1002
  " if self.training and random.random() < self.cfg_dropout_prob:\n",
@@ -1010,33 +624,35 @@
1010
  " cloth_images = batch['cloth'].to(self.device)\n",
1011
  " masks = batch['mask'].to(self.device)\n",
1012
  "\n",
1013
- " concat_dim = -2 # FIXME: y axis concat\n",
1014
- " # Prepare inputs to Tensor\n",
 
1015
  " image, condition_image, mask = check_inputs(person_images, cloth_images, masks, self.width, self.height)\n",
1016
  " image = prepare_image(person_images).to(self.device, dtype=self.weight_dtype)\n",
1017
  " condition_image = prepare_image(cloth_images).to(self.device, dtype=self.weight_dtype)\n",
1018
  " mask = prepare_mask_image(masks).to(self.device, dtype=self.weight_dtype)\n",
 
1019
  " # Mask image\n",
1020
  " masked_image = image * (mask < 0.5)\n",
1021
  "\n",
1022
  " with torch.cuda.amp.autocast(enabled=self.use_mixed_precision):\n",
1023
- " # VAE encoding\n",
1024
  " masked_latent = compute_vae_encodings(masked_image, self.encoder)\n",
1025
  " person_latent = compute_vae_encodings(person_images, self.encoder)\n",
1026
  " condition_latent = compute_vae_encodings(condition_image, self.encoder)\n",
1027
  " mask_latent = torch.nn.functional.interpolate(mask, size=masked_latent.shape[-2:], mode=\"nearest\")\n",
 
1028
  " del image, mask, condition_image\n",
1029
  "\n",
1030
- "\n",
1031
  " # Apply CFG dropout to garment latent\n",
1032
  " condition_latent = self._apply_cfg_dropout(condition_latent)\n",
1033
  " \n",
1034
  " # Concatenate latents\n",
1035
  " masked_latent_concat = torch.cat([masked_latent, condition_latent], dim=concat_dim)\n",
1036
  " mask_latent_concat = torch.cat([mask_latent, torch.zeros_like(mask_latent)], dim=concat_dim)\n",
1037
- " target_latents=torch.cat([person_latent, condition_latent], dim=concat_dim)\n",
1038
  "\n",
1039
- " noise=randn_tensor(\n",
1040
  " masked_latent_concat.shape,\n",
1041
  " generator=self.generator,\n",
1042
  " device=masked_latent_concat.device,\n",
@@ -1061,11 +677,9 @@
1061
  " # Get initial noise prediction\n",
1062
  " with torch.no_grad():\n",
1063
  " epsilon_theta = self.diffusion(\n",
1064
- " inpainting_latent_model_input,\n",
1065
- " timesteps_embedding\n",
1066
- " )\n",
1067
- "\n",
1068
- " # print(f\"Predicted noise shape: {epsilon_theta.shape}\")\n",
1069
  " \n",
1070
  " # Apply DREAM: zˆt = √αt*z0 + √(1-αt)*(ε + λ*εθ)\n",
1071
  " alphas_cumprod = self.scheduler.alphas_cumprod.to(device=self.device, dtype=self.weight_dtype)\n",
@@ -1087,7 +701,7 @@
1087
  " masked_latent_concat\n",
1088
  " ], dim=1)\n",
1089
  "\n",
1090
- " predicted_noise= self.diffusion(\n",
1091
  " dream_model_input,\n",
1092
  " timesteps_embedding\n",
1093
  " )\n",
@@ -1106,64 +720,61 @@
1106
  " return loss\n",
1107
  " \n",
1108
  " def train_epoch(self) -> float:\n",
1109
- " \"\"\"Train for one epoch\"\"\"\n",
1110
- " self.models['diffusion'].train()\n",
1111
  " total_loss = 0.0\n",
1112
  " num_batches = len(self.train_dataloader)\n",
1113
  " \n",
1114
  " progress_bar = tqdm(self.train_dataloader, desc=f\"Epoch {self.current_epoch+1}\")\n",
1115
  " \n",
1116
  " for step, batch in enumerate(progress_bar):\n",
1117
- " # Compute loss with mixed precision\n",
 
 
 
1118
  " if self.use_mixed_precision:\n",
1119
  " with torch.cuda.amp.autocast():\n",
1120
  " loss = self.compute_loss(batch)\n",
1121
  " \n",
1122
- " # Scale loss for gradient accumulation\n",
1123
- " loss = loss / self.gradient_accumulation_steps\n",
1124
- " \n",
1125
  " # Backward pass with scaling\n",
1126
  " self.scaler.scale(loss).backward()\n",
 
 
 
 
 
 
 
 
 
 
1127
  " else:\n",
1128
  " loss = self.compute_loss(batch)\n",
1129
- " loss = loss / self.gradient_accumulation_steps\n",
1130
  " loss.backward()\n",
1131
- " \n",
1132
- " # Gradient accumulation\n",
1133
- " if (step + 1) % self.gradient_accumulation_steps == 0:\n",
1134
- " if self.use_mixed_precision:\n",
1135
- " # Unscale gradients and clip\n",
1136
- " self.scaler.unscale_(self.optimizer)\n",
1137
- " torch.nn.utils.clip_grad_norm_(\n",
1138
- " [p for p in self.diffusion.parameters() if p.requires_grad],\n",
1139
- " self.max_grad_norm\n",
1140
- " )\n",
1141
- " \n",
1142
- " # Optimizer step with scaling\n",
1143
- " self.scaler.step(self.optimizer)\n",
1144
- " self.scaler.update()\n",
1145
- " else:\n",
1146
- " # Clip gradients\n",
1147
- " torch.nn.utils.clip_grad_norm_(\n",
1148
- " [p for p in self.diffusion.parameters() if p.requires_grad],\n",
1149
- " self.max_grad_norm\n",
1150
- " )\n",
1151
- " self.optimizer.step()\n",
1152
  " \n",
1153
- " self.lr_scheduler.step()\n",
1154
- " self.optimizer.zero_grad()\n",
1155
- " self.global_step += 1\n",
 
 
 
 
 
 
 
 
 
1156
  " \n",
1157
- " total_loss += loss.item() * self.gradient_accumulation_steps\n",
1158
  " \n",
1159
  " # Update progress bar\n",
1160
  " progress_bar.set_postfix({\n",
1161
- " 'loss': loss.item() * self.gradient_accumulation_steps,\n",
1162
  " 'lr': self.optimizer.param_groups[0]['lr'],\n",
1163
  " 'step': self.global_step\n",
1164
  " })\n",
1165
  " \n",
1166
- " # Save checkpoint\n",
1167
  " if self.global_step % self.save_steps == 0:\n",
1168
  " self._save_checkpoint()\n",
1169
  " \n",
@@ -1174,29 +785,32 @@
1174
  " return total_loss / num_batches\n",
1175
  " \n",
1176
  " def train(self):\n",
1177
- " \"\"\"Main training loop\"\"\"\n",
1178
  " print(f\"Starting training for {self.num_epochs} epochs\")\n",
1179
- " print(f\"Total training steps: {self.num_epochs * len(self.train_dataloader)}\")\n",
1180
  " print(f\"Using DREAM with lambda = {self.dream_lambda}\")\n",
1181
  " print(f\"Mixed precision: {self.use_mixed_precision}\")\n",
1182
  " \n",
1183
  " for epoch in range(self.current_epoch, self.num_epochs):\n",
1184
  " self.current_epoch = epoch\n",
1185
  " \n",
1186
- " # Train\n",
1187
  " train_loss = self.train_epoch()\n",
1188
  " \n",
1189
- " print(f\"Epoch {epoch+1}/{self.num_epochs}\")\n",
1190
- " print(f\"Train Loss: {train_loss:.6f}\")\n",
1191
  " \n",
1192
  " # Save epoch checkpoint\n",
1193
- " if (epoch + 1) % 10 == 0:\n",
1194
  " self._save_checkpoint(epoch_checkpoint=True)\n",
1195
  " \n",
1196
  " # Clear cache at end of epoch\n",
1197
  " torch.cuda.empty_cache()\n",
 
 
 
 
1198
  " \n",
1199
- " def _save_checkpoint(self, is_best: bool = False, epoch_checkpoint: bool = False):\n",
1200
  " \"\"\"Save model checkpoint\"\"\"\n",
1201
  " checkpoint = {\n",
1202
  " 'global_step': self.global_step,\n",
@@ -1211,7 +825,9 @@
1211
  " if self.use_mixed_precision:\n",
1212
  " checkpoint['scaler_state_dict'] = self.scaler.state_dict()\n",
1213
  " \n",
1214
- " if is_best:\n",
 
 
1215
  " checkpoint_path = self.output_dir / \"best_model.pth\"\n",
1216
  " elif epoch_checkpoint:\n",
1217
  " checkpoint_path = self.output_dir / f\"checkpoint_epoch_{self.current_epoch+1}.pth\"\n",
@@ -1239,65 +855,57 @@
1239
  " print(f\"Checkpoint loaded: {checkpoint_path}\")\n",
1240
  " print(f\"Resuming from epoch {self.current_epoch}, step {self.global_step}\")\n",
1241
  "\n",
1242
- "def create_dataloaders(args) -> Tuple[DataLoader, Optional[DataLoader]]:\n",
1243
- " \"\"\"Create training and validation dataloaders\"\"\"\n",
1244
- " # Dataset\n",
1245
  " if args.dataset_name == \"vitonhd\":\n",
1246
  " dataset = VITONHDTestDataset(args)\n",
1247
  " else:\n",
1248
- " raise ValueError(f\"Invalid dataset name {args.dataset}.\")\n",
 
1249
  " print(f\"Dataset {args.dataset_name} loaded, total {len(dataset)} pairs.\")\n",
1250
- "\n",
1251
  " dataloader = DataLoader(\n",
1252
  " dataset,\n",
1253
  " batch_size=args.batch_size,\n",
1254
- " shuffle=False,\n",
1255
- " num_workers=args.dataloader_num_workers\n",
 
 
 
1256
  " )\n",
1257
  " \n",
1258
  " return dataloader\n",
1259
  "\n",
1260
  "\n",
1261
  "def main():\n",
1262
- " args=argparse.Namespace()\n",
1263
- " args.__dict__= {\n",
1264
  " \"base_model_path\": \"sd-v1-5-inpainting.ckpt\",\n",
1265
- " \"resume_path\": \"zhengchong/CatVTON\",\n",
1266
  " \"dataset_name\": \"vitonhd\",\n",
1267
  " \"data_root_path\": \"/kaggle/input/viton-hd-dataset\",\n",
1268
- " \"output_dir\": \"./output\",\n",
 
1269
  " \"seed\": 42,\n",
1270
  " \"batch_size\": 1,\n",
1271
- " \"num_inference_steps\": 50,\n",
1272
- " \"guidance_scale\": 2.5,\n",
1273
  " \"width\": 384,\n",
1274
  " \"height\": 512,\n",
1275
  " \"repaint\": True,\n",
1276
  " \"eval_pair\": True,\n",
1277
  " \"concat_eval_results\": True,\n",
1278
- " \"allow_tf32\": True,\n",
1279
- " \"dataloader_num_workers\": 4,\n",
1280
- " \"mixed_precision\": 'no',\n",
1281
  " \"concat_axis\": 'y',\n",
1282
- " \"enable_condition_noise\": True,\n",
1283
- " \"device\":\"cuda\",\n",
1284
- " \"num_training_steps\": 16000,\n",
1285
  " \"learning_rate\": 1e-5,\n",
1286
- " \"gradient_accumulation_steps\": 128, # Simulate batch size 128\n",
1287
  " \"max_grad_norm\": 1.0,\n",
1288
- " \"use_peft\": True,\n",
1289
  " \"cfg_dropout_prob\": 0.1,\n",
1290
  " \"dream_lambda\": 0,\n",
 
1291
  " \"use_mixed_precision\": True,\n",
1292
- " \"output_dir\": \"./checkpoints\",\n",
1293
  " \"save_steps\": 1000,\n",
1294
- " \"resume_from_checkpoint\": None,\n",
1295
  " \"is_train\": True\n",
1296
  " }\n",
1297
  " \n",
1298
- " # Calculate epochs from training steps\n",
1299
- " # This will be calculated after dataloader creation\n",
1300
- " \n",
1301
  " # Set random seeds\n",
1302
  " torch.manual_seed(args.seed)\n",
1303
  " np.random.seed(args.seed)\n",
@@ -1305,27 +913,23 @@
1305
  " if torch.cuda.is_available():\n",
1306
  " torch.cuda.manual_seed_all(args.seed)\n",
1307
  " \n",
1308
- " # Optimize CUDA settings for memory\n",
1309
  " torch.backends.cudnn.benchmark = True\n",
1310
- " torch.backends.cuda.matmul.allow_tf32 = True \n",
 
1311
  " torch.set_float32_matmul_precision(\"high\")\n",
1312
  "\n",
1313
  " # Load pretrained models\n",
1314
  " print(\"Loading pretrained models...\")\n",
1315
  " models = preload_models_from_standard_weights(args.base_model_path, args.device)\n",
1316
  " \n",
1317
- " # Create dataloaders\n",
1318
- " print(\"Creating dataloaders...\")\n",
1319
  " train_dataloader = create_dataloaders(args)\n",
1320
  " \n",
1321
- " # Calculate epochs from training steps\n",
1322
- " steps_per_epoch = len(train_dataloader) // args.gradient_accumulation_steps\n",
1323
- " num_epochs = (args.num_training_steps + steps_per_epoch - 1) // steps_per_epoch\n",
1324
- " print(f\"Training for {num_epochs} epochs ({args.num_training_steps} steps)\")\n",
1325
- " args.num_epochs = num_epochs\n",
1326
- " print(f\"Steps per epoch: {steps_per_epoch}\")\n",
1327
- " print(f\"Total training steps: {args.num_training_steps}\")\n",
1328
- " print(f\"Total epochs: {num_epochs}\")\n",
1329
  " # Initialize trainer\n",
1330
  " print(\"Initializing trainer...\") \n",
1331
  " trainer = CatVTONTrainer(\n",
@@ -1337,27 +941,28 @@
1337
  " save_steps=args.save_steps,\n",
1338
  " output_dir=args.output_dir,\n",
1339
  " cfg_dropout_prob=args.cfg_dropout_prob,\n",
1340
- " guidance_scale=args.guidance_scale,\n",
1341
- " num_inference_steps=50, # Fixed as per paper\n",
1342
- " gradient_accumulation_steps=args.gradient_accumulation_steps,\n",
1343
  " max_grad_norm=args.max_grad_norm,\n",
1344
  " use_peft=args.use_peft,\n",
1345
  " dream_lambda=args.dream_lambda,\n",
1346
  " resume_from_checkpoint=args.resume_from_checkpoint,\n",
1347
- " use_mixed_precision=args.use_mixed_precision\n",
 
 
1348
  " )\n",
 
1349
  " # Start training\n",
1350
  " print(\"Starting training...\")\n",
1351
  " trainer.train() \n",
1352
  "\n",
 
1353
  "if __name__ == \"__main__\":\n",
1354
- " main()\n"
1355
  ]
1356
  },
1357
  {
1358
  "cell_type": "code",
1359
  "execution_count": null,
1360
- "id": "2eff454d",
1361
  "metadata": {},
1362
  "outputs": [],
1363
  "source": []
@@ -1365,7 +970,7 @@
1365
  {
1366
  "cell_type": "code",
1367
  "execution_count": null,
1368
- "id": "2eefd6bc",
1369
  "metadata": {},
1370
  "outputs": [],
1371
  "source": []
 
11
  "output_type": "stream",
12
  "text": [
13
  "Cloning into 'stable-diffusion'...\n",
14
+ "remote: Enumerating objects: 184, done.\u001b[K\n",
15
+ "remote: Counting objects: 100% (184/184), done.\u001b[K\n",
16
+ "remote: Compressing objects: 100% (156/156), done.\u001b[K\n",
17
+ "remote: Total 184 (delta 44), reused 165 (delta 26), pack-reused 0 (from 0)\u001b[K\n",
18
+ "Receiving objects: 100% (184/184), 9.94 MiB | 37.02 MiB/s, done.\n",
19
+ "Resolving deltas: 100% (44/44), done.\n"
20
  ]
21
  }
22
  ],
 
70
  "name": "stdout",
71
  "output_type": "stream",
72
  "text": [
73
+ "--2025-06-17 08:50:15-- https://huggingface.co/sd-legacy/stable-diffusion-inpainting/resolve/main/sd-v1-5-inpainting.ckpt\n",
74
+ "Resolving huggingface.co (huggingface.co)... 3.171.171.104, 3.171.171.128, 3.171.171.6, ...\n",
75
+ "Connecting to huggingface.co (huggingface.co)|3.171.171.104|:443... connected.\n",
76
  "HTTP request sent, awaiting response... 307 Temporary Redirect\n",
77
  "Location: /stable-diffusion-v1-5/stable-diffusion-inpainting/resolve/main/sd-v1-5-inpainting.ckpt [following]\n",
78
+ "--2025-06-17 08:50:15-- https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-inpainting/resolve/main/sd-v1-5-inpainting.ckpt\n",
79
  "Reusing existing connection to huggingface.co:443.\n",
80
  "HTTP request sent, awaiting response... 302 Found\n",
81
+ "Location: https://cdn-lfs.hf.co/repos/f6/56/f656f0fa3b8a40ac76d297fa2a4b00f981e8eb1261963460764e7dd3b35ec97f/c6bbc15e3224e6973459ba78de4998b80b50112b0ae5b5c67113d56b4e366b19?response-content-disposition=inline%3B+filename*%3DUTF-8%27%27sd-v1-5-inpainting.ckpt%3B+filename%3D%22sd-v1-5-inpainting.ckpt%22%3B&Expires=1750153142&Policy=eyJTdGF0ZW1lbnQiOlt7IkNvbmRpdGlvbiI6eyJEYXRlTGVzc1RoYW4iOnsiQVdTOkVwb2NoVGltZSI6MTc1MDE1MzE0Mn19LCJSZXNvdXJjZSI6Imh0dHBzOi8vY2RuLWxmcy5oZi5jby9yZXBvcy9mNi81Ni9mNjU2ZjBmYTNiOGE0MGFjNzZkMjk3ZmEyYTRiMDBmOTgxZThlYjEyNjE5NjM0NjA3NjRlN2RkM2IzNWVjOTdmL2M2YmJjMTVlMzIyNGU2OTczNDU5YmE3OGRlNDk5OGI4MGI1MDExMmIwYWU1YjVjNjcxMTNkNTZiNGUzNjZiMTk%7EcmVzcG9uc2UtY29udGVudC1kaXNwb3NpdGlvbj0qIn1dfQ__&Signature=kAea10Cu%7EhNLABWiXI0i%7E5gAtwsQUUM6CIZczAEWsswZur-XllSQvXEoKksmPdojVE654r7s-CxII8r%7EQ52to%7EQMLbjsjw-JmXq4duiq91qz6U5aenByAXSpOO1ihAoCmCkP02e7L5Wcbs%7EhaV26W9Q%7EAfbwyQ1mn9ta%7EHIDiE7AuNuHgkEEA2IP45ao25b9zsaFw6fIUlBy93Meuf82zwzsw8CJPWV9QEwj-oPVeSDyv3ZhfxS3iCgGSYS320Vs7NcK%7EqJxPfttpTHG9m6zAnfxOpWjYVQfre6HnHUt3VHOy4QdDvpyfljgEQoH4LxRBWI%7Ev72YjOJZDEgSPoTi1Q__&Key-Pair-Id=K3RPWS32NSSJCE [following]\n",
82
+ "--2025-06-17 08:50:15-- https://cdn-lfs.hf.co/repos/f6/56/f656f0fa3b8a40ac76d297fa2a4b00f981e8eb1261963460764e7dd3b35ec97f/c6bbc15e3224e6973459ba78de4998b80b50112b0ae5b5c67113d56b4e366b19?response-content-disposition=inline%3B+filename*%3DUTF-8%27%27sd-v1-5-inpainting.ckpt%3B+filename%3D%22sd-v1-5-inpainting.ckpt%22%3B&Expires=1750153142&Policy=eyJTdGF0ZW1lbnQiOlt7IkNvbmRpdGlvbiI6eyJEYXRlTGVzc1RoYW4iOnsiQVdTOkVwb2NoVGltZSI6MTc1MDE1MzE0Mn19LCJSZXNvdXJjZSI6Imh0dHBzOi8vY2RuLWxmcy5oZi5jby9yZXBvcy9mNi81Ni9mNjU2ZjBmYTNiOGE0MGFjNzZkMjk3ZmEyYTRiMDBmOTgxZThlYjEyNjE5NjM0NjA3NjRlN2RkM2IzNWVjOTdmL2M2YmJjMTVlMzIyNGU2OTczNDU5YmE3OGRlNDk5OGI4MGI1MDExMmIwYWU1YjVjNjcxMTNkNTZiNGUzNjZiMTk%7EcmVzcG9uc2UtY29udGVudC1kaXNwb3NpdGlvbj0qIn1dfQ__&Signature=kAea10Cu%7EhNLABWiXI0i%7E5gAtwsQUUM6CIZczAEWsswZur-XllSQvXEoKksmPdojVE654r7s-CxII8r%7EQ52to%7EQMLbjsjw-JmXq4duiq91qz6U5aenByAXSpOO1ihAoCmCkP02e7L5Wcbs%7EhaV26W9Q%7EAfbwyQ1mn9ta%7EHIDiE7AuNuHgkEEA2IP45ao25b9zsaFw6fIUlBy93Meuf82zwzsw8CJPWV9QEwj-oPVeSDyv3ZhfxS3iCgGSYS320Vs7NcK%7EqJxPfttpTHG9m6zAnfxOpWjYVQfre6HnHUt3VHOy4QdDvpyfljgEQoH4LxRBWI%7Ev72YjOJZDEgSPoTi1Q__&Key-Pair-Id=K3RPWS32NSSJCE\n",
83
+ "Resolving cdn-lfs.hf.co (cdn-lfs.hf.co)... 18.160.78.83, 18.160.78.87, 18.160.78.43, ...\n",
84
+ "Connecting to cdn-lfs.hf.co (cdn-lfs.hf.co)|18.160.78.83|:443... connected.\n",
85
  "HTTP request sent, awaiting response... 200 OK\n",
86
  "Length: 4265437280 (4.0G) [binary/octet-stream]\n",
87
  "Saving to: ‘sd-v1-5-inpainting.ckpt’\n",
88
  "\n",
89
+ "sd-v1-5-inpainting. 100%[===================>] 3.97G 324MB/s in 12s \n",
90
  "\n",
91
+ "2025-06-17 08:50:27 (341 MB/s) - ‘sd-v1-5-inpainting.ckpt’ saved [4265437280/4265437280]\n",
92
  "\n"
93
  ]
94
  }
 
99
  },
100
  {
101
  "cell_type": "code",
102
+ "execution_count": 5,
103
  "id": "4c5198ca",
104
  "metadata": {},
105
  "outputs": [
 
107
  "name": "stdout",
108
  "output_type": "stream",
109
  "text": [
110
+ "attention.py interface.py\t README.md\t\t utils.py\n",
111
+ "clip.py load_model.py\t requirements.txt\t VITON_Dataset.py\n",
112
+ "ddpm.py merges.txt\t sample_dataset\t vocab.json\n",
113
+ "decoder.py model_converter.py sd-v1-5-inpainting.ckpt\n",
114
+ "diffusion.py output\t\t test.ipynb\n",
115
+ "encoder.py pipeline.py\t training.ipynb\n"
116
  ]
117
  }
118
  ],
 
173
  "# !pip install -U --no-cache-dir gdown --pre"
174
  ]
175
  },
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
176
  {
177
  "cell_type": "code",
178
  "execution_count": null,
 
240
  },
241
  {
242
  "cell_type": "code",
243
+ "execution_count": 6,
244
  "id": "53095103",
245
  "metadata": {},
 
 
 
 
 
 
 
 
 
 
 
246
  "outputs": [
247
  {
248
  "name": "stdout",
249
  "output_type": "stream",
250
  "text": [
251
+ "mkdir: cannot create directory ‘output’: File exists\n"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
252
  ]
253
  }
254
  ],
255
  "source": [
256
+ "!mkdir output\n",
257
+ "!mkdir checkpoints"
258
  ]
259
  },
260
  {
261
  "cell_type": "code",
262
+ "execution_count": 34,
263
  "id": "7efe325c",
264
  "metadata": {},
265
  "outputs": [],
 
281
  },
282
  {
283
  "cell_type": "code",
284
+ "execution_count": 35,
285
  "id": "a48f2753",
286
  "metadata": {},
287
  "outputs": [
 
292
  "traceback": [
293
  "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
294
  "\u001b[0;31mKeyError\u001b[0m Traceback (most recent call last)",
295
+ "\u001b[0;32m/tmp/ipykernel_69/1017109895.py\u001b[0m in \u001b[0;36m<cell line: 0>\u001b[0;34m()\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcuda\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mempty_cache\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;31m# Release unused GPU memory\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 5\u001b[0;31m \u001b[0mgc\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcollect\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;31m# Run Python garbage collector\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
296
  "\u001b[0;32m/usr/local/lib/python3.11/dist-packages/IPython/core/displayhook.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, result)\u001b[0m\n\u001b[1;32m 261\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mwrite_output_prompt\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 262\u001b[0m \u001b[0mformat_dict\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmd_dict\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcompute_format_data\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mresult\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 263\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mupdate_user_ns\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mresult\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 264\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfill_exec_result\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mresult\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 265\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mformat_dict\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
297
  "\u001b[0;32m/usr/local/lib/python3.11/dist-packages/IPython/core/displayhook.py\u001b[0m in \u001b[0;36mupdate_user_ns\u001b[0;34m(self, result)\u001b[0m\n\u001b[1;32m 199\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 200\u001b[0m \u001b[0;31m# Avoid recursive reference when displaying _oh/Out\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 201\u001b[0;31m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcache_size\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0mresult\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshell\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0muser_ns\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'_oh'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 202\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshell\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0muser_ns\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'_oh'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m>=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcache_size\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdo_full_cache\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 203\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcull_cache\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
298
  "\u001b[0;31mKeyError\u001b[0m: '_oh'"
 
309
  },
310
  {
311
  "cell_type": "code",
312
+ "execution_count": 36,
313
  "id": "5a57d765",
314
  "metadata": {},
315
  "outputs": [],
 
332
  },
333
  {
334
  "cell_type": "code",
335
+ "execution_count": 37,
336
  "id": "5957ec57",
337
  "metadata": {},
338
+ "outputs": [],
 
 
 
 
 
 
 
 
 
 
 
339
  "source": [
340
  "import tensorflow as tf\n",
341
  "tf.keras.backend.clear_session()"
 
343
  },
344
  {
345
  "cell_type": "code",
346
+ "execution_count": 38,
347
  "id": "796e8ef7",
348
  "metadata": {},
349
  "outputs": [
 
368
  },
369
  {
370
  "cell_type": "code",
371
+ "execution_count": null,
372
  "id": "32ed173e",
373
  "metadata": {},
374
  "outputs": [
 
377
  "output_type": "stream",
378
  "text": [
379
  "Total RAM: 31.35 GB\n",
380
+ "Available RAM: 24.16 GB\n"
381
  ]
382
  }
383
  ],
 
401
  },
402
  {
403
  "cell_type": "code",
404
+ "execution_count": 39,
405
  "id": "3ce888b6",
406
  "metadata": {},
407
  "outputs": [],
 
413
  " (image_tensor.shape[0], 4, image_tensor.shape[2] // 8, image_tensor.shape[3] // 8),\n",
414
  " device=device,\n",
415
  " )\n",
416
+ " with torch.no_grad(): # VAE encoding doesn't need gradients\n",
417
+ " return encoder(image_tensor, encoder_noise)"
 
 
418
  ]
419
  },
420
  {
421
  "cell_type": "code",
422
+ "execution_count": 41,
423
+ "id": "3aea80d9",
424
  "metadata": {},
425
  "outputs": [
426
  {
 
431
  ]
432
  },
433
  {
434
+ "ename": "OutOfMemoryError",
435
+ "evalue": "CUDA out of memory. Tried to allocate 58.00 MiB. GPU 0 has a total capacity of 15.89 GiB of which 29.12 MiB is free. Process 3907 has 15.85 GiB memory in use. Of the allocated memory 15.49 GiB is allocated by PyTorch, and 62.82 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation. See documentation for Memory Management (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
436
  "output_type": "error",
437
  "traceback": [
438
  "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
439
+ "\u001b[0;31mOutOfMemoryError\u001b[0m Traceback (most recent call last)",
440
+ "\u001b[0;32m/tmp/ipykernel_69/1468414648.py\u001b[0m in \u001b[0;36m<cell line: 0>\u001b[0;34m()\u001b[0m\n\u001b[1;32m 502\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 503\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0m__name__\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;34m\"__main__\"\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 504\u001b[0;31m \u001b[0mmain\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
441
+ "\u001b[0;32m/tmp/ipykernel_69/1468414648.py\u001b[0m in \u001b[0;36mmain\u001b[0;34m()\u001b[0m\n\u001b[1;32m 467\u001b[0m \u001b[0;31m# Load pretrained models\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 468\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"Loading pretrained models...\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 469\u001b[0;31m \u001b[0mmodels\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mpreload_models_from_standard_weights\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbase_model_path\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0margs\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 470\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 471\u001b[0m \u001b[0;31m# Create dataloader\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
442
+ "\u001b[0;32m/kaggle/working/stable-diffusion/load_model.py\u001b[0m in \u001b[0;36mpreload_models_from_standard_weights\u001b[0;34m(ckpt_path, device, finetune_weights_path)\u001b[0m\n\u001b[1;32m 14\u001b[0m \u001b[0mstate_dict\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mmodel_converter\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mload_from_standard_weights\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mckpt_path\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdevice\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 15\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 16\u001b[0;31m \u001b[0mdiffusion\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mDiffusion\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0min_channels\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0min_channels\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mout_channels\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mout_channels\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mto\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 17\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 18\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mfinetune_weights_path\u001b[0m \u001b[0;34m!=\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
443
+ "\u001b[0;32m/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36mto\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1338\u001b[0m \u001b[0;32mraise\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1339\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1340\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_apply\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mconvert\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1341\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1342\u001b[0m def register_full_backward_pre_hook(\n",
444
+ "\u001b[0;32m/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_apply\u001b[0;34m(self, fn, recurse)\u001b[0m\n\u001b[1;32m 898\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mrecurse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 899\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mmodule\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mchildren\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 900\u001b[0;31m \u001b[0mmodule\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_apply\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfn\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 901\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 902\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mcompute_should_use_set_data\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtensor\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtensor_applied\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
445
+ "\u001b[0;32m/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_apply\u001b[0;34m(self, fn, recurse)\u001b[0m\n\u001b[1;32m 898\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mrecurse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 899\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mmodule\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mchildren\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 900\u001b[0;31m \u001b[0mmodule\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_apply\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfn\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 901\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 902\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mcompute_should_use_set_data\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtensor\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtensor_applied\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
446
+ "\u001b[0;32m/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_apply\u001b[0;34m(self, fn, recurse)\u001b[0m\n\u001b[1;32m 898\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mrecurse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 899\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mmodule\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mchildren\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 900\u001b[0;31m \u001b[0mmodule\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_apply\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfn\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 901\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 902\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mcompute_should_use_set_data\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtensor\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtensor_applied\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
447
+ "\u001b[0;32m/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_apply\u001b[0;34m(self, fn, recurse)\u001b[0m\n\u001b[1;32m 898\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mrecurse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 899\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mmodule\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mchildren\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 900\u001b[0;31m \u001b[0mmodule\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_apply\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfn\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 901\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 902\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mcompute_should_use_set_data\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtensor\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtensor_applied\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
448
+ "\u001b[0;32m/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_apply\u001b[0;34m(self, fn, recurse)\u001b[0m\n\u001b[1;32m 898\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mrecurse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 899\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mmodule\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mchildren\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 900\u001b[0;31m \u001b[0mmodule\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_apply\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfn\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 901\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 902\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mcompute_should_use_set_data\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtensor\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtensor_applied\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
449
+ "\u001b[0;32m/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_apply\u001b[0;34m(self, fn, recurse)\u001b[0m\n\u001b[1;32m 925\u001b[0m \u001b[0;31m# `with torch.no_grad():`\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 926\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mno_grad\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 927\u001b[0;31m \u001b[0mparam_applied\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mfn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mparam\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 928\u001b[0m \u001b[0mp_should_use_set_data\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mcompute_should_use_set_data\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mparam\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mparam_applied\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 929\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
450
+ "\u001b[0;32m/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36mconvert\u001b[0;34m(t)\u001b[0m\n\u001b[1;32m 1324\u001b[0m \u001b[0mmemory_format\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mconvert_to_format\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1325\u001b[0m )\n\u001b[0;32m-> 1326\u001b[0;31m return t.to(\n\u001b[0m\u001b[1;32m 1327\u001b[0m \u001b[0mdevice\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1328\u001b[0m \u001b[0mdtype\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mt\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mis_floating_point\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0mt\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mis_complex\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32melse\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
451
+ "\u001b[0;31mOutOfMemoryError\u001b[0m: CUDA out of memory. Tried to allocate 58.00 MiB. GPU 0 has a total capacity of 15.89 GiB of which 29.12 MiB is free. Process 3907 has 15.85 GiB memory in use. Of the allocated memory 15.49 GiB is allocated by PyTorch, and 62.82 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation. See documentation for Memory Management (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)"
 
 
 
 
 
 
 
 
 
452
  ]
453
  }
454
  ],
 
474
  "\n",
475
  "# Import your custom modules\n",
476
  "from load_model import preload_models_from_standard_weights\n",
477
+ "from ddpm import DDPMSampler\n",
478
  "from utils import check_inputs, get_time_embedding, prepare_image, prepare_mask_image\n",
479
  "from diffusers.utils.torch_utils import randn_tensor\n",
480
  "\n",
481
  "class CatVTONTrainer:\n",
482
+ " \"\"\"Simplified CatVTON Training Class with PEFT, CFG and DREAM support\"\"\"\n",
483
  " \n",
484
  " def __init__(\n",
485
  " self,\n",
 
487
  " train_dataloader: DataLoader,\n",
488
  " val_dataloader: Optional[DataLoader] = None,\n",
489
  " device: str = \"cuda\",\n",
490
+ " learning_rate: float = 1e-5,\n",
491
+ " num_epochs: int = 50,\n",
492
  " save_steps: int = 1000,\n",
493
  " output_dir: str = \"./checkpoints\",\n",
494
  " cfg_dropout_prob: float = 0.1,\n",
 
 
 
495
  " max_grad_norm: float = 1.0,\n",
496
  " use_peft: bool = True,\n",
497
+ " dream_lambda: float = 10.0,\n",
498
  " resume_from_checkpoint: Optional[str] = None,\n",
499
+ " use_mixed_precision: bool = True,\n",
500
  " height=512,\n",
501
  " width=384,\n",
502
  " ):\n",
 
510
  " self.save_steps = save_steps\n",
511
  " self.output_dir = Path(output_dir)\n",
512
  " self.cfg_dropout_prob = cfg_dropout_prob\n",
 
 
 
513
  " self.max_grad_norm = max_grad_norm\n",
514
  " self.use_peft = use_peft\n",
515
  " self.dream_lambda = dream_lambda\n",
516
  " self.use_mixed_precision = use_mixed_precision\n",
517
+ " self.height = height\n",
518
+ " self.width = width\n",
519
  " self.generator = torch.Generator(device=device)\n",
520
  " \n",
521
  " # Create output directory\n",
 
536
  " if resume_from_checkpoint:\n",
537
  " self._load_checkpoint(resume_from_checkpoint)\n",
538
  " \n",
539
+ " self.encoder = self.models.get('encoder', None)\n",
540
+ " self.decoder = self.models.get('decoder', None)\n",
541
+ " self.diffusion = self.models.get('diffusion', None)\n",
542
  "\n",
543
  " # Setup models and optimizers\n",
544
  " self._setup_training()\n",
545
  " \n",
546
  " def _setup_training(self):\n",
547
  " \"\"\"Setup models for training with PEFT\"\"\"\n",
548
+ " # Move models to device\n",
549
  " for name, model in self.models.items():\n",
550
  " model.to(self.device)\n",
 
 
551
  " \n",
552
  " # Freeze all parameters first\n",
553
  " for model in self.models.values():\n",
 
559
  " self._enable_peft_training()\n",
560
  " else:\n",
561
  " # Enable full training for diffusion model\n",
562
+ " for param in self.diffusion.parameters():\n",
563
  " param.requires_grad = True\n",
564
  " \n",
565
  " # Collect trainable parameters\n",
 
577
  " print(f\"Total parameters: {total_params:,}\")\n",
578
  " print(f\"Trainable parameters: {trainable_count:,} ({trainable_count/total_params*100:.2f}%)\")\n",
579
  " \n",
 
 
 
 
 
 
580
  " # Setup optimizer - AdamW as per paper\n",
581
  " self.optimizer = AdamW(\n",
582
  " trainable_params,\n",
 
586
  " eps=1e-8\n",
587
  " )\n",
588
  " \n",
589
+ " # Setup learning rate scheduler (constant)\n",
 
590
  " self.lr_scheduler = torch.optim.lr_scheduler.LambdaLR(\n",
591
  " self.optimizer, lr_lambda=lambda epoch: 1.0\n",
592
  " )\n",
593
  " \n",
594
  " def _enable_peft_training(self):\n",
595
+ " \"\"\"Enable PEFT training - only self-attention layers\"\"\"\n",
596
  " print(\"Enabling PEFT training (self-attention layers only)\")\n",
597
  " \n",
598
  " unet = self.diffusion.unet\n",
599
  " \n",
600
+ " # Enable attention layers in encoders and decoders\n",
601
  " for layers in [unet.encoders, unet.decoders]:\n",
602
  " for layer in layers:\n",
603
+ " for module_idx, module in enumerate(layer):\n",
604
+ " for name, param in module.named_parameters():\n",
605
+ " if 'attention_1' in name:\n",
606
+ " param.requires_grad = True\n",
607
+ " \n",
608
  " # Enable attention layers in bottleneck\n",
609
  " for layer in unet.bottleneck:\n",
610
+ " for name, param in layer.named_parameters():\n",
611
+ " if 'attention_1' in name:\n",
612
  " param.requires_grad = True\n",
613
+ " \n",
614
  " def _apply_cfg_dropout(self, garment_latent: torch.Tensor) -> torch.Tensor:\n",
615
  " \"\"\"Apply classifier-free guidance dropout (10% chance)\"\"\"\n",
616
  " if self.training and random.random() < self.cfg_dropout_prob:\n",
 
624
  " cloth_images = batch['cloth'].to(self.device)\n",
625
  " masks = batch['mask'].to(self.device)\n",
626
  "\n",
627
+ " concat_dim = -2 # y axis concat\n",
628
+ " \n",
629
+ " # Prepare inputs\n",
630
  " image, condition_image, mask = check_inputs(person_images, cloth_images, masks, self.width, self.height)\n",
631
  " image = prepare_image(person_images).to(self.device, dtype=self.weight_dtype)\n",
632
  " condition_image = prepare_image(cloth_images).to(self.device, dtype=self.weight_dtype)\n",
633
  " mask = prepare_mask_image(masks).to(self.device, dtype=self.weight_dtype)\n",
634
+ " \n",
635
  " # Mask image\n",
636
  " masked_image = image * (mask < 0.5)\n",
637
  "\n",
638
  " with torch.cuda.amp.autocast(enabled=self.use_mixed_precision):\n",
639
+ " # VAE encoding\n",
640
  " masked_latent = compute_vae_encodings(masked_image, self.encoder)\n",
641
  " person_latent = compute_vae_encodings(person_images, self.encoder)\n",
642
  " condition_latent = compute_vae_encodings(condition_image, self.encoder)\n",
643
  " mask_latent = torch.nn.functional.interpolate(mask, size=masked_latent.shape[-2:], mode=\"nearest\")\n",
644
+ " \n",
645
  " del image, mask, condition_image\n",
646
  "\n",
 
647
  " # Apply CFG dropout to garment latent\n",
648
  " condition_latent = self._apply_cfg_dropout(condition_latent)\n",
649
  " \n",
650
  " # Concatenate latents\n",
651
  " masked_latent_concat = torch.cat([masked_latent, condition_latent], dim=concat_dim)\n",
652
  " mask_latent_concat = torch.cat([mask_latent, torch.zeros_like(mask_latent)], dim=concat_dim)\n",
653
+ " target_latents = torch.cat([person_latent, condition_latent], dim=concat_dim)\n",
654
  "\n",
655
+ " noise = randn_tensor(\n",
656
  " masked_latent_concat.shape,\n",
657
  " generator=self.generator,\n",
658
  " device=masked_latent_concat.device,\n",
 
677
  " # Get initial noise prediction\n",
678
  " with torch.no_grad():\n",
679
  " epsilon_theta = self.diffusion(\n",
680
+ " inpainting_latent_model_input,\n",
681
+ " timesteps_embedding\n",
682
+ " )\n",
 
 
683
  " \n",
684
  " # Apply DREAM: zˆt = √αt*z0 + √(1-αt)*(ε + λ*εθ)\n",
685
  " alphas_cumprod = self.scheduler.alphas_cumprod.to(device=self.device, dtype=self.weight_dtype)\n",
 
701
  " masked_latent_concat\n",
702
  " ], dim=1)\n",
703
  "\n",
704
+ " predicted_noise = self.diffusion(\n",
705
  " dream_model_input,\n",
706
  " timesteps_embedding\n",
707
  " )\n",
 
720
  " return loss\n",
721
  " \n",
722
  " def train_epoch(self) -> float:\n",
723
+ " \"\"\"Train for one epoch - simplified version\"\"\"\n",
724
+ " self.diffusion.train()\n",
725
  " total_loss = 0.0\n",
726
  " num_batches = len(self.train_dataloader)\n",
727
  " \n",
728
  " progress_bar = tqdm(self.train_dataloader, desc=f\"Epoch {self.current_epoch+1}\")\n",
729
  " \n",
730
  " for step, batch in enumerate(progress_bar):\n",
731
+ " # Zero gradients\n",
732
+ " self.optimizer.zero_grad()\n",
733
+ " \n",
734
+ " # Forward pass with mixed precision\n",
735
  " if self.use_mixed_precision:\n",
736
  " with torch.cuda.amp.autocast():\n",
737
  " loss = self.compute_loss(batch)\n",
738
  " \n",
 
 
 
739
  " # Backward pass with scaling\n",
740
  " self.scaler.scale(loss).backward()\n",
741
+ " \n",
742
+ " # Gradient clipping and optimizer step\n",
743
+ " self.scaler.unscale_(self.optimizer)\n",
744
+ " torch.nn.utils.clip_grad_norm_(\n",
745
+ " [p for p in self.diffusion.parameters() if p.requires_grad],\n",
746
+ " self.max_grad_norm\n",
747
+ " )\n",
748
+ " \n",
749
+ " self.scaler.step(self.optimizer)\n",
750
+ " self.scaler.update()\n",
751
  " else:\n",
752
  " loss = self.compute_loss(batch)\n",
 
753
  " loss.backward()\n",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
754
  " \n",
755
+ " # Gradient clipping\n",
756
+ " torch.nn.utils.clip_grad_norm_(\n",
757
+ " [p for p in self.diffusion.parameters() if p.requires_grad],\n",
758
+ " self.max_grad_norm\n",
759
+ " )\n",
760
+ " \n",
761
+ " # Optimizer step\n",
762
+ " self.optimizer.step()\n",
763
+ " \n",
764
+ " # Update learning rate\n",
765
+ " self.lr_scheduler.step()\n",
766
+ " self.global_step += 1\n",
767
  " \n",
768
+ " total_loss += loss.item()\n",
769
  " \n",
770
  " # Update progress bar\n",
771
  " progress_bar.set_postfix({\n",
772
+ " 'loss': loss.item(),\n",
773
  " 'lr': self.optimizer.param_groups[0]['lr'],\n",
774
  " 'step': self.global_step\n",
775
  " })\n",
776
  " \n",
777
+ " # Save checkpoint based on steps\n",
778
  " if self.global_step % self.save_steps == 0:\n",
779
  " self._save_checkpoint()\n",
780
  " \n",
 
785
  " return total_loss / num_batches\n",
786
  " \n",
787
  " def train(self):\n",
788
+ " \"\"\"Main training loop - simplified version\"\"\"\n",
789
  " print(f\"Starting training for {self.num_epochs} epochs\")\n",
790
+ " print(f\"Total training batches per epoch: {len(self.train_dataloader)}\")\n",
791
  " print(f\"Using DREAM with lambda = {self.dream_lambda}\")\n",
792
  " print(f\"Mixed precision: {self.use_mixed_precision}\")\n",
793
  " \n",
794
  " for epoch in range(self.current_epoch, self.num_epochs):\n",
795
  " self.current_epoch = epoch\n",
796
  " \n",
797
+ " # Train one epoch\n",
798
  " train_loss = self.train_epoch()\n",
799
  " \n",
800
+ " print(f\"Epoch {epoch+1}/{self.num_epochs} - Train Loss: {train_loss:.6f}\")\n",
 
801
  " \n",
802
  " # Save epoch checkpoint\n",
803
+ " if (epoch + 1) % 5 == 0: # Save every 5 epochs\n",
804
  " self._save_checkpoint(epoch_checkpoint=True)\n",
805
  " \n",
806
  " # Clear cache at end of epoch\n",
807
  " torch.cuda.empty_cache()\n",
808
+ " \n",
809
+ " # Save final checkpoint\n",
810
+ " self._save_checkpoint(is_final=True)\n",
811
+ " print(\"Training completed!\")\n",
812
  " \n",
813
+ " def _save_checkpoint(self, is_best: bool = False, epoch_checkpoint: bool = False, is_final: bool = False):\n",
814
  " \"\"\"Save model checkpoint\"\"\"\n",
815
  " checkpoint = {\n",
816
  " 'global_step': self.global_step,\n",
 
825
  " if self.use_mixed_precision:\n",
826
  " checkpoint['scaler_state_dict'] = self.scaler.state_dict()\n",
827
  " \n",
828
+ " if is_final:\n",
829
+ " checkpoint_path = self.output_dir / \"final_model.pth\"\n",
830
+ " elif is_best:\n",
831
  " checkpoint_path = self.output_dir / \"best_model.pth\"\n",
832
  " elif epoch_checkpoint:\n",
833
  " checkpoint_path = self.output_dir / f\"checkpoint_epoch_{self.current_epoch+1}.pth\"\n",
 
855
  " print(f\"Checkpoint loaded: {checkpoint_path}\")\n",
856
  " print(f\"Resuming from epoch {self.current_epoch}, step {self.global_step}\")\n",
857
  "\n",
858
+ "\n",
859
def create_dataloaders(args) -> DataLoader:
    """Create the training dataloader.

    Args:
        args: Namespace with at least ``dataset_name`` and ``batch_size``.
            Optional attributes (defaults preserve the previous hard-coded
            behavior): ``num_workers`` (8), ``pin_memory`` (True),
            ``prefetch_factor`` (2).

    Returns:
        A shuffled ``DataLoader`` over the requested dataset.

    Raises:
        ValueError: If ``args.dataset_name`` is not a supported dataset name.
    """
    if args.dataset_name == "vitonhd":
        # NOTE(review): this instantiates the *Test* dataset class for
        # training — confirm that is intentional.
        dataset = VITONHDTestDataset(args)
    else:
        raise ValueError(f"Invalid dataset name {args.dataset_name}.")

    print(f"Dataset {args.dataset_name} loaded, total {len(dataset)} pairs.")

    # Worker settings were hard-coded; make them tunable without changing
    # existing callers (getattr defaults reproduce the old values).
    num_workers = getattr(args, "num_workers", 8)

    dataloader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=getattr(args, "pin_memory", True),
        # Fix: persistent_workers=True and an int prefetch_factor are invalid
        # when num_workers == 0 (PyTorch raises ValueError); only enable them
        # for multi-worker loading.
        persistent_workers=num_workers > 0,
        prefetch_factor=getattr(args, "prefetch_factor", 2) if num_workers > 0 else None,
    )

    return dataloader
879
  "\n",
880
  "\n",
881
  "def main():\n",
882
+ " args = argparse.Namespace()\n",
883
+ " args.__dict__ = {\n",
884
  " \"base_model_path\": \"sd-v1-5-inpainting.ckpt\",\n",
 
885
  " \"dataset_name\": \"vitonhd\",\n",
886
  " \"data_root_path\": \"/kaggle/input/viton-hd-dataset\",\n",
887
+ " \"output_dir\": \"./checkpoints\",\n",
888
+ " \"resume_from_checkpoint\": None,\n",
889
  " \"seed\": 42,\n",
890
  " \"batch_size\": 1,\n",
 
 
891
  " \"width\": 384,\n",
892
  " \"height\": 512,\n",
893
  " \"repaint\": True,\n",
894
  " \"eval_pair\": True,\n",
895
  " \"concat_eval_results\": True,\n",
 
 
 
896
  " \"concat_axis\": 'y',\n",
897
+ " \"device\": \"cuda\",\n",
898
+ " \"num_epochs\": 50, \n",
 
899
  " \"learning_rate\": 1e-5,\n",
 
900
  " \"max_grad_norm\": 1.0,\n",
 
901
  " \"cfg_dropout_prob\": 0.1,\n",
902
  " \"dream_lambda\": 0,\n",
903
+ " \"use_peft\": True,\n",
904
  " \"use_mixed_precision\": True,\n",
 
905
  " \"save_steps\": 1000,\n",
 
906
  " \"is_train\": True\n",
907
  " }\n",
908
  " \n",
 
 
 
909
  " # Set random seeds\n",
910
  " torch.manual_seed(args.seed)\n",
911
  " np.random.seed(args.seed)\n",
 
913
  " if torch.cuda.is_available():\n",
914
  " torch.cuda.manual_seed_all(args.seed)\n",
915
  " \n",
916
+ " # Optimize CUDA settings\n",
917
  " torch.backends.cudnn.benchmark = True\n",
918
+ " torch.backends.cuda.matmul.allow_tf32 = True \n",
919
+ " torch.backends.cudnn.allow_tf32 = True \n",
920
  " torch.set_float32_matmul_precision(\"high\")\n",
921
  "\n",
922
  " # Load pretrained models\n",
923
  " print(\"Loading pretrained models...\")\n",
924
  " models = preload_models_from_standard_weights(args.base_model_path, args.device)\n",
925
  " \n",
926
+ " # Create dataloader\n",
927
+ " print(\"Creating dataloader...\")\n",
928
  " train_dataloader = create_dataloaders(args)\n",
929
  " \n",
930
+ " print(f\"Training for {args.num_epochs} epochs\")\n",
931
+ " print(f\"Batches per epoch: {len(train_dataloader)}\")\n",
932
+ " \n",
 
 
 
 
 
933
  " # Initialize trainer\n",
934
  " print(\"Initializing trainer...\") \n",
935
  " trainer = CatVTONTrainer(\n",
 
941
  " save_steps=args.save_steps,\n",
942
  " output_dir=args.output_dir,\n",
943
  " cfg_dropout_prob=args.cfg_dropout_prob,\n",
 
 
 
944
  " max_grad_norm=args.max_grad_norm,\n",
945
  " use_peft=args.use_peft,\n",
946
  " dream_lambda=args.dream_lambda,\n",
947
  " resume_from_checkpoint=args.resume_from_checkpoint,\n",
948
+ " use_mixed_precision=args.use_mixed_precision,\n",
949
+ " height=args.height,\n",
950
+ " width=args.width\n",
951
  " )\n",
952
+ " \n",
953
  " # Start training\n",
954
  " print(\"Starting training...\")\n",
955
  " trainer.train() \n",
956
  "\n",
957
+ "\n",
958
  "if __name__ == \"__main__\":\n",
959
+ " main()"
960
  ]
961
  },
962
  {
963
  "cell_type": "code",
964
  "execution_count": null,
965
+ "id": "77892d6a",
966
  "metadata": {},
967
  "outputs": [],
968
  "source": []
 
970
  {
971
  "cell_type": "code",
972
  "execution_count": null,
973
+ "id": "b3917d76",
974
  "metadata": {},
975
  "outputs": [],
976
  "source": []