cerulianx committed on
Commit
eb44892
1 Parent(s): 783469b

Upload usage.ipynb

Browse files
Files changed (1) hide show
  1. usage.ipynb +116 -0
usage.ipynb ADDED
@@ -0,0 +1,116 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": null,
6
+ "metadata": {},
7
+ "outputs": [],
8
+ "source": [
9
+ "import io\n",
10
+ "import os, sys\n",
11
+ "import requests\n",
12
+ "import PIL\n",
13
+ "\n",
14
+ "import torch\n",
15
+ "import torchvision.transforms as T\n",
16
+ "import torchvision.transforms.functional as TF\n",
17
+ "\n",
18
+ "from dall_e import map_pixels, unmap_pixels, load_model\n",
19
+ "from IPython.display import display, display_markdown\n",
20
+ "\n",
21
+ "target_image_size = 256\n",
22
+ "\n",
23
+ "def download_image(url):\n",
24
+ " resp = requests.get(url)\n",
25
+ " resp.raise_for_status()\n",
26
+ " return PIL.Image.open(io.BytesIO(resp.content))\n",
27
+ "\n",
28
+ "def preprocess(img):\n",
29
+ " s = min(img.size)\n",
30
+ " \n",
31
+ " if s < target_image_size:\n",
32
+ " raise ValueError(f'min dim for image {s} < {target_image_size}')\n",
33
+ " \n",
34
+ " r = target_image_size / s\n",
35
+ " s = (round(r * img.size[1]), round(r * img.size[0]))\n",
36
+ " img = TF.resize(img, s, interpolation=PIL.Image.LANCZOS)\n",
37
+ " img = TF.center_crop(img, output_size=2 * [target_image_size])\n",
38
+ " img = torch.unsqueeze(T.ToTensor()(img), 0)\n",
39
+ " return map_pixels(img)"
40
+ ]
41
+ },
42
+ {
43
+ "cell_type": "code",
44
+ "execution_count": null,
45
+ "metadata": {},
46
+ "outputs": [],
47
+ "source": [
48
+ "# This can be changed to a GPU, e.g. 'cuda:0'.\n",
49
+ "dev = torch.device('cpu')\n",
50
+ "\n",
51
+ "# For faster load times, download these files locally and use the local paths instead.\n",
52
+ "enc = load_model(\"https://cdn.openai.com/dall-e/encoder.pkl\", dev)\n",
53
+ "dec = load_model(\"https://cdn.openai.com/dall-e/decoder.pkl\", dev)"
54
+ ]
55
+ },
56
+ {
57
+ "cell_type": "code",
58
+ "execution_count": null,
59
+ "metadata": {},
60
+ "outputs": [],
61
+ "source": [
62
+ "x = preprocess(download_image('https://assets.bwbx.io/images/users/iqjWHBFdfxIU/iKIWgaiJUtss/v2/1000x-1.jpg'))\n",
63
+ "display_markdown('Original image:')\n",
64
+ "display(T.ToPILImage(mode='RGB')(x[0]))"
65
+ ]
66
+ },
67
+ {
68
+ "cell_type": "code",
69
+ "execution_count": null,
70
+ "metadata": {},
71
+ "outputs": [],
72
+ "source": [
73
+ "import torch.nn.functional as F\n",
74
+ "\n",
75
+ "z_logits = enc(x)\n",
76
+ "z = torch.argmax(z_logits, axis=1)\n",
77
+ "z = F.one_hot(z, num_classes=enc.vocab_size).permute(0, 3, 1, 2).float()\n",
78
+ "\n",
79
+ "x_stats = dec(z).float()\n",
80
+ "x_rec = unmap_pixels(torch.sigmoid(x_stats[:, :3]))\n",
81
+ "x_rec = T.ToPILImage(mode='RGB')(x_rec[0])\n",
82
+ "\n",
83
+ "display_markdown('Reconstructed image:')\n",
84
+ "display(x_rec)"
85
+ ]
86
+ },
87
+ {
88
+ "cell_type": "code",
89
+ "execution_count": null,
90
+ "metadata": {},
91
+ "outputs": [],
92
+ "source": []
93
+ }
94
+ ],
95
+ "metadata": {
96
+ "kernelspec": {
97
+ "display_name": "Python 3",
98
+ "language": "python",
99
+ "name": "python3"
100
+ },
101
+ "language_info": {
102
+ "codemirror_mode": {
103
+ "name": "ipython",
104
+ "version": 3
105
+ },
106
+ "file_extension": ".py",
107
+ "mimetype": "text/x-python",
108
+ "name": "python",
109
+ "nbconvert_exporter": "python",
110
+ "pygments_lexer": "ipython3",
111
+ "version": "3.9.1"
112
+ }
113
+ },
114
+ "nbformat": 4,
115
+ "nbformat_minor": 2
116
+ }