AlbeRota committed on
Commit af15687 · verified · 1 Parent(s): afc472b

Upload weights, notebooks, sample images

notebooks/UnReflectAnything.ipynb ADDED
@@ -0,0 +1,485 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "metadata": {},
6
+ "source": [
7
+ "# UnReflectAnything API & CLI Examples\n",
8
+ "---\n",
9
+ "\n",
10
+ "### 1. Installation and asset download\n",
11
+ "Ensure you have installed UnReflectAnything with \n",
12
+ "```bash\n",
13
+ "pip install unreflectanything\n",
14
+ "```\n",
15
+ "This also installs the CLI, which can be invoked via the aliases `unreflect` and `ura`. Verify the installation and check the version with:\n",
16
+ "```bash\n",
17
+ "unreflectanything --help\n",
18
+ "```\n",
19
+ "```bash\n",
20
+ "unreflect --version\n",
21
+ "```\n",
22
+ "```bash\n",
23
+ "ura --version\n",
24
+ "```"
25
+ ]
26
+ },
27
+ {
28
+ "cell_type": "code",
29
+ "execution_count": 31,
30
+ "metadata": {},
31
+ "outputs": [
32
+ {
33
+ "name": "stdout",
34
+ "output_type": "stream",
35
+ "text": [
36
+ "Using device: cuda\n"
37
+ ]
38
+ }
39
+ ],
40
+ "source": [
41
+ "import torch\n",
42
+ "from pathlib import Path\n",
43
+ "\n",
44
+ "# Import UnreflectAnything!\n",
45
+ "import unreflectanything\n",
46
+ "\n",
47
+ "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
48
+ "print(f\"Using device: {device}\")"
49
+ ]
50
+ },
51
+ {
52
+ "cell_type": "markdown",
53
+ "metadata": {},
54
+ "source": [
55
+ "`pip install`ing UnReflectAnything does not download the pretrained model weights. Download them with the CLI command\n",
56
+ "```bash\n",
57
+ "unreflectanything download --weights\n",
58
+ "```\n",
59
+ "or"
60
+ ]
61
+ },
62
+ {
63
+ "cell_type": "code",
64
+ "execution_count": 32,
65
+ "metadata": {},
66
+ "outputs": [
67
+ {
68
+ "data": {
69
+ "application/vnd.jupyter.widget-view+json": {
70
+ "model_id": "e2614e279f5f40609b18ac82d696b01d",
71
+ "version_major": 2,
72
+ "version_minor": 0
73
+ },
74
+ "text/plain": [
75
+ "Fetching 4 files: 0%| | 0/4 [00:00<?, ?it/s]"
76
+ ]
77
+ },
78
+ "metadata": {},
79
+ "output_type": "display_data"
80
+ },
81
+ {
82
+ "data": {
83
+ "application/vnd.jupyter.widget-view+json": {
84
+ "model_id": "7293fbadfcc74c438a3cb5a0e4df1ac2",
85
+ "version_major": 2,
86
+ "version_minor": 0
87
+ },
88
+ "text/plain": [
89
+ "weights/diffuse_decoder.pt: 0%| | 0.00/418M [00:00<?, ?B/s]"
90
+ ]
91
+ },
92
+ "metadata": {},
93
+ "output_type": "display_data"
94
+ },
95
+ {
96
+ "data": {
97
+ "application/vnd.jupyter.widget-view+json": {
98
+ "model_id": "f1253cb4b21d48c2b0c3780f4c4060ef",
99
+ "version_major": 2,
100
+ "version_minor": 0
101
+ },
102
+ "text/plain": [
103
+ "weights/token_inpainter.pt: 0%| | 0.00/307M [00:00<?, ?B/s]"
104
+ ]
105
+ },
106
+ "metadata": {},
107
+ "output_type": "display_data"
108
+ },
109
+ {
110
+ "data": {
111
+ "application/vnd.jupyter.widget-view+json": {
112
+ "model_id": "74ed792f61264beab4a890ea5ede8a58",
113
+ "version_major": 2,
114
+ "version_minor": 0
115
+ },
116
+ "text/plain": [
117
+ "weights/highlight_decoder.pt: 0%| | 0.00/54.8M [00:00<?, ?B/s]"
118
+ ]
119
+ },
120
+ "metadata": {},
121
+ "output_type": "display_data"
122
+ },
123
+ {
124
+ "data": {
125
+ "application/vnd.jupyter.widget-view+json": {
126
+ "model_id": "1ea92cd82ad34af68f06be823dd12060",
127
+ "version_major": 2,
128
+ "version_minor": 0
129
+ },
130
+ "text/plain": [
131
+ "weights/full_model_weights.pt: 0%| | 0.00/3.55G [00:00<?, ?B/s]"
132
+ ]
133
+ },
134
+ "metadata": {},
135
+ "output_type": "display_data"
136
+ },
137
+ {
138
+ "name": "stdout",
139
+ "output_type": "stream",
140
+ "text": [
141
+ "Weights saved to /home/arota/.cache/unreflectanything/weights\n"
142
+ ]
143
+ }
144
+ ],
145
+ "source": [
146
+ "weights_dir = unreflectanything.download(\"weights\")"
147
+ ]
148
+ },
149
+ {
150
+ "cell_type": "markdown",
151
+ "id": "511c670a",
152
+ "metadata": {},
153
+ "source": [
154
+ "Download the sample images used in this notebook with\n",
155
+ "```bash\n",
156
+ "unreflectanything download --images\n",
157
+ "```\n",
158
+ "or"
159
+ ]
160
+ },
161
+ {
162
+ "cell_type": "code",
163
+ "execution_count": null,
164
+ "id": "e0c60b28",
165
+ "metadata": {},
166
+ "outputs": [],
167
+ "source": [
168
+ "images_dir = unreflectanything.download(\"images\")"
169
+ ]
170
+ },
171
+ {
172
+ "cell_type": "markdown",
173
+ "id": "36417577",
174
+ "metadata": {},
175
+ "source": [
176
+ "### 2. Running UnReflectAnything with pretrained weights"
177
+ ]
178
+ },
179
+ {
180
+ "cell_type": "code",
181
+ "execution_count": null,
182
+ "id": "c318961c",
183
+ "metadata": {},
184
+ "outputs": [],
185
+ "source": [
186
+ "# Instantiate the default pretrained UnReflectAnything model.\n",
187
+ "unreflect = unreflectanything.model(device=device)"
188
+ ]
189
+ },
190
+ {
191
+ "cell_type": "code",
192
+ "execution_count": null,
193
+ "id": "448d5456",
194
+ "metadata": {},
195
+ "outputs": [],
196
+ "source": [
197
+ "from PIL import Image\n",
198
+ "import numpy as np\n",
199
+ "\n",
200
+ "# Build a simple dataloader over a dataset that loads images from a directory\n",
201
+ "sample_dataset = unreflectanything.ImageDirDataset(images_dir)\n",
202
+ "sample_dataloader = torch.utils.data.DataLoader(\n",
203
+ " sample_dataset, batch_size=1, shuffle=False\n",
204
+ ")\n",
205
+ "\n",
206
+ "# The inpaint-mask threshold and dilation can be overridden; defaults are 0.2 and 40\n",
207
+ "THRESHOLD = 0.2\n",
208
+ "DILATION = 40\n",
209
+ "\n",
210
+ "# Process and display only N images out of the full sample dataset\n",
211
+ "DISPLAY_N_IMAGES = 2\n",
212
+ "\n",
213
+ "outputs = []\n",
214
+ "for batch in sample_dataloader:\n",
215
+ " # Forward pass\n",
216
+ " batch_output = unreflect(\n",
217
+ " batch.to(device), return_dict=True, threshold=THRESHOLD, dilation=DILATION\n",
218
+ " )\n",
219
+ " outputs.append(batch_output)\n",
220
+ " if len(outputs) >= DISPLAY_N_IMAGES:\n",
221
+ " break\n",
222
+ "\n"
223
+ ]
224
+ },
225
+ {
226
+ "cell_type": "code",
227
+ "execution_count": null,
228
+ "id": "98c60bf4",
229
+ "metadata": {},
230
+ "outputs": [],
231
+ "source": [
232
+ "# Helper: convert a [C, H, W] float32 tensor in [0, 1] to a uint8 [H, W, C] array for display\n",
233
+ "def tensor_to_uint8_img(t):\n",
234
+ " arr = t.permute(1, 2, 0).detach().numpy()\n",
235
+ " arr = np.clip(arr, 0, 1)\n",
236
+ " arr = (arr * 255).round().astype(np.uint8)\n",
237
+ " return arr\n",
238
+ "\n",
239
+ "# Plotting a collage of the input, the diffuse output, and the highlight mask\n",
240
+ "for input_batch, output_batch in zip(sample_dataloader, outputs):\n",
241
+ " concat_images = torch.cat(\n",
242
+ " [\n",
243
+ " input_batch.cpu(),\n",
244
+ " output_batch[\"diffuse\"].cpu(),\n",
245
+ " output_batch[\"highlight\"].repeat(1, 3, 1, 1).cpu(), # expand the 1-channel mask to 3 channels\n",
246
+ " ],\n",
247
+ " dim=3,\n",
248
+ " )\n",
249
+ " for sample in concat_images:\n",
250
+ " img_uint8 = tensor_to_uint8_img(sample)\n",
251
+ " display(Image.fromarray(img_uint8))\n"
253
+ ]
254
+ },
255
+ {
256
+ "cell_type": "markdown",
257
+ "id": "cf0e6ac6",
258
+ "metadata": {},
259
+ "source": [
260
+ "### 3. Inference API and CLI endpoint\n",
261
+ "The `inference` wrapper instantiates the UnReflectAnything model and calls its forward function in a single API call. It either:\n",
262
+ "- Inputs a batched image tensor and outputs a batched image tensor (sketched in a cell below)\n",
263
+ "- Inputs the path to an image (or directory of images) and saves the output results at a given path (of file or directory)\n",
264
+ "- Inputs the path to an image and outputs a batched image tensor \n",
265
+ "\n",
266
+ "Some example CLI calls:\n",
267
+ "```bash\n",
268
+ "unreflect inference path/to/image/dir/ -o output/dir/ --threshold 0.3 --dilation 40\n",
269
+ "```\n",
270
+ "```bash\n",
271
+ "unreflect inference path/to/image.png -o path/to/output.png --threshold 0.3 --dilation 40\n",
272
+ "```"
273
+ ]
274
+ },
275
+ {
276
+ "cell_type": "code",
277
+ "execution_count": null,
278
+ "metadata": {},
279
+ "outputs": [],
280
+ "source": [
281
+ "# Pick a sample image from the downloaded assets. `input` can also be the path to a dir\n",
282
+ "input_path = list(images_dir.glob(\"*.png\"))[0]\n",
283
+ "print(\"Input file: \", input_path)\n",
284
+ "# Specify the output path. If `input` is a directory, `output` should be one too.\n",
285
+ "output_path = Path(\"output_example.png\").resolve()\n",
286
+ "print(\"Output file: \", output_path)\n",
287
+ "\n",
288
+ "unreflectanything.inference(\n",
289
+ " input=input_path,\n",
290
+ " output=output_path,\n",
291
+ " device=device,\n",
292
+ " threshold=THRESHOLD, \n",
293
+ " dilation=DILATION, \n",
294
+ ")\n",
295
+ "\n",
296
+ "# Loading the saved output and original input from files, then displaying them\n",
297
+ "input_img = Image.open(input_path).convert(\"RGB\")\n",
298
+ "output_img = Image.open(output_path).convert(\"RGB\")\n",
299
+ "\n",
300
+ "def to_tensor(img):\n",
301
+ " return torch.from_numpy(np.array(img)).permute(2, 0, 1).float() / 255.\n",
302
+ "\n",
303
+ "input_tensor = to_tensor(input_img)\n",
304
+ "output_tensor = to_tensor(output_img)\n",
305
+ "concat = torch.cat([input_tensor, output_tensor], dim=2)\n",
306
+ "concat_uint8 = (concat.permute(1,2,0).numpy() * 255).clip(0,255).astype(np.uint8)\n",
307
+ "display(Image.fromarray(concat_uint8))"
308
+ ]
309
+ },
310
+ {
311
+ "cell_type": "code",
312
+ "execution_count": null,
313
+ "id": "5118ea92",
314
+ "metadata": {},
315
+ "outputs": [],
316
+ "source": [
317
+ "print(\"Equivalent CLI command:\\n\")\n",
318
+ "print(f\"unreflect inference {input_path} -o {output_path} --threshold {THRESHOLD} --dilation {DILATION}\")"
319
+ ]
320
+ },
321
+ {
322
+ "cell_type": "markdown",
323
+ "id": "562c17f1",
324
+ "metadata": {},
325
+ "source": [
326
+ "`inference` initializes the model on every call by default. To skip this step, pass the model instance to the API call."
327
+ ]
328
+ },
329
+ {
330
+ "cell_type": "code",
331
+ "execution_count": null,
332
+ "id": "89fbbf62",
333
+ "metadata": {},
334
+ "outputs": [],
335
+ "source": [
336
+ "# Pick a sample image from the downloaded assets. `input` can also be the path to a dir\n",
337
+ "input_path = list(images_dir.glob(\"*.png\"))[6]\n",
338
+ "# Specify the output path\n",
339
+ "output_path = Path(\"output_example.png\")\n",
340
+ " \n",
341
+ "unreflectanything.inference(\n",
342
+ " model=unreflect, # <-- pass the model instance so it is not reloaded at every `inference` call\n",
343
+ " input=input_path,\n",
344
+ " output=output_path,\n",
345
+ " device=device,\n",
346
+ " threshold=THRESHOLD, \n",
347
+ " dilation=DILATION, \n",
348
+ ")\n",
349
+ "\n",
350
+ "# Loading the saved output and original input from files, then displaying them\n",
351
+ "input_img = Image.open(input_path).convert(\"RGB\")\n",
352
+ "output_img = Image.open(output_path).convert(\"RGB\")\n",
353
+ "\n",
354
+ "def to_tensor(img):\n",
355
+ " return torch.from_numpy(np.array(img)).permute(2, 0, 1).float() / 255.\n",
356
+ "\n",
357
+ "input_tensor = to_tensor(input_img)\n",
358
+ "output_tensor = to_tensor(output_img)\n",
359
+ "concat = torch.cat([input_tensor, output_tensor], dim=2)\n",
360
+ "concat_uint8 = (concat.permute(1,2,0).numpy() * 255).clip(0,255).astype(np.uint8)\n",
361
+ "display(Image.fromarray(concat_uint8))"
362
+ ]
363
+ },
364
+ {
365
+ "cell_type": "markdown",
366
+ "id": "57441af0",
367
+ "metadata": {},
368
+ "source": [
369
+ "### 4. The Cache Directory\n",
370
+ "\n",
371
+ "`unreflectanything download` saves the downloaded assets in your system cache. Print the cache path with\n",
372
+ "```bash\n",
373
+ "unreflectanything cache --dir\n",
374
+ "```\n",
375
+ "or clear the cache with \n",
376
+ "```bash\n",
377
+ "unreflectanything cache --clear\n",
378
+ "```\n",
379
+ "The same endpoints are also available through the API."
380
+ ]
381
+ },
382
+ {
383
+ "cell_type": "code",
384
+ "execution_count": null,
385
+ "id": "b331050d",
386
+ "metadata": {},
387
+ "outputs": [],
388
+ "source": [
389
+ "unreflectanything.cache(\"dir\") # Also unreflectanything.cache()\n",
390
+ "unreflectanything.cache(\"clear\")\n",
391
+ "# unreflectanything.cache.clear()"
392
+ ]
393
+ },
394
+ {
395
+ "cell_type": "markdown",
396
+ "metadata": {},
397
+ "source": [
398
+ "### 5. Verify Assets\n",
399
+ "\n",
400
+ "You can verify that the weights are correctly downloaded and loadable."
401
+ ]
402
+ },
403
+ {
404
+ "cell_type": "code",
405
+ "execution_count": null,
406
+ "metadata": {},
407
+ "outputs": [],
408
+ "source": [
409
+ "is_valid = unreflectanything.verify(\"weights\")"
410
+ ]
411
+ },
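+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "A small follow-up sketch, assuming `verify` returns a boolean: re-download the weights if verification fails."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Assumes `verify` returns True/False; re-fetch the weights on failure.\n",
+ "if not is_valid:\n",
+ " unreflectanything.download(\"weights\")\n",
+ " is_valid = unreflectanything.verify(\"weights\")\n",
+ "print(f\"Weights valid: {is_valid}\")"
+ ]
+ },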
412
+ {
413
+ "cell_type": "markdown",
414
+ "metadata": {},
415
+ "source": [
416
+ "### CLI Equivalent\n",
417
+ "\n",
418
+ "```bash\n",
419
+ "unreflect verify --weights\n",
420
+ "```\n",
421
+ "```bash\n",
422
+ "unreflect verify --weights\n",
423
+ "```"
424
+ ]
425
+ },
434
+ {
435
+ "cell_type": "markdown",
436
+ "metadata": {},
437
+ "source": [
438
+ "### 6. Cite\n",
439
+ "\n",
440
+ "If you use UnReflectAnything in your research, please cite it:"
441
+ ]
442
+ },
443
+ {
444
+ "cell_type": "code",
445
+ "execution_count": null,
446
+ "metadata": {},
447
+ "outputs": [],
448
+ "source": [
449
+ "print(unreflectanything.cite(format=\"bibtex\"))"
450
+ ]
451
+ },
452
+ {
453
+ "cell_type": "markdown",
454
+ "metadata": {},
455
+ "source": [
456
+ "### CLI Equivalent\n",
457
+ "\n",
458
+ "```bash\n",
459
+ "unreflect cite --bibtex\n",
460
+ "```"
461
+ ]
462
+ }
463
+ ],
464
+ "metadata": {
465
+ "kernelspec": {
466
+ "display_name": "Python 3 (ipykernel)",
467
+ "language": "python",
468
+ "name": "python3"
469
+ },
470
+ "language_info": {
471
+ "codemirror_mode": {
472
+ "name": "ipython",
473
+ "version": 3
474
+ },
475
+ "file_extension": ".py",
476
+ "mimetype": "text/x-python",
477
+ "name": "python",
478
+ "nbconvert_exporter": "python",
479
+ "pygments_lexer": "ipython3",
480
+ "version": "3.12.3"
481
+ }
482
+ },
483
+ "nbformat": 4,
484
+ "nbformat_minor": 5
485
+ }
notebooks/api_examples.ipynb DELETED
@@ -1,253 +0,0 @@
1
- {
2
- "cells": [
3
- {
4
- "cell_type": "markdown",
5
- "id": "d5e78019",
6
- "metadata": {},
7
- "source": [
8
- "# UnReflectAnything API Examples\n",
9
- "---"
10
- ]
11
- },
12
- {
13
- "cell_type": "markdown",
14
- "id": "d423248d",
15
- "metadata": {},
16
- "source": [
17
- "### Package Import"
18
- ]
19
- },
20
- {
21
- "cell_type": "code",
22
- "execution_count": 1,
23
- "id": "db2eda79",
24
- "metadata": {},
25
- "outputs": [
26
- {
27
- "name": "stdout",
28
- "output_type": "stream",
29
- "text": [
30
- "Using device: cuda\n"
31
- ]
32
- }
33
- ],
34
- "source": [
35
- "import unreflectanything\n",
36
- "import torch\n",
37
- "\n",
38
- "%load_ext autoreload\n",
39
- "%autoreload 2\n",
40
- "\n",
41
- "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
42
- "print(f\"Using device: {device}\")"
43
- ]
44
- },
45
- {
46
- "cell_type": "markdown",
47
- "id": "c3828c5e",
48
- "metadata": {},
49
- "source": [
50
- "### Model Loading"
51
- ]
52
- },
53
- {
54
- "cell_type": "markdown",
55
- "id": "cabb1b8a",
56
- "metadata": {},
57
- "source": [
58
- "If you haven't downloaded the pre-trained weights yet, do so with \n",
59
- "\n",
60
- "`unreflectanything download --weights` from the terminal\n",
61
- "\n",
62
- "\n",
63
- "or with `unreflectanything.download(\"weights\")` from Python."
64
- ]
65
- },
66
- {
67
- "cell_type": "code",
68
- "execution_count": 6,
69
- "id": "d58ad7f1",
70
- "metadata": {},
71
- "outputs": [
72
- {
73
- "data": {
74
- "text/html": [
75
- "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">MODEL <span style=\"font-weight: bold\">[</span><span style=\"color: #00ff00; text-decoration-color: #00ff00; font-weight: bold\">18:45:03</span><span style=\"font-weight: bold\">]</span> ✓ Decoder <span style=\"color: #008000; text-decoration-color: #008000\">'diffuse'</span>: Successfully loaded all <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">54</span> state dict keys from weights/rgb_decoder.pth\n",
76
- "</pre>\n"
77
- ],
78
- "text/plain": [
79
- "MODEL \u001b[1m[\u001b[0m\u001b[1;92m18:45:03\u001b[0m\u001b[1m]\u001b[0m ✓ Decoder \u001b[32m'diffuse'\u001b[0m: Successfully loaded all \u001b[1;36m54\u001b[0m state dict keys from weights/rgb_decoder.pth\n"
80
- ]
81
- },
82
- "metadata": {},
83
- "output_type": "display_data"
84
- },
85
- {
86
- "data": {
87
- "text/html": [
88
- "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">MODEL <span style=\"font-weight: bold\">[</span><span style=\"color: #00ff00; text-decoration-color: #00ff00; font-weight: bold\">18:45:03</span><span style=\"font-weight: bold\">]</span> Loaded pre-trained decoder weights from weights/rgb_decoder.pth\n",
89
- "</pre>\n"
90
- ],
91
- "text/plain": [
92
- "MODEL \u001b[1m[\u001b[0m\u001b[1;92m18:45:03\u001b[0m\u001b[1m]\u001b[0m Loaded pre-trained decoder weights from weights/rgb_decoder.pth\n"
93
- ]
94
- },
95
- "metadata": {},
96
- "output_type": "display_data"
97
- },
98
- {
99
- "data": {
100
- "text/html": [
101
- "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">MODEL <span style=\"font-weight: bold\">[</span><span style=\"color: #00ff00; text-decoration-color: #00ff00; font-weight: bold\">18:45:03</span><span style=\"font-weight: bold\">]</span> ✓ Token Inpainter: Successfully loaded all <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">78</span> state dict keys from weights/token_inpainter.pth\n",
102
- "</pre>\n"
103
- ],
104
- "text/plain": [
105
- "MODEL \u001b[1m[\u001b[0m\u001b[1;92m18:45:03\u001b[0m\u001b[1m]\u001b[0m ✓ Token Inpainter: Successfully loaded all \u001b[1;36m78\u001b[0m state dict keys from weights/token_inpainter.pth\n"
106
- ]
107
- },
108
- "metadata": {},
109
- "output_type": "display_data"
110
- },
111
- {
112
- "data": {
113
- "text/html": [
114
- "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">MODEL <span style=\"font-weight: bold\">[</span><span style=\"color: #00ff00; text-decoration-color: #00ff00; font-weight: bold\">18:45:03</span><span style=\"font-weight: bold\">]</span> Loaded pretrained token inpainter weights from weights/token_inpainter.pth\n",
115
- "</pre>\n"
116
- ],
117
- "text/plain": [
118
- "MODEL \u001b[1m[\u001b[0m\u001b[1;92m18:45:03\u001b[0m\u001b[1m]\u001b[0m Loaded pretrained token inpainter weights from weights/token_inpainter.pth\n"
119
- ]
120
- },
121
- "metadata": {},
122
- "output_type": "display_data"
123
- },
124
- {
125
- "name": "stdout",
126
- "output_type": "stream",
127
- "text": [
128
- "Warning: missing keys when loading checkpoint: ['decoders.highlight.reassemble_layers.0.proj.weight', 'decoders.highlight.reassemble_layers.0.proj.bias', 'decoders.highlight.reassemble_layers.0.resample.weight', 'decoders.highlight.reassemble_layers.0.resample.bias', 'decoders.highlight.reassemble_layers.1.proj.weight', 'decoders.highlight.reassemble_layers.1.proj.bias', 'decoders.highlight.reassemble_layers.1.resample.weight', 'decoders.highlight.reassemble_layers.1.resample.bias', 'decoders.highlight.reassemble_layers.2.proj.weight', 'decoders.highlight.reassemble_layers.2.proj.bias', 'decoders.highlight.reassemble_layers.3.proj.weight', 'decoders.highlight.reassemble_layers.3.proj.bias', 'decoders.highlight.reassemble_layers.3.resample.weight', 'decoders.highlight.reassemble_layers.3.resample.bias', 'decoders.highlight.fusion_blocks.0.residual_conv1.weight', 'decoders.highlight.fusion_blocks.0.residual_conv1.bias', 'decoders.highlight.fusion_blocks.0.residual_conv2.0.weight', 'decoders.highlight.fusion_blocks.0.residual_conv2.0.bias', 'decoders.highlight.fusion_blocks.0.residual_conv2.3.weight', 'decoders.highlight.fusion_blocks.0.residual_conv2.3.bias', 'decoders.highlight.fusion_blocks.0.out_conv.weight', 'decoders.highlight.fusion_blocks.0.out_conv.bias', 'decoders.highlight.fusion_blocks.1.residual_conv1.weight', 'decoders.highlight.fusion_blocks.1.residual_conv1.bias', 'decoders.highlight.fusion_blocks.1.residual_conv2.0.weight', 'decoders.highlight.fusion_blocks.1.residual_conv2.0.bias', 'decoders.highlight.fusion_blocks.1.residual_conv2.3.weight', 'decoders.highlight.fusion_blocks.1.residual_conv2.3.bias', 'decoders.highlight.fusion_blocks.1.out_conv.weight', 'decoders.highlight.fusion_blocks.1.out_conv.bias', 'decoders.highlight.fusion_blocks.2.residual_conv1.weight', 'decoders.highlight.fusion_blocks.2.residual_conv1.bias', 'decoders.highlight.fusion_blocks.2.residual_conv2.0.weight', 'decoders.highlight.fusion_blocks.2.residual_conv2.0.bias', 'decoders.highlight.fusion_blocks.2.residual_conv2.3.weight', 'decoders.highlight.fusion_blocks.2.residual_conv2.3.bias', 'decoders.highlight.fusion_blocks.2.out_conv.weight', 'decoders.highlight.fusion_blocks.2.out_conv.bias', 'decoders.highlight.fusion_blocks.3.residual_conv1.weight', 'decoders.highlight.fusion_blocks.3.residual_conv1.bias', 'decoders.highlight.fusion_blocks.3.residual_conv2.0.weight', 'decoders.highlight.fusion_blocks.3.residual_conv2.0.bias', 'decoders.highlight.fusion_blocks.3.residual_conv2.3.weight', 'decoders.highlight.fusion_blocks.3.residual_conv2.3.bias', 'decoders.highlight.fusion_blocks.3.out_conv.weight', 'decoders.highlight.fusion_blocks.3.out_conv.bias', 'decoders.highlight.rgb_head.0.weight', 'decoders.highlight.rgb_head.0.bias', 'decoders.highlight.rgb_head.5.weight', 'decoders.highlight.rgb_head.5.bias', 'decoders.highlight.rgb_head.9.weight', 'decoders.highlight.rgb_head.9.bias', 'decoders.highlight.rgb_head.13.weight', 'decoders.highlight.rgb_head.13.bias', 'token_inpaint.mask_token', 'token_inpaint.mask_indicator', 'token_inpaint.blocks.0.attn.norm.weight', 'token_inpaint.blocks.0.attn.norm.bias', 'token_inpaint.blocks.0.attn.fn.attn.in_proj_weight', 'token_inpaint.blocks.0.attn.fn.attn.in_proj_bias', 'token_inpaint.blocks.0.attn.fn.attn.out_proj.weight', 'token_inpaint.blocks.0.attn.fn.attn.out_proj.bias', 'token_inpaint.blocks.0.mlp.norm.weight', 'token_inpaint.blocks.0.mlp.norm.bias', 'token_inpaint.blocks.0.mlp.fn.fc1.weight', 'token_inpaint.blocks.0.mlp.fn.fc1.bias', 'token_inpaint.blocks.0.mlp.fn.fc2.weight', 'token_inpaint.blocks.0.mlp.fn.fc2.bias', 'token_inpaint.blocks.1.attn.norm.weight', 'token_inpaint.blocks.1.attn.norm.bias', 'token_inpaint.blocks.1.attn.fn.attn.in_proj_weight', 'token_inpaint.blocks.1.attn.fn.attn.in_proj_bias', 'token_inpaint.blocks.1.attn.fn.attn.out_proj.weight', 'token_inpaint.blocks.1.attn.fn.attn.out_proj.bias', 'token_inpaint.blocks.1.mlp.norm.weight', 'token_inpaint.blocks.1.mlp.norm.bias', 'token_inpaint.blocks.1.mlp.fn.fc1.weight', 'token_inpaint.blocks.1.mlp.fn.fc1.bias', 'token_inpaint.blocks.1.mlp.fn.fc2.weight', 'token_inpaint.blocks.1.mlp.fn.fc2.bias', 'token_inpaint.blocks.2.attn.norm.weight', 'token_inpaint.blocks.2.attn.norm.bias', 'token_inpaint.blocks.2.attn.fn.attn.in_proj_weight', 'token_inpaint.blocks.2.attn.fn.attn.in_proj_bias', 'token_inpaint.blocks.2.attn.fn.attn.out_proj.weight', 'token_inpaint.blocks.2.attn.fn.attn.out_proj.bias', 'token_inpaint.blocks.2.mlp.norm.weight', 'token_inpaint.blocks.2.mlp.norm.bias', 'token_inpaint.blocks.2.mlp.fn.fc1.weight', 'token_inpaint.blocks.2.mlp.fn.fc1.bias', 'token_inpaint.blocks.2.mlp.fn.fc2.weight', 'token_inpaint.blocks.2.mlp.fn.fc2.bias', 'token_inpaint.blocks.3.attn.norm.weight', 'token_inpaint.blocks.3.attn.norm.bias', 'token_inpaint.blocks.3.attn.fn.attn.in_proj_weight', 'token_inpaint.blocks.3.attn.fn.attn.in_proj_bias', 'token_inpaint.blocks.3.attn.fn.attn.out_proj.weight', 'token_inpaint.blocks.3.attn.fn.attn.out_proj.bias', 'token_inpaint.blocks.3.mlp.norm.weight', 'token_inpaint.blocks.3.mlp.norm.bias', 'token_inpaint.blocks.3.mlp.fn.fc1.weight', 'token_inpaint.blocks.3.mlp.fn.fc1.bias', 'token_inpaint.blocks.3.mlp.fn.fc2.weight', 'token_inpaint.blocks.3.mlp.fn.fc2.bias', 'token_inpaint.blocks.4.attn.norm.weight', 'token_inpaint.blocks.4.attn.norm.bias', 'token_inpaint.blocks.4.attn.fn.attn.in_proj_weight', 'token_inpaint.blocks.4.attn.fn.attn.in_proj_bias', 'token_inpaint.blocks.4.attn.fn.attn.out_proj.weight', 'token_inpaint.blocks.4.attn.fn.attn.out_proj.bias', 'token_inpaint.blocks.4.mlp.norm.weight', 'token_inpaint.blocks.4.mlp.norm.bias', 'token_inpaint.blocks.4.mlp.fn.fc1.weight', 'token_inpaint.blocks.4.mlp.fn.fc1.bias', 'token_inpaint.blocks.4.mlp.fn.fc2.weight', 'token_inpaint.blocks.4.mlp.fn.fc2.bias', 'token_inpaint.blocks.5.attn.norm.weight', 'token_inpaint.blocks.5.attn.norm.bias', 'token_inpaint.blocks.5.attn.fn.attn.in_proj_weight', 'token_inpaint.blocks.5.attn.fn.attn.in_proj_bias', 'token_inpaint.blocks.5.attn.fn.attn.out_proj.weight', 'token_inpaint.blocks.5.attn.fn.attn.out_proj.bias', 'token_inpaint.blocks.5.mlp.norm.weight', 'token_inpaint.blocks.5.mlp.norm.bias', 'token_inpaint.blocks.5.mlp.fn.fc1.weight', 'token_inpaint.blocks.5.mlp.fn.fc1.bias', 'token_inpaint.blocks.5.mlp.fn.fc2.weight', 'token_inpaint.blocks.5.mlp.fn.fc2.bias', 'token_inpaint.out_proj.weight', 'token_inpaint.out_proj.bias', 'token_inpaint._final_norm.weight', 'token_inpaint._final_norm.bias']\n"
129
- ]
130
- }
131
- ],
132
- "source": [
133
- "# unreflectanything.download(\"weights\")\n",
134
- "# unreflectanything.download(\"images\") # --> Loads 20 sample images\n",
135
- "unreflectanythingmodel = unreflectanything.model(pretrained=True)"
136
- ]
137
- },
138
- {
139
- "cell_type": "markdown",
140
- "id": "f3dfa889",
141
- "metadata": {},
142
- "source": [
143
- "Load a dataset of images. Change `PATH_TO_IMAGE_DIR` to point to your own image directory"
144
- ]
145
- },
146
- {
147
- "cell_type": "code",
148
- "execution_count": null,
149
- "id": "da39fa39",
150
- "metadata": {},
151
- "outputs": [],
152
- "source": [
153
- "from unreflectanything import ImageDirDataset, get_cache_dir\n",
154
- "from torch.utils.data import DataLoader\n",
155
- "\n",
156
- "PATH_TO_IMAGE_DIR = get_cache_dir(\n",
157
- " \"images\"\n",
158
- ") # Modify this path to point to your image directory\n",
159
- "\n",
160
- "ds = ImageDirDataset(PATH_TO_IMAGE_DIR, target_size=(448, 448), return_path=False)\n",
161
- "loader = DataLoader(ds, batch_size=1, shuffle=False)"
162
- ]
163
- },
164
- {
165
- "cell_type": "markdown",
166
- "id": "4c8312f0",
167
- "metadata": {},
168
- "source": [
169
- "### Forward Pass / Inference"
170
- ]
171
- },
172
- {
173
- "cell_type": "code",
174
- "execution_count": 8,
175
- "id": "34e01754",
176
- "metadata": {},
177
- "outputs": [],
178
- "source": [
179
- "output_images = [unreflectanythingmodel(batch_images) for batch_images in loader]"
180
- ]
181
- },
182
- {
183
- "cell_type": "markdown",
184
- "id": "94690751",
185
- "metadata": {},
186
- "source": [
187
- "### Displaying results"
188
- ]
189
- },
190
- {
191
- "cell_type": "code",
192
- "execution_count": 9,
193
- "id": "a130c042",
194
- "metadata": {},
195
- "outputs": [
196
- {
197
- "ename": "RuntimeError",
198
- "evalue": "Sizes of tensors must match except in dimension 3. Expected size 896 but got size 448 for tensor number 1 in the list.",
199
- "output_type": "error",
200
- "traceback": [
201
- "\u001b[31m---------------------------------------------------------------------------\u001b[39m",
202
- "\u001b[31mRuntimeError\u001b[39m Traceback (most recent call last)",
203
- "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[9]\u001b[39m\u001b[32m, line 14\u001b[39m\n\u001b[32m 10\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m arr\n\u001b[32m 13\u001b[39m \u001b[38;5;28;01mfor\u001b[39;00m input_batch, output_batch \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mzip\u001b[39m(loader, output_images):\n\u001b[32m---> \u001b[39m\u001b[32m14\u001b[39m concat_images = \u001b[43mtorch\u001b[49m\u001b[43m.\u001b[49m\u001b[43mcat\u001b[49m\u001b[43m(\u001b[49m\n\u001b[32m 15\u001b[39m \u001b[43m \u001b[49m\u001b[43m[\u001b[49m\u001b[43minput_batch\u001b[49m\u001b[43m.\u001b[49m\u001b[43mcpu\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43moutput_batch\u001b[49m\u001b[43m.\u001b[49m\u001b[43mcpu\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdim\u001b[49m\u001b[43m=\u001b[49m\u001b[32;43m3\u001b[39;49m\n\u001b[32m 16\u001b[39m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m \u001b[38;5;66;03m# (B, 3, H, 2W)\u001b[39;00m\n\u001b[32m 17\u001b[39m \u001b[38;5;28;01mfor\u001b[39;00m sample \u001b[38;5;129;01min\u001b[39;00m concat_images:\n\u001b[32m 18\u001b[39m img_uint8 = tensor_to_uint8_img(sample)\n",
204
- "\u001b[31mRuntimeError\u001b[39m: Sizes of tensors must match except in dimension 3. Expected size 896 but got size 448 for tensor number 1 in the list."
205
- ]
206
- }
207
- ],
208
- "source": [
209
- "from PIL import Image\n",
210
- "import numpy as np\n",
211
- "\n",
212
- "\n",
213
- "# Helper: Convert tensor [H, W, C] in [0,1] float32 to uint8\n",
214
- "def tensor_to_uint8_img(t):\n",
215
- " arr = t.permute(1, 2, 0).cpu().detach().numpy()\n",
216
- " arr = np.clip(arr, 0, 1)\n",
217
- " arr = (arr * 255).round().astype(np.uint8)\n",
218
- " return arr\n",
219
- "\n",
220
- "\n",
221
- "for input_batch, output_batch in zip(loader, output_images):\n",
222
- " concat_images = torch.cat(\n",
223
- " [input_batch.cpu(), output_batch.cpu()], dim=3\n",
224
- " ) # (B, 3, H, 2W)\n",
225
- " for sample in concat_images:\n",
226
- " img_uint8 = tensor_to_uint8_img(sample)\n",
227
- " display(Image.fromarray(img_uint8))\n",
228
- " break\n"
229
- ]
230
- }
231
- ],
232
- "metadata": {
233
- "kernelspec": {
234
- "display_name": "Python 3 (ipykernel)",
235
- "language": "python",
236
- "name": "python3"
237
- },
238
- "language_info": {
239
- "codemirror_mode": {
240
- "name": "ipython",
241
- "version": 3
242
- },
243
- "file_extension": ".py",
244
- "mimetype": "text/x-python",
245
- "name": "python",
246
- "nbconvert_exporter": "python",
247
- "pygments_lexer": "ipython3",
248
- "version": "3.12.11"
249
- }
250
- },
251
- "nbformat": 4,
252
- "nbformat_minor": 5
253
- }