Upload usage.ipynb
Browse files- usage.ipynb +116 -0
usage.ipynb
ADDED
@@ -0,0 +1,116 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "code",
|
5 |
+
"execution_count": null,
|
6 |
+
"metadata": {},
|
7 |
+
"outputs": [],
|
8 |
+
"source": [
|
9 |
+
"import io\n",
|
10 |
+
"import os, sys\n",
|
11 |
+
"import requests\n",
|
12 |
+
"import PIL\n",
|
13 |
+
"\n",
|
14 |
+
"import torch\n",
|
15 |
+
"import torchvision.transforms as T\n",
|
16 |
+
"import torchvision.transforms.functional as TF\n",
|
17 |
+
"\n",
|
18 |
+
"from dall_e import map_pixels, unmap_pixels, load_model\n",
|
19 |
+
"from IPython.display import display, display_markdown\n",
|
20 |
+
"\n",
|
21 |
+
"target_image_size = 256\n",
|
22 |
+
"\n",
|
23 |
+
"def download_image(url):\n",
|
24 |
+
" resp = requests.get(url)\n",
|
25 |
+
" resp.raise_for_status()\n",
|
26 |
+
" return PIL.Image.open(io.BytesIO(resp.content))\n",
|
27 |
+
"\n",
|
28 |
+
"def preprocess(img):\n",
|
29 |
+
" s = min(img.size)\n",
|
30 |
+
" \n",
|
31 |
+
" if s < target_image_size:\n",
|
32 |
+
" raise ValueError(f'min dim for image {s} < {target_image_size}')\n",
|
33 |
+
" \n",
|
34 |
+
" r = target_image_size / s\n",
|
35 |
+
" s = (round(r * img.size[1]), round(r * img.size[0]))\n",
|
36 |
+
" img = TF.resize(img, s, interpolation=PIL.Image.LANCZOS)\n",
|
37 |
+
" img = TF.center_crop(img, output_size=2 * [target_image_size])\n",
|
38 |
+
" img = torch.unsqueeze(T.ToTensor()(img), 0)\n",
|
39 |
+
" return map_pixels(img)"
|
40 |
+
]
|
41 |
+
},
|
42 |
+
{
|
43 |
+
"cell_type": "code",
|
44 |
+
"execution_count": null,
|
45 |
+
"metadata": {},
|
46 |
+
"outputs": [],
|
47 |
+
"source": [
|
48 |
+
"# This can be changed to a GPU, e.g. 'cuda:0'.\n",
|
49 |
+
"dev = torch.device('cpu')\n",
|
50 |
+
"\n",
|
51 |
+
"# For faster load times, download these files locally and use the local paths instead.\n",
|
52 |
+
"enc = load_model(\"https://cdn.openai.com/dall-e/encoder.pkl\", dev)\n",
|
53 |
+
"dec = load_model(\"https://cdn.openai.com/dall-e/decoder.pkl\", dev)"
|
54 |
+
]
|
55 |
+
},
|
56 |
+
{
|
57 |
+
"cell_type": "code",
|
58 |
+
"execution_count": null,
|
59 |
+
"metadata": {},
|
60 |
+
"outputs": [],
|
61 |
+
"source": [
|
62 |
+
"x = preprocess(download_image('https://assets.bwbx.io/images/users/iqjWHBFdfxIU/iKIWgaiJUtss/v2/1000x-1.jpg'))\n",
|
63 |
+
"display_markdown('Original image:')\n",
|
64 |
+
"display(T.ToPILImage(mode='RGB')(x[0]))"
|
65 |
+
]
|
66 |
+
},
|
67 |
+
{
|
68 |
+
"cell_type": "code",
|
69 |
+
"execution_count": null,
|
70 |
+
"metadata": {},
|
71 |
+
"outputs": [],
|
72 |
+
"source": [
|
73 |
+
"import torch.nn.functional as F\n",
|
74 |
+
"\n",
|
75 |
+
"z_logits = enc(x)\n",
|
76 |
+
"z = torch.argmax(z_logits, axis=1)\n",
|
77 |
+
"z = F.one_hot(z, num_classes=enc.vocab_size).permute(0, 3, 1, 2).float()\n",
|
78 |
+
"\n",
|
79 |
+
"x_stats = dec(z).float()\n",
|
80 |
+
"x_rec = unmap_pixels(torch.sigmoid(x_stats[:, :3]))\n",
|
81 |
+
"x_rec = T.ToPILImage(mode='RGB')(x_rec[0])\n",
|
82 |
+
"\n",
|
83 |
+
"display_markdown('Reconstructed image:')\n",
|
84 |
+
"display(x_rec)"
|
85 |
+
]
|
86 |
+
},
|
87 |
+
{
|
88 |
+
"cell_type": "code",
|
89 |
+
"execution_count": null,
|
90 |
+
"metadata": {},
|
91 |
+
"outputs": [],
|
92 |
+
"source": []
|
93 |
+
}
|
94 |
+
],
|
95 |
+
"metadata": {
|
96 |
+
"kernelspec": {
|
97 |
+
"display_name": "Python 3",
|
98 |
+
"language": "python",
|
99 |
+
"name": "python3"
|
100 |
+
},
|
101 |
+
"language_info": {
|
102 |
+
"codemirror_mode": {
|
103 |
+
"name": "ipython",
|
104 |
+
"version": 3
|
105 |
+
},
|
106 |
+
"file_extension": ".py",
|
107 |
+
"mimetype": "text/x-python",
|
108 |
+
"name": "python",
|
109 |
+
"nbconvert_exporter": "python",
|
110 |
+
"pygments_lexer": "ipython3",
|
111 |
+
"version": "3.9.1"
|
112 |
+
}
|
113 |
+
},
|
114 |
+
"nbformat": 4,
|
115 |
+
"nbformat_minor": 2
|
116 |
+
}
|